package device import ( "context" "time" "github.com/google/uuid" "github.com/jackc/pgx/v5/pgxpool" ) type Repository interface { Enroll(ctx context.Context, userID uuid.UUID, gatewayID uuid.UUID, input EnrollRequest, assignedIP string, dnsServers []string, allowedIPs []string) (EnrollmentResponse, error) ListByUser(ctx context.Context, userID uuid.UUID) ([]Device, error) ListAll(ctx context.Context) ([]Device, error) GetLatestEnrollmentByUser(ctx context.Context, userID uuid.UUID) (EnrollmentResponse, error) GetEnrollmentByDeviceID(ctx context.Context, deviceID uuid.UUID) (EnrollmentResponse, error) Revoke(ctx context.Context, deviceID uuid.UUID) error Rotate(ctx context.Context, deviceID uuid.UUID) error } type PGRepository struct { db *pgxpool.Pool } func NewPGRepository(db *pgxpool.Pool) *PGRepository { return &PGRepository{db: db} } func (r *PGRepository) Enroll(ctx context.Context, userID uuid.UUID, gatewayID uuid.UUID, input EnrollRequest, assignedIP string, dnsServers []string, allowedIPs []string) (EnrollmentResponse, error) { tx, err := r.db.Begin(ctx) if err != nil { return EnrollmentResponse{}, err } defer tx.Rollback(ctx) deviceID := uuid.New() peerID := uuid.New() _, err = tx.Exec(ctx, ` insert into devices ( id, user_id, gateway_id, name, platform, os_version, app_version, device_fingerprint, public_key, status, approved_at ) values ($1, $2, $3, $4, $5, $6, $7, $8, $9, 'active', now()) `, deviceID, userID, gatewayID, input.Name, input.Platform, input.OSVersion, input.AppVersion, input.DeviceFingerprint, input.PublicKey) if err != nil { return EnrollmentResponse{}, err } _, err = tx.Exec(ctx, ` insert into wireguard_peers ( id, device_id, gateway_id, public_key, assigned_ip, allowed_ips, dns_servers, last_profile_issued_at ) values ($1, $2, $3, $4, $5::inet, $6::cidr[], $7::text[], now()) `, peerID, deviceID, gatewayID, input.PublicKey, assignedIP, allowedIPs, dnsServers) if err != nil { return EnrollmentResponse{}, err } _, err = tx.Exec(ctx, ` insert into ip_allocations (id, gateway_id, device_id, address, status) values ($1, $2, $3, $4::inet, 'allocated') `, uuid.New(), gatewayID, deviceID, assignedIP) if err != nil { return EnrollmentResponse{}, err } if err := tx.Commit(ctx); err != nil { return EnrollmentResponse{}, err } return EnrollmentResponse{ Device: Device{ ID: deviceID, UserID: userID, GatewayID: gatewayID, Name: input.Name, Platform: input.Platform, Status: "active", AssignedIP: assignedIP, }, }, nil } func (r *PGRepository) GetLatestEnrollmentByUser(ctx context.Context, userID uuid.UUID) (EnrollmentResponse, error) { row := r.db.QueryRow(ctx, ` select d.id, d.user_id, d.gateway_id, d.name, d.platform, d.status, host(wp.assigned_ip), wp.profile_revision, wp.last_profile_issued_at, g.name, g.endpoint, g.public_key, coalesce(wp.dns_servers, '{}')::text[], coalesce(array(select cidr::text from unnest(wp.allowed_ips) as cidr), '{}')::text[] from devices d join wireguard_peers wp on wp.device_id = d.id and wp.deleted_at is null join gateways g on g.id = wp.gateway_id where d.user_id = $1 and d.deleted_at is null order by d.created_at desc limit 1 `, userID) return scanEnrollmentRow(row) } func (r *PGRepository) GetEnrollmentByDeviceID(ctx context.Context, deviceID uuid.UUID) (EnrollmentResponse, error) { row := r.db.QueryRow(ctx, ` select d.id, d.user_id, d.gateway_id, d.name, d.platform, d.status, host(wp.assigned_ip), wp.profile_revision, wp.last_profile_issued_at, g.name, g.endpoint, g.public_key, coalesce(wp.dns_servers, '{}')::text[], coalesce(array(select cidr::text from unnest(wp.allowed_ips) as cidr), '{}')::text[] from devices d join wireguard_peers wp on wp.device_id = d.id and wp.deleted_at is null join gateways g on g.id = wp.gateway_id where d.id = $1 and d.deleted_at is null `, deviceID) return scanEnrollmentRow(row) } func (r *PGRepository) ListByUser(ctx context.Context, userID uuid.UUID) ([]Device, error) { rows, err := r.db.Query(ctx, ` select d.id, d.user_id, d.gateway_id, d.name, d.platform, d.status, coalesce(host(wp.assigned_ip), '') from devices d left join wireguard_peers wp on wp.device_id = d.id and wp.deleted_at is null where d.user_id = $1 and d.deleted_at is null order by d.created_at desc `, userID) if err != nil { return nil, err } defer rows.Close() var items []Device for rows.Next() { var item Device if err := rows.Scan(&item.ID, &item.UserID, &item.GatewayID, &item.Name, &item.Platform, &item.Status, &item.AssignedIP); err != nil { return nil, err } items = append(items, item) } return items, rows.Err() } func (r *PGRepository) ListAll(ctx context.Context) ([]Device, error) { rows, err := r.db.Query(ctx, ` select d.id, d.user_id, d.gateway_id, d.name, d.platform, d.status, coalesce(host(wp.assigned_ip), '') from devices d left join wireguard_peers wp on wp.device_id = d.id and wp.deleted_at is null where d.deleted_at is null order by d.created_at desc `) if err != nil { return nil, err } defer rows.Close() var items []Device for rows.Next() { var item Device if err := rows.Scan(&item.ID, &item.UserID, &item.GatewayID, &item.Name, &item.Platform, &item.Status, &item.AssignedIP); err != nil { return nil, err } items = append(items, item) } return items, rows.Err() } func (r *PGRepository) Revoke(ctx context.Context, deviceID uuid.UUID) error { tx, err := r.db.Begin(ctx) if err != nil { return err } defer tx.Rollback(ctx) if _, err := tx.Exec(ctx, `update devices set status = 'revoked', revoked_at = now(), updated_at = now() where id = $1`, deviceID); err != nil { return err } if _, err := tx.Exec(ctx, `update wireguard_peers set deleted_at = now(), updated_at = now() where device_id = $1 and deleted_at is null`, deviceID); err != nil { return err } if _, err := tx.Exec(ctx, `update ip_allocations set status = 'released', released_at = now(), updated_at = now() where device_id = $1 and status = 'allocated'`, deviceID); err != nil { return err } return tx.Commit(ctx) } func (r *PGRepository) Rotate(ctx context.Context, deviceID uuid.UUID) error { _, err := r.db.Exec(ctx, ` update wireguard_peers set profile_revision = profile_revision + 1, last_profile_issued_at = now(), updated_at = now() where device_id = $1 and deleted_at is null `, deviceID) return err } type enrollmentRowScanner interface { Scan(dest ...any) error } func scanEnrollmentRow(row enrollmentRowScanner) (EnrollmentResponse, error) { var response EnrollmentResponse var profileRevision int var lastIssuedAt *time.Time var gatewayName string var gatewayEndpoint string var gatewayPublicKey string var dnsServers []string var allowedIPs []string if err := row.Scan( &response.Device.ID, &response.Device.UserID, &response.Device.GatewayID, &response.Device.Name, &response.Device.Platform, &response.Device.Status, &response.Device.AssignedIP, &profileRevision, &lastIssuedAt, &gatewayName, &gatewayEndpoint, &gatewayPublicKey, &dnsServers, &allowedIPs, ); err != nil { return EnrollmentResponse{}, err } response.Peer = PeerView{ AssignedIP: response.Device.AssignedIP, DNSServers: dnsServers, AllowedIPs: allowedIPs, Gateway: GatewayView{ ID: response.Device.GatewayID, Name: gatewayName, Endpoint: gatewayEndpoint, PublicKey: gatewayPublicKey, }, ProfileRevision: profileRevision, } for _, destination := range allowedIPs { response.Resources = append(response.Resources, Resource{ Type: "cidr", Value: destination, Label: destination, }) } _ = lastIssuedAt return response, nil }