feat: add DNS server routes to WireGuard profiles and gateway firewall rules
Add mergeProfileAllowedIPs function to combine policy destinations with DNS server routes in device enrollment and rotation. Add dnsServerRoute helper to convert DNS server IPs to /32 CIDR notation. Update BuildSyncBundle query to include gateway DNS servers in peer data. Add DNSServers field to wireguard.Peer struct. Update gateway nftables configuration to allow UDP/TCP port 53 traffic from assigned IPs to DNS servers before
This commit is contained in:
@@ -2,6 +2,7 @@ package device
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
|
|
||||||
@@ -61,11 +62,12 @@ func (s *Service) Enroll(ctx context.Context, userID uuid.UUID, input EnrollRequ
|
|||||||
if len(destinations) == 0 {
|
if len(destinations) == 0 {
|
||||||
destinations = []string{"172.16.10.0/24"}
|
destinations = []string{"172.16.10.0/24"}
|
||||||
}
|
}
|
||||||
|
profileAllowedIPs := mergeProfileAllowedIPs(destinations, selectedGateway.DNSServers)
|
||||||
|
|
||||||
enrollment.Peer = PeerView{
|
enrollment.Peer = PeerView{
|
||||||
AssignedIP: assignedIP,
|
AssignedIP: assignedIP,
|
||||||
DNSServers: selectedGateway.DNSServers,
|
DNSServers: selectedGateway.DNSServers,
|
||||||
AllowedIPs: destinations,
|
AllowedIPs: profileAllowedIPs,
|
||||||
Gateway: GatewayView{
|
Gateway: GatewayView{
|
||||||
ID: selectedGateway.ID,
|
ID: selectedGateway.ID,
|
||||||
Name: selectedGateway.Name,
|
Name: selectedGateway.Name,
|
||||||
@@ -90,7 +92,7 @@ func (s *Service) Enroll(ctx context.Context, userID uuid.UUID, input EnrollRequ
|
|||||||
DNSServers: selectedGateway.DNSServers,
|
DNSServers: selectedGateway.DNSServers,
|
||||||
ServerPublicKey: selectedGateway.PublicKey,
|
ServerPublicKey: selectedGateway.PublicKey,
|
||||||
ServerEndpoint: selectedGateway.Endpoint,
|
ServerEndpoint: selectedGateway.Endpoint,
|
||||||
AllowedIPs: destinations,
|
AllowedIPs: profileAllowedIPs,
|
||||||
PersistentKeepal: 25,
|
PersistentKeepal: 25,
|
||||||
}),
|
}),
|
||||||
}
|
}
|
||||||
@@ -149,6 +151,7 @@ func (s *Service) Rotate(ctx context.Context, deviceID uuid.UUID) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func withDebugProfile(enrollment EnrollmentResponse) EnrollmentResponse {
|
func withDebugProfile(enrollment EnrollmentResponse) EnrollmentResponse {
|
||||||
|
profileAllowedIPs := mergeProfileAllowedIPs(enrollment.Peer.AllowedIPs, enrollment.Peer.DNSServers)
|
||||||
enrollment.Profile = ProfileView{
|
enrollment.Profile = ProfileView{
|
||||||
Format: "wireguard",
|
Format: "wireguard",
|
||||||
Content: profile.BuildWireGuardConfig(profile.BuildInput{
|
Content: profile.BuildWireGuardConfig(profile.BuildInput{
|
||||||
@@ -157,9 +160,51 @@ func withDebugProfile(enrollment EnrollmentResponse) EnrollmentResponse {
|
|||||||
DNSServers: enrollment.Peer.DNSServers,
|
DNSServers: enrollment.Peer.DNSServers,
|
||||||
ServerPublicKey: enrollment.Peer.Gateway.PublicKey,
|
ServerPublicKey: enrollment.Peer.Gateway.PublicKey,
|
||||||
ServerEndpoint: enrollment.Peer.Gateway.Endpoint,
|
ServerEndpoint: enrollment.Peer.Gateway.Endpoint,
|
||||||
AllowedIPs: enrollment.Peer.AllowedIPs,
|
AllowedIPs: profileAllowedIPs,
|
||||||
PersistentKeepal: 25,
|
PersistentKeepal: 25,
|
||||||
}),
|
}),
|
||||||
}
|
}
|
||||||
return enrollment
|
return enrollment
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func mergeProfileAllowedIPs(destinations []string, dnsServers []string) []string {
|
||||||
|
seen := make(map[string]struct{}, len(destinations)+len(dnsServers))
|
||||||
|
merged := make([]string, 0, len(destinations)+len(dnsServers))
|
||||||
|
|
||||||
|
for _, destination := range destinations {
|
||||||
|
destination = strings.TrimSpace(destination)
|
||||||
|
if destination == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, exists := seen[destination]; exists {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seen[destination] = struct{}{}
|
||||||
|
merged = append(merged, destination)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, dnsServer := range dnsServers {
|
||||||
|
route := dnsServerRoute(dnsServer)
|
||||||
|
if route == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, exists := seen[route]; exists {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seen[route] = struct{}{}
|
||||||
|
merged = append(merged, route)
|
||||||
|
}
|
||||||
|
|
||||||
|
return merged
|
||||||
|
}
|
||||||
|
|
||||||
|
func dnsServerRoute(value string) string {
|
||||||
|
value = strings.TrimSpace(value)
|
||||||
|
if value == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if strings.Contains(value, "/") {
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
return value + "/32"
|
||||||
|
}
|
||||||
|
|||||||
@@ -91,9 +91,11 @@ func (r *PGRepository) BuildSyncBundle(ctx context.Context, gatewayID uuid.UUID)
|
|||||||
d.id,
|
d.id,
|
||||||
wp.public_key,
|
wp.public_key,
|
||||||
set_masklen(wp.assigned_ip, 32)::text,
|
set_masklen(wp.assigned_ip, 32)::text,
|
||||||
coalesce(array_agg(distinct pd.destination::text) filter (where pd.destination is not null), '{}')
|
coalesce(array_agg(distinct pd.destination::text) filter (where pd.destination is not null), '{}'),
|
||||||
|
coalesce(g.dns_servers, '{}')::text[]
|
||||||
from devices d
|
from devices d
|
||||||
join wireguard_peers wp on wp.device_id = d.id and wp.deleted_at is null
|
join wireguard_peers wp on wp.device_id = d.id and wp.deleted_at is null
|
||||||
|
join gateways g on g.id = d.gateway_id
|
||||||
left join group_memberships gm on gm.user_id = d.user_id
|
left join group_memberships gm on gm.user_id = d.user_id
|
||||||
left join policy_targets pt on (
|
left join policy_targets pt on (
|
||||||
(pt.target_type = 'device' and pt.target_id = d.id) or
|
(pt.target_type = 'device' and pt.target_id = d.id) or
|
||||||
@@ -102,7 +104,7 @@ func (r *PGRepository) BuildSyncBundle(ctx context.Context, gatewayID uuid.UUID)
|
|||||||
)
|
)
|
||||||
left join policy_destinations pd on pd.policy_id = pt.policy_id
|
left join policy_destinations pd on pd.policy_id = pt.policy_id
|
||||||
where d.gateway_id = $1 and d.deleted_at is null and d.status = 'active'
|
where d.gateway_id = $1 and d.deleted_at is null and d.status = 'active'
|
||||||
group by d.id, wp.public_key, wp.assigned_ip
|
group by d.id, wp.public_key, wp.assigned_ip, g.dns_servers
|
||||||
`, gatewayID)
|
`, gatewayID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return wireguard.GatewayBundle{}, err
|
return wireguard.GatewayBundle{}, err
|
||||||
@@ -112,7 +114,7 @@ func (r *PGRepository) BuildSyncBundle(ctx context.Context, gatewayID uuid.UUID)
|
|||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var peer wireguard.Peer
|
var peer wireguard.Peer
|
||||||
var deviceID uuid.UUID
|
var deviceID uuid.UUID
|
||||||
if err := rows.Scan(&deviceID, &peer.PublicKey, &peer.AssignedIP, &peer.AllowedDestinations); err != nil {
|
if err := rows.Scan(&deviceID, &peer.PublicKey, &peer.AssignedIP, &peer.AllowedDestinations, &peer.DNSServers); err != nil {
|
||||||
return wireguard.GatewayBundle{}, err
|
return wireguard.GatewayBundle{}, err
|
||||||
}
|
}
|
||||||
peer.DeviceID = deviceID.String()
|
peer.DeviceID = deviceID.String()
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ type Peer struct {
|
|||||||
PublicKey string `json:"public_key"`
|
PublicKey string `json:"public_key"`
|
||||||
AssignedIP string `json:"assigned_ip"`
|
AssignedIP string `json:"assigned_ip"`
|
||||||
AllowedDestinations []string `json:"allowed_destinations"`
|
AllowedDestinations []string `json:"allowed_destinations"`
|
||||||
|
DNSServers []string `json:"dns_servers"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type GatewayBundle struct {
|
type GatewayBundle struct {
|
||||||
|
|||||||
@@ -114,6 +114,10 @@ EOF
|
|||||||
|
|
||||||
jq -c '.peers[]?' "${STATE_JSON}" | while read -r peer; do
|
jq -c '.peers[]?' "${STATE_JSON}" | while read -r peer; do
|
||||||
ASSIGNED_IP=$(printf '%s' "${peer}" | jq -r '.assigned_ip')
|
ASSIGNED_IP=$(printf '%s' "${peer}" | jq -r '.assigned_ip')
|
||||||
|
printf '%s' "${peer}" | jq -r '.dns_servers[]?' | while read -r dns_server; do
|
||||||
|
echo " iifname \"${IFACE}\" ip saddr ${ASSIGNED_IP} ip daddr ${dns_server} udp dport 53 accept"
|
||||||
|
echo " iifname \"${IFACE}\" ip saddr ${ASSIGNED_IP} ip daddr ${dns_server} tcp dport 53 accept"
|
||||||
|
done
|
||||||
printf '%s' "${peer}" | jq -r '.allowed_destinations[]?' | while read -r destination; do
|
printf '%s' "${peer}" | jq -r '.allowed_destinations[]?' | while read -r destination; do
|
||||||
echo " iifname \"${IFACE}\" ip saddr ${ASSIGNED_IP} ip daddr ${destination} accept"
|
echo " iifname \"${IFACE}\" ip saddr ${ASSIGNED_IP} ip daddr ${destination} accept"
|
||||||
done
|
done
|
||||||
|
|||||||
Reference in New Issue
Block a user