feat: add environment-based DNS server override support with service-aware fallback logic

Add dnsServersForProfile and dnsServersForPeer helpers with conditional DNS server selection based on service presence. Use NEXAVPN_CLIENT_DNS_SERVERS override when services are configured, otherwise fall back to DEFAULT_DNS_SERVERS or gateway base DNS servers.

Replace direct gateway DNS server usage in Enroll and applyCurrentPolicy with profileDNSServers variable. Update BuildSyncBundle to scan gateway DNS servers separately
This commit is contained in:
2026-03-20 08:30:35 +01:00
parent b199b58840
commit 784971f111
2 changed files with 86 additions and 4 deletions

View File

@@ -72,14 +72,15 @@ func (s *Service) Enroll(ctx context.Context, userID uuid.UUID, input EnrollRequ
selectedDestinations = destinations selectedDestinations = destinations
} }
selectedServices := servicesForSelectedProfile(availableProfiles, selectedProfileID) selectedServices := servicesForSelectedProfile(availableProfiles, selectedProfileID)
profileDNSServers := dnsServersForProfile(selectedGateway.DNSServers, selectedServices)
profileAllowedIPs := mergeProfileAllowedIPs( profileAllowedIPs := mergeProfileAllowedIPs(
append(selectedDestinations, proxyRoutesForServices(selectedServices)...), append(selectedDestinations, proxyRoutesForServices(selectedServices)...),
selectedGateway.DNSServers, profileDNSServers,
) )
enrollment.Peer = PeerView{ enrollment.Peer = PeerView{
AssignedIP: assignedIP, AssignedIP: assignedIP,
DNSServers: selectedGateway.DNSServers, DNSServers: profileDNSServers,
AllowedIPs: profileAllowedIPs, AllowedIPs: profileAllowedIPs,
Gateway: GatewayView{ Gateway: GatewayView{
ID: selectedGateway.ID, ID: selectedGateway.ID,
@@ -98,7 +99,7 @@ func (s *Service) Enroll(ctx context.Context, userID uuid.UUID, input EnrollRequ
Content: profile.BuildWireGuardConfig(profile.BuildInput{ Content: profile.BuildWireGuardConfig(profile.BuildInput{
PrivateKey: privateKeyPlaceholder, PrivateKey: privateKeyPlaceholder,
Address: assignedIP, Address: assignedIP,
DNSServers: selectedGateway.DNSServers, DNSServers: profileDNSServers,
ServerPublicKey: selectedGateway.PublicKey, ServerPublicKey: selectedGateway.PublicKey,
ServerEndpoint: selectedGateway.Endpoint, ServerEndpoint: selectedGateway.Endpoint,
AllowedIPs: profileAllowedIPs, AllowedIPs: profileAllowedIPs,
@@ -223,6 +224,7 @@ func (s *Service) applyCurrentPolicy(ctx context.Context, enrollment EnrollmentR
if len(selectedDestinations) == 0 && len(selectedServices) == 0 { if len(selectedDestinations) == 0 && len(selectedServices) == 0 {
selectedDestinations = []string{"172.16.10.0/24"} selectedDestinations = []string{"172.16.10.0/24"}
} }
enrollment.Peer.DNSServers = dnsServersForProfile(enrollment.Peer.DNSServers, selectedServices)
enrollment.Resources = resourcesFromProfile(selectedDestinations, selectedServices) enrollment.Resources = resourcesFromProfile(selectedDestinations, selectedServices)
enrollment.AvailableProfiles = availableProfiles enrollment.AvailableProfiles = availableProfiles
@@ -394,3 +396,42 @@ func dnsServerRoute(value string) string {
} }
return value + "/32" return value + "/32"
} }
func dnsServersForProfile(base []string, services []AccessService) []string {
if len(services) > 0 {
if override := parseEnvList("NEXAVPN_CLIENT_DNS_SERVERS"); len(override) > 0 {
return override
}
return dedupeList(base)
}
if override := parseEnvList("DEFAULT_DNS_SERVERS"); len(override) > 0 {
return override
}
return dedupeList(base)
}
func parseEnvList(key string) []string {
return parseCommaList(os.Getenv(key))
}
func parseCommaList(raw string) []string {
seen := make(map[string]struct{})
values := make([]string, 0)
for _, part := range strings.Split(raw, ",") {
value := strings.TrimSpace(part)
if value == "" {
continue
}
if _, ok := seen[value]; ok {
continue
}
seen[value] = struct{}{}
values = append(values, value)
}
return values
}
func dedupeList(values []string) []string {
return parseCommaList(strings.Join(values, ","))
}

View File

@@ -128,7 +128,8 @@ func (r *PGRepository) BuildSyncBundle(ctx context.Context, gatewayID uuid.UUID)
var peer wireguard.Peer var peer wireguard.Peer
var deviceID uuid.UUID var deviceID uuid.UUID
var selectedProfileID *string var selectedProfileID *string
if err := rows.Scan(&deviceID, &peer.PublicKey, &peer.AssignedIP, &peer.AllowedDestinations, &peer.DNSServers, &selectedProfileID); err != nil { var gatewayDNSServers []string
if err := rows.Scan(&deviceID, &peer.PublicKey, &peer.AssignedIP, &peer.AllowedDestinations, &gatewayDNSServers, &selectedProfileID); err != nil {
return wireguard.GatewayBundle{}, err return wireguard.GatewayBundle{}, err
} }
peer.DeviceID = deviceID.String() peer.DeviceID = deviceID.String()
@@ -136,6 +137,7 @@ func (r *PGRepository) BuildSyncBundle(ctx context.Context, gatewayID uuid.UUID)
if err != nil { if err != nil {
return wireguard.GatewayBundle{}, err return wireguard.GatewayBundle{}, err
} }
peer.DNSServers = dnsServersForPeer(gatewayDNSServers, services)
peer.AllowedServices = services peer.AllowedServices = services
bundle.Peers = append(bundle.Peers, peer) bundle.Peers = append(bundle.Peers, peer)
} }
@@ -220,6 +222,45 @@ func (r *PGRepository) ListServiceDNSRecords(ctx context.Context) ([]ServiceDNSR
return items, rows.Err() return items, rows.Err()
} }
func dnsServersForPeer(base []string, services []wireguard.AllowedService) []string {
if len(services) > 0 {
if override := parseEnvList("NEXAVPN_CLIENT_DNS_SERVERS"); len(override) > 0 {
return override
}
return dedupeList(base)
}
if override := parseEnvList("DEFAULT_DNS_SERVERS"); len(override) > 0 {
return override
}
return dedupeList(base)
}
func parseEnvList(key string) []string {
return parseCommaList(os.Getenv(key))
}
func parseCommaList(raw string) []string {
seen := make(map[string]struct{})
values := make([]string, 0)
for _, part := range strings.Split(raw, ",") {
value := strings.TrimSpace(part)
if value == "" {
continue
}
if _, ok := seen[value]; ok {
continue
}
seen[value] = struct{}{}
values = append(values, value)
}
return values
}
func dedupeList(values []string) []string {
return parseCommaList(strings.Join(values, ","))
}
func (r *PGRepository) Update(ctx context.Context, gatewayID uuid.UUID, input UpdateRequest) (Gateway, error) { func (r *PGRepository) Update(ctx context.Context, gatewayID uuid.UUID, input UpdateRequest) (Gateway, error) {
row := r.db.QueryRow(ctx, ` row := r.db.QueryRow(ctx, `
update gateways update gateways