Files
NexaVPN/backend/internal/device/repository.go
nessi aaa601a8ba feat: add access profile selection support with device-specific profile persistence
Add SelectOwnProfile handler to allow users to choose from available access profiles. Store selected profile ID per device in settings table with device_access_profile category. Implement GetSelectedProfileID and SetSelectedProfileID repository methods using JSONB storage.

Add ListSelectableProfiles to policy repository and service to query user/group/device-specific profiles ordered by priority. Filter gateway
2026-03-18 12:21:48 +01:00

422 lines
11 KiB
Go

package device
import (
"context"
"encoding/json"
"errors"
"net/netip"
"time"
"github.com/google/uuid"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
)
// Repository is the persistence contract for devices, their WireGuard peer
// records, per-gateway IP allocations and per-device access-profile
// selection. The descriptions below follow what PGRepository does.
type Repository interface {
	// FindNextAvailableIP returns a free host address from vpnCIDR for the
	// gateway, skipping addresses already recorded in ip_allocations.
	FindNextAvailableIP(ctx context.Context, gatewayID uuid.UUID, vpnCIDR string, startOffset int) (string, error)
	// Enroll creates an active device together with its WireGuard peer and
	// its IP allocation.
	Enroll(ctx context.Context, userID uuid.UUID, gatewayID uuid.UUID, input EnrollRequest, assignedIP string, dnsServers []string, allowedIPs []string) (EnrollmentResponse, error)

	// ListByUser lists the non-deleted devices of one user, newest first.
	ListByUser(ctx context.Context, userID uuid.UUID) ([]Device, error)
	// ListAll lists every non-deleted device, newest first.
	ListAll(ctx context.Context) ([]Device, error)
	// GetLatestEnrollmentByUser returns the newest device of the user that
	// still has a live WireGuard peer.
	GetLatestEnrollmentByUser(ctx context.Context, userID uuid.UUID) (EnrollmentResponse, error)
	// GetEnrollmentByDeviceID returns the enrollment view of one device.
	GetEnrollmentByDeviceID(ctx context.Context, deviceID uuid.UUID) (EnrollmentResponse, error)

	// GetSelectedProfileID returns the device's chosen access profile, or nil
	// when none has been stored.
	GetSelectedProfileID(ctx context.Context, deviceID uuid.UUID) (*uuid.UUID, error)
	// SetSelectedProfileID stores (or replaces) the device's chosen profile.
	SetSelectedProfileID(ctx context.Context, deviceID uuid.UUID, profileID uuid.UUID) error

	// Revoke marks the device revoked and releases its peer and address.
	Revoke(ctx context.Context, deviceID uuid.UUID) error
	// Rotate bumps the profile revision of the device's live peer.
	Rotate(ctx context.Context, deviceID uuid.UUID) error
}
// PGRepository implements the device repository on top of PostgreSQL.
type PGRepository struct {
	// db is the connection pool every query and transaction is issued on.
	db *pgxpool.Pool
}

