Add dnsServersForProfile and dnsServersForPeer helpers with conditional DNS server selection based on service presence. Use the NEXAVPN_CLIENT_DNS_SERVERS override when services are configured; otherwise fall back to DEFAULT_DNS_SERVERS or the gateway's base DNS servers. Replace direct gateway DNS server usage in Enroll and applyCurrentPolicy with the profileDNSServers variable. Update BuildSyncBundle to scan gateway DNS servers separately.
342 lines
10 KiB
Go
342 lines
10 KiB
Go
package gateway
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"net/netip"
|
|
"os"
|
|
"strings"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/jackc/pgx/v5/pgxpool"
|
|
|
|
"nexavpn/backend/internal/wireguard"
|
|
)
|
|
|
|
type Repository interface {
|
|
List(ctx context.Context) ([]Gateway, error)
|
|
FirstActive(ctx context.Context) (Gateway, error)
|
|
BuildSyncBundle(ctx context.Context, gatewayID uuid.UUID) (wireguard.GatewayBundle, error)
|
|
ListServiceDNSRecords(ctx context.Context) ([]ServiceDNSRecord, error)
|
|
StoreTelemetry(ctx context.Context, gatewayID uuid.UUID, snapshot TelemetrySnapshot) error
|
|
Update(ctx context.Context, gatewayID uuid.UUID, input UpdateRequest) (Gateway, error)
|
|
UpsertByName(ctx context.Context, input BootstrapRequest) (Gateway, error)
|
|
}
|
|
|
|
type PGRepository struct {
|
|
db *pgxpool.Pool
|
|
}
|
|
|
|
func NewPGRepository(db *pgxpool.Pool) *PGRepository {
|
|
return &PGRepository{db: db}
|
|
}
|
|
|
|
func (r *PGRepository) List(ctx context.Context) ([]Gateway, error) {
|
|
rows, err := r.db.Query(ctx, `
|
|
select id, name, endpoint, public_key, listen_port, vpn_cidr::text, dns_servers, is_active
|
|
from gateways
|
|
where deleted_at is null
|
|
order by created_at desc
|
|
`)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
var items []Gateway
|
|
for rows.Next() {
|
|
var item Gateway
|
|
if err := rows.Scan(&item.ID, &item.Name, &item.Endpoint, &item.PublicKey, &item.ListenPort, &item.VPNCIDR, &item.DNSServers, &item.IsActive); err != nil {
|
|
return nil, err
|
|
}
|
|
items = append(items, item)
|
|
}
|
|
return items, rows.Err()
|
|
}
|
|
|
|
func (r *PGRepository) FirstActive(ctx context.Context) (Gateway, error) {
|
|
row := r.db.QueryRow(ctx, `
|
|
select id, name, endpoint, public_key, listen_port, vpn_cidr::text, dns_servers, is_active
|
|
from gateways
|
|
where deleted_at is null and is_active = true
|
|
order by created_at asc
|
|
limit 1
|
|
`)
|
|
|
|
var item Gateway
|
|
err := row.Scan(&item.ID, &item.Name, &item.Endpoint, &item.PublicKey, &item.ListenPort, &item.VPNCIDR, &item.DNSServers, &item.IsActive)
|
|
return item, err
|
|
}
|
|
|
|
func (r *PGRepository) BuildSyncBundle(ctx context.Context, gatewayID uuid.UUID) (wireguard.GatewayBundle, error) {
|
|
var bundle wireguard.GatewayBundle
|
|
bundle.GatewayID = gatewayID.String()
|
|
bundle.Revision = 1
|
|
|
|
row := r.db.QueryRow(ctx, `
|
|
select vpn_cidr::text, listen_port
|
|
from gateways
|
|
where id = $1 and deleted_at is null
|
|
`, gatewayID)
|
|
var vpnCIDR string
|
|
if err := row.Scan(&vpnCIDR, &bundle.Interface.ListenPort); err != nil {
|
|
return wireguard.GatewayBundle{}, err
|
|
}
|
|
interfaceAddress, err := gatewayInterfaceAddress(vpnCIDR)
|
|
if err != nil {
|
|
return wireguard.GatewayBundle{}, err
|
|
}
|
|
bundle.Interface.Address = interfaceAddress
|
|
bundle.Interface.NetworkCIDR = vpnCIDR
|
|
|
|
rows, err := r.db.Query(ctx, `
|
|
select
|
|
d.id,
|
|
wp.public_key,
|
|
set_masklen(wp.assigned_ip, 32)::text,
|
|
coalesce(array_agg(distinct pd.destination::text) filter (where pd.destination is not null), '{}'),
|
|
coalesce(g.dns_servers, '{}')::text[],
|
|
s.value->>'profile_id'
|
|
from devices d
|
|
join wireguard_peers wp on wp.device_id = d.id and wp.deleted_at is null
|
|
join gateways g on g.id = d.gateway_id
|
|
left join settings s on s.category = 'device_access_profile' and s.key = d.id::text
|
|
left join group_memberships gm on gm.user_id = d.user_id
|
|
left join policy_targets pt on (
|
|
(pt.target_type = 'device' and pt.target_id = d.id) or
|
|
(pt.target_type = 'user' and pt.target_id = d.user_id) or
|
|
(pt.target_type = 'group' and pt.target_id = gm.group_id)
|
|
)
|
|
left join policies p on p.id = pt.policy_id
|
|
and p.deleted_at is null
|
|
and p.is_active = true
|
|
and p.effect = 'allow'
|
|
left join policy_destinations pd on pd.policy_id = p.id
|
|
and (
|
|
s.value->>'profile_id' is null
|
|
or p.id::text = s.value->>'profile_id'
|
|
)
|
|
where d.gateway_id = $1 and d.deleted_at is null and d.status = 'active'
|
|
group by d.id, wp.public_key, wp.assigned_ip, g.dns_servers, s.value
|
|
`, gatewayID)
|
|
if err != nil {
|
|
return wireguard.GatewayBundle{}, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
for rows.Next() {
|
|
var peer wireguard.Peer
|
|
var deviceID uuid.UUID
|
|
var selectedProfileID *string
|
|
var gatewayDNSServers []string
|
|
if err := rows.Scan(&deviceID, &peer.PublicKey, &peer.AssignedIP, &peer.AllowedDestinations, &gatewayDNSServers, &selectedProfileID); err != nil {
|
|
return wireguard.GatewayBundle{}, err
|
|
}
|
|
peer.DeviceID = deviceID.String()
|
|
services, err := r.listAllowedServices(ctx, deviceID, selectedProfileID)
|
|
if err != nil {
|
|
return wireguard.GatewayBundle{}, err
|
|
}
|
|
peer.DNSServers = dnsServersForPeer(gatewayDNSServers, services)
|
|
peer.AllowedServices = services
|
|
bundle.Peers = append(bundle.Peers, peer)
|
|
}
|
|
|
|
return bundle, rows.Err()
|
|
}
|
|
|
|
func (r *PGRepository) listAllowedServices(ctx context.Context, deviceID uuid.UUID, selectedProfileID *string) ([]wireguard.AllowedService, error) {
|
|
rows, err := r.db.Query(ctx, `
|
|
select distinct
|
|
s.name,
|
|
s.domain,
|
|
host(s.upstream_ip),
|
|
host(s.proxy_ip),
|
|
s.ports
|
|
from devices d
|
|
left join group_memberships gm on gm.user_id = d.user_id
|
|
join policy_targets pt on (
|
|
(pt.target_type = 'device' and pt.target_id = d.id) or
|
|
(pt.target_type = 'user' and pt.target_id = d.user_id) or
|
|
(pt.target_type = 'group' and pt.target_id = gm.group_id)
|
|
)
|
|
join policies p on p.id = pt.policy_id
|
|
and p.deleted_at is null
|
|
and p.is_active = true
|
|
and p.effect = 'allow'
|
|
join policy_services ps on ps.policy_id = p.id
|
|
join services s on s.id = ps.service_id and s.deleted_at is null and s.is_active = true
|
|
where d.id = $1
|
|
and ($2::text is null or p.id::text = $2::text)
|
|
order by s.name asc
|
|
`, deviceID, selectedProfileID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
var items []wireguard.AllowedService
|
|
for rows.Next() {
|
|
var item wireguard.AllowedService
|
|
if err := rows.Scan(&item.Name, &item.Domain, &item.UpstreamIP, &item.ProxyIP, &item.Ports); err != nil {
|
|
return nil, err
|
|
}
|
|
item.AccessProxyIP = effectiveAccessProxyIP(item.ProxyIP)
|
|
items = append(items, item)
|
|
}
|
|
return items, rows.Err()
|
|
}
|
|
|
|
// effectiveAccessProxyIP returns the whitespace-trimmed value of the
// NEXAVPN_ACCESS_PROXY_IP environment variable when it is non-empty, and
// proxyIP otherwise.
func effectiveAccessProxyIP(proxyIP string) string {
	if forced := strings.TrimSpace(os.Getenv("NEXAVPN_ACCESS_PROXY_IP")); forced != "" {
		return forced
	}
	return proxyIP
}
|
|
|
|
func (r *PGRepository) ListServiceDNSRecords(ctx context.Context) ([]ServiceDNSRecord, error) {
|
|
rows, err := r.db.Query(ctx, `
|
|
select distinct
|
|
s.domain,
|
|
host(s.proxy_ip)
|
|
from services s
|
|
where s.deleted_at is null and s.is_active = true
|
|
order by s.domain asc
|
|
`)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
var items []ServiceDNSRecord
|
|
for rows.Next() {
|
|
var item ServiceDNSRecord
|
|
var proxyIP string
|
|
if err := rows.Scan(&item.Domain, &proxyIP); err != nil {
|
|
return nil, err
|
|
}
|
|
item.TargetIP = effectiveAccessProxyIP(proxyIP)
|
|
items = append(items, item)
|
|
}
|
|
return items, rows.Err()
|
|
}
|
|
|
|
func dnsServersForPeer(base []string, services []wireguard.AllowedService) []string {
|
|
if len(services) > 0 {
|
|
if override := parseEnvList("NEXAVPN_CLIENT_DNS_SERVERS"); len(override) > 0 {
|
|
return override
|
|
}
|
|
return dedupeList(base)
|
|
}
|
|
|
|
if override := parseEnvList("DEFAULT_DNS_SERVERS"); len(override) > 0 {
|
|
return override
|
|
}
|
|
return dedupeList(base)
|
|
}
|
|
|
|
func parseEnvList(key string) []string {
|
|
return parseCommaList(os.Getenv(key))
|
|
}
|
|
|
|
// parseCommaList splits raw on commas, trims surrounding whitespace from each
// piece, drops empty pieces and repeated values, and returns what is left in
// first-seen order. The result is never nil; it is an empty slice when raw
// holds no usable entries.
func parseCommaList(raw string) []string {
	pieces := strings.Split(raw, ",")
	known := make(map[string]struct{})
	unique := make([]string, 0)

	for i := range pieces {
		entry := strings.TrimSpace(pieces[i])
		if entry == "" {
			continue
		}
		if _, duplicate := known[entry]; !duplicate {
			known[entry] = struct{}{}
			unique = append(unique, entry)
		}
	}
	return unique
}
|
|
|
|
func dedupeList(values []string) []string {
|
|
return parseCommaList(strings.Join(values, ","))
|
|
}
|
|
|
|
func (r *PGRepository) Update(ctx context.Context, gatewayID uuid.UUID, input UpdateRequest) (Gateway, error) {
|
|
row := r.db.QueryRow(ctx, `
|
|
update gateways
|
|
set endpoint = $2,
|
|
public_key = $3,
|
|
listen_port = $4,
|
|
vpn_cidr = $5::cidr,
|
|
dns_servers = $6::text[],
|
|
is_active = $7,
|
|
updated_at = now()
|
|
where id = $1
|
|
returning id, name, endpoint, public_key, listen_port, vpn_cidr::text, dns_servers, is_active
|
|
`, gatewayID, input.Endpoint, input.PublicKey, input.ListenPort, input.VPNCIDR, input.DNSServers, input.IsActive)
|
|
|
|
var item Gateway
|
|
err := row.Scan(&item.ID, &item.Name, &item.Endpoint, &item.PublicKey, &item.ListenPort, &item.VPNCIDR, &item.DNSServers, &item.IsActive)
|
|
return item, err
|
|
}
|
|
|
|
func (r *PGRepository) StoreTelemetry(ctx context.Context, gatewayID uuid.UUID, snapshot TelemetrySnapshot) error {
|
|
payload, err := json.Marshal(snapshot)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
_, err = r.db.Exec(ctx, `
|
|
insert into settings (category, key, value, updated_at)
|
|
values ('gateway_runtime', $1, $2::jsonb, now())
|
|
on conflict (category, key)
|
|
do update set value = excluded.value, updated_at = now()
|
|
`, gatewayID.String(), string(payload))
|
|
return err
|
|
}
|
|
|
|
func (r *PGRepository) UpsertByName(ctx context.Context, input BootstrapRequest) (Gateway, error) {
|
|
row := r.db.QueryRow(ctx, `
|
|
insert into gateways (id, name, endpoint, public_key, listen_port, vpn_cidr, dns_servers, is_active)
|
|
values ($1, $2, $3, $4, $5, $6::cidr, $7::text[], true)
|
|
on conflict (name)
|
|
do update set
|
|
endpoint = excluded.endpoint,
|
|
public_key = excluded.public_key,
|
|
listen_port = excluded.listen_port,
|
|
vpn_cidr = excluded.vpn_cidr,
|
|
dns_servers = excluded.dns_servers,
|
|
is_active = true,
|
|
updated_at = now()
|
|
returning id, name, endpoint, public_key, listen_port, vpn_cidr::text, dns_servers, is_active
|
|
`, uuid.New(), input.Name, input.Endpoint, input.PublicKey, input.ListenPort, input.VPNCIDR, input.DNSServers)
|
|
|
|
var item Gateway
|
|
err := row.Scan(&item.ID, &item.Name, &item.Endpoint, &item.PublicKey, &item.ListenPort, &item.VPNCIDR, &item.DNSServers, &item.IsActive)
|
|
return item, err
|
|
}
|
|
|
|
// noInterfaceAddressError reports that a network has no address after its
// base address that still lies inside the network.
type noInterfaceAddressError struct {
	cidr string
}

// Error implements the error interface.
func (e noInterfaceAddressError) Error() string {
	return "gateway: network " + e.cidr + " has no usable interface address"
}

// gatewayInterfaceAddress derives the gateway's interface address from the
// VPN network: the address immediately after the network's base address,
// rendered with the network's prefix length (for example "10.8.0.0/24"
// yields "10.8.0.1/24").
//
// The prefix is masked first, so host bits in cidr are ignored and
// "10.8.0.5/24" also yields "10.8.0.1/24". An error is returned when cidr
// does not parse, or when the following address would overflow or fall
// outside the network (single-address networks such as /32 and /128).
func gatewayInterfaceAddress(cidr string) (string, error) {
	prefix, err := netip.ParsePrefix(cidr)
	if err != nil {
		return "", err
	}

	network := prefix.Masked()
	first := network.Addr().Next()
	// Next returns the zero Addr on overflow; Contains rejects an address
	// that stepped past the end of the network.
	if !first.IsValid() || !network.Contains(first) {
		return "", noInterfaceAddressError{cidr: cidr}
	}

	return netip.PrefixFrom(first, network.Bits()).String(), nil
}
|
|
|
|
// intToString returns the base-10 representation of value, with a leading
// '-' for negative numbers.
func intToString(value int) string {
	if value == 0 {
		return "0"
	}

	negative := value < 0
	// Work on the unsigned magnitude so the most negative int, whose
	// absolute value does not fit in an int, is still rendered correctly.
	magnitude := uint64(value)
	if negative {
		magnitude = -magnitude
	}

	// 20 bytes hold the 19 digits of the largest 64-bit magnitude plus a sign.
	var digits [20]byte
	index := len(digits)
	for magnitude > 0 {
		index--
		digits[index] = byte('0' + magnitude%10)
		magnitude /= 10
	}
	if negative {
		index--
		digits[index] = '-'
	}
	return string(digits[index:])
}
|