From e3bd6d3b9636014926ee9dc628f3dc24ef998efd Mon Sep 17 00:00:00 2001 From: nessi Date: Wed, 18 Mar 2026 08:48:08 +0100 Subject: [PATCH] feat: add DNS server routes to WireGuard profiles and gateway firewall rules Add mergeProfileAllowedIPs function to combine policy destinations with DNS server routes in device enrollment and rotation. Add dnsServerRoute helper to convert DNS server IPs to /32 CIDR notation. Update BuildSyncBundle query to include gateway DNS servers in peer data. Add DNSServers field to wireguard.Peer struct. Update gateway nftables configuration to allow UDP/TCP port 53 traffic from assigned IPs to DNS servers before --- backend/internal/device/service.go | 51 ++++++++++++++++++++++++-- backend/internal/gateway/repository.go | 8 ++-- backend/internal/wireguard/types.go | 1 + deploy/scripts/gateway-entrypoint.sh | 4 ++ 4 files changed, 58 insertions(+), 6 deletions(-) diff --git a/backend/internal/device/service.go b/backend/internal/device/service.go index 2e89377..da60f24 100644 --- a/backend/internal/device/service.go +++ b/backend/internal/device/service.go @@ -2,6 +2,7 @@ package device import ( "context" + "strings" "github.com/google/uuid" @@ -61,11 +62,12 @@ func (s *Service) Enroll(ctx context.Context, userID uuid.UUID, input EnrollRequ if len(destinations) == 0 { destinations = []string{"172.16.10.0/24"} } + profileAllowedIPs := mergeProfileAllowedIPs(destinations, selectedGateway.DNSServers) enrollment.Peer = PeerView{ AssignedIP: assignedIP, DNSServers: selectedGateway.DNSServers, - AllowedIPs: destinations, + AllowedIPs: profileAllowedIPs, Gateway: GatewayView{ ID: selectedGateway.ID, Name: selectedGateway.Name, @@ -90,7 +92,7 @@ func (s *Service) Enroll(ctx context.Context, userID uuid.UUID, input EnrollRequ DNSServers: selectedGateway.DNSServers, ServerPublicKey: selectedGateway.PublicKey, ServerEndpoint: selectedGateway.Endpoint, - AllowedIPs: destinations, + AllowedIPs: profileAllowedIPs, PersistentKeepal: 25, }), } @@ -149,6 +151,7 @@ func (s *Service) Rotate(ctx context.Context, deviceID uuid.UUID) error { } func withDebugProfile(enrollment EnrollmentResponse) EnrollmentResponse { + profileAllowedIPs := mergeProfileAllowedIPs(enrollment.Peer.AllowedIPs, enrollment.Peer.DNSServers) enrollment.Profile = ProfileView{ Format: "wireguard", Content: profile.BuildWireGuardConfig(profile.BuildInput{ @@ -157,9 +160,51 @@ func withDebugProfile(enrollment EnrollmentResponse) EnrollmentResponse { DNSServers: enrollment.Peer.DNSServers, ServerPublicKey: enrollment.Peer.Gateway.PublicKey, ServerEndpoint: enrollment.Peer.Gateway.Endpoint, - AllowedIPs: enrollment.Peer.AllowedIPs, + AllowedIPs: profileAllowedIPs, PersistentKeepal: 25, }), } return enrollment } + +func mergeProfileAllowedIPs(destinations []string, dnsServers []string) []string { + seen := make(map[string]struct{}, len(destinations)+len(dnsServers)) + merged := make([]string, 0, len(destinations)+len(dnsServers)) + + for _, destination := range destinations { + destination = strings.TrimSpace(destination) + if destination == "" { + continue + } + if _, exists := seen[destination]; exists { + continue + } + seen[destination] = struct{}{} + merged = append(merged, destination) + } + + for _, dnsServer := range dnsServers { + route := dnsServerRoute(dnsServer) + if route == "" { + continue + } + if _, exists := seen[route]; exists { + continue + } + seen[route] = struct{}{} + merged = append(merged, route) + } + + return merged +} + +func dnsServerRoute(value string) string { + value = strings.TrimSpace(value) + if value == "" { + return "" + } + if strings.Contains(value, "/") { + return value + } + return value + "/32" +} diff --git a/backend/internal/gateway/repository.go b/backend/internal/gateway/repository.go index e1b8e97..2b4c33a 100644 --- a/backend/internal/gateway/repository.go +++ b/backend/internal/gateway/repository.go @@ -91,9 +91,11 @@ func (r *PGRepository) BuildSyncBundle(ctx context.Context, gatewayID uuid.UUID) d.id, wp.public_key, set_masklen(wp.assigned_ip, 32)::text, - coalesce(array_agg(distinct pd.destination::text) filter (where pd.destination is not null), '{}') + coalesce(array_agg(distinct pd.destination::text) filter (where pd.destination is not null), '{}'), + coalesce(g.dns_servers, '{}')::text[] from devices d join wireguard_peers wp on wp.device_id = d.id and wp.deleted_at is null + join gateways g on g.id = d.gateway_id left join group_memberships gm on gm.user_id = d.user_id left join policy_targets pt on ( (pt.target_type = 'device' and pt.target_id = d.id) or @@ -102,7 +104,7 @@ func (r *PGRepository) BuildSyncBundle(ctx context.Context, gatewayID uuid.UUID) ) left join policy_destinations pd on pd.policy_id = pt.policy_id where d.gateway_id = $1 and d.deleted_at is null and d.status = 'active' - group by d.id, wp.public_key, wp.assigned_ip + group by d.id, wp.public_key, wp.assigned_ip, g.dns_servers `, gatewayID) if err != nil { return wireguard.GatewayBundle{}, err @@ -112,7 +114,7 @@ func (r *PGRepository) BuildSyncBundle(ctx context.Context, gatewayID uuid.UUID) for rows.Next() { var peer wireguard.Peer var deviceID uuid.UUID - if err := rows.Scan(&deviceID, &peer.PublicKey, &peer.AssignedIP, &peer.AllowedDestinations); err != nil { + if err := rows.Scan(&deviceID, &peer.PublicKey, &peer.AssignedIP, &peer.AllowedDestinations, &peer.DNSServers); err != nil { return wireguard.GatewayBundle{}, err } peer.DeviceID = deviceID.String() diff --git a/backend/internal/wireguard/types.go b/backend/internal/wireguard/types.go index b4c208b..7160d99 100644 --- a/backend/internal/wireguard/types.go +++ b/backend/internal/wireguard/types.go @@ -5,6 +5,7 @@ type Peer struct { PublicKey string `json:"public_key"` AssignedIP string `json:"assigned_ip"` AllowedDestinations []string `json:"allowed_destinations"` + DNSServers []string `json:"dns_servers"` } type GatewayBundle struct { diff --git a/deploy/scripts/gateway-entrypoint.sh b/deploy/scripts/gateway-entrypoint.sh index 586946a..dc48a08 100644 --- a/deploy/scripts/gateway-entrypoint.sh +++ b/deploy/scripts/gateway-entrypoint.sh @@ -114,6 +114,10 @@ EOF jq -c '.peers[]?' "${STATE_JSON}" | while read -r peer; do ASSIGNED_IP=$(printf '%s' "${peer}" | jq -r '.assigned_ip') + printf '%s' "${peer}" | jq -r '.dns_servers[]?' | while read -r dns_server; do + echo " iifname \"${IFACE}\" ip saddr ${ASSIGNED_IP} ip daddr ${dns_server} udp dport 53 accept" + echo " iifname \"${IFACE}\" ip saddr ${ASSIGNED_IP} ip daddr ${dns_server} tcp dport 53 accept" + done printf '%s' "${peer}" | jq -r '.allowed_destinations[]?' | while read -r destination; do echo " iifname \"${IFACE}\" ip saddr ${ASSIGNED_IP} ip daddr ${destination} accept" done