Files
NexaVPN/backend/internal/gateway/repository.go
nessi 3e2169f217 feat: add VPN DNS service with dynamic service catalog resolution and CoreDNS integration
Add ServiceDNSRecord type and gateway API endpoint to expose active service domain-to-IP mappings. Implement ListServiceDNSRecords repository method querying services table with proxy_ip resolution using effectiveAccessProxyIP helper.

Add vpn-dns microservice built on CoreDNS with periodic sync from backend API. Generate Corefile with configurable upstream DNS servers and hosts plugin for service overrides.
2026-03-18 13:30:34 +01:00

301 lines
9.1 KiB
Go

package gateway
import (
"context"
"encoding/json"
"net/netip"
"os"
"strings"
"github.com/google/uuid"
"github.com/jackc/pgx/v5/pgxpool"
"nexavpn/backend/internal/wireguard"
)
type Repository interface {
List(ctx context.Context) ([]Gateway, error)
FirstActive(ctx context.Context) (Gateway, error)
BuildSyncBundle(ctx context.Context, gatewayID uuid.UUID) (wireguard.GatewayBundle, error)
ListServiceDNSRecords(ctx context.Context) ([]ServiceDNSRecord, error)
StoreTelemetry(ctx context.Context, gatewayID uuid.UUID, snapshot TelemetrySnapshot) error
Update(ctx context.Context, gatewayID uuid.UUID, input UpdateRequest) (Gateway, error)
UpsertByName(ctx context.Context, input BootstrapRequest) (Gateway, error)
}
type PGRepository struct {
db *pgxpool.Pool
}
func NewPGRepository(db *pgxpool.Pool) *PGRepository {
return &PGRepository{db: db}
}
func (r *PGRepository) List(ctx context.Context) ([]Gateway, error) {
rows, err := r.db.Query(ctx, `
select id, name, endpoint, public_key, listen_port, vpn_cidr::text, dns_servers, is_active
from gateways
where deleted_at is null
order by created_at desc
`)
if err != nil {
return nil, err
}
defer rows.Close()
var items []Gateway
for rows.Next() {
var item Gateway
if err := rows.Scan(&item.ID, &item.Name, &item.Endpoint, &item.PublicKey, &item.ListenPort, &item.VPNCIDR, &item.DNSServers, &item.IsActive); err != nil {
return nil, err
}
items = append(items, item)
}
return items, rows.Err()
}
func (r *PGRepository) FirstActive(ctx context.Context) (Gateway, error) {
row := r.db.QueryRow(ctx, `
select id, name, endpoint, public_key, listen_port, vpn_cidr::text, dns_servers, is_active
from gateways
where deleted_at is null and is_active = true
order by created_at asc
limit 1
`)
var item Gateway
err := row.Scan(&item.ID, &item.Name, &item.Endpoint, &item.PublicKey, &item.ListenPort, &item.VPNCIDR, &item.DNSServers, &item.IsActive)
return item, err
}
func (r *PGRepository) BuildSyncBundle(ctx context.Context, gatewayID uuid.UUID) (wireguard.GatewayBundle, error) {
var bundle wireguard.GatewayBundle
bundle.GatewayID = gatewayID.String()
bundle.Revision = 1
row := r.db.QueryRow(ctx, `
select vpn_cidr::text, listen_port
from gateways
where id = $1 and deleted_at is null
`, gatewayID)
var vpnCIDR string
if err := row.Scan(&vpnCIDR, &bundle.Interface.ListenPort); err != nil {
return wireguard.GatewayBundle{}, err
}
interfaceAddress, err := gatewayInterfaceAddress(vpnCIDR)
if err != nil {
return wireguard.GatewayBundle{}, err
}
bundle.Interface.Address = interfaceAddress
bundle.Interface.NetworkCIDR = vpnCIDR
rows, err := r.db.Query(ctx, `
select
d.id,
wp.public_key,
set_masklen(wp.assigned_ip, 32)::text,
coalesce(array_agg(distinct pd.destination::text) filter (where pd.destination is not null), '{}'),
coalesce(g.dns_servers, '{}')::text[],
s.value->>'profile_id'
from devices d
join wireguard_peers wp on wp.device_id = d.id and wp.deleted_at is null
join gateways g on g.id = d.gateway_id
left join settings s on s.category = 'device_access_profile' and s.key = d.id::text
left join group_memberships gm on gm.user_id = d.user_id
left join policy_targets pt on (
(pt.target_type = 'device' and pt.target_id = d.id) or
(pt.target_type = 'user' and pt.target_id = d.user_id) or
(pt.target_type = 'group' and pt.target_id = gm.group_id)
)
left join policies p on p.id = pt.policy_id
and p.deleted_at is null
and p.is_active = true
and p.effect = 'allow'
left join policy_destinations pd on pd.policy_id = p.id
and (
s.value->>'profile_id' is null
or p.id::text = s.value->>'profile_id'
)
where d.gateway_id = $1 and d.deleted_at is null and d.status = 'active'
group by d.id, wp.public_key, wp.assigned_ip, g.dns_servers, s.value
`, gatewayID)
if err != nil {
return wireguard.GatewayBundle{}, err
}
defer rows.Close()
for rows.Next() {
var peer wireguard.Peer
var deviceID uuid.UUID
var selectedProfileID *string
if err := rows.Scan(&deviceID, &peer.PublicKey, &peer.AssignedIP, &peer.AllowedDestinations, &peer.DNSServers, &selectedProfileID); err != nil {
return wireguard.GatewayBundle{}, err
}
peer.DeviceID = deviceID.String()
services, err := r.listAllowedServices(ctx, deviceID, selectedProfileID)
if err != nil {
return wireguard.GatewayBundle{}, err
}
peer.AllowedServices = services
bundle.Peers = append(bundle.Peers, peer)
}
return bundle, rows.Err()
}
func (r *PGRepository) listAllowedServices(ctx context.Context, deviceID uuid.UUID, selectedProfileID *string) ([]wireguard.AllowedService, error) {
rows, err := r.db.Query(ctx, `
select distinct
s.name,
s.domain,
host(s.upstream_ip),
host(s.proxy_ip),
s.ports
from devices d
left join group_memberships gm on gm.user_id = d.user_id
join policy_targets pt on (
(pt.target_type = 'device' and pt.target_id = d.id) or
(pt.target_type = 'user' and pt.target_id = d.user_id) or
(pt.target_type = 'group' and pt.target_id = gm.group_id)
)
join policies p on p.id = pt.policy_id
and p.deleted_at is null
and p.is_active = true
and p.effect = 'allow'
join policy_services ps on ps.policy_id = p.id
join services s on s.id = ps.service_id and s.deleted_at is null and s.is_active = true
where d.id = $1
and ($2::text is null or p.id::text = $2::text)
order by s.name asc
`, deviceID, selectedProfileID)
if err != nil {
return nil, err
}
defer rows.Close()
var items []wireguard.AllowedService
for rows.Next() {
var item wireguard.AllowedService
if err := rows.Scan(&item.Name, &item.Domain, &item.UpstreamIP, &item.ProxyIP, &item.Ports); err != nil {
return nil, err
}
item.AccessProxyIP = effectiveAccessProxyIP(item.ProxyIP)
items = append(items, item)
}
return items, rows.Err()
}
func effectiveAccessProxyIP(proxyIP string) string {
override := strings.TrimSpace(os.Getenv("NEXAVPN_ACCESS_PROXY_IP"))
if override != "" {
return override
}
return proxyIP
}
func (r *PGRepository) ListServiceDNSRecords(ctx context.Context) ([]ServiceDNSRecord, error) {
rows, err := r.db.Query(ctx, `
select distinct
s.domain,
host(s.proxy_ip)
from services s
where s.deleted_at is null and s.is_active = true
order by s.domain asc
`)
if err != nil {
return nil, err
}
defer rows.Close()
var items []ServiceDNSRecord
for rows.Next() {
var item ServiceDNSRecord
var proxyIP string
if err := rows.Scan(&item.Domain, &proxyIP); err != nil {
return nil, err
}
item.TargetIP = effectiveAccessProxyIP(proxyIP)
items = append(items, item)
}
return items, rows.Err()
}
func (r *PGRepository) Update(ctx context.Context, gatewayID uuid.UUID, input UpdateRequest) (Gateway, error) {
row := r.db.QueryRow(ctx, `
update gateways
set endpoint = $2,
public_key = $3,
listen_port = $4,
vpn_cidr = $5::cidr,
dns_servers = $6::text[],
is_active = $7,
updated_at = now()
where id = $1
returning id, name, endpoint, public_key, listen_port, vpn_cidr::text, dns_servers, is_active
`, gatewayID, input.Endpoint, input.PublicKey, input.ListenPort, input.VPNCIDR, input.DNSServers, input.IsActive)
var item Gateway
err := row.Scan(&item.ID, &item.Name, &item.Endpoint, &item.PublicKey, &item.ListenPort, &item.VPNCIDR, &item.DNSServers, &item.IsActive)
return item, err
}
func (r *PGRepository) StoreTelemetry(ctx context.Context, gatewayID uuid.UUID, snapshot TelemetrySnapshot) error {
payload, err := json.Marshal(snapshot)
if err != nil {
return err
}
_, err = r.db.Exec(ctx, `
insert into settings (category, key, value, updated_at)
values ('gateway_runtime', $1, $2::jsonb, now())
on conflict (category, key)
do update set value = excluded.value, updated_at = now()
`, gatewayID.String(), string(payload))
return err
}
func (r *PGRepository) UpsertByName(ctx context.Context, input BootstrapRequest) (Gateway, error) {
row := r.db.QueryRow(ctx, `
insert into gateways (id, name, endpoint, public_key, listen_port, vpn_cidr, dns_servers, is_active)
values ($1, $2, $3, $4, $5, $6::cidr, $7::text[], true)
on conflict (name)
do update set
endpoint = excluded.endpoint,
public_key = excluded.public_key,
listen_port = excluded.listen_port,
vpn_cidr = excluded.vpn_cidr,
dns_servers = excluded.dns_servers,
is_active = true,
updated_at = now()
returning id, name, endpoint, public_key, listen_port, vpn_cidr::text, dns_servers, is_active
`, uuid.New(), input.Name, input.Endpoint, input.PublicKey, input.ListenPort, input.VPNCIDR, input.DNSServers)
var item Gateway
err := row.Scan(&item.ID, &item.Name, &item.Endpoint, &item.PublicKey, &item.ListenPort, &item.VPNCIDR, &item.DNSServers, &item.IsActive)
return item, err
}
func gatewayInterfaceAddress(cidr string) (string, error) {
prefix, err := netip.ParsePrefix(cidr)
if err != nil {
return "", err
}
return prefix.Addr().Next().String() + "/" + intToString(prefix.Bits()), nil
}
func intToString(value int) string {
if value == 0 {
return "0"
}
var digits [20]byte
index := len(digits)
for value > 0 {
index--
digits[index] = byte('0' + value%10)
value /= 10
}
return string(digits[index:])
}