// Package gateway provides PostgreSQL-backed storage for gateway records and
// assembles the wireguard.GatewayBundle describing a gateway and its peers.
package gateway

import (
	"context"
	"fmt"
	"net/netip"
	"strconv"

	"github.com/google/uuid"
	"github.com/jackc/pgx/v5/pgxpool"

	"nexavpn/backend/internal/wireguard"
)

// Repository is the persistence interface for gateways.
type Repository interface {
	// List returns all gateways that are not soft-deleted, newest first.
	List(ctx context.Context) ([]Gateway, error)
	// FirstActive returns the oldest active gateway that is not soft-deleted.
	FirstActive(ctx context.Context) (Gateway, error)
	// BuildSyncBundle assembles the interface settings and peer list for the
	// given gateway.
	BuildSyncBundle(ctx context.Context, gatewayID uuid.UUID) (wireguard.GatewayBundle, error)
	// Update overwrites the mutable fields of the given gateway and returns
	// the stored row.
	Update(ctx context.Context, gatewayID uuid.UUID, input UpdateRequest) (Gateway, error)
}

// Compile-time check that PGRepository implements Repository.
var _ Repository = (*PGRepository)(nil)

// PGRepository is a Repository backed by a pgx connection pool.
type PGRepository struct {
	db *pgxpool.Pool
}

// NewPGRepository returns a PGRepository that runs its queries on db.
func NewPGRepository(db *pgxpool.Pool) *PGRepository {
	return &PGRepository{db: db}
}

// scanGateway reads one gateway row through scan, which must be the Scan
// method of a row produced by a query selecting, in order: id, name, endpoint,
// public_key, listen_port, vpn_cidr, dns_servers, is_active.
// On error the zero Gateway is returned.
func scanGateway(scan func(dest ...any) error) (Gateway, error) {
	var item Gateway
	err := scan(
		&item.ID,
		&item.Name,
		&item.Endpoint,
		&item.PublicKey,
		&item.ListenPort,
		&item.VPNCIDR,
		&item.DNSServers,
		&item.IsActive,
	)
	if err != nil {
		return Gateway{}, err
	}
	return item, nil
}

// List returns every gateway whose deleted_at is null, ordered by created_at
// descending. The returned slice is nil when there are no rows. On any query,
// scan or iteration error it returns a nil slice and that error.
func (r *PGRepository) List(ctx context.Context) ([]Gateway, error) {
	rows, err := r.db.Query(ctx, `
		select id, name, endpoint, public_key, listen_port, vpn_cidr, dns_servers, is_active
		from gateways
		where deleted_at is null
		order by created_at desc
	`)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var items []Gateway
	for rows.Next() {
		item, err := scanGateway(rows.Scan)
		if err != nil {
			return nil, err
		}
		items = append(items, item)
	}
	// Do not hand back a partial list when iteration stopped on an error.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return items, nil
}

// FirstActive returns the active, non-deleted gateway with the earliest
// created_at. When no row matches, the error produced by the row scan is
// returned unchanged (pgx.ErrNoRows in pgx v5) together with the zero Gateway.
func (r *PGRepository) FirstActive(ctx context.Context) (Gateway, error) {
	row := r.db.QueryRow(ctx, `
		select id, name, endpoint, public_key, listen_port, vpn_cidr, dns_servers, is_active
		from gateways
		where deleted_at is null and is_active = true
		order by created_at asc
		limit 1
	`)
	return scanGateway(row.Scan)
}

