Files
NexaVPN/backend/internal/device/service.go
nessi 784971f111 feat: add environment-based DNS server override support with service-aware fallback logic
Add dnsServersForProfile and dnsServersForPeer helpers with conditional DNS server selection based on service presence. Use NEXAVPN_CLIENT_DNS_SERVERS override when services are configured, otherwise fall back to DEFAULT_DNS_SERVERS or gateway base DNS servers.

Replace direct gateway DNS server usage in Enroll and applyCurrentPolicy with profileDNSServers variable. Update BuildSyncBundle to scan gateway DNS servers separately
2026-03-20 08:30:35 +01:00

438 lines
12 KiB
Go

package device
import (
"context"
"fmt"
"os"
"strings"
"github.com/google/uuid"
"nexavpn/backend/internal/gateway"
"nexavpn/backend/internal/ipam"
"nexavpn/backend/internal/policy"
"nexavpn/backend/internal/profile"
)
// Service implements the device use cases of this package (enrollment,
// listing, profile selection, status) on top of a repository and the policy,
// gateway, and IPAM services.
type Service struct {
// repo stores devices and enrollments and tracks each device's selected profile ID.
repo Repository
// policyService resolves destinations and the selectable access profiles.
policyService *policy.Service
// gatewayService selects the active gateway used during enrollment.
gatewayService *gateway.Service
// ipamService allocates an address when the repository lookup fails in Enroll.
ipamService *ipam.Service
}
// NewService builds a Service from its four collaborators. The arguments are
// stored exactly as given; nothing is validated or nil-checked here.
func NewService(repo Repository, policyService *policy.Service, gatewayService *gateway.Service, ipamService *ipam.Service) *Service {
	svc := new(Service)
	svc.repo = repo
	svc.policyService = policyService
	svc.gatewayService = gatewayService
	svc.ipamService = ipamService
	return svc
}
// Enroll registers a new device for userID and returns the enrollment together
// with a rendered WireGuard profile.
//
// Steps, in order:
//  1. select the active gateway;
//  2. pick an address inside the gateway's VPN CIDR, asking the repository
//     first and the IPAM service only if the repository lookup fails;
//  3. resolve the user's destinations and persist the enrollment;
//  4. resolve destinations and access profiles again now that a device ID exists;
//  5. derive DNS servers and AllowedIPs for the selected profile and render the
//     WireGuard configuration with privateKeyPlaceholder as the private key.
//
// Apart from the first address lookup (which triggers the IPAM fallback), any
// error from a collaborator is returned unchanged with a zero EnrollmentResponse.
func (s *Service) Enroll(ctx context.Context, userID uuid.UUID, input EnrollRequest, privateKeyPlaceholder string) (EnrollmentResponse, error) {
	selectedGateway, err := s.gatewayService.SelectActive(ctx)
	if err != nil {
		return EnrollmentResponse{}, err
	}

	// NOTE(review): the literal 10 is forwarded to both allocators; presumably
	// a starting host offset inside the CIDR — confirm against their contracts.
	assignedIP, err := s.repo.FindNextAvailableIP(ctx, selectedGateway.ID, selectedGateway.VPNCIDR, 10)
	if err != nil {
		// The repository error is intentionally replaced by the IPAM attempt.
		assignedIP, err = s.ipamService.Allocate(selectedGateway.VPNCIDR, 10)
		if err != nil {
			return EnrollmentResponse{}, err
		}
	}

	// No device exists yet, so destinations are resolved for the user only.
	destinations, err := s.policyService.ResolveDestinations(ctx, userID, nil)
	if err != nil {
		return EnrollmentResponse{}, err
	}
	if len(destinations) == 0 {
		destinations = []string{"172.16.10.0/24"}
	}

	enrollment, err := s.repo.Enroll(ctx, userID, selectedGateway.ID, input, assignedIP, selectedGateway.DNSServers, destinations)
	if err != nil {
		return EnrollmentResponse{}, err
	}

	// Resolve again with the new device ID so device-scoped policy applies.
	destinations, err = s.policyService.ResolveDestinations(ctx, userID, &enrollment.Device.ID)
	if err != nil {
		return EnrollmentResponse{}, err
	}
	availableProfiles, selectedProfileID, selectedDestinations, err := s.resolveAccessProfiles(ctx, userID, enrollment.Device.ID)
	if err != nil {
		return EnrollmentResponse{}, err
	}

	// Computed once and reused for the fallback decision, DNS, routes, and resources.
	selectedServices := servicesForSelectedProfile(availableProfiles, selectedProfileID)
	if len(selectedDestinations) == 0 && len(selectedServices) == 0 {
		selectedDestinations = []string{"172.16.10.0/24"}
	}
	if len(selectedDestinations) == 0 {
		selectedDestinations = destinations
	}

	profileDNSServers := dnsServersForProfile(selectedGateway.DNSServers, selectedServices)

	// Build the route list in a fresh slice: selectedDestinations may share its
	// backing array with the access profile's Destinations, and appending to it
	// directly could overwrite spare capacity owned by that slice.
	proxyRoutes := proxyRoutesForServices(selectedServices)
	routes := make([]string, 0, len(selectedDestinations)+len(proxyRoutes))
	routes = append(routes, selectedDestinations...)
	routes = append(routes, proxyRoutes...)
	profileAllowedIPs := mergeProfileAllowedIPs(routes, profileDNSServers)

	enrollment.Peer = PeerView{
		AssignedIP: assignedIP,
		DNSServers: profileDNSServers,
		AllowedIPs: profileAllowedIPs,
		Gateway: GatewayView{
			ID:        selectedGateway.ID,
			Name:      selectedGateway.Name,
			Endpoint:  selectedGateway.Endpoint,
			PublicKey: selectedGateway.PublicKey,
		},
		ProfileRevision: 1,
	}
	enrollment.Resources = resourcesFromProfile(selectedDestinations, selectedServices)
	enrollment.AvailableProfiles = availableProfiles
	enrollment.SelectedProfileID = selectedProfileID
	enrollment.Profile = ProfileView{
		Format: "wireguard",
		Content: profile.BuildWireGuardConfig(profile.BuildInput{
			PrivateKey:       privateKeyPlaceholder,
			Address:          assignedIP,
			DNSServers:       profileDNSServers,
			ServerPublicKey:  selectedGateway.PublicKey,
			ServerEndpoint:   selectedGateway.Endpoint,
			AllowedIPs:       profileAllowedIPs,
			PersistentKeepal: 25,
		}),
	}
	return enrollment, nil
}
// ListByUser returns whatever the repository's ListByUser yields for userID;
// both the device slice and the error are passed through unchanged.
func (s *Service) ListByUser(ctx context.Context, userID uuid.UUID) ([]Device, error) {
return s.repo.ListByUser(ctx, userID)
}
// ListAll returns whatever the repository's ListAll yields; both the device
// slice and the error are passed through unchanged.
func (s *Service) ListAll(ctx context.Context) ([]Device, error) {
return s.repo.ListAll(ctx)
}
// GetLatestEnrollmentByUser loads the user's latest stored enrollment from the
// repository and returns it after refreshing it with applyCurrentPolicy.
// A repository error is returned as-is with a zero EnrollmentResponse.
func (s *Service) GetLatestEnrollmentByUser(ctx context.Context, userID uuid.UUID) (EnrollmentResponse, error) {
	stored, err := s.repo.GetLatestEnrollmentByUser(ctx, userID)
	if err == nil {
		return s.applyCurrentPolicy(ctx, stored)
	}
	return EnrollmentResponse{}, err
}
// GetEnrollmentByDeviceID loads the stored enrollment for deviceID from the
// repository and returns it after refreshing it with applyCurrentPolicy.
// A repository error is returned as-is with a zero EnrollmentResponse.
func (s *Service) GetEnrollmentByDeviceID(ctx context.Context, deviceID uuid.UUID) (EnrollmentResponse, error) {
	stored, err := s.repo.GetEnrollmentByDeviceID(ctx, deviceID)
	if err == nil {
		return s.applyCurrentPolicy(ctx, stored)
	}
	return EnrollmentResponse{}, err
}
// SelectProfile records profileID as the selected access profile for the
// device behind the user's latest enrollment and returns that enrollment
// refreshed with applyCurrentPolicy.
//
// The profile must be among those the policy service lists as selectable for
// the device; otherwise an error is returned and nothing is stored.
func (s *Service) SelectProfile(ctx context.Context, userID uuid.UUID, profileID uuid.UUID) (EnrollmentResponse, error) {
	enrollment, err := s.repo.GetLatestEnrollmentByUser(ctx, userID)
	if err != nil {
		return EnrollmentResponse{}, err
	}

	candidates, err := s.policyService.ListSelectableProfiles(ctx, userID, &enrollment.Device.ID)
	if err != nil {
		return EnrollmentResponse{}, err
	}

	// Check membership before persisting the selection.
	selectable := false
	for _, candidate := range candidates {
		if candidate.ID != profileID {
			continue
		}
		selectable = true
		break
	}
	if !selectable {
		return EnrollmentResponse{}, fmt.Errorf("selected access profile is not available for this device")
	}

	err = s.repo.SetSelectedProfileID(ctx, enrollment.Device.ID, profileID)
	if err != nil {
		return EnrollmentResponse{}, err
	}
	return s.applyCurrentPolicy(ctx, enrollment)
}
// GetConnectionStatus summarises the user's latest enrollment as a
// ConnectionStatus.
//
// If the enrollment lookup fails for any reason the user is reported as
// "disconnected" with an empty resource list and a nil error. Otherwise the
// enrollment is refreshed with applyCurrentPolicy and reported as
// "provisioned"; an error from that refresh is returned to the caller.
func (s *Service) GetConnectionStatus(ctx context.Context, userID uuid.UUID) (ConnectionStatus, error) {
	stored, err := s.repo.GetLatestEnrollmentByUser(ctx, userID)
	if err != nil {
		// The lookup error is deliberately not propagated.
		disconnected := ConnectionStatus{
			Status:    "disconnected",
			Resources: []Resource{},
		}
		return disconnected, nil
	}

	current, err := s.applyCurrentPolicy(ctx, stored)
	if err != nil {
		return ConnectionStatus{}, err
	}

	// The sync label is a fixed string, not a measured timestamp.
	syncLabel := "just now"
	status := ConnectionStatus{
		Status:       "provisioned",
		AssignedIP:   current.Peer.AssignedIP,
		LastSyncTime: &syncLabel,
		Resources:    current.Resources,
	}
	return status, nil
}
// Revoke forwards to the repository's Revoke for deviceID and returns its
// error unchanged. What revocation does is defined by the repository.
func (s *Service) Revoke(ctx context.Context, deviceID uuid.UUID) error {
return s.repo.Revoke(ctx, deviceID)
}
// Delete forwards to the repository's Delete for deviceID and returns its
// error unchanged.
func (s *Service) Delete(ctx context.Context, deviceID uuid.UUID) error {
return s.repo.Delete(ctx, deviceID)
}
// Rotate forwards to the repository's Rotate for deviceID and returns its
// error unchanged.
// NOTE(review): presumably rotates the device's key material — confirm against
// the repository implementation.
func (s *Service) Rotate(ctx context.Context, deviceID uuid.UUID) error {
return s.repo.Rotate(ctx, deviceID)
}
// withDebugProfile returns enrollment with its Profile replaced by a WireGuard
// configuration rendered from the enrollment's current Peer data. The private
// key slot is filled with a fixed placeholder string, not real key material.
func withDebugProfile(enrollment EnrollmentResponse) EnrollmentResponse {
	peer := enrollment.Peer
	rendered := profile.BuildWireGuardConfig(profile.BuildInput{
		PrivateKey:       "__CLIENT_PRIVATE_KEY_REQUIRED__",
		Address:          peer.AssignedIP,
		DNSServers:       peer.DNSServers,
		ServerPublicKey:  peer.Gateway.PublicKey,
		ServerEndpoint:   peer.Gateway.Endpoint,
		AllowedIPs:       peer.AllowedIPs,
		PersistentKeepal: 25,
	})
	enrollment.Profile = ProfileView{Format: "wireguard", Content: rendered}
	return enrollment
}
// applyCurrentPolicy recomputes the policy-derived parts of a stored
// enrollment: the available and selected access profiles, resources, peer DNS
// servers, and peer AllowedIPs. It then re-renders the profile with
// withDebugProfile, which uses a placeholder private key.
//
// Resolving access profiles may persist a fallback selection for the device
// (see resolveAccessProfiles). An error from that resolution is returned
// unchanged with a zero EnrollmentResponse.
func (s *Service) applyCurrentPolicy(ctx context.Context, enrollment EnrollmentResponse) (EnrollmentResponse, error) {
	availableProfiles, selectedProfileID, selectedDestinations, err := s.resolveAccessProfiles(ctx, enrollment.Device.UserID, enrollment.Device.ID)
	if err != nil {
		return EnrollmentResponse{}, err
	}

	selectedServices := servicesForSelectedProfile(availableProfiles, selectedProfileID)
	// With neither destinations nor services, fall back to the default subnet.
	if len(selectedDestinations) == 0 && len(selectedServices) == 0 {
		selectedDestinations = []string{"172.16.10.0/24"}
	}

	enrollment.Peer.DNSServers = dnsServersForProfile(enrollment.Peer.DNSServers, selectedServices)
	enrollment.Resources = resourcesFromProfile(selectedDestinations, selectedServices)
	enrollment.AvailableProfiles = availableProfiles
	enrollment.SelectedProfileID = selectedProfileID

	// Build the route list in a fresh slice: selectedDestinations may share its
	// backing array with the access profile's Destinations, and appending to it
	// directly could overwrite spare capacity owned by that slice.
	proxyRoutes := proxyRoutesForServices(selectedServices)
	routes := make([]string, 0, len(selectedDestinations)+len(proxyRoutes))
	routes = append(routes, selectedDestinations...)
	routes = append(routes, proxyRoutes...)
	enrollment.Peer.AllowedIPs = mergeProfileAllowedIPs(routes, enrollment.Peer.DNSServers)

	return withDebugProfile(enrollment), nil
}
// resolveAccessProfiles lists the access profiles selectable for the given
// user and device, converts them to this package's AccessProfile view, and
// determines which one is selected.
//
// Returns the converted profiles, the selected profile's ID, and that
// profile's destinations. If no profiles are selectable, every result is nil.
// If the device has no stored selection, or the stored one is no longer among
// the available profiles, the first available profile is persisted as the
// selection and returned.
func (s *Service) resolveAccessProfiles(ctx context.Context, userID uuid.UUID, deviceID uuid.UUID) ([]AccessProfile, *uuid.UUID, []string, error) {
	selectable, err := s.policyService.ListSelectableProfiles(ctx, userID, &deviceID)
	if err != nil {
		return nil, nil, nil, err
	}

	// Convert policy-layer profiles and their services field by field.
	availableProfiles := make([]AccessProfile, 0, len(selectable))
	for _, source := range selectable {
		converted := make([]AccessService, 0, len(source.Services))
		for _, item := range source.Services {
			converted = append(converted, AccessService{
				ID:          item.ID,
				Name:        item.Name,
				Description: item.Description,
				Domain:      item.Domain,
				UpstreamIP:  item.UpstreamIP,
				ProxyIP:     item.ProxyIP,
				Ports:       item.Ports,
			})
		}
		availableProfiles = append(availableProfiles, AccessProfile{
			ID:           source.ID,
			Name:         source.Name,
			Description:  source.Description,
			FullTunnel:   source.FullTunnel,
			Destinations: source.Destinations,
			Services:     converted,
		})
	}
	if len(availableProfiles) == 0 {
		return nil, nil, nil, nil
	}

	selectedProfileID, err := s.repo.GetSelectedProfileID(ctx, deviceID)
	if err != nil {
		return nil, nil, nil, err
	}
	if selectedProfileID != nil {
		for _, candidate := range availableProfiles {
			if candidate.ID == *selectedProfileID {
				return availableProfiles, selectedProfileID, candidate.Destinations, nil
			}
		}
	}

	// No usable stored selection: persist the first profile as the selection.
	fallback := availableProfiles[0]
	if err := s.repo.SetSelectedProfileID(ctx, deviceID, fallback.ID); err != nil {
		return nil, nil, nil, err
	}
	return availableProfiles, &fallback.ID, fallback.Destinations, nil
}
// resourcesFromDestinations maps each destination string to a Resource of type
// "cidr" whose Value and Label are both the destination itself. Order is
// preserved and the result is never nil.
func resourcesFromDestinations(destinations []string) []Resource {
	out := make([]Resource, len(destinations))
	for i, cidr := range destinations {
		out[i] = Resource{Type: "cidr", Value: cidr, Label: cidr}
	}
	return out
}
// resourcesFromProfile returns the "cidr" resources for destinations followed
// by one "service" resource per service. A service resource uses the service's
// Domain as Value and Domain, and its Name as Label.
func resourcesFromProfile(destinations []string, services []AccessService) []Resource {
	out := resourcesFromDestinations(destinations)
	for i := range services {
		entry := Resource{
			Type:   "service",
			Value:  services[i].Domain,
			Label:  services[i].Name,
			Domain: services[i].Domain,
		}
		out = append(out, entry)
	}
	return out
}
// servicesForSelectedProfile returns the services of the profile whose ID
// equals *selectedProfileID. With a nil selectedProfileID it returns the
// services of the first profile. It returns nil when no profile matches or
// when profiles is empty.
func servicesForSelectedProfile(profiles []AccessProfile, selectedProfileID *uuid.UUID) []AccessService {
	if selectedProfileID != nil {
		for i := range profiles {
			if profiles[i].ID == *selectedProfileID {
				return profiles[i].Services
			}
		}
		return nil
	}
	if len(profiles) > 0 {
		return profiles[0].Services
	}
	return nil
}
// proxyRoutesForServices returns one route per distinct effective proxy IP
// across services, in first-seen order. Each proxy IP goes through
// effectiveServiceProxyIP and then dnsServerRoute; services that yield an
// empty route are skipped.
func proxyRoutesForServices(services []AccessService) []string {
	routes := make([]string, 0, len(services))
	emitted := make(map[string]struct{}, len(services))
	for i := range services {
		route := dnsServerRoute(effectiveServiceProxyIP(services[i].ProxyIP))
		_, duplicate := emitted[route]
		if route == "" || duplicate {
			continue
		}
		emitted[route] = struct{}{}
		routes = append(routes, route)
	}
	return routes
}
// effectiveServiceProxyIP returns the whitespace-trimmed value of the
// NEXAVPN_ACCESS_PROXY_IP environment variable when it is non-empty, and
// otherwise returns proxyIP exactly as passed in (untrimmed).
func effectiveServiceProxyIP(proxyIP string) string {
	if forced := strings.TrimSpace(os.Getenv("NEXAVPN_ACCESS_PROXY_IP")); forced != "" {
		return forced
	}
	return proxyIP
}
// mergeProfileAllowedIPs builds the AllowedIPs list: the trimmed destinations
// first, then one route per DNS server as produced by dnsServerRoute. Empty
// entries are dropped and duplicates keep their first occurrence. The result
// is never nil.
func mergeProfileAllowedIPs(destinations []string, dnsServers []string) []string {
	total := len(destinations) + len(dnsServers)
	merged := make([]string, 0, total)
	seen := make(map[string]struct{}, total)

	// add appends route unless it is empty or already present.
	add := func(route string) {
		if route == "" {
			return
		}
		if _, dup := seen[route]; dup {
			return
		}
		seen[route] = struct{}{}
		merged = append(merged, route)
	}

	for _, dest := range destinations {
		add(strings.TrimSpace(dest))
	}
	for _, server := range dnsServers {
		add(dnsServerRoute(server))
	}
	return merged
}
// dnsServerRoute converts a single address into a CIDR route.
//
// The input is trimmed first. An empty input yields "". A value that already
// contains "/" is returned as-is. A bare address containing ":" is treated as
// IPv6 and gets "/128"; any other value gets "/32".
//
// Previously IPv6 addresses also received "/32", which denotes an entire IPv6
// /32 prefix rather than a single host.
func dnsServerRoute(value string) string {
	value = strings.TrimSpace(value)
	switch {
	case value == "":
		return ""
	case strings.Contains(value, "/"):
		// Already in prefix notation.
		return value
	case strings.Contains(value, ":"):
		// A colon only appears in IPv6 literals; a host route there is /128.
		return value + "/128"
	default:
		return value + "/32"
	}
}
// dnsServersForProfile picks the DNS servers for a profile.
//
// When services is non-empty the NEXAVPN_CLIENT_DNS_SERVERS environment list
// is consulted; otherwise DEFAULT_DNS_SERVERS is. If the consulted variable
// yields at least one entry, that list is returned. Otherwise the
// de-duplicated base list is returned.
func dnsServersForProfile(base []string, services []AccessService) []string {
	envKey := "DEFAULT_DNS_SERVERS"
	if len(services) > 0 {
		envKey = "NEXAVPN_CLIENT_DNS_SERVERS"
	}
	if fromEnv := parseEnvList(envKey); len(fromEnv) > 0 {
		return fromEnv
	}
	return dedupeList(base)
}
// parseEnvList reads the environment variable key and parses its value as a
// comma-separated list via parseCommaList.
func parseEnvList(key string) []string {
	raw := os.Getenv(key)
	return parseCommaList(raw)
}
// parseCommaList splits raw on commas, trims surrounding whitespace from each
// piece, drops empty pieces, and removes duplicates while keeping first-seen
// order. The result is never nil; it is an empty slice when nothing remains.
func parseCommaList(raw string) []string {
	// FieldsFunc already discards empty fields between consecutive commas;
	// whitespace-only fields are filtered after trimming.
	fields := strings.FieldsFunc(raw, func(r rune) bool { return r == ',' })

	unique := make([]string, 0)
	known := make(map[string]struct{})
	for _, field := range fields {
		item := strings.TrimSpace(field)
		if _, dup := known[item]; dup || item == "" {
			continue
		}
		known[item] = struct{}{}
		unique = append(unique, item)
	}
	return unique
}
// dedupeList normalises values by joining them with commas and re-parsing the
// result with parseCommaList. Entries are therefore trimmed, emptied entries
// dropped, duplicates removed, and any entry that itself contains a comma is
// split into separate entries.
func dedupeList(values []string) []string {
	joined := strings.Join(values, ",")
	return parseCommaList(joined)
}