package device

import (
	"context"
	"encoding/json"
	"errors"
	"net/netip"
	"time"

	"github.com/google/uuid"
	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgxpool"
)

// Repository is the persistence contract for enrolled devices, their
// WireGuard peers, their IP allocations and their selected access profile.
type Repository interface {
	FindNextAvailableIP(ctx context.Context, gatewayID uuid.UUID, vpnCIDR string, startOffset int) (string, error)
	Enroll(ctx context.Context, userID uuid.UUID, gatewayID uuid.UUID, input EnrollRequest, assignedIP string, dnsServers []string, allowedIPs []string) (EnrollmentResponse, error)
	ListByUser(ctx context.Context, userID uuid.UUID) ([]Device, error)
	ListAll(ctx context.Context) ([]Device, error)
	GetLatestEnrollmentByUser(ctx context.Context, userID uuid.UUID) (EnrollmentResponse, error)
	GetEnrollmentByDeviceID(ctx context.Context, deviceID uuid.UUID) (EnrollmentResponse, error)
	GetSelectedProfileID(ctx context.Context, deviceID uuid.UUID) (*uuid.UUID, error)
	SetSelectedProfileID(ctx context.Context, deviceID uuid.UUID, profileID uuid.UUID) error
	Revoke(ctx context.Context, deviceID uuid.UUID) error
	Delete(ctx context.Context, deviceID uuid.UUID) error
	Rotate(ctx context.Context, deviceID uuid.UUID) error
}

// PGRepository implements Repository on top of a pgx connection pool.
type PGRepository struct {
	db *pgxpool.Pool
}

// Compile-time check that PGRepository satisfies Repository.
var _ Repository = (*PGRepository)(nil)

// NewPGRepository returns a PGRepository that issues all queries through db.
func NewPGRepository(db *pgxpool.Pool) *PGRepository {
	return &PGRepository{db: db}
}

// FindNextAvailableIP returns the first address inside vpnCIDR that has no
// ip_allocations row for gatewayID. The result is a single-host prefix:
// "10.0.0.7/32" for IPv4 or "fd00::7/128" for IPv6.
//
// The search starts startOffset addresses after the address part of vpnCIDR.
// A startOffset of 1 or less starts at the very next address. The search
// walks upward until it leaves the prefix. For IPv4 prefixes of /30 or wider,
// the final (broadcast) address is never handed out.
//
// Every ip_allocations row for the gateway counts as taken, whatever its
// status, so addresses released by Revoke or Delete are not reused.
//
// NOTE(review): this lookup and the insert in Enroll do not share a
// transaction. Two concurrent enrollments can pick the same address unless a
// database constraint rejects the second insert. Confirm against the schema.
func (r *PGRepository) FindNextAvailableIP(ctx context.Context, gatewayID uuid.UUID, vpnCIDR string, startOffset int) (string, error) {
	prefix, err := netip.ParsePrefix(vpnCIDR)
	if err != nil {
		return "", err
	}

	rows, err := r.db.Query(ctx, `
		select host(address)
		from ip_allocations
		where gateway_id = $1
	`, gatewayID)
	if err != nil {
		return "", err
	}
	defer rows.Close()

	used := map[string]struct{}{}
	for rows.Next() {
		var address string
		if err := rows.Scan(&address); err != nil {
			return "", err
		}
		used[address] = struct{}{}
	}
	if err := rows.Err(); err != nil {
		return "", err
	}

	// Start from the address exactly as written in vpnCIDR, without masking.
	// A CIDR such as "10.8.0.1/24" therefore never yields .1 itself.
	address := prefix.Addr().Next()
	for i := 1; i < startOffset; i++ {
		address = address.Next()
	}

	// IPv4 subnets larger than /31 reserve their last address for broadcast.
	skipBroadcast := prefix.Addr().Is4() && prefix.Bits() <= 30
	for prefix.Contains(address) {
		next := address.Next()
		if skipBroadcast && !prefix.Contains(next) {
			break
		}
		if _, exists := used[address.String()]; !exists {
			// Use the address family's full length so IPv6 gets /128, not /32.
			return netip.PrefixFrom(address, address.BitLen()).String(), nil
		}
		address = next
	}
	return "", errors.New("no available ip addresses for gateway")
}

