diff --git a/backend/internal/device/repository.go b/backend/internal/device/repository.go index 589a816..c9141bc 100644 --- a/backend/internal/device/repository.go +++ b/backend/internal/device/repository.go @@ -2,6 +2,8 @@ package device import ( "context" + "errors" + "net/netip" "time" "github.com/google/uuid" @@ -9,6 +11,7 @@ import ( ) type Repository interface { + FindNextAvailableIP(ctx context.Context, gatewayID uuid.UUID, vpnCIDR string, startOffset int) (string, error) Enroll(ctx context.Context, userID uuid.UUID, gatewayID uuid.UUID, input EnrollRequest, assignedIP string, dnsServers []string, allowedIPs []string) (EnrollmentResponse, error) ListByUser(ctx context.Context, userID uuid.UUID) ([]Device, error) ListAll(ctx context.Context) ([]Device, error) @@ -26,6 +29,49 @@ func NewPGRepository(db *pgxpool.Pool) *PGRepository { return &PGRepository{db: db} } +func (r *PGRepository) FindNextAvailableIP(ctx context.Context, gatewayID uuid.UUID, vpnCIDR string, startOffset int) (string, error) { + prefix, err := netip.ParsePrefix(vpnCIDR) + if err != nil { + return "", err + } + + rows, err := r.db.Query(ctx, ` + select host(address) + from ip_allocations + where gateway_id = $1 + `, gatewayID) + if err != nil { + return "", err + } + defer rows.Close() + + used := map[string]struct{}{} + for rows.Next() { + var address string + if err := rows.Scan(&address); err != nil { + return "", err + } + used[address] = struct{}{} + } + if err := rows.Err(); err != nil { + return "", err + } + + address := prefix.Addr().Next() + for i := 1; i < startOffset; i++ { + address = address.Next() + } + + for prefix.Contains(address) { + if _, exists := used[address.String()]; !exists { + return address.String() + "/32", nil + } + address = address.Next() + } + + return "", errors.New("no available ip addresses for gateway") +} + func (r *PGRepository) Enroll(ctx context.Context, userID uuid.UUID, gatewayID uuid.UUID, input EnrollRequest, assignedIP string, dnsServers []string, allowedIPs []string) (EnrollmentResponse, error) { tx, err := r.db.Begin(ctx) if err != nil { diff --git a/backend/internal/device/service.go b/backend/internal/device/service.go index 6f733a4..2e89377 100644 --- a/backend/internal/device/service.go +++ b/backend/internal/device/service.go @@ -33,9 +33,12 @@ func (s *Service) Enroll(ctx context.Context, userID uuid.UUID, input EnrollRequ return EnrollmentResponse{}, err } - assignedIP, err := s.ipamService.Allocate(selectedGateway.VPNCIDR, 10) + assignedIP, err := s.repo.FindNextAvailableIP(ctx, selectedGateway.ID, selectedGateway.VPNCIDR, 10) if err != nil { - return EnrollmentResponse{}, err + assignedIP, err = s.ipamService.Allocate(selectedGateway.VPNCIDR, 10) + if err != nil { + return EnrollmentResponse{}, err + } } destinations, err := s.policyService.ResolveDestinations(ctx, userID, nil)