Add mergeProfileAllowedIPs function to combine policy destinations with DNS server routes in device enrollment and rotation. Add dnsServerRoute helper to convert DNS server IPs to /32 CIDR notation. Update BuildSyncBundle query to include gateway DNS servers in peer data. Add DNSServers field to wireguard.Peer struct. Update gateway nftables configuration to allow UDP/TCP port 53 traffic from assigned IPs to DNS servers before
205 lines
6.3 KiB
Go
205 lines
6.3 KiB
Go
package gateway
|
|
|
|
import (
	"context"
	"encoding/json"
	"fmt"
	"net/netip"
	"strconv"

	"github.com/google/uuid"
	"github.com/jackc/pgx/v5/pgxpool"

	"nexavpn/backend/internal/wireguard"
)
|
|
|
|
type Repository interface {
|
|
List(ctx context.Context) ([]Gateway, error)
|
|
FirstActive(ctx context.Context) (Gateway, error)
|
|
BuildSyncBundle(ctx context.Context, gatewayID uuid.UUID) (wireguard.GatewayBundle, error)
|
|
StoreTelemetry(ctx context.Context, gatewayID uuid.UUID, snapshot TelemetrySnapshot) error
|
|
Update(ctx context.Context, gatewayID uuid.UUID, input UpdateRequest) (Gateway, error)
|
|
UpsertByName(ctx context.Context, input BootstrapRequest) (Gateway, error)
|
|
}
|
|
|
|
type PGRepository struct {
|
|
db *pgxpool.Pool
|
|
}
|
|
|
|
func NewPGRepository(db *pgxpool.Pool) *PGRepository {
|
|
return &PGRepository{db: db}
|
|
}
|
|
|
|
func (r *PGRepository) List(ctx context.Context) ([]Gateway, error) {
|
|
rows, err := r.db.Query(ctx, `
|
|
select id, name, endpoint, public_key, listen_port, vpn_cidr::text, dns_servers, is_active
|
|
from gateways
|
|
where deleted_at is null
|
|
order by created_at desc
|
|
`)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
var items []Gateway
|
|
for rows.Next() {
|
|
var item Gateway
|
|
if err := rows.Scan(&item.ID, &item.Name, &item.Endpoint, &item.PublicKey, &item.ListenPort, &item.VPNCIDR, &item.DNSServers, &item.IsActive); err != nil {
|
|
return nil, err
|
|
}
|
|
items = append(items, item)
|
|
}
|
|
return items, rows.Err()
|
|
}
|
|
|
|
func (r *PGRepository) FirstActive(ctx context.Context) (Gateway, error) {
|
|
row := r.db.QueryRow(ctx, `
|
|
select id, name, endpoint, public_key, listen_port, vpn_cidr::text, dns_servers, is_active
|
|
from gateways
|
|
where deleted_at is null and is_active = true
|
|
order by created_at asc
|
|
limit 1
|
|
`)
|
|
|
|
var item Gateway
|
|
err := row.Scan(&item.ID, &item.Name, &item.Endpoint, &item.PublicKey, &item.ListenPort, &item.VPNCIDR, &item.DNSServers, &item.IsActive)
|
|
return item, err
|
|
}
|
|
|
|
func (r *PGRepository) BuildSyncBundle(ctx context.Context, gatewayID uuid.UUID) (wireguard.GatewayBundle, error) {
|
|
var bundle wireguard.GatewayBundle
|
|
bundle.GatewayID = gatewayID.String()
|
|
bundle.Revision = 1
|
|
|
|
row := r.db.QueryRow(ctx, `
|
|
select vpn_cidr::text, listen_port
|
|
from gateways
|
|
where id = $1 and deleted_at is null
|
|
`, gatewayID)
|
|
var vpnCIDR string
|
|
if err := row.Scan(&vpnCIDR, &bundle.Interface.ListenPort); err != nil {
|
|
return wireguard.GatewayBundle{}, err
|
|
}
|
|
interfaceAddress, err := gatewayInterfaceAddress(vpnCIDR)
|
|
if err != nil {
|
|
return wireguard.GatewayBundle{}, err
|
|
}
|
|
bundle.Interface.Address = interfaceAddress
|
|
bundle.Interface.NetworkCIDR = vpnCIDR
|
|
|
|
rows, err := r.db.Query(ctx, `
|
|
select
|
|
d.id,
|
|
wp.public_key,
|
|
set_masklen(wp.assigned_ip, 32)::text,
|
|
coalesce(array_agg(distinct pd.destination::text) filter (where pd.destination is not null), '{}'),
|
|
coalesce(g.dns_servers, '{}')::text[]
|
|
from devices d
|
|
join wireguard_peers wp on wp.device_id = d.id and wp.deleted_at is null
|
|
join gateways g on g.id = d.gateway_id
|
|
left join group_memberships gm on gm.user_id = d.user_id
|
|
left join policy_targets pt on (
|
|
(pt.target_type = 'device' and pt.target_id = d.id) or
|
|
(pt.target_type = 'user' and pt.target_id = d.user_id) or
|
|
(pt.target_type = 'group' and pt.target_id = gm.group_id)
|
|
)
|
|
left join policy_destinations pd on pd.policy_id = pt.policy_id
|
|
where d.gateway_id = $1 and d.deleted_at is null and d.status = 'active'
|
|
group by d.id, wp.public_key, wp.assigned_ip, g.dns_servers
|
|
`, gatewayID)
|
|
if err != nil {
|
|
return wireguard.GatewayBundle{}, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
for rows.Next() {
|
|
var peer wireguard.Peer
|
|
var deviceID uuid.UUID
|
|
if err := rows.Scan(&deviceID, &peer.PublicKey, &peer.AssignedIP, &peer.AllowedDestinations, &peer.DNSServers); err != nil {
|
|
return wireguard.GatewayBundle{}, err
|
|
}
|
|
peer.DeviceID = deviceID.String()
|
|
bundle.Peers = append(bundle.Peers, peer)
|
|
}
|
|
|
|
return bundle, rows.Err()
|
|
}
|
|
|
|
func (r *PGRepository) Update(ctx context.Context, gatewayID uuid.UUID, input UpdateRequest) (Gateway, error) {
|
|
row := r.db.QueryRow(ctx, `
|
|
update gateways
|
|
set endpoint = $2,
|
|
public_key = $3,
|
|
listen_port = $4,
|
|
vpn_cidr = $5::cidr,
|
|
dns_servers = $6::text[],
|
|
is_active = $7,
|
|
updated_at = now()
|
|
where id = $1
|
|
returning id, name, endpoint, public_key, listen_port, vpn_cidr::text, dns_servers, is_active
|
|
`, gatewayID, input.Endpoint, input.PublicKey, input.ListenPort, input.VPNCIDR, input.DNSServers, input.IsActive)
|
|
|
|
var item Gateway
|
|
err := row.Scan(&item.ID, &item.Name, &item.Endpoint, &item.PublicKey, &item.ListenPort, &item.VPNCIDR, &item.DNSServers, &item.IsActive)
|
|
return item, err
|
|
}
|
|
|
|
func (r *PGRepository) StoreTelemetry(ctx context.Context, gatewayID uuid.UUID, snapshot TelemetrySnapshot) error {
|
|
payload, err := json.Marshal(snapshot)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
_, err = r.db.Exec(ctx, `
|
|
insert into settings (category, key, value, updated_at)
|
|
values ('gateway_runtime', $1, $2::jsonb, now())
|
|
on conflict (category, key)
|
|
do update set value = excluded.value, updated_at = now()
|
|
`, gatewayID.String(), string(payload))
|
|
return err
|
|
}
|
|
|
|
func (r *PGRepository) UpsertByName(ctx context.Context, input BootstrapRequest) (Gateway, error) {
|
|
row := r.db.QueryRow(ctx, `
|
|
insert into gateways (id, name, endpoint, public_key, listen_port, vpn_cidr, dns_servers, is_active)
|
|
values ($1, $2, $3, $4, $5, $6::cidr, $7::text[], true)
|
|
on conflict (name)
|
|
do update set
|
|
endpoint = excluded.endpoint,
|
|
public_key = excluded.public_key,
|
|
listen_port = excluded.listen_port,
|
|
vpn_cidr = excluded.vpn_cidr,
|
|
dns_servers = excluded.dns_servers,
|
|
is_active = true,
|
|
updated_at = now()
|
|
returning id, name, endpoint, public_key, listen_port, vpn_cidr::text, dns_servers, is_active
|
|
`, uuid.New(), input.Name, input.Endpoint, input.PublicKey, input.ListenPort, input.VPNCIDR, input.DNSServers)
|
|
|
|
var item Gateway
|
|
err := row.Scan(&item.ID, &item.Name, &item.Endpoint, &item.PublicKey, &item.ListenPort, &item.VPNCIDR, &item.DNSServers, &item.IsActive)
|
|
return item, err
|
|
}
|
|
|
|
// gatewayInterfaceAddress derives the gateway's own interface address from
// the VPN network CIDR. It takes the first address after the network address
// and keeps the network's prefix length: "10.8.0.0/24" -> "10.8.0.1/24".
//
// Host bits in cidr are masked off first, so "10.8.0.5/24" also yields
// "10.8.0.1/24". An error is returned if cidr does not parse, or if the
// network has no address after its network address (e.g. a /32 or /128).
func gatewayInterfaceAddress(cidr string) (string, error) {
	prefix, err := netip.ParsePrefix(cidr)
	if err != nil {
		return "", err
	}

	network := prefix.Masked()
	host := network.Addr().Next()
	// Next returns the zero (invalid) Addr at the end of the address space,
	// and for a single-address network it steps outside the network. Contains
	// is false in both cases.
	if !network.Contains(host) {
		return "", fmt.Errorf("vpn cidr %q has no room for a gateway address", cidr)
	}

	return host.String() + "/" + strconv.Itoa(network.Bits()), nil
}
|
|
|
|
// intToString formats value in base 10, with a leading '-' for negative
// values. It delegates to strconv.Itoa, which also covers the full int range.
func intToString(value int) string {
	return strconv.Itoa(value)
}
|