// Enroll creates an active, pre-approved device for userID on gatewayID.
// In the same transaction it creates the device's WireGuard peer, with
// assignedIP, allowedIPs and dnsServers, and records assignedIP in
// ip_allocations.
//
// The returned response fills in only its Device part. Peer, gateway and
// resource details are left at their zero values.
func (r *PGRepository) Enroll(ctx context.Context, userID uuid.UUID, gatewayID uuid.UUID, input EnrollRequest, assignedIP string, dnsServers []string, allowedIPs []string) (EnrollmentResponse, error) {
	tx, err := r.db.Begin(ctx)
	if err != nil {
		return EnrollmentResponse{}, err
	}
	// Rolls back on any early return. After a successful Commit this only
	// reports that the transaction is already closed, so the error is ignored.
	defer tx.Rollback(ctx)

	deviceID := uuid.New()
	peerID := uuid.New()

	_, err = tx.Exec(ctx, `
		insert into devices (
			id, user_id, gateway_id, name, platform, os_version, app_version,
			device_fingerprint, public_key, status, approved_at
		)
		values ($1, $2, $3, $4, $5, $6, $7, $8, $9, 'active', now())
	`, deviceID, userID, gatewayID, input.Name, input.Platform, input.OSVersion, input.AppVersion, input.DeviceFingerprint, input.PublicKey)
	if err != nil {
		return EnrollmentResponse{}, err
	}

	_, err = tx.Exec(ctx, `
		insert into wireguard_peers (
			id, device_id, gateway_id, public_key, assigned_ip, allowed_ips,
			dns_servers, last_profile_issued_at
		)
		values ($1, $2, $3, $4, $5::inet, $6::cidr[], $7::text[], now())
	`, peerID, deviceID, gatewayID, input.PublicKey, assignedIP, allowedIPs, dnsServers)
	if err != nil {
		return EnrollmentResponse{}, err
	}

	_, err = tx.Exec(ctx, `
		insert into ip_allocations (id, gateway_id, device_id, address, status)
		values ($1, $2, $3, $4::inet, 'allocated')
	`, uuid.New(), gatewayID, deviceID, assignedIP)
	if err != nil {
		return EnrollmentResponse{}, err
	}

	if err := tx.Commit(ctx); err != nil {
		return EnrollmentResponse{}, err
	}

	return EnrollmentResponse{
		Device: Device{
			ID:         deviceID,
			UserID:     userID,
			GatewayID:  gatewayID,
			Name:       input.Name,
			Platform:   input.Platform,
			Status:     "active",
			AssignedIP: assignedIP,
		},
	}, nil
}

// GetLatestEnrollmentByUser returns the enrollment, including peer and
// gateway details, of the most recently created non-deleted device of userID
// that still has an active WireGuard peer. When no such device exists, it
// returns pgx.ErrNoRows from the row scan.
func (r *PGRepository) GetLatestEnrollmentByUser(ctx context.Context, userID uuid.UUID) (EnrollmentResponse, error) {
	row := r.db.QueryRow(ctx, `
		select
			d.id, d.user_id, d.gateway_id, d.name, d.platform, d.status,
			host(wp.assigned_ip), wp.profile_revision, wp.last_profile_issued_at,
			g.name, g.endpoint, g.public_key,
			coalesce(wp.dns_servers, '{}')::text[],
			coalesce(array(select cidr::text from unnest(wp.allowed_ips) as cidr), '{}')::text[]
		from devices d
		join wireguard_peers wp on wp.device_id = d.id and wp.deleted_at is null
		join gateways g on g.id = wp.gateway_id
		where d.user_id = $1 and d.deleted_at is null
		order by d.created_at desc
		limit 1
	`, userID)
	return scanEnrollmentRow(row)
}

// GetEnrollmentByDeviceID returns the enrollment, including peer and gateway
// details, of deviceID. The device must not be deleted and must have an
// active WireGuard peer. Otherwise it returns pgx.ErrNoRows from the row scan.
func (r *PGRepository) GetEnrollmentByDeviceID(ctx context.Context, deviceID uuid.UUID) (EnrollmentResponse, error) {
	row := r.db.QueryRow(ctx, `
		select
			d.id, d.user_id, d.gateway_id, d.name, d.platform, d.status,
			host(wp.assigned_ip), wp.profile_revision, wp.last_profile_issued_at,
			g.name, g.endpoint, g.public_key,
			coalesce(wp.dns_servers, '{}')::text[],
			coalesce(array(select cidr::text from unnest(wp.allowed_ips) as cidr), '{}')::text[]
		from devices d
		join wireguard_peers wp on wp.device_id = d.id and wp.deleted_at is null
		join gateways g on g.id = wp.gateway_id
		where d.id = $1 and d.deleted_at is null
	`, deviceID)
	return scanEnrollmentRow(row)
}

