Files
NexaVPN/backend/internal/device/service.go
nessi e3bd6d3b96 feat: add DNS server routes to WireGuard profiles and gateway firewall rules
Add mergeProfileAllowedIPs function to combine policy destinations with DNS server routes in device enrollment and rotation. Add dnsServerRoute helper to convert DNS server IPs to /32 CIDR notation. Update BuildSyncBundle query to include gateway DNS servers in peer data. Add DNSServers field to wireguard.Peer struct. Update gateway nftables configuration to allow UDP/TCP port 53 traffic from assigned IPs to DNS servers before
2026-03-18 08:48:08 +01:00

211 lines
5.8 KiB
Go

package device
import (
"context"
"strings"
"github.com/google/uuid"
"nexavpn/backend/internal/gateway"
"nexavpn/backend/internal/ipam"
"nexavpn/backend/internal/policy"
"nexavpn/backend/internal/profile"
)
// Service implements the device operations of this package on top of a
// Repository and the policy, gateway and IPAM services.
type Service struct {
	// repo is the device store; every method of Service goes through it.
	repo Repository
	// policyService resolves the destinations a user (and, once enrolled, a
	// specific device) is given in Enroll.
	policyService *policy.Service
	// gatewayService selects the gateway a new device is enrolled on.
	gatewayService *gateway.Service
	// ipamService allocates an address in Enroll when the repository lookup
	// for the next available IP fails.
	ipamService *ipam.Service
}
// NewService returns a Service wired to the given repository and policy,
// gateway and IPAM services. The arguments are stored as passed; none of them
// is checked for nil.
func NewService(repo Repository, policyService *policy.Service, gatewayService *gateway.Service, ipamService *ipam.Service) *Service {
	svc := new(Service)
	svc.repo = repo
	svc.policyService = policyService
	svc.gatewayService = gatewayService
	svc.ipamService = ipamService
	return svc
}
// Enroll registers a new device for userID on the active gateway and returns
// the enrollment with its peer view, resource list and WireGuard profile
// filled in.
//
// The address comes from the repository's next-available-IP lookup; if that
// lookup fails, the IPAM service is asked instead. Destinations are resolved
// twice: once without a device ID (that result is what gets persisted with
// the enrollment) and once with the ID of the freshly created device (that
// result drives the returned peer, resources and profile). Whenever a
// resolution comes back empty, the single CIDR 172.16.10.0/24 is used.
//
// privateKeyPlaceholder is passed to the profile builder as the private key.
// On any error the zero EnrollmentResponse is returned along with that error.
func (s *Service) Enroll(ctx context.Context, userID uuid.UUID, input EnrollRequest, privateKeyPlaceholder string) (EnrollmentResponse, error) {
	const fallbackDestination = "172.16.10.0/24"
	var none EnrollmentResponse

	gw, err := s.gatewayService.SelectActive(ctx)
	if err != nil {
		return none, err
	}

	// Prefer the repository's view of free addresses; fall back to IPAM only
	// when that lookup errors.
	addr, err := s.repo.FindNextAvailableIP(ctx, gw.ID, gw.VPNCIDR, 10)
	if err != nil {
		if addr, err = s.ipamService.Allocate(gw.VPNCIDR, 10); err != nil {
			return none, err
		}
	}

	// First pass: no device exists yet, so resolve for the user only.
	dests, err := s.policyService.ResolveDestinations(ctx, userID, nil)
	if err != nil {
		return none, err
	}
	if len(dests) == 0 {
		dests = []string{fallbackDestination}
	}

	enrollment, err := s.repo.Enroll(ctx, userID, gw.ID, input, addr, gw.DNSServers, dests)
	if err != nil {
		return none, err
	}

	// Second pass: resolve again, this time passing the new device's ID.
	dests, err = s.policyService.ResolveDestinations(ctx, userID, &enrollment.Device.ID)
	if err != nil {
		return none, err
	}
	if len(dests) == 0 {
		dests = []string{fallbackDestination}
	}

	// The profile routes the policy destinations plus one route per DNS server.
	allowed := mergeProfileAllowedIPs(dests, gw.DNSServers)

	gatewayView := GatewayView{
		ID:        gw.ID,
		Name:      gw.Name,
		Endpoint:  gw.Endpoint,
		PublicKey: gw.PublicKey,
	}
	enrollment.Peer = PeerView{
		AssignedIP:      addr,
		DNSServers:      gw.DNSServers,
		AllowedIPs:      allowed,
		Gateway:         gatewayView,
		ProfileRevision: 1,
	}

	// Resources list the policy destinations only, not the DNS routes.
	for i := range dests {
		cidr := dests[i]
		enrollment.Resources = append(enrollment.Resources, Resource{Type: "cidr", Value: cidr, Label: cidr})
	}

	wgInput := profile.BuildInput{
		PrivateKey:       privateKeyPlaceholder,
		Address:          addr,
		DNSServers:       gw.DNSServers,
		ServerPublicKey:  gw.PublicKey,
		ServerEndpoint:   gw.Endpoint,
		AllowedIPs:       allowed,
		PersistentKeepal: 25,
	}
	enrollment.Profile = ProfileView{
		Format:  "wireguard",
		Content: profile.BuildWireGuardConfig(wgInput),
	}

	return enrollment, nil
}
// ListByUser returns the devices the repository reports for userID. It is a
// direct pass-through: the repository's result and error are returned as is.
func (s *Service) ListByUser(ctx context.Context, userID uuid.UUID) ([]Device, error) {
	return s.repo.ListByUser(ctx, userID)
}
// ListAll returns every device the repository reports, regardless of owner.
// It is a direct pass-through: the repository's result and error are returned
// as is.
func (s *Service) ListAll(ctx context.Context) ([]Device, error) {
	return s.repo.ListAll(ctx)
}
// GetLatestEnrollmentByUser fetches the enrollment the repository returns from
// its GetLatestEnrollmentByUser lookup for userID and attaches a regenerated
// profile to it via withDebugProfile. A repository error is returned unchanged
// together with the zero EnrollmentResponse.
func (s *Service) GetLatestEnrollmentByUser(ctx context.Context, userID uuid.UUID) (EnrollmentResponse, error) {
	stored, err := s.repo.GetLatestEnrollmentByUser(ctx, userID)
	if err != nil {
		var none EnrollmentResponse
		return none, err
	}

	return withDebugProfile(stored), nil
}
// GetEnrollmentByDeviceID fetches the enrollment the repository holds for
// deviceID and attaches a regenerated profile to it via withDebugProfile. A
// repository error is returned unchanged together with the zero
// EnrollmentResponse.
func (s *Service) GetEnrollmentByDeviceID(ctx context.Context, deviceID uuid.UUID) (EnrollmentResponse, error) {
	stored, err := s.repo.GetEnrollmentByDeviceID(ctx, deviceID)
	if err != nil {
		var none EnrollmentResponse
		return none, err
	}

	return withDebugProfile(stored), nil
}
// GetConnectionStatus summarises the enrollment returned by the repository's
// GetLatestEnrollmentByUser lookup for userID.
//
// If that lookup fails for any reason, the result is a "disconnected" status
// with an empty, non-nil resource list and a nil error; the lookup error is
// not surfaced to the caller. Otherwise the status is "provisioned" and
// carries the enrollment's assigned IP and resources. LastSyncTime always
// points at the fixed text "just now"; it is not a measured timestamp.
func (s *Service) GetConnectionStatus(ctx context.Context, userID uuid.UUID) (ConnectionStatus, error) {
	var status ConnectionStatus

	enrollment, err := s.repo.GetLatestEnrollmentByUser(ctx, userID)
	if err != nil {
		// Any lookup failure is reported as "disconnected" rather than as an
		// error.
		status.Status = "disconnected"
		status.Resources = []Resource{}
		return status, nil
	}

	syncLabel := "just now"
	status.Status = "provisioned"
	status.AssignedIP = enrollment.Peer.AssignedIP
	status.LastSyncTime = &syncLabel
	status.Resources = enrollment.Resources
	return status, nil
}
// Revoke asks the repository to revoke the device identified by deviceID and
// returns the repository's error unchanged. What revocation entails is
// defined by the Repository implementation.
func (s *Service) Revoke(ctx context.Context, deviceID uuid.UUID) error {
	return s.repo.Revoke(ctx, deviceID)
}
// Rotate asks the repository to rotate the device identified by deviceID and
// returns the repository's error unchanged.
// NOTE(review): what exactly is rotated (presumably the device's key material
// or profile revision) is decided by the Repository implementation — confirm
// there.
func (s *Service) Rotate(ctx context.Context, deviceID uuid.UUID) error {
	return s.repo.Rotate(ctx, deviceID)
}
// withDebugProfile returns enrollment with its Profile replaced by a WireGuard
// profile rebuilt from the enrollment's Peer data. The peer's allowed IPs are
// merged with one route per DNS server before rendering. The private key is
// not available here, so the literal marker "__CLIENT_PRIVATE_KEY_REQUIRED__"
// is passed as the private key instead.
func withDebugProfile(enrollment EnrollmentResponse) EnrollmentResponse {
	peer := enrollment.Peer

	wgInput := profile.BuildInput{
		PrivateKey:       "__CLIENT_PRIVATE_KEY_REQUIRED__",
		Address:          peer.AssignedIP,
		DNSServers:       peer.DNSServers,
		ServerPublicKey:  peer.Gateway.PublicKey,
		ServerEndpoint:   peer.Gateway.Endpoint,
		AllowedIPs:       mergeProfileAllowedIPs(peer.AllowedIPs, peer.DNSServers),
		PersistentKeepal: 25,
	}

	enrollment.Profile = ProfileView{
		Format:  "wireguard",
		Content: profile.BuildWireGuardConfig(wgInput),
	}
	return enrollment
}
// mergeProfileAllowedIPs builds the AllowedIPs list for a profile: the policy
// destinations first, followed by one route per DNS server (as produced by
// dnsServerRoute).
//
// Destinations are whitespace-trimmed. Empty entries are dropped and exact
// string duplicates are kept only once, at their first position. Comparison is
// textual, so equivalent CIDRs written differently are not collapsed. The
// result is never nil.
func mergeProfileAllowedIPs(destinations []string, dnsServers []string) []string {
	capacity := len(destinations) + len(dnsServers)
	out := make([]string, 0, capacity)
	known := make(map[string]struct{}, capacity)

	// add appends entry unless it is empty or has been appended before.
	add := func(entry string) {
		if entry == "" {
			return
		}
		if _, dup := known[entry]; dup {
			return
		}
		known[entry] = struct{}{}
		out = append(out, entry)
	}

	for _, dest := range destinations {
		add(strings.TrimSpace(dest))
	}
	for _, server := range dnsServers {
		add(dnsServerRoute(server))
	}
	return out
}
// dnsServerRoute converts a DNS server entry into a single-host route suitable
// for an AllowedIPs list.
//
// Surrounding whitespace is trimmed. An empty entry yields "". An entry that
// already contains "/" is taken to be a CIDR and returned as is. Otherwise a
// host prefix length is appended: "/128" when the entry contains ":" (an IPv6
// literal) and "/32" for everything else. The address itself is not
// validated.
func dnsServerRoute(value string) string {
	value = strings.TrimSpace(value)
	switch {
	case value == "":
		return ""
	case strings.Contains(value, "/"):
		return value
	case strings.Contains(value, ":"):
		// IPv6 addresses are 128 bits wide. A "/32" suffix here would describe
		// a 2^96-address prefix instead of the single resolver.
		return value + "/128"
	default:
		return value + "/32"
	}
}