Add SelectOwnProfile handler to allow users to choose from available access profiles. Store selected profile ID per device in settings table with device_access_profile category. Implement GetSelectedProfileID and SetSelectedProfileID repository methods using JSONB storage. Add ListSelectableProfiles to policy repository and service to query user/group/device-specific profiles ordered by priority. Filter gateway
422 lines
11 KiB
Go
422 lines
11 KiB
Go
package device
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"net/netip"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/jackc/pgx/v5"
|
|
"github.com/jackc/pgx/v5/pgxpool"
|
|
)
|
|
|
|
type Repository interface {
|
|
FindNextAvailableIP(ctx context.Context, gatewayID uuid.UUID, vpnCIDR string, startOffset int) (string, error)
|
|
Enroll(ctx context.Context, userID uuid.UUID, gatewayID uuid.UUID, input EnrollRequest, assignedIP string, dnsServers []string, allowedIPs []string) (EnrollmentResponse, error)
|
|
ListByUser(ctx context.Context, userID uuid.UUID) ([]Device, error)
|
|
ListAll(ctx context.Context) ([]Device, error)
|
|
GetLatestEnrollmentByUser(ctx context.Context, userID uuid.UUID) (EnrollmentResponse, error)
|
|
GetEnrollmentByDeviceID(ctx context.Context, deviceID uuid.UUID) (EnrollmentResponse, error)
|
|
GetSelectedProfileID(ctx context.Context, deviceID uuid.UUID) (*uuid.UUID, error)
|
|
SetSelectedProfileID(ctx context.Context, deviceID uuid.UUID, profileID uuid.UUID) error
|
|
Revoke(ctx context.Context, deviceID uuid.UUID) error
|
|
Rotate(ctx context.Context, deviceID uuid.UUID) error
|
|
}
|
|
|
|
type PGRepository struct {
|
|
db *pgxpool.Pool
|
|
}
|
|
|
|
func NewPGRepository(db *pgxpool.Pool) *PGRepository {
|
|
return &PGRepository{db: db}
|
|
}
|
|
|
|
// FindNextAvailableIP returns the first address inside vpnCIDR that has no
// ip_allocations row for gatewayID, formatted as a single-host prefix
// ("10.0.0.7/32" for IPv4, "fd00::7/128" for IPv6).
//
// The scan starts at the first address after the network address and skips a
// further startOffset-1 addresses, so any startOffset <= 1 begins at
// network+1. Host bits in vpnCIDR are ignored ("10.0.0.5/24" is treated as
// "10.0.0.0/24"). For IPv4 ranges of /30 or wider, the broadcast address is
// never returned.
//
// NOTE(review): every ip_allocations row for the gateway counts as taken,
// including rows that Revoke marked 'released', so freed addresses are not
// reused. The lookup is also not serialized with Enroll. Concurrent
// enrollments can be offered the same address unless a database constraint
// rejects one of them.
func (r *PGRepository) FindNextAvailableIP(ctx context.Context, gatewayID uuid.UUID, vpnCIDR string, startOffset int) (string, error) {
	prefix, err := netip.ParsePrefix(vpnCIDR)
	if err != nil {
		return "", err
	}

	used, err := r.allocatedAddresses(ctx, gatewayID)
	if err != nil {
		return "", err
	}

	address, ok := firstFreeAddress(prefix, startOffset, used)
	if !ok {
		return "", errors.New("no available ip addresses for gateway")
	}
	// A full-length prefix gives "/32" for IPv4 (unchanged) and "/128" for IPv6.
	return netip.PrefixFrom(address, address.BitLen()).String(), nil
}

