Files
NexaVPN/backend/internal/device/repository.go
nessi b16564ac5c feat: add database-backed IP allocation with fallback to IPAM service
Add a FindNextAvailableIP repository method that queries the ip_allocations table to find the next available IP address within the gateway's VPN CIDR range. It queries existing allocations from the database and builds a map of used IPs, then iterates through the CIDR range, starting at the given offset, to find the first unused address. Update the Enroll service method to call FindNextAvailableIP first, falling back to the IPAM service's Allocate method on error. Add the netip and errors imports to the repository.
2026-03-17 21:43:42 +01:00

312 lines
8.7 KiB
Go

package device
import (
"context"
"errors"
"net/netip"
"time"
"github.com/google/uuid"
"github.com/jackc/pgx/v5/pgxpool"
)
// Repository is the persistence API for VPN devices, their WireGuard peers and
// their IP allocations. PGRepository in this package implements it on top of
// PostgreSQL.
type Repository interface {
	// Enrollment: pick an address, then create the device, peer and allocation.
	FindNextAvailableIP(ctx context.Context, gatewayID uuid.UUID, vpnCIDR string, startOffset int) (string, error)
	Enroll(ctx context.Context, userID uuid.UUID, gatewayID uuid.UUID, input EnrollRequest, assignedIP string, dnsServers []string, allowedIPs []string) (EnrollmentResponse, error)

	// Lookups.
	GetLatestEnrollmentByUser(ctx context.Context, userID uuid.UUID) (EnrollmentResponse, error)
	GetEnrollmentByDeviceID(ctx context.Context, deviceID uuid.UUID) (EnrollmentResponse, error)
	ListByUser(ctx context.Context, userID uuid.UUID) ([]Device, error)
	ListAll(ctx context.Context) ([]Device, error)

	// Lifecycle changes for an existing device.
	Revoke(ctx context.Context, deviceID uuid.UUID) error
	Rotate(ctx context.Context, deviceID uuid.UUID) error
}
// PGRepository implements Repository using a pgx connection pool.
type PGRepository struct {
	db *pgxpool.Pool
}