// NewPGRepository returns a PGRepository that uses the given pool.
func NewPGRepository(db *pgxpool.Pool) *PGRepository {
	repo := new(PGRepository)
	repo.db = db
	return repo
}
// FindNextAvailableIP returns the first host address inside vpnCIDR that has
// no ip_allocations row for gatewayID. The result is formatted as a
// single-host prefix: "10.0.0.7/32" for IPv4, "fd00::7/128" for IPv6.
//
// vpnCIDR is normalised to its network address first, so "10.0.0.5/24" and
// "10.0.0.0/24" behave the same. The search starts startOffset addresses
// above the network address; values below 1 behave like 1, the first address
// after the network address. For IPv4 prefixes shorter than /31, the subnet
// broadcast address is never returned.
//
// NOTE(review): every ip_allocations row for the gateway counts as used,
// including rows Revoke marked 'released', so released addresses are never
// reused. Confirm this is intended before filtering on status. A unique
// index on (gateway_id, address), if one exists, would make reuse fail.
func (r *PGRepository) FindNextAvailableIP(ctx context.Context, gatewayID uuid.UUID, vpnCIDR string, startOffset int) (string, error) {
	prefix, err := netip.ParsePrefix(vpnCIDR)
	if err != nil {
		return "", err
	}
	// Start counting from the real network address, not from whatever host
	// address happened to be written in the configured CIDR.
	prefix = prefix.Masked()

	rows, err := r.db.Query(ctx, `
select host(address)
from ip_allocations
where gateway_id = $1
`, gatewayID)
	if err != nil {
		return "", err
	}
	defer rows.Close()

	// Key by parsed address rather than text so that equivalent spellings
	// (relevant for IPv6) still compare equal.
	used := make(map[netip.Addr]struct{})
	for rows.Next() {
		var text string
		if err := rows.Scan(&text); err != nil {
			return "", err
		}
		addr, parseErr := netip.ParseAddr(text)
		if parseErr != nil {
			return "", parseErr
		}
		used[addr] = struct{}{}
	}
	if err := rows.Err(); err != nil {
		return "", err
	}

	// The last address of an IPv4 subnet larger than /31 is its broadcast
	// address. Presumably the gateway interface carries an address in this
	// subnet (TODO confirm). On Linux that gives the broadcast address a
	// local broadcast route, so traffic to a peer holding it would likely
	// never reach the peer.
	skipBroadcast := prefix.Addr().Is4() && prefix.Bits() < 31

	address := prefix.Addr().Next()
	for i := 1; i < startOffset; i++ {
		address = address.Next()
	}
	for ; prefix.Contains(address); address = address.Next() {
		if skipBroadcast && !prefix.Contains(address.Next()) {
			break // address is the broadcast address, i.e. the last in the prefix
		}
		if _, taken := used[address]; !taken {
			// Single-host prefix: /32 for IPv4, /128 for IPv6.
			return netip.PrefixFrom(address, address.BitLen()).String(), nil
		}
	}
	return "", errors.New("no available ip addresses for gateway")
}
// Enroll registers a new device for userID on gatewayID in a single
// transaction. It inserts three rows:
//   - a devices row with status 'active' and approved_at = now(),
//   - a wireguard_peers row holding assignedIP (inet), allowedIPs (cidr[])
//     and dnsServers (text[]),
//   - an ip_allocations row with status 'allocated' for assignedIP.
//
// Only the Device part of the returned response is filled in. Its
// AssignedIP is assignedIP exactly as passed.
//
// NOTE(review): FindNextAvailableIP produces "a.b.c.d/32", so this
// AssignedIP keeps the suffix. The read paths return the bare host via
// host(...). Confirm callers expect that difference.
func (r *PGRepository) Enroll(ctx context.Context, userID uuid.UUID, gatewayID uuid.UUID, input EnrollRequest, assignedIP string, dnsServers []string, allowedIPs []string) (EnrollmentResponse, error) {
	tx, err := r.db.Begin(ctx)
	if err != nil {
		return EnrollmentResponse{}, err
	}
	// pgx documents Rollback as safe to call after Commit. This therefore
	// only undoes work when one of the statements below fails.
	defer tx.Rollback(ctx)

	var (
		deviceID     = uuid.New()
		peerID       = uuid.New()
		allocationID = uuid.New()
	)

	// Executed in order; the first failure aborts the whole enrollment.
	inserts := []struct {
		query string
		args  []any
	}{
		{
			query: `
insert into devices (
id, user_id, gateway_id, name, platform, os_version, app_version, device_fingerprint, public_key, status, approved_at
) values ($1, $2, $3, $4, $5, $6, $7, $8, $9, 'active', now())
`,
			args: []any{deviceID, userID, gatewayID, input.Name, input.Platform, input.OSVersion, input.AppVersion, input.DeviceFingerprint, input.PublicKey},
		},
		{
			query: `
insert into wireguard_peers (
id, device_id, gateway_id, public_key, assigned_ip, allowed_ips, dns_servers, last_profile_issued_at
) values ($1, $2, $3, $4, $5::inet, $6::cidr[], $7::text[], now())
`,
			args: []any{peerID, deviceID, gatewayID, input.PublicKey, assignedIP, allowedIPs, dnsServers},
		},
		{
			query: `
insert into ip_allocations (id, gateway_id, device_id, address, status)
values ($1, $2, $3, $4::inet, 'allocated')
`,
			args: []any{allocationID, gatewayID, deviceID, assignedIP},
		},
	}
	for _, ins := range inserts {
		if _, execErr := tx.Exec(ctx, ins.query, ins.args...); execErr != nil {
			return EnrollmentResponse{}, execErr
		}
	}

	if commitErr := tx.Commit(ctx); commitErr != nil {
		return EnrollmentResponse{}, commitErr
	}

	created := Device{
		ID:         deviceID,
		UserID:     userID,
		GatewayID:  gatewayID,
		Name:       input.Name,
		Platform:   input.Platform,
		Status:     "active",
		AssignedIP: assignedIP,
	}
	return EnrollmentResponse{Device: created}, nil
}
// enrollmentSelectSQL is the SELECT/FROM/JOIN part shared by the
// single-enrollment lookups. The 14 columns, and their order, must match the
// Scan call in scanEnrollmentRow.
//
// The inner join on wireguard_peers with deleted_at is null means only
// devices that still have a live peer are returned.
const enrollmentSelectSQL = `
select
d.id,
d.user_id,
d.gateway_id,
d.name,
d.platform,
d.status,
host(wp.assigned_ip),
wp.profile_revision,
wp.last_profile_issued_at,
g.name,
g.endpoint,
g.public_key,
coalesce(wp.dns_servers, '{}')::text[],
coalesce(array(select cidr::text from unnest(wp.allowed_ips) as cidr), '{}')::text[]
from devices d
join wireguard_peers wp on wp.device_id = d.id and wp.deleted_at is null
join gateways g on g.id = wp.gateway_id
`

