Update the go.mod module declaration and all internal imports across the backend codebase to use the simplified nexavpn/backend module path instead of the full GitHub URL.
103 lines
3.0 KiB
Go
103 lines
3.0 KiB
Go
package gateway
|
|
|
|
import (
|
|
"context"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/jackc/pgx/v5/pgxpool"
|
|
|
|
"nexavpn/backend/internal/wireguard"
|
|
)
|
|
|
|
type Repository interface {
|
|
List(ctx context.Context) ([]Gateway, error)
|
|
FirstActive(ctx context.Context) (Gateway, error)
|
|
BuildSyncBundle(ctx context.Context, gatewayID uuid.UUID) (wireguard.GatewayBundle, error)
|
|
}
|
|
|
|
type PGRepository struct {
|
|
db *pgxpool.Pool
|
|
}
|
|
|
|
func NewPGRepository(db *pgxpool.Pool) *PGRepository {
|
|
return &PGRepository{db: db}
|
|
}
|
|
|
|
func (r *PGRepository) List(ctx context.Context) ([]Gateway, error) {
|
|
rows, err := r.db.Query(ctx, `
|
|
select id, name, endpoint, public_key, listen_port, vpn_cidr, dns_servers, is_active
|
|
from gateways
|
|
where deleted_at is null
|
|
order by created_at desc
|
|
`)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
var items []Gateway
|
|
for rows.Next() {
|
|
var item Gateway
|
|
if err := rows.Scan(&item.ID, &item.Name, &item.Endpoint, &item.PublicKey, &item.ListenPort, &item.VPNCIDR, &item.DNSServers, &item.IsActive); err != nil {
|
|
return nil, err
|
|
}
|
|
items = append(items, item)
|
|
}
|
|
return items, rows.Err()
|
|
}
|
|
|
|
func (r *PGRepository) FirstActive(ctx context.Context) (Gateway, error) {
|
|
row := r.db.QueryRow(ctx, `
|
|
select id, name, endpoint, public_key, listen_port, vpn_cidr, dns_servers, is_active
|
|
from gateways
|
|
where deleted_at is null and is_active = true
|
|
order by created_at asc
|
|
limit 1
|
|
`)
|
|
|
|
var item Gateway
|
|
err := row.Scan(&item.ID, &item.Name, &item.Endpoint, &item.PublicKey, &item.ListenPort, &item.VPNCIDR, &item.DNSServers, &item.IsActive)
|
|
return item, err
|
|
}
|
|
|
|
func (r *PGRepository) BuildSyncBundle(ctx context.Context, gatewayID uuid.UUID) (wireguard.GatewayBundle, error) {
|
|
var bundle wireguard.GatewayBundle
|
|
bundle.GatewayID = gatewayID.String()
|
|
bundle.Revision = 1
|
|
|
|
row := r.db.QueryRow(ctx, `
|
|
select host(vpn_cidr), listen_port
|
|
from gateways
|
|
where id = $1 and deleted_at is null
|
|
`, gatewayID)
|
|
if err := row.Scan(&bundle.Interface.Address, &bundle.Interface.ListenPort); err != nil {
|
|
return wireguard.GatewayBundle{}, err
|
|
}
|
|
|
|
rows, err := r.db.Query(ctx, `
|
|
select d.id, wp.public_key, host(wp.assigned_ip), coalesce(array_agg(pd.destination::text) filter (where pd.destination is not null), '{}')
|
|
from devices d
|
|
join wireguard_peers wp on wp.device_id = d.id and wp.deleted_at is null
|
|
left join policy_targets pt on pt.target_id = d.id and pt.target_type = 'device'
|
|
left join policy_destinations pd on pd.policy_id = pt.policy_id
|
|
where d.gateway_id = $1 and d.deleted_at is null and d.status = 'active'
|
|
group by d.id, wp.public_key, wp.assigned_ip
|
|
`, gatewayID)
|
|
if err != nil {
|
|
return wireguard.GatewayBundle{}, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
for rows.Next() {
|
|
var peer wireguard.Peer
|
|
var deviceID uuid.UUID
|
|
if err := rows.Scan(&deviceID, &peer.PublicKey, &peer.AssignedIP, &peer.AllowedDestinations); err != nil {
|
|
return wireguard.GatewayBundle{}, err
|
|
}
|
|
peer.DeviceID = deviceID.String()
|
|
bundle.Peers = append(bundle.Peers, peer)
|
|
}
|
|
|
|
return bundle, rows.Err()
|
|
}
|