// NewPGRepository returns a PGRepository that issues all of its queries
// through db.
func NewPGRepository(db *pgxpool.Pool) *PGRepository {
	repo := new(PGRepository)
	repo.db = db
	return repo
}
// FindNextAvailableIP returns the first address inside vpnCIDR that has no
// ip_allocations row for gatewayID. The result is formatted as a single-host
// prefix: "/32" for IPv4 (e.g. "10.8.0.2/32") and "/128" for IPv6.
//
// The search starts max(startOffset, 1) addresses after the address written in
// vpnCIDR, so that address itself is never returned. An error is returned
// when vpnCIDR does not parse, when the allocation query fails, or when every
// remaining address in the prefix is taken.
//
// NOTE(review): host bits in vpnCIDR are not masked, so "10.8.0.1/24" starts
// counting from 10.8.0.1 rather than 10.8.0.0. Confirm whether callers rely on
// this before canonicalizing with prefix.Masked().
//
// NOTE(review): every ip_allocations row for the gateway counts as used,
// including rows Revoke marks 'released', so released addresses are never
// handed out again. Filtering on status = 'allocated' would allow reuse only
// if the table's uniqueness constraints permit re-inserting an address. Check
// the schema first.
//
// The returned address is not reserved. Concurrent callers may receive the
// same address until Enroll inserts the allocation row.
func (r *PGRepository) FindNextAvailableIP(ctx context.Context, gatewayID uuid.UUID, vpnCIDR string, startOffset int) (string, error) {
	prefix, err := netip.ParsePrefix(vpnCIDR)
	if err != nil {
		return "", err
	}
	rows, err := r.db.Query(ctx, `
select host(address)
from ip_allocations
where gateway_id = $1
`, gatewayID)
	if err != nil {
		return "", err
	}
	defer rows.Close()

	// Key by parsed address rather than by text so the membership test does
	// not depend on PostgreSQL and Go formatting addresses identically
	// (notably for IPv6).
	used := make(map[netip.Addr]struct{})
	for rows.Next() {
		var text string
		if err := rows.Scan(&text); err != nil {
			return "", err
		}
		addr, err := netip.ParseAddr(text)
		if err != nil {
			return "", err
		}
		used[addr] = struct{}{}
	}
	if err := rows.Err(); err != nil {
		return "", err
	}

	// Advance to the starting offset, but stop once we leave the prefix: the
	// scan below would find nothing out there, so walking further is wasted.
	candidate := prefix.Addr().Next()
	for i := 1; i < startOffset && prefix.Contains(candidate); i++ {
		candidate = candidate.Next()
	}
	for ; prefix.Contains(candidate); candidate = candidate.Next() {
		if _, taken := used[candidate]; !taken {
			// Single-host prefix for the address family: /32 (IPv4) or /128 (IPv6).
			return netip.PrefixFrom(candidate, candidate.BitLen()).String(), nil
		}
	}
	return "", errors.New("no available ip addresses for gateway")
}
// Enroll creates a new active device for userID on gatewayID in one
// transaction. It inserts, in order:
//   - a devices row (status 'active', approved now) built from input,
//   - a wireguard_peers row holding assignedIP, allowedIPs and dnsServers,
//   - an ip_allocations row marking assignedIP as 'allocated'.
//
// Nothing is persisted unless all three inserts and the commit succeed.
//
// Only Device is populated in the returned response, and AssignedIP is echoed
// back exactly as passed in. When it comes from FindNextAvailableIP it carries
// a "/32"-style suffix, whereas the Get*/List* reads return host(...) without
// one. NOTE(review): confirm callers expect that difference.
func (r *PGRepository) Enroll(ctx context.Context, userID uuid.UUID, gatewayID uuid.UUID, input EnrollRequest, assignedIP string, dnsServers []string, allowedIPs []string) (EnrollmentResponse, error) {
	tx, err := r.db.Begin(ctx)
	if err != nil {
		return EnrollmentResponse{}, err
	}
	// Rolling back a committed pgx transaction is harmless, so this only
	// matters on the error paths.
	defer tx.Rollback(ctx)

	newDeviceID, newPeerID := uuid.New(), uuid.New()

	// Executed in this order. The peer and allocation rows carry newDeviceID
	// (presumably a foreign key to devices; confirm against the schema).
	steps := []struct {
		query string
		args  []any
	}{
		{
			query: `
insert into devices (
id, user_id, gateway_id, name, platform, os_version, app_version, device_fingerprint, public_key, status, approved_at
) values ($1, $2, $3, $4, $5, $6, $7, $8, $9, 'active', now())
`,
			args: []any{newDeviceID, userID, gatewayID, input.Name, input.Platform, input.OSVersion, input.AppVersion, input.DeviceFingerprint, input.PublicKey},
		},
		{
			query: `
insert into wireguard_peers (
id, device_id, gateway_id, public_key, assigned_ip, allowed_ips, dns_servers, last_profile_issued_at
) values ($1, $2, $3, $4, $5::inet, $6::cidr[], $7::text[], now())
`,
			args: []any{newPeerID, newDeviceID, gatewayID, input.PublicKey, assignedIP, allowedIPs, dnsServers},
		},
		{
			query: `
insert into ip_allocations (id, gateway_id, device_id, address, status)
values ($1, $2, $3, $4::inet, 'allocated')
`,
			args: []any{uuid.New(), gatewayID, newDeviceID, assignedIP},
		},
	}
	for _, step := range steps {
		if _, err := tx.Exec(ctx, step.query, step.args...); err != nil {
			return EnrollmentResponse{}, err
		}
	}

	if err := tx.Commit(ctx); err != nil {
		return EnrollmentResponse{}, err
	}

	var resp EnrollmentResponse
	resp.Device = Device{
		ID:         newDeviceID,
		UserID:     userID,
		GatewayID:  gatewayID,
		Name:       input.Name,
		Platform:   input.Platform,
		Status:     "active",
		AssignedIP: assignedIP,
	}
	return resp, nil
}
// enrollmentSelectSQL is the projection and joins shared by the enrollment
// lookups. Its column order must match the Scan destinations in
// scanEnrollmentRow. Callers append their own where/order by/limit clauses.
// Only devices with a non-deleted WireGuard peer are visible through it.
const enrollmentSelectSQL = `
select
d.id,
d.user_id,
d.gateway_id,
d.name,
d.platform,
d.status,
host(wp.assigned_ip),
wp.profile_revision,
wp.last_profile_issued_at,
g.name,
g.endpoint,
g.public_key,
coalesce(wp.dns_servers, '{}')::text[],
coalesce(array(select cidr::text from unnest(wp.allowed_ips) as cidr), '{}')::text[]
from devices d
join wireguard_peers wp on wp.device_id = d.id and wp.deleted_at is null
join gateways g on g.id = wp.gateway_id
`

