package gateway import ( "context" "encoding/json" "net/netip" "os" "strings" "github.com/google/uuid" "github.com/jackc/pgx/v5/pgxpool" "nexavpn/backend/internal/wireguard" ) type Repository interface { List(ctx context.Context) ([]Gateway, error) FirstActive(ctx context.Context) (Gateway, error) BuildSyncBundle(ctx context.Context, gatewayID uuid.UUID) (wireguard.GatewayBundle, error) ListServiceDNSRecords(ctx context.Context) ([]ServiceDNSRecord, error) StoreTelemetry(ctx context.Context, gatewayID uuid.UUID, snapshot TelemetrySnapshot) error Update(ctx context.Context, gatewayID uuid.UUID, input UpdateRequest) (Gateway, error) UpsertByName(ctx context.Context, input BootstrapRequest) (Gateway, error) } type PGRepository struct { db *pgxpool.Pool } func NewPGRepository(db *pgxpool.Pool) *PGRepository { return &PGRepository{db: db} } func (r *PGRepository) List(ctx context.Context) ([]Gateway, error) { rows, err := r.db.Query(ctx, ` select id, name, endpoint, public_key, listen_port, vpn_cidr::text, dns_servers, is_active from gateways where deleted_at is null order by created_at desc `) if err != nil { return nil, err } defer rows.Close() var items []Gateway for rows.Next() { var item Gateway if err := rows.Scan(&item.ID, &item.Name, &item.Endpoint, &item.PublicKey, &item.ListenPort, &item.VPNCIDR, &item.DNSServers, &item.IsActive); err != nil { return nil, err } items = append(items, item) } return items, rows.Err() } func (r *PGRepository) FirstActive(ctx context.Context) (Gateway, error) { row := r.db.QueryRow(ctx, ` select id, name, endpoint, public_key, listen_port, vpn_cidr::text, dns_servers, is_active from gateways where deleted_at is null and is_active = true order by created_at asc limit 1 `) var item Gateway err := row.Scan(&item.ID, &item.Name, &item.Endpoint, &item.PublicKey, &item.ListenPort, &item.VPNCIDR, &item.DNSServers, &item.IsActive) return item, err } func (r *PGRepository) BuildSyncBundle(ctx context.Context, gatewayID uuid.UUID) (wireguard.GatewayBundle, error) { var bundle wireguard.GatewayBundle bundle.GatewayID = gatewayID.String() bundle.Revision = 1 row := r.db.QueryRow(ctx, ` select vpn_cidr::text, listen_port from gateways where id = $1 and deleted_at is null `, gatewayID) var vpnCIDR string if err := row.Scan(&vpnCIDR, &bundle.Interface.ListenPort); err != nil { return wireguard.GatewayBundle{}, err } interfaceAddress, err := gatewayInterfaceAddress(vpnCIDR) if err != nil { return wireguard.GatewayBundle{}, err } bundle.Interface.Address = interfaceAddress bundle.Interface.NetworkCIDR = vpnCIDR rows, err := r.db.Query(ctx, ` select d.id, wp.public_key, set_masklen(wp.assigned_ip, 32)::text, coalesce(array_agg(distinct pd.destination::text) filter (where pd.destination is not null), '{}'), coalesce(g.dns_servers, '{}')::text[], s.value->>'profile_id' from devices d join wireguard_peers wp on wp.device_id = d.id and wp.deleted_at is null join gateways g on g.id = d.gateway_id left join settings s on s.category = 'device_access_profile' and s.key = d.id::text left join group_memberships gm on gm.user_id = d.user_id left join policy_targets pt on ( (pt.target_type = 'device' and pt.target_id = d.id) or (pt.target_type = 'user' and pt.target_id = d.user_id) or (pt.target_type = 'group' and pt.target_id = gm.group_id) ) left join policies p on p.id = pt.policy_id and p.deleted_at is null and p.is_active = true and p.effect = 'allow' left join policy_destinations pd on pd.policy_id = p.id and ( s.value->>'profile_id' is null or p.id::text = s.value->>'profile_id' ) where d.gateway_id = $1 and d.deleted_at is null and d.status = 'active' group by d.id, wp.public_key, wp.assigned_ip, g.dns_servers, s.value `, gatewayID) if err != nil { return wireguard.GatewayBundle{}, err } defer rows.Close() for rows.Next() { var peer wireguard.Peer var deviceID uuid.UUID var selectedProfileID *string var gatewayDNSServers []string if err := rows.Scan(&deviceID, &peer.PublicKey, &peer.AssignedIP, &peer.AllowedDestinations, &gatewayDNSServers, &selectedProfileID); err != nil { return wireguard.GatewayBundle{}, err } peer.DeviceID = deviceID.String() services, err := r.listAllowedServices(ctx, deviceID, selectedProfileID) if err != nil { return wireguard.GatewayBundle{}, err } peer.DNSServers = dnsServersForPeer(gatewayDNSServers, services) peer.AllowedServices = services bundle.Peers = append(bundle.Peers, peer) } return bundle, rows.Err() } func (r *PGRepository) listAllowedServices(ctx context.Context, deviceID uuid.UUID, selectedProfileID *string) ([]wireguard.AllowedService, error) { rows, err := r.db.Query(ctx, ` select distinct s.name, s.domain, host(s.upstream_ip), host(s.proxy_ip), s.ports from devices d left join group_memberships gm on gm.user_id = d.user_id join policy_targets pt on ( (pt.target_type = 'device' and pt.target_id = d.id) or (pt.target_type = 'user' and pt.target_id = d.user_id) or (pt.target_type = 'group' and pt.target_id = gm.group_id) ) join policies p on p.id = pt.policy_id and p.deleted_at is null and p.is_active = true and p.effect = 'allow' join policy_services ps on ps.policy_id = p.id join services s on s.id = ps.service_id and s.deleted_at is null and s.is_active = true where d.id = $1 and ($2::text is null or p.id::text = $2::text) order by s.name asc `, deviceID, selectedProfileID) if err != nil { return nil, err } defer rows.Close() var items []wireguard.AllowedService for rows.Next() { var item wireguard.AllowedService if err := rows.Scan(&item.Name, &item.Domain, &item.UpstreamIP, &item.ProxyIP, &item.Ports); err != nil { return nil, err } item.AccessProxyIP = effectiveAccessProxyIP(item.ProxyIP) items = append(items, item) } return items, rows.Err() } func effectiveAccessProxyIP(proxyIP string) string { override := strings.TrimSpace(os.Getenv("NEXAVPN_ACCESS_PROXY_IP")) if override != "" { return override } return proxyIP } func (r *PGRepository) ListServiceDNSRecords(ctx context.Context) ([]ServiceDNSRecord, error) { rows, err := r.db.Query(ctx, ` select distinct s.domain, host(s.proxy_ip) from services s where s.deleted_at is null and s.is_active = true order by s.domain asc `) if err != nil { return nil, err } defer rows.Close() var items []ServiceDNSRecord for rows.Next() { var item ServiceDNSRecord var proxyIP string if err := rows.Scan(&item.Domain, &proxyIP); err != nil { return nil, err } item.TargetIP = effectiveAccessProxyIP(proxyIP) items = append(items, item) } return items, rows.Err() } func dnsServersForPeer(base []string, services []wireguard.AllowedService) []string { if len(services) > 0 { if override := parseEnvList("NEXAVPN_CLIENT_DNS_SERVERS"); len(override) > 0 { return override } return dedupeList(base) } if override := parseEnvList("DEFAULT_DNS_SERVERS"); len(override) > 0 { return override } return dedupeList(base) } func parseEnvList(key string) []string { return parseCommaList(os.Getenv(key)) } func parseCommaList(raw string) []string { seen := make(map[string]struct{}) values := make([]string, 0) for _, part := range strings.Split(raw, ",") { value := strings.TrimSpace(part) if value == "" { continue } if _, ok := seen[value]; ok { continue } seen[value] = struct{}{} values = append(values, value) } return values } func dedupeList(values []string) []string { return parseCommaList(strings.Join(values, ",")) } func (r *PGRepository) Update(ctx context.Context, gatewayID uuid.UUID, input UpdateRequest) (Gateway, error) { row := r.db.QueryRow(ctx, ` update gateways set endpoint = $2, public_key = $3, listen_port = $4, vpn_cidr = $5::cidr, dns_servers = $6::text[], is_active = $7, updated_at = now() where id = $1 returning id, name, endpoint, public_key, listen_port, vpn_cidr::text, dns_servers, is_active `, gatewayID, input.Endpoint, input.PublicKey, input.ListenPort, input.VPNCIDR, input.DNSServers, input.IsActive) var item Gateway err := row.Scan(&item.ID, &item.Name, &item.Endpoint, &item.PublicKey, &item.ListenPort, &item.VPNCIDR, &item.DNSServers, &item.IsActive) return item, err } func (r *PGRepository) StoreTelemetry(ctx context.Context, gatewayID uuid.UUID, snapshot TelemetrySnapshot) error { payload, err := json.Marshal(snapshot) if err != nil { return err } _, err = r.db.Exec(ctx, ` insert into settings (category, key, value, updated_at) values ('gateway_runtime', $1, $2::jsonb, now()) on conflict (category, key) do update set value = excluded.value, updated_at = now() `, gatewayID.String(), string(payload)) return err } func (r *PGRepository) UpsertByName(ctx context.Context, input BootstrapRequest) (Gateway, error) { row := r.db.QueryRow(ctx, ` insert into gateways (id, name, endpoint, public_key, listen_port, vpn_cidr, dns_servers, is_active) values ($1, $2, $3, $4, $5, $6::cidr, $7::text[], true) on conflict (name) do update set endpoint = excluded.endpoint, public_key = excluded.public_key, listen_port = excluded.listen_port, vpn_cidr = excluded.vpn_cidr, dns_servers = excluded.dns_servers, is_active = true, updated_at = now() returning id, name, endpoint, public_key, listen_port, vpn_cidr::text, dns_servers, is_active `, uuid.New(), input.Name, input.Endpoint, input.PublicKey, input.ListenPort, input.VPNCIDR, input.DNSServers) var item Gateway err := row.Scan(&item.ID, &item.Name, &item.Endpoint, &item.PublicKey, &item.ListenPort, &item.VPNCIDR, &item.DNSServers, &item.IsActive) return item, err } func gatewayInterfaceAddress(cidr string) (string, error) { prefix, err := netip.ParsePrefix(cidr) if err != nil { return "", err } return prefix.Addr().Next().String() + "/" + intToString(prefix.Bits()), nil } func intToString(value int) string { if value == 0 { return "0" } var digits [20]byte index := len(digits) for value > 0 { index-- digits[index] = byte('0' + value%10) value /= 10 } return string(digits[index:]) }