// GetLatestEnrollmentByUser returns the enrollment view of the user's most
// recently created, non-deleted device that still has a live WireGuard peer.
// When nothing matches, it returns the error produced by Scan (pgx.ErrNoRows
// for pgx QueryRow).
func (r *PGRepository) GetLatestEnrollmentByUser(ctx context.Context, userID uuid.UUID) (EnrollmentResponse, error) {
	// d.id breaks ties between devices created at the same instant, so the
	// result is deterministic.
	row := r.db.QueryRow(ctx, enrollmentSelectSQL+`where d.user_id = $1 and d.deleted_at is null
order by d.created_at desc, d.id desc
limit 1
`, userID)
	return scanEnrollmentRow(row)
}

// GetEnrollmentByDeviceID returns the enrollment view of a single
// non-deleted device with a live WireGuard peer. When nothing matches, it
// returns the error produced by Scan (pgx.ErrNoRows for pgx QueryRow).
func (r *PGRepository) GetEnrollmentByDeviceID(ctx context.Context, deviceID uuid.UUID) (EnrollmentResponse, error) {
	row := r.db.QueryRow(ctx, enrollmentSelectSQL+`where d.id = $1 and d.deleted_at is null
`, deviceID)
	return scanEnrollmentRow(row)
}
// GetSelectedProfileID returns the access profile recorded for the device by
// SetSelectedProfileID. The value lives in the settings table under category
// 'device_access_profile', keyed by the device ID string.
//
// It returns (nil, nil) when the device has no choice recorded. That covers
// a missing settings row and a row whose JSON has no "profile_id", or a null
// or empty one. A present but malformed profile ID is returned as an error.
func (r *PGRepository) GetSelectedProfileID(ctx context.Context, deviceID uuid.UUID) (*uuid.UUID, error) {
	row := r.db.QueryRow(ctx, `
select value->>'profile_id'
from settings
where category = 'device_access_profile' and key = $1
`, deviceID.String())
	// ->> yields SQL NULL when the key is absent. pgx returns an error when
	// scanning NULL into a plain string, so scan into a pointer instead.
	var raw *string
	if err := row.Scan(&raw); err != nil {
		if errors.Is(err, pgx.ErrNoRows) {
			return nil, nil
		}
		return nil, err
	}
	if raw == nil || *raw == "" {
		return nil, nil
	}
	profileID, err := uuid.Parse(*raw)
	if err != nil {
		return nil, err
	}
	return &profileID, nil
}
// SetSelectedProfileID upserts the device's chosen access profile into the
// settings table. It uses category 'device_access_profile' and the device ID
// string as key, and stores {"profile_id": "<uuid>"} as jsonb. An existing
// row for the device has its value replaced and updated_at refreshed.
func (r *PGRepository) SetSelectedProfileID(ctx context.Context, deviceID uuid.UUID, profileID uuid.UUID) error {
	doc := struct {
		ProfileID string `json:"profile_id"`
	}{ProfileID: profileID.String()}

	encoded, marshalErr := json.Marshal(doc)
	if marshalErr != nil {
		return marshalErr
	}

	const upsert = `
insert into settings (category, key, value, updated_at)
values ('device_access_profile', $1, $2::jsonb, now())
on conflict (category, key)
do update set value = excluded.value, updated_at = now()
`
	// The payload is sent as text and cast to jsonb in SQL.
	_, execErr := r.db.Exec(ctx, upsert, deviceID.String(), string(encoded))
	return execErr
}
// ListByUser returns the user's non-deleted devices, newest first, with
// runtime traffic counters applied. AssignedIP is the bare host address of
// the device's live WireGuard peer, or "" when it has none (for example
// after Revoke soft-deleted the peer).
func (r *PGRepository) ListByUser(ctx context.Context, userID uuid.UUID) ([]Device, error) {
	rows, err := r.db.Query(ctx, `
select d.id, d.user_id, d.gateway_id, d.name, d.platform, d.status, coalesce(host(wp.assigned_ip), '')
from devices d
left join wireguard_peers wp on wp.device_id = d.id and wp.deleted_at is null
where d.user_id = $1 and d.deleted_at is null
order by d.created_at desc
`, userID)
	if err != nil {
		return nil, err
	}
	items, err := scanDeviceRows(rows)
	if err != nil {
		return nil, err
	}
	return r.applyRuntimeStats(ctx, items)
}