// GetSelectedProfileID returns the access profile selected for deviceID.
// The selection is stored in settings under category 'device_access_profile',
// keyed by the device ID string, as JSON of the form {"profile_id": "<uuid>"}.
//
// It returns (nil, nil) when no row exists or the stored JSON has no
// profile_id. It returns an error if the stored profile_id is not a valid
// UUID.
func (r *PGRepository) GetSelectedProfileID(ctx context.Context, deviceID uuid.UUID) (*uuid.UUID, error) {
	row := r.db.QueryRow(ctx, `
		select value->>'profile_id'
		from settings
		where category = 'device_access_profile' and key = $1
	`, deviceID.String())

	// ->> yields SQL NULL when the key (or the whole value) is missing.
	// Scanning into a pointer lets that case mean "no selection" instead of
	// producing a NULL-into-string scan error.
	var raw *string
	if err := row.Scan(&raw); err != nil {
		if errors.Is(err, pgx.ErrNoRows) {
			return nil, nil
		}
		return nil, err
	}
	if raw == nil {
		return nil, nil
	}

	value, err := uuid.Parse(*raw)
	if err != nil {
		return nil, err
	}
	return &value, nil
}

// SetSelectedProfileID stores profileID as the selected access profile of
// deviceID. It inserts or overwrites the settings row that
// GetSelectedProfileID reads.
func (r *PGRepository) SetSelectedProfileID(ctx context.Context, deviceID uuid.UUID, profileID uuid.UUID) error {
	payload, err := json.Marshal(map[string]string{"profile_id": profileID.String()})
	if err != nil {
		return err
	}

	_, err = r.db.Exec(ctx, `
		insert into settings (category, key, value, updated_at)
		values ('device_access_profile', $1, $2::jsonb, now())
		on conflict (category, key) do update
		set value = excluded.value, updated_at = now()
	`, deviceID.String(), string(payload))
	return err
}

// ListByUser returns the non-deleted devices of userID, newest first. Each
// device is enriched with traffic counters and a connection status taken from
// the gateway runtime snapshots.
func (r *PGRepository) ListByUser(ctx context.Context, userID uuid.UUID) ([]Device, error) {
	rows, err := r.db.Query(ctx, `
		select
			d.id, d.user_id, coalesce(u.username, ''), d.gateway_id,
			d.name, d.platform, d.status, coalesce(host(wp.assigned_ip), '')
		from devices d
		left join users u on u.id = d.user_id
		left join wireguard_peers wp on wp.device_id = d.id and wp.deleted_at is null
		where d.user_id = $1 and d.deleted_at is null
		order by d.created_at desc
	`, userID)
	if err != nil {
		return nil, err
	}

	items, err := scanDeviceList(rows)
	if err != nil {
		return nil, err
	}
	return r.applyRuntimeStats(ctx, items)
}

// ListAll returns every non-deleted device, newest first. Each device is
// enriched with traffic counters and a connection status taken from the
// gateway runtime snapshots.
func (r *PGRepository) ListAll(ctx context.Context) ([]Device, error) {
	rows, err := r.db.Query(ctx, `
		select
			d.id, d.user_id, coalesce(u.username, ''), d.gateway_id,
			d.name, d.platform, d.status, coalesce(host(wp.assigned_ip), '')
		from devices d
		left join users u on u.id = d.user_id
		left join wireguard_peers wp on wp.device_id = d.id and wp.deleted_at is null
		where d.deleted_at is null
		order by d.created_at desc
	`)
	if err != nil {
		return nil, err
	}

	items, err := scanDeviceList(rows)
	if err != nil {
		return nil, err
	}
	return r.applyRuntimeStats(ctx, items)
}