// allocatedAddresses loads every address that has an ip_allocations row for
// gatewayID, whatever that row's status.
func (r *PGRepository) allocatedAddresses(ctx context.Context, gatewayID uuid.UUID) (map[netip.Addr]struct{}, error) {
	rows, err := r.db.Query(ctx, `
		select host(address)
		from ip_allocations
		where gateway_id = $1
	`, gatewayID)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	used := make(map[netip.Addr]struct{})
	for rows.Next() {
		var raw string
		if err := rows.Scan(&raw); err != nil {
			return nil, err
		}
		// Parsing normalizes the textual form so lookups compare addresses,
		// not strings.
		address, err := netip.ParseAddr(raw)
		if err != nil {
			return nil, err
		}
		used[address] = struct{}{}
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return used, nil
}

// firstFreeAddress walks the network of prefix from network+max(startOffset, 1)
// and returns the first address not present in used. For IPv4 prefixes shorter
// than /31, the last (broadcast) address is excluded. ok is false when the
// range is exhausted.
func firstFreeAddress(prefix netip.Prefix, startOffset int, used map[netip.Addr]struct{}) (netip.Addr, bool) {
	// Masking drops host bits so the walk starts at the real network address.
	prefix = prefix.Masked()

	address := prefix.Addr().Next()
	for i := 1; i < startOffset; i++ {
		address = address.Next()
	}

	skipBroadcast := prefix.Addr().Is4() && prefix.Bits() < 31
	for prefix.Contains(address) {
		next := address.Next()
		if skipBroadcast && !prefix.Contains(next) {
			// address is the last one in the range: the IPv4 broadcast address.
			break
		}
		if _, taken := used[address]; !taken {
			return address, true
		}
		address = next
	}
	return netip.Addr{}, false
}
|
|
|
|
func (r *PGRepository) Enroll(ctx context.Context, userID uuid.UUID, gatewayID uuid.UUID, input EnrollRequest, assignedIP string, dnsServers []string, allowedIPs []string) (EnrollmentResponse, error) {
|
|
tx, err := r.db.Begin(ctx)
|
|
if err != nil {
|
|
return EnrollmentResponse{}, err
|
|
}
|
|
defer tx.Rollback(ctx)
|
|
|
|
deviceID := uuid.New()
|
|
peerID := uuid.New()
|
|
|
|
_, err = tx.Exec(ctx, `
|
|
insert into devices (
|
|
id, user_id, gateway_id, name, platform, os_version, app_version, device_fingerprint, public_key, status, approved_at
|
|
) values ($1, $2, $3, $4, $5, $6, $7, $8, $9, 'active', now())
|
|
`, deviceID, userID, gatewayID, input.Name, input.Platform, input.OSVersion, input.AppVersion, input.DeviceFingerprint, input.PublicKey)
|
|
if err != nil {
|
|
return EnrollmentResponse{}, err
|
|
}
|
|
|
|
_, err = tx.Exec(ctx, `
|
|
insert into wireguard_peers (
|
|
id, device_id, gateway_id, public_key, assigned_ip, allowed_ips, dns_servers, last_profile_issued_at
|
|
) values ($1, $2, $3, $4, $5::inet, $6::cidr[], $7::text[], now())
|
|
`, peerID, deviceID, gatewayID, input.PublicKey, assignedIP, allowedIPs, dnsServers)
|
|
if err != nil {
|
|
return EnrollmentResponse{}, err
|
|
}
|
|
|
|
_, err = tx.Exec(ctx, `
|
|
insert into ip_allocations (id, gateway_id, device_id, address, status)
|
|
values ($1, $2, $3, $4::inet, 'allocated')
|
|
`, uuid.New(), gatewayID, deviceID, assignedIP)
|
|
if err != nil {
|
|
return EnrollmentResponse{}, err
|
|
}
|
|
|
|
if err := tx.Commit(ctx); err != nil {
|
|
return EnrollmentResponse{}, err
|
|
}
|
|
|
|
return EnrollmentResponse{
|
|
Device: Device{
|
|
ID: deviceID,
|
|
UserID: userID,
|
|
GatewayID: gatewayID,
|
|
Name: input.Name,
|
|
Platform: input.Platform,
|
|
Status: "active",
|
|
AssignedIP: assignedIP,
|
|
},
|
|
}, nil
|
|
}
|
|
|
|
// enrollmentSelectSQL is the select list and joins shared by the enrollment
// lookups below. Only devices with a non-deleted WireGuard peer match. The
// column order must stay in sync with the Scan call in scanEnrollmentRow.
const enrollmentSelectSQL = `
	select
		d.id,
		d.user_id,
		d.gateway_id,
		d.name,
		d.platform,
		d.status,
		host(wp.assigned_ip),
		wp.profile_revision,
		wp.last_profile_issued_at,
		g.name,
		g.endpoint,
		g.public_key,
		coalesce(wp.dns_servers, '{}')::text[],
		coalesce(array(select cidr::text from unnest(wp.allowed_ips) as cidr), '{}')::text[]
	from devices d
	join wireguard_peers wp on wp.device_id = d.id and wp.deleted_at is null
	join gateways g on g.id = wp.gateway_id
`

// GetLatestEnrollmentByUser returns the enrollment of the user's most recently
// created, non-deleted device that still has an active peer. When none
// exists, the error from scanning the empty result (pgx.ErrNoRows) is
// returned.
func (r *PGRepository) GetLatestEnrollmentByUser(ctx context.Context, userID uuid.UUID) (EnrollmentResponse, error) {
	row := r.db.QueryRow(ctx, enrollmentSelectSQL+`
	where d.user_id = $1 and d.deleted_at is null
	order by d.created_at desc
	limit 1
`, userID)
	return scanEnrollmentRow(row)
}

// GetEnrollmentByDeviceID returns the enrollment of the given non-deleted
// device, provided it still has an active peer. Otherwise the error from
// scanning the empty result (pgx.ErrNoRows) is returned.
func (r *PGRepository) GetEnrollmentByDeviceID(ctx context.Context, deviceID uuid.UUID) (EnrollmentResponse, error) {
	row := r.db.QueryRow(ctx, enrollmentSelectSQL+`
	where d.id = $1 and d.deleted_at is null
`, deviceID)
	return scanEnrollmentRow(row)
}
|
|
|
|
// GetSelectedProfileID returns the access profile chosen for deviceID. The
// selection is read from settings (category 'device_access_profile', key =
// device ID) as the JSON field "profile_id".
//
// It returns (nil, nil) when no row exists, or when the row's profile_id is
// missing, JSON null or empty. A non-empty value that is not a UUID is
// reported as an error.
func (r *PGRepository) GetSelectedProfileID(ctx context.Context, deviceID uuid.UUID) (*uuid.UUID, error) {
	// ->> yields SQL NULL when the key is absent or JSON null; scanning into a
	// *string (rather than a string) lets that case through without an error.
	var raw *string
	err := r.db.QueryRow(ctx, `
		select value->>'profile_id'
		from settings
		where category = 'device_access_profile' and key = $1
	`, deviceID.String()).Scan(&raw)

	switch {
	case errors.Is(err, pgx.ErrNoRows):
		return nil, nil
	case err != nil:
		return nil, err
	case raw == nil || *raw == "":
		return nil, nil
	}

	profileID, err := uuid.Parse(*raw)
	if err != nil {
		return nil, err
	}
	return &profileID, nil
}
|
|
|
|
func (r *PGRepository) SetSelectedProfileID(ctx context.Context, deviceID uuid.UUID, profileID uuid.UUID) error {
|
|
payload, err := json.Marshal(map[string]string{"profile_id": profileID.String()})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
_, err = r.db.Exec(ctx, `
|
|
insert into settings (category, key, value, updated_at)
|
|
values ('device_access_profile', $1, $2::jsonb, now())
|
|
on conflict (category, key)
|
|
do update set value = excluded.value, updated_at = now()
|
|
`, deviceID.String(), string(payload))
|
|
return err
|
|
}
|
|
|
|
func (r *PGRepository) ListByUser(ctx context.Context, userID uuid.UUID) ([]Device, error) {
|
|
rows, err := r.db.Query(ctx, `
|
|
select d.id, d.user_id, d.gateway_id, d.name, d.platform, d.status, coalesce(host(wp.assigned_ip), '')
|
|
from devices d
|
|
left join wireguard_peers wp on wp.device_id = d.id and wp.deleted_at is null
|
|
where d.user_id = $1 and d.deleted_at is null
|
|
order by d.created_at desc
|
|
`, userID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
var items []Device
|
|
for rows.Next() {
|
|
var item Device
|
|
if err := rows.Scan(&item.ID, &item.UserID, &item.GatewayID, &item.Name, &item.Platform, &item.Status, &item.AssignedIP); err != nil {
|
|
return nil, err
|
|
}
|
|
items = append(items, item)
|
|
}
|
|
if err := rows.Err(); err != nil {
|
|
return nil, err
|
|
}
|
|
return r.applyRuntimeStats(ctx, items)
|
|
}
|
|
|
|
func (r *PGRepository) ListAll(ctx context.Context) ([]Device, error) {
|
|
rows, err := r.db.Query(ctx, `
|
|
select d.id, d.user_id, d.gateway_id, d.name, d.platform, d.status, coalesce(host(wp.assigned_ip), '')
|
|
from devices d
|
|
left join wireguard_peers wp on wp.device_id = d.id and wp.deleted_at is null
|
|
where d.deleted_at is null
|
|
order by d.created_at desc
|
|
`)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
var items []Device
|
|
for rows.Next() {
|
|
var item Device
|
|
if err := rows.Scan(&item.ID, &item.UserID, &item.GatewayID, &item.Name, &item.Platform, &item.Status, &item.AssignedIP); err != nil {
|
|
return nil, err
|
|
}
|
|
items = append(items, item)
|
|
}
|
|
if err := rows.Err(); err != nil {
|
|
return nil, err
|
|
}
|
|
return r.applyRuntimeStats(ctx, items)
|
|
}
|
|
|
|
// Revoke runs three updates in one transaction. It marks the device
// 'revoked', soft-deletes its live WireGuard peers, and sets its 'allocated'
// IP allocations to 'released'. A deviceID that matches no rows is not
// treated as an error.
func (r *PGRepository) Revoke(ctx context.Context, deviceID uuid.UUID) error {
	// Each statement takes deviceID as its only parameter and runs in order.
	statements := []string{
		`update devices set status = 'revoked', revoked_at = now(), updated_at = now() where id = $1`,
		`update wireguard_peers set deleted_at = now(), updated_at = now() where device_id = $1 and deleted_at is null`,
		`update ip_allocations set status = 'released', released_at = now(), updated_at = now() where device_id = $1 and status = 'allocated'`,
	}

	tx, err := r.db.Begin(ctx)
	if err != nil {
		return err
	}
	defer tx.Rollback(ctx)

	for _, statement := range statements {
		if _, err := tx.Exec(ctx, statement, deviceID); err != nil {
			return err
		}
	}
	return tx.Commit(ctx)
}
|
|
|
|
func (r *PGRepository) Rotate(ctx context.Context, deviceID uuid.UUID) error {
|
|
_, err := r.db.Exec(ctx, `
|
|
update wireguard_peers
|
|
set profile_revision = profile_revision + 1, last_profile_issued_at = now(), updated_at = now()
|
|
where device_id = $1 and deleted_at is null
|
|
`, deviceID)
|
|
return err
|
|
}
|
|
|
|
// enrollmentRowScanner is the single-row Scan method that scanEnrollmentRow
// needs; the pgx row returned by QueryRow is passed in as one.
type enrollmentRowScanner interface {
	// Scan copies the current row's columns, in order, into dest.
	Scan(dest ...any) error
}
|
|
|
|
// JSON shapes decoded from settings rows in category 'gateway_runtime'
// (see applyRuntimeStats).
type (
	// runtimeSnapshot is one gateway's reported list of peer counters.
	runtimeSnapshot struct {
		Peers []runtimePeer `json:"peers"`
	}

	// runtimePeer holds the byte counters reported for a single device.
	runtimePeer struct {
		DeviceID string `json:"device_id"`
		RXBytes  uint64 `json:"rx_bytes"`
		TXBytes  uint64 `json:"tx_bytes"`
	}
)
|
|
|
|
// applyRuntimeStats overlays RX/TX byte counters onto items.
//
// Counters come from every settings row in category 'gateway_runtime'. Each
// row's value is decoded as a runtimeSnapshot, and counters for the same
// device are summed across rows. Rows that fail to decode, and peers whose
// device_id is not a UUID, are skipped. Devices with no reported counters
// keep their existing values.
//
// items is updated in place and returned. An empty input is returned
// immediately without querying.
func (r *PGRepository) applyRuntimeStats(ctx context.Context, items []Device) ([]Device, error) {
	if len(items) == 0 {
		return items, nil
	}

	rows, err := r.db.Query(ctx, `
		select value
		from settings
		where category = 'gateway_runtime'
	`)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	totals := map[uuid.UUID]runtimePeer{}
	for rows.Next() {
		var payload []byte
		if err := rows.Scan(&payload); err != nil {
			return nil, err
		}

		var snapshot runtimeSnapshot
		if json.Unmarshal(payload, &snapshot) != nil {
			continue // best effort: a malformed snapshot contributes nothing
		}

		for _, peer := range snapshot.Peers {
			id, parseErr := uuid.Parse(peer.DeviceID)
			if parseErr != nil {
				continue
			}
			total := totals[id]
			total.RXBytes += peer.RXBytes
			total.TXBytes += peer.TXBytes
			totals[id] = total
		}
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}

	for i := range items {
		total, found := totals[items[i].ID]
		if !found {
			continue
		}
		items[i].RXBytes, items[i].TXBytes = total.RXBytes, total.TXBytes
	}
	return items, nil
}
|
|
|
|
// scanEnrollmentRow decodes one row of the enrollment select list into an
// EnrollmentResponse.
//
// The Device and Peer sections are populated from the row. Peer.AssignedIP
// and Peer.Gateway.ID mirror the corresponding Device fields. Every allowed
// IP also becomes a "cidr" Resource, with Value and Label both set to the
// destination. last_profile_issued_at is read but not surfaced. On a Scan
// error the zero EnrollmentResponse is returned with the error.
func scanEnrollmentRow(row enrollmentRowScanner) (EnrollmentResponse, error) {
	var (
		out      EnrollmentResponse
		issuedAt *time.Time // consumes the column; not currently exposed
	)

	// Destination order must match the column order of the select list.
	err := row.Scan(
		&out.Device.ID,
		&out.Device.UserID,
		&out.Device.GatewayID,
		&out.Device.Name,
		&out.Device.Platform,
		&out.Device.Status,
		&out.Device.AssignedIP,
		&out.Peer.ProfileRevision,
		&issuedAt,
		&out.Peer.Gateway.Name,
		&out.Peer.Gateway.Endpoint,
		&out.Peer.Gateway.PublicKey,
		&out.Peer.DNSServers,
		&out.Peer.AllowedIPs,
	)
	if err != nil {
		return EnrollmentResponse{}, err
	}

	out.Peer.AssignedIP = out.Device.AssignedIP
	out.Peer.Gateway.ID = out.Device.GatewayID

	for _, cidr := range out.Peer.AllowedIPs {
		out.Resources = append(out.Resources, Resource{Type: "cidr", Value: cidr, Label: cidr})
	}
	return out, nil
}
|