// GetLatestEnrollmentByUser returns the enrollment of userID's most recently
// created non-deleted device that still has an active peer. When no row
// matches, the error comes from the row's Scan (pgx reports ErrNoRows).
func (r *PGRepository) GetLatestEnrollmentByUser(ctx context.Context, userID uuid.UUID) (EnrollmentResponse, error) {
	row := r.db.QueryRow(ctx, enrollmentSelectSQL+`
where d.user_id = $1 and d.deleted_at is null
order by d.created_at desc
limit 1
`, userID)
	return scanEnrollmentRow(row)
}

// GetEnrollmentByDeviceID returns the enrollment for deviceID, provided the
// device is not deleted and still has an active peer. When no row matches,
// the error comes from the row's Scan (pgx reports ErrNoRows).
func (r *PGRepository) GetEnrollmentByDeviceID(ctx context.Context, deviceID uuid.UUID) (EnrollmentResponse, error) {
	row := r.db.QueryRow(ctx, enrollmentSelectSQL+`
where d.id = $1 and d.deleted_at is null
`, deviceID)
	return scanEnrollmentRow(row)
}
// ListByUser returns userID's non-deleted devices, newest first. AssignedIP is
// the host address of the device's non-deleted WireGuard peer, or "" when it
// has none. The result is nil when no device matches.
func (r *PGRepository) ListByUser(ctx context.Context, userID uuid.UUID) ([]Device, error) {
	return r.queryDeviceList(ctx, `
select d.id, d.user_id, d.gateway_id, d.name, d.platform, d.status, coalesce(host(wp.assigned_ip), '')
from devices d
left join wireguard_peers wp on wp.device_id = d.id and wp.deleted_at is null
where d.user_id = $1 and d.deleted_at is null
order by d.created_at desc
`, userID)
}

// ListAll returns every non-deleted device across all users, newest first,
// with AssignedIP filled in the same way as ListByUser.
func (r *PGRepository) ListAll(ctx context.Context) ([]Device, error) {
	return r.queryDeviceList(ctx, `
select d.id, d.user_id, d.gateway_id, d.name, d.platform, d.status, coalesce(host(wp.assigned_ip), '')
from devices d
left join wireguard_peers wp on wp.device_id = d.id and wp.deleted_at is null
where d.deleted_at is null
order by d.created_at desc
`)
}