// scanDeviceList reads and closes the rows produced by the ListByUser and
// ListAll queries. The columns are: id, user_id, owner username, gateway_id,
// name, platform, status, assigned IP. It returns a nil slice when there are
// no rows.
func scanDeviceList(rows pgx.Rows) ([]Device, error) {
	defer rows.Close()

	var items []Device
	for rows.Next() {
		var item Device
		if err := rows.Scan(&item.ID, &item.UserID, &item.OwnerUsername, &item.GatewayID, &item.Name, &item.Platform, &item.Status, &item.AssignedIP); err != nil {
			return nil, err
		}
		items = append(items, item)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return items, nil
}

// Revoke does the following in one transaction:
//   - marks deviceID as revoked,
//   - soft-deletes its active WireGuard peer,
//   - releases its allocated IP addresses.
//
// The device row itself is kept.
func (r *PGRepository) Revoke(ctx context.Context, deviceID uuid.UUID) error {
	tx, err := r.db.Begin(ctx)
	if err != nil {
		return err
	}
	// Rolls back on any early return. After a successful Commit this is a
	// harmless no-op.
	defer tx.Rollback(ctx)

	if _, err := tx.Exec(ctx, `
		update devices
		set status = 'revoked', revoked_at = now(), updated_at = now()
		where id = $1
	`, deviceID); err != nil {
		return err
	}
	if _, err := tx.Exec(ctx, `
		update wireguard_peers
		set deleted_at = now(), updated_at = now()
		where device_id = $1 and deleted_at is null
	`, deviceID); err != nil {
		return err
	}
	if _, err := tx.Exec(ctx, `
		update ip_allocations
		set status = 'released', released_at = now(), updated_at = now()
		where device_id = $1 and status = 'allocated'
	`, deviceID); err != nil {
		return err
	}
	return tx.Commit(ctx)
}

// Delete soft-deletes deviceID in one transaction:
//   - sets status 'deleted' and stamps revoked_at if it is still unset,
//   - soft-deletes its active WireGuard peer,
//   - releases its allocated IP addresses,
//   - removes its stored access-profile selection.
//
// A device row that is already deleted is not re-stamped.
func (r *PGRepository) Delete(ctx context.Context, deviceID uuid.UUID) error {
	tx, err := r.db.Begin(ctx)
	if err != nil {
		return err
	}
	// Rolls back on any early return. After a successful Commit this is a
	// harmless no-op.
	defer tx.Rollback(ctx)

	if _, err := tx.Exec(ctx, `
		update devices
		set status = 'deleted',
			deleted_at = now(),
			revoked_at = coalesce(revoked_at, now()),
			updated_at = now()
		where id = $1 and deleted_at is null
	`, deviceID); err != nil {
		return err
	}
	if _, err := tx.Exec(ctx, `
		update wireguard_peers
		set deleted_at = now(), updated_at = now()
		where device_id = $1 and deleted_at is null
	`, deviceID); err != nil {
		return err
	}
	if _, err := tx.Exec(ctx, `
		update ip_allocations
		set status = 'released', released_at = coalesce(released_at, now()), updated_at = now()
		where device_id = $1 and status = 'allocated'
	`, deviceID); err != nil {
		return err
	}
	if _, err := tx.Exec(ctx, `
		delete from settings
		where category = 'device_access_profile' and key = $1
	`, deviceID.String()); err != nil {
		return err
	}
	return tx.Commit(ctx)
}

// Rotate increments the profile revision of the device's active WireGuard
// peer and stamps last_profile_issued_at. If no active peer exists, nothing
// is updated and no error is returned.
func (r *PGRepository) Rotate(ctx context.Context, deviceID uuid.UUID) error {
	_, err := r.db.Exec(ctx, `
		update wireguard_peers
		set profile_revision = profile_revision + 1,
			last_profile_issued_at = now(),
			updated_at = now()
		where device_id = $1 and deleted_at is null
	`, deviceID)
	return err
}

// enrollmentRowScanner is the subset of pgx.Row used by scanEnrollmentRow.
type enrollmentRowScanner interface {
	Scan(dest ...any) error
}

// runtimeSnapshot is the JSON document stored in settings rows with category
// 'gateway_runtime'.
type runtimeSnapshot struct {
	Peers []runtimePeer `json:"peers"`
}

// runtimePeer holds the counters one gateway reports for one device's peer.
// LatestHandshakeAt is in Unix seconds.
type runtimePeer struct {
	DeviceID          string `json:"device_id"`
	RXBytes           uint64 `json:"rx_bytes"`
	TXBytes           uint64 `json:"tx_bytes"`
	LatestHandshakeAt *int64 `json:"latest_handshake_at,omitempty"`
}

// applyRuntimeStats fills RXBytes, TXBytes and ConnectionStatus on items in
// place, using every 'gateway_runtime' snapshot in settings, and returns
// items.
//
// When a device appears in several snapshots, its counters are summed and the
// newest handshake wins. Snapshots that fail to decode, and peers whose
// device_id is not a UUID, are skipped (best effort). Devices absent from all
// snapshots are reported as "disconnected".
func (r *PGRepository) applyRuntimeStats(ctx context.Context, items []Device) ([]Device, error) {
	if len(items) == 0 {
		return items, nil
	}

	rows, err := r.db.Query(ctx, `
		select value
		from settings
		where category = 'gateway_runtime'
	`)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	statsByDevice := make(map[uuid.UUID]runtimePeer)
	for rows.Next() {
		var raw []byte
		if err := rows.Scan(&raw); err != nil {
			return nil, err
		}

		var snapshot runtimeSnapshot
		if err := json.Unmarshal(raw, &snapshot); err != nil {
			// A malformed snapshot from one gateway must not break the listing.
			continue
		}

		for _, peer := range snapshot.Peers {
			deviceID, err := uuid.Parse(peer.DeviceID)
			if err != nil {
				continue
			}

			existing := statsByDevice[deviceID]
			existing.DeviceID = peer.DeviceID
			existing.RXBytes += peer.RXBytes
			existing.TXBytes += peer.TXBytes
			if peer.LatestHandshakeAt != nil {
				if existing.LatestHandshakeAt == nil || *peer.LatestHandshakeAt > *existing.LatestHandshakeAt {
					existing.LatestHandshakeAt = peer.LatestHandshakeAt
				}
			}
			statsByDevice[deviceID] = existing
		}
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}

	for index := range items {
		stats, ok := statsByDevice[items[index].ID]
		if !ok {
			items[index].ConnectionStatus = "disconnected"
			continue
		}
		items[index].RXBytes = stats.RXBytes
		items[index].TXBytes = stats.TXBytes
		if isPeerConnected(stats.LatestHandshakeAt) {
			items[index].ConnectionStatus = "connected"
		} else {
			items[index].ConnectionStatus = "disconnected"
		}
	}
	return items, nil
}

// isPeerConnected reports whether latestHandshakeAt (Unix seconds) is
// positive and no more than three minutes in the past. A nil value counts as
// never connected.
//
// NOTE(review): the three-minute window is presumably slightly longer than
// WireGuard's roughly two-minute rekey interval for active sessions. Confirm
// that this is the intent.
func isPeerConnected(latestHandshakeAt *int64) bool {
	if latestHandshakeAt == nil || *latestHandshakeAt <= 0 {
		return false
	}
	handshakeTime := time.Unix(*latestHandshakeAt, 0)
	return time.Since(handshakeTime) <= 3*time.Minute
}

// scanEnrollmentRow decodes one row of the enrollment queries used by
// GetLatestEnrollmentByUser and GetEnrollmentByDeviceID into an
// EnrollmentResponse. It builds one "cidr" Resource per allowed IP.
// Resources stays nil when there are none.
func scanEnrollmentRow(row enrollmentRowScanner) (EnrollmentResponse, error) {
	var (
		response         EnrollmentResponse
		profileRevision  int
		lastIssuedAt     *time.Time
		gatewayName      string
		gatewayEndpoint  string
		gatewayPublicKey string
		dnsServers       []string
		allowedIPs       []string
	)

	if err := row.Scan(
		&response.Device.ID,
		&response.Device.UserID,
		&response.Device.GatewayID,
		&response.Device.Name,
		&response.Device.Platform,
		&response.Device.Status,
		&response.Device.AssignedIP,
		&profileRevision,
		&lastIssuedAt,
		&gatewayName,
		&gatewayEndpoint,
		&gatewayPublicKey,
		&dnsServers,
		&allowedIPs,
	); err != nil {
		return EnrollmentResponse{}, err
	}

	response.Peer = PeerView{
		AssignedIP: response.Device.AssignedIP,
		DNSServers: dnsServers,
		AllowedIPs: allowedIPs,
		Gateway: GatewayView{
			ID:        response.Device.GatewayID,
			Name:      gatewayName,
			Endpoint:  gatewayEndpoint,
			PublicKey: gatewayPublicKey,
		},
		ProfileRevision: profileRevision,
	}

	for _, destination := range allowedIPs {
		response.Resources = append(response.Resources, Resource{
			Type:  "cidr",
			Value: destination,
			Label: destination,
		})
	}

	// last_profile_issued_at is scanned to match the query's column list but
	// is not exposed in the response yet.
	_ = lastIssuedAt
	return response, nil
}