// ListAll returns every non-deleted device, newest first, with runtime
// traffic counters applied. AssignedIP follows the same rule as ListByUser.
func (r *PGRepository) ListAll(ctx context.Context) ([]Device, error) {
	rows, err := r.db.Query(ctx, `
select d.id, d.user_id, d.gateway_id, d.name, d.platform, d.status, coalesce(host(wp.assigned_ip), '')
from devices d
left join wireguard_peers wp on wp.device_id = d.id and wp.deleted_at is null
where d.deleted_at is null
order by d.created_at desc
`)
	if err != nil {
		return nil, err
	}
	items, err := scanDeviceRows(rows)
	if err != nil {
		return nil, err
	}
	return r.applyRuntimeStats(ctx, items)
}

// scanDeviceRows drains rows produced by the device list queries, which
// select id, user_id, gateway_id, name, platform, status and assigned IP.
// It always closes rows, so the connection is released before any follow-up
// query. It returns a nil slice when there are no rows.
func scanDeviceRows(rows pgx.Rows) ([]Device, error) {
	defer rows.Close()
	var items []Device
	for rows.Next() {
		var item Device
		if err := rows.Scan(&item.ID, &item.UserID, &item.GatewayID, &item.Name, &item.Platform, &item.Status, &item.AssignedIP); err != nil {
			return nil, err
		}
		items = append(items, item)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return items, nil
}
// Revoke retires a device in one transaction. It:
//   - sets the devices row to status 'revoked' and stamps revoked_at,
//   - soft-deletes the device's live wireguard_peers rows,
//   - marks its 'allocated' ip_allocations rows as 'released'.
//
// No statement checks affected row counts, so revoking an unknown device
// succeeds without error. The device's 'device_access_profile' settings row
// is not touched.
func (r *PGRepository) Revoke(ctx context.Context, deviceID uuid.UUID) error {
	tx, err := r.db.Begin(ctx)
	if err != nil {
		return err
	}
	// Safe after Commit; only undoes work when a statement below fails.
	defer tx.Rollback(ctx)

	steps := []string{
		`update devices set status = 'revoked', revoked_at = now(), updated_at = now() where id = $1`,
		`update wireguard_peers set deleted_at = now(), updated_at = now() where device_id = $1 and deleted_at is null`,
		`update ip_allocations set status = 'released', released_at = now(), updated_at = now() where device_id = $1 and status = 'allocated'`,
	}
	for _, stmt := range steps {
		if _, execErr := tx.Exec(ctx, stmt, deviceID); execErr != nil {
			return execErr
		}
	}
	return tx.Commit(ctx)
}
// Rotate increments profile_revision and refreshes last_profile_issued_at on
// the device's live (non-deleted) WireGuard peer rows. Affected rows are not
// counted, so a device without a live peer yields no error.
func (r *PGRepository) Rotate(ctx context.Context, deviceID uuid.UUID) error {
	const bumpRevision = `
update wireguard_peers
set profile_revision = profile_revision + 1, last_profile_issued_at = now(), updated_at = now()
where device_id = $1 and deleted_at is null
`
	if _, execErr := r.db.Exec(ctx, bumpRevision, deviceID); execErr != nil {
		return execErr
	}
	return nil
}
type (
	// enrollmentRowScanner is the single method scanEnrollmentRow needs from a
	// query result row, so it can consume the row returned by QueryRow.
	enrollmentRowScanner interface {
		Scan(dest ...any) error
	}

	// runtimeSnapshot is the JSON document stored in settings rows with
	// category 'gateway_runtime'; only its peer list is decoded here.
	runtimeSnapshot struct {
		Peers []runtimePeer `json:"peers"`
	}

	// runtimePeer holds one peer's traffic counters from a runtime snapshot.
	// DeviceID is expected to be a UUID string; peers whose ID does not parse
	// are skipped by applyRuntimeStats.
	runtimePeer struct {
		DeviceID string `json:"device_id"`
		RXBytes  uint64 `json:"rx_bytes"`
		TXBytes  uint64 `json:"tx_bytes"`
	}
)
// applyRuntimeStats fills RXBytes/TXBytes on items from the gateway runtime
// snapshots stored in settings rows with category 'gateway_runtime'.
//
// Counters for the same device are summed across all snapshots. Snapshots
// that fail to decode, and peers whose device_id is not a UUID, are skipped
// on a best-effort basis. Items without any matching peer keep their current
// counters. items is modified in place and returned; an empty input returns
// immediately without querying.
func (r *PGRepository) applyRuntimeStats(ctx context.Context, items []Device) ([]Device, error) {
	if len(items) == 0 {
		return items, nil
	}

	rows, err := r.db.Query(ctx, `
select value
from settings
where category = 'gateway_runtime'
`)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	type counters struct {
		rx, tx uint64
	}
	totals := map[uuid.UUID]counters{}

	for rows.Next() {
		var payload []byte
		if scanErr := rows.Scan(&payload); scanErr != nil {
			return nil, scanErr
		}
		var snap runtimeSnapshot
		if json.Unmarshal(payload, &snap) != nil {
			continue // undecodable snapshot: stats are best-effort
		}
		for _, peer := range snap.Peers {
			id, parseErr := uuid.Parse(peer.DeviceID)
			if parseErr != nil {
				continue
			}
			sum := totals[id]
			sum.rx += peer.RXBytes
			sum.tx += peer.TXBytes
			totals[id] = sum
		}
	}
	if rowsErr := rows.Err(); rowsErr != nil {
		return nil, rowsErr
	}

	for i := range items {
		sum, found := totals[items[i].ID]
		if !found {
			continue
		}
		items[i].RXBytes, items[i].TXBytes = sum.rx, sum.tx
	}
	return items, nil
}
// scanEnrollmentRow reads one 14-column row in the order produced by the
// enrollment lookup queries and assembles an EnrollmentResponse. The columns
// are: device id, user id, gateway id, name, platform, status, assigned IP;
// profile revision, last issued at; gateway name, endpoint, public key; DNS
// servers; allowed IPs.
//
// Every allowed IP also becomes a Resource of Type "cidr" whose Value and
// Label are the CIDR text. last_profile_issued_at is read so the column is
// consumed but is not exposed in the response. On a scan error the zero
// response is returned.
func scanEnrollmentRow(row enrollmentRowScanner) (EnrollmentResponse, error) {
	var (
		out         EnrollmentResponse
		revision    int
		issuedAt    *time.Time
		gwName      string
		gwEndpoint  string
		gwPublicKey string
		dns         []string
		allowed     []string
	)
	dev := &out.Device
	scanErr := row.Scan(
		&dev.ID, &dev.UserID, &dev.GatewayID, &dev.Name, &dev.Platform, &dev.Status, &dev.AssignedIP,
		&revision, &issuedAt,
		&gwName, &gwEndpoint, &gwPublicKey,
		&dns, &allowed,
	)
	if scanErr != nil {
		return EnrollmentResponse{}, scanErr
	}

	out.Peer = PeerView{
		AssignedIP:      dev.AssignedIP,
		DNSServers:      dns,
		AllowedIPs:      allowed,
		ProfileRevision: revision,
		Gateway: GatewayView{
			ID:        dev.GatewayID,
			Name:      gwName,
			Endpoint:  gwEndpoint,
			PublicKey: gwPublicKey,
		},
	}
	// Appending keeps Resources nil when there are no allowed IPs.
	for i := range allowed {
		out.Resources = append(out.Resources, Resource{
			Type:  "cidr",
			Value: allowed[i],
			Label: allowed[i],
		})
	}
	return out, nil
}