// queryDeviceList runs query with args and collects the rows as Devices. The
// query must select id, user_id, gateway_id, name, platform, status and the
// assigned IP text, in that order. On any error the partial result is
// discarded and nil is returned.
func (r *PGRepository) queryDeviceList(ctx context.Context, query string, args ...any) ([]Device, error) {
	rows, err := r.db.Query(ctx, query, args...)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	var items []Device
	for rows.Next() {
		var item Device
		if err := rows.Scan(&item.ID, &item.UserID, &item.GatewayID, &item.Name, &item.Platform, &item.Status, &item.AssignedIP); err != nil {
			return nil, err
		}
		items = append(items, item)
	}
	// Iteration errors (e.g. a dropped connection) surface here. Do not hand
	// back a silently truncated list alongside the error.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return items, nil
}
// Revoke disables deviceID in a single transaction. It does three things:
//   - marks the device 'revoked' and stamps revoked_at,
//   - soft-deletes its non-deleted WireGuard peers,
//   - releases its 'allocated' IP addresses.
//
// Either all three updates take effect or none do. Revoking an unknown device
// is not an error; the updates simply match no rows.
func (r *PGRepository) Revoke(ctx context.Context, deviceID uuid.UUID) error {
	statements := [...]string{
		`update devices set status = 'revoked', revoked_at = now(), updated_at = now() where id = $1`,
		`update wireguard_peers set deleted_at = now(), updated_at = now() where device_id = $1 and deleted_at is null`,
		`update ip_allocations set status = 'released', released_at = now(), updated_at = now() where device_id = $1 and status = 'allocated'`,
	}

	tx, err := r.db.Begin(ctx)
	if err != nil {
		return err
	}
	// Rolling back a committed pgx transaction is harmless, so this only
	// matters when a statement fails.
	defer tx.Rollback(ctx)

	for _, stmt := range statements {
		if _, err := tx.Exec(ctx, stmt, deviceID); err != nil {
			return err
		}
	}
	return tx.Commit(ctx)
}
// Rotate increments the profile revision of deviceID's non-deleted WireGuard
// peer and sets last_profile_issued_at and updated_at to now. It does not
// report whether any peer matched.
func (r *PGRepository) Rotate(ctx context.Context, deviceID uuid.UUID) error {
	const bumpRevision = `
update wireguard_peers
set profile_revision = profile_revision + 1, last_profile_issued_at = now(), updated_at = now()
where device_id = $1 and deleted_at is null
`
	if _, err := r.db.Exec(ctx, bumpRevision, deviceID); err != nil {
		return err
	}
	return nil
}
// enrollmentRowScanner is the single method scanEnrollmentRow needs from a
// query result. The row returned by QueryRow provides it, and a stub can
// stand in for it in tests.
type enrollmentRowScanner interface {
	Scan(targets ...any) error
}
// scanEnrollmentRow reads one enrollment row and assembles an
// EnrollmentResponse from it. The row must supply these columns, in order:
// device id, user id, gateway id, device name, platform, status, assigned IP,
// profile revision, last profile issue time, gateway name, gateway endpoint,
// gateway public key, DNS servers, allowed IPs.
//
// The assigned IP is copied into both Device and Peer. Each allowed IP also
// becomes a Resource of type "cidr" whose Value and Label are the CIDR text.
// On a Scan error, a zero EnrollmentResponse is returned with the error.
func scanEnrollmentRow(row enrollmentRowScanner) (EnrollmentResponse, error) {
	var (
		result   EnrollmentResponse
		revision int
		// Scanned into a pointer so a NULL timestamp is accepted. The value
		// is read but not currently surfaced anywhere.
		issuedAt    *time.Time
		gwName      string
		gwEndpoint  string
		gwPublicKey string
		dns         []string
		allowed     []string
	)
	device := &result.Device
	if err := row.Scan(
		&device.ID,
		&device.UserID,
		&device.GatewayID,
		&device.Name,
		&device.Platform,
		&device.Status,
		&device.AssignedIP,
		&revision,
		&issuedAt,
		&gwName,
		&gwEndpoint,
		&gwPublicKey,
		&dns,
		&allowed,
	); err != nil {
		return EnrollmentResponse{}, err
	}

	gateway := GatewayView{
		ID:        device.GatewayID,
		Name:      gwName,
		Endpoint:  gwEndpoint,
		PublicKey: gwPublicKey,
	}
	result.Peer = PeerView{
		AssignedIP:      device.AssignedIP,
		DNSServers:      dns,
		AllowedIPs:      allowed,
		Gateway:         gateway,
		ProfileRevision: revision,
	}
	for i := range allowed {
		cidr := allowed[i]
		result.Resources = append(result.Resources, Resource{Type: "cidr", Value: cidr, Label: cidr})
	}
	return result, nil
}