// BuildSyncBundle assembles the bundle for gatewayID in two queries: the first
// loads the gateway's VPN CIDR and listen port, the second loads one peer per
// active, non-deleted device on that gateway that has a non-deleted
// wireguard_peers row, together with the distinct policy destinations that
// target the device or its user.
//
// The interface address is derived from the VPN CIDR by
// gatewayInterfaceAddress. Database errors are returned unchanged; on any
// error the zero GatewayBundle is returned.
func (r *PGRepository) BuildSyncBundle(ctx context.Context, gatewayID uuid.UUID) (wireguard.GatewayBundle, error) {
	var bundle wireguard.GatewayBundle
	bundle.GatewayID = gatewayID.String()
	// NOTE(review): the revision is always 1 here; confirm whether consumers
	// expect it to change between bundles.
	bundle.Revision = 1

	row := r.db.QueryRow(ctx, `
		select vpn_cidr::text, listen_port
		from gateways
		where id = $1 and deleted_at is null
	`, gatewayID)

	var vpnCIDR string
	if err := row.Scan(&vpnCIDR, &bundle.Interface.ListenPort); err != nil {
		return wireguard.GatewayBundle{}, err
	}

	interfaceAddress, err := gatewayInterfaceAddress(vpnCIDR)
	if err != nil {
		return wireguard.GatewayBundle{}, err
	}
	bundle.Interface.Address = interfaceAddress
	bundle.Interface.NetworkCIDR = vpnCIDR

	// The peer address is emitted as a single-host prefix: /32 for IPv4 and
	// /128 for IPv6 (a fixed /32 would cover a whole IPv6 range).
	rows, err := r.db.Query(ctx, `
		select
			d.id,
			wp.public_key,
			set_masklen(wp.assigned_ip, case family(wp.assigned_ip) when 4 then 32 else 128 end)::text,
			coalesce(array_agg(distinct pd.destination::text) filter (where pd.destination is not null), '{}')
		from devices d
		join wireguard_peers wp on wp.device_id = d.id and wp.deleted_at is null
		left join policy_targets pt on (
			(pt.target_type = 'device' and pt.target_id = d.id)
			or (pt.target_type = 'user' and pt.target_id = d.user_id)
		)
		left join policy_destinations pd on pd.policy_id = pt.policy_id
		where d.gateway_id = $1 and d.deleted_at is null and d.status = 'active'
		group by d.id, wp.public_key, wp.assigned_ip
	`, gatewayID)
	if err != nil {
		return wireguard.GatewayBundle{}, err
	}
	defer rows.Close()

	for rows.Next() {
		var (
			peer     wireguard.Peer
			deviceID uuid.UUID
		)
		if err := rows.Scan(&deviceID, &peer.PublicKey, &peer.AssignedIP, &peer.AllowedDestinations); err != nil {
			return wireguard.GatewayBundle{}, err
		}
		peer.DeviceID = deviceID.String()
		bundle.Peers = append(bundle.Peers, peer)
	}
	// A bundle with a truncated peer list must not be returned as if usable.
	if err := rows.Err(); err != nil {
		return wireguard.GatewayBundle{}, err
	}
	return bundle, nil
}

// Update overwrites endpoint, public key, listen port, VPN CIDR, DNS servers
// and the active flag of the gateway identified by gatewayID, sets updated_at
// to now(), and returns the stored row.
//
// Soft-deleted gateways are excluded, consistent with the read queries in this
// file. When no row matches, the error produced by the row scan is returned
// unchanged (pgx.ErrNoRows in pgx v5) together with the zero Gateway.
func (r *PGRepository) Update(ctx context.Context, gatewayID uuid.UUID, input UpdateRequest) (Gateway, error) {
	row := r.db.QueryRow(ctx, `
		update gateways
		set endpoint = $2,
			public_key = $3,
			listen_port = $4,
			vpn_cidr = $5::cidr,
			dns_servers = $6::text[],
			is_active = $7,
			updated_at = now()
		where id = $1 and deleted_at is null
		returning id, name, endpoint, public_key, listen_port, vpn_cidr, dns_servers, is_active
	`, gatewayID, input.Endpoint, input.PublicKey, input.ListenPort, input.VPNCIDR, input.DNSServers, input.IsActive)
	return scanGateway(row.Scan)
}

// gatewayInterfaceAddress parses cidr and returns the address immediately
// following the prefix's address, combined with the prefix length, for example
// "10.8.0.0/24" yields "10.8.0.1/24".
//
// The result is not checked against the prefix range, so a single-address
// prefix yields an address outside it. An error is returned when cidr cannot
// be parsed or when the prefix address is the last one of its address family
// and therefore has no successor.
func gatewayInterfaceAddress(cidr string) (string, error) {
	prefix, err := netip.ParsePrefix(cidr)
	if err != nil {
		return "", fmt.Errorf("parsing gateway vpn cidr %q: %w", cidr, err)
	}

	// Next returns the zero Addr on overflow; formatting it would silently
	// produce "invalid IP/<bits>".
	next := prefix.Addr().Next()
	if !next.IsValid() {
		return "", fmt.Errorf("gateway vpn cidr %q has no address after %s", cidr, prefix.Addr())
	}
	return next.String() + "/" + intToString(prefix.Bits()), nil
}

// intToString returns the base-10 representation of value, including a
// leading minus sign for negative numbers.
func intToString(value int) string {
	return strconv.Itoa(value)
}