diff --git a/backend/internal/device/service.go b/backend/internal/device/service.go index 2e89377..da60f24 100644 --- a/backend/internal/device/service.go +++ b/backend/internal/device/service.go @@ -2,6 +2,7 @@ package device import ( "context" + "strings" "github.com/google/uuid" @@ -61,11 +62,12 @@ func (s *Service) Enroll(ctx context.Context, userID uuid.UUID, input EnrollRequ if len(destinations) == 0 { destinations = []string{"172.16.10.0/24"} } + profileAllowedIPs := mergeProfileAllowedIPs(destinations, selectedGateway.DNSServers) enrollment.Peer = PeerView{ AssignedIP: assignedIP, DNSServers: selectedGateway.DNSServers, - AllowedIPs: destinations, + AllowedIPs: profileAllowedIPs, Gateway: GatewayView{ ID: selectedGateway.ID, Name: selectedGateway.Name, @@ -90,7 +92,7 @@ func (s *Service) Enroll(ctx context.Context, userID uuid.UUID, input EnrollRequ DNSServers: selectedGateway.DNSServers, ServerPublicKey: selectedGateway.PublicKey, ServerEndpoint: selectedGateway.Endpoint, - AllowedIPs: destinations, + AllowedIPs: profileAllowedIPs, PersistentKeepal: 25, }), } @@ -149,6 +151,7 @@ func (s *Service) Rotate(ctx context.Context, deviceID uuid.UUID) error { } func withDebugProfile(enrollment EnrollmentResponse) EnrollmentResponse { + profileAllowedIPs := mergeProfileAllowedIPs(enrollment.Peer.AllowedIPs, enrollment.Peer.DNSServers) enrollment.Profile = ProfileView{ Format: "wireguard", Content: profile.BuildWireGuardConfig(profile.BuildInput{ @@ -157,9 +160,51 @@ func withDebugProfile(enrollment EnrollmentResponse) EnrollmentResponse { DNSServers: enrollment.Peer.DNSServers, ServerPublicKey: enrollment.Peer.Gateway.PublicKey, ServerEndpoint: enrollment.Peer.Gateway.Endpoint, - AllowedIPs: enrollment.Peer.AllowedIPs, + AllowedIPs: profileAllowedIPs, PersistentKeepal: 25, }), } return enrollment } + +func mergeProfileAllowedIPs(destinations []string, dnsServers []string) []string { + seen := make(map[string]struct{}, len(destinations)+len(dnsServers)) + merged := make([]string, 0, len(destinations)+len(dnsServers)) + + for _, destination := range destinations { + destination = strings.TrimSpace(destination) + if destination == "" { + continue + } + if _, exists := seen[destination]; exists { + continue + } + seen[destination] = struct{}{} + merged = append(merged, destination) + } + + for _, dnsServer := range dnsServers { + route := dnsServerRoute(dnsServer) + if route == "" { + continue + } + if _, exists := seen[route]; exists { + continue + } + seen[route] = struct{}{} + merged = append(merged, route) + } + + return merged +} + +func dnsServerRoute(value string) string { + value = strings.TrimSpace(value) + if value == "" { + return "" + } + if strings.Contains(value, "/") { + return value + } + return value + "/32" +} diff --git a/backend/internal/gateway/repository.go b/backend/internal/gateway/repository.go index e1b8e97..2b4c33a 100644 --- a/backend/internal/gateway/repository.go +++ b/backend/internal/gateway/repository.go @@ -91,9 +91,11 @@ func (r *PGRepository) BuildSyncBundle(ctx context.Context, gatewayID uuid.UUID) d.id, wp.public_key, set_masklen(wp.assigned_ip, 32)::text, - coalesce(array_agg(distinct pd.destination::text) filter (where pd.destination is not null), '{}') + coalesce(array_agg(distinct pd.destination::text) filter (where pd.destination is not null), '{}'), + coalesce(g.dns_servers, '{}')::text[] from devices d join wireguard_peers wp on wp.device_id = d.id and wp.deleted_at is null + join gateways g on g.id = d.gateway_id left join group_memberships gm on gm.user_id = d.user_id left join policy_targets pt on ( (pt.target_type = 'device' and pt.target_id = d.id) or @@ -102,7 +104,7 @@ func (r *PGRepository) BuildSyncBundle(ctx context.Context, gatewayID uuid.UUID) ) left join policy_destinations pd on pd.policy_id = pt.policy_id where d.gateway_id = $1 and d.deleted_at is null and d.status = 'active' - group by d.id, wp.public_key, wp.assigned_ip + group by d.id, wp.public_key, wp.assigned_ip, g.dns_servers `, gatewayID) if err != nil { return wireguard.GatewayBundle{}, err @@ -112,7 +114,7 @@ func (r *PGRepository) BuildSyncBundle(ctx context.Context, gatewayID uuid.UUID) for rows.Next() { var peer wireguard.Peer var deviceID uuid.UUID - if err := rows.Scan(&deviceID, &peer.PublicKey, &peer.AssignedIP, &peer.AllowedDestinations); err != nil { + if err := rows.Scan(&deviceID, &peer.PublicKey, &peer.AssignedIP, &peer.AllowedDestinations, &peer.DNSServers); err != nil { return wireguard.GatewayBundle{}, err } peer.DeviceID = deviceID.String() diff --git a/backend/internal/wireguard/types.go b/backend/internal/wireguard/types.go index b4c208b..7160d99 100644 --- a/backend/internal/wireguard/types.go +++ b/backend/internal/wireguard/types.go @@ -5,6 +5,7 @@ type Peer struct { PublicKey string `json:"public_key"` AssignedIP string `json:"assigned_ip"` AllowedDestinations []string `json:"allowed_destinations"` + DNSServers []string `json:"dns_servers"` } type GatewayBundle struct { diff --git a/deploy/scripts/gateway-entrypoint.sh b/deploy/scripts/gateway-entrypoint.sh index 586946a..dc48a08 100644 --- a/deploy/scripts/gateway-entrypoint.sh +++ b/deploy/scripts/gateway-entrypoint.sh @@ -114,6 +114,10 @@ EOF jq -c '.peers[]?' "${STATE_JSON}" | while read -r peer; do ASSIGNED_IP=$(printf '%s' "${peer}" | jq -r '.assigned_ip') + printf '%s' "${peer}" | jq -r '.dns_servers[]?' | while read -r dns_server; do + echo " iifname \"${IFACE}\" ip saddr ${ASSIGNED_IP} ip daddr ${dns_server} udp dport 53 accept" + echo " iifname \"${IFACE}\" ip saddr ${ASSIGNED_IP} ip daddr ${dns_server} tcp dport 53 accept" + done printf '%s' "${peer}" | jq -r '.allowed_destinations[]?' | while read -r destination; do echo " iifname \"${IFACE}\" ip saddr ${ASSIGNED_IP} ip daddr ${destination} accept" done