package device import ( "context" "fmt" "os" "strings" "github.com/google/uuid" "nexavpn/backend/internal/gateway" "nexavpn/backend/internal/ipam" "nexavpn/backend/internal/policy" "nexavpn/backend/internal/profile" ) type Service struct { repo Repository policyService *policy.Service gatewayService *gateway.Service ipamService *ipam.Service } func NewService(repo Repository, policyService *policy.Service, gatewayService *gateway.Service, ipamService *ipam.Service) *Service { return &Service{ repo: repo, policyService: policyService, gatewayService: gatewayService, ipamService: ipamService, } } func (s *Service) Enroll(ctx context.Context, userID uuid.UUID, input EnrollRequest, privateKeyPlaceholder string) (EnrollmentResponse, error) { selectedGateway, err := s.gatewayService.SelectActive(ctx) if err != nil { return EnrollmentResponse{}, err } assignedIP, err := s.repo.FindNextAvailableIP(ctx, selectedGateway.ID, selectedGateway.VPNCIDR, 10) if err != nil { assignedIP, err = s.ipamService.Allocate(selectedGateway.VPNCIDR, 10) if err != nil { return EnrollmentResponse{}, err } } destinations, err := s.policyService.ResolveDestinations(ctx, userID, nil) if err != nil { return EnrollmentResponse{}, err } if len(destinations) == 0 { destinations = []string{"172.16.10.0/24"} } enrollment, err := s.repo.Enroll(ctx, userID, selectedGateway.ID, input, assignedIP, selectedGateway.DNSServers, destinations) if err != nil { return EnrollmentResponse{}, err } destinations, err = s.policyService.ResolveDestinations(ctx, userID, &enrollment.Device.ID) if err != nil { return EnrollmentResponse{}, err } if len(destinations) == 0 { destinations = []string{"172.16.10.0/24"} } availableProfiles, selectedProfileID, selectedDestinations, err := s.resolveAccessProfiles(ctx, userID, enrollment.Device.ID) if err != nil { return EnrollmentResponse{}, err } if len(selectedDestinations) == 0 { selectedDestinations = destinations } profileAllowedIPs := mergeProfileAllowedIPs(selectedDestinations, selectedGateway.DNSServers, alwaysAllowWebProxyTargets()) enrollment.Peer = PeerView{ AssignedIP: assignedIP, DNSServers: selectedGateway.DNSServers, AllowedIPs: profileAllowedIPs, Gateway: GatewayView{ ID: selectedGateway.ID, Name: selectedGateway.Name, Endpoint: selectedGateway.Endpoint, PublicKey: selectedGateway.PublicKey, }, ProfileRevision: 1, } enrollment.Resources = resourcesFromDestinations(selectedDestinations) enrollment.AvailableProfiles = availableProfiles enrollment.SelectedProfileID = selectedProfileID enrollment.Profile = ProfileView{ Format: "wireguard", Content: profile.BuildWireGuardConfig(profile.BuildInput{ PrivateKey: privateKeyPlaceholder, Address: assignedIP, DNSServers: selectedGateway.DNSServers, ServerPublicKey: selectedGateway.PublicKey, ServerEndpoint: selectedGateway.Endpoint, AllowedIPs: profileAllowedIPs, PersistentKeepal: 25, }), } return enrollment, nil } func (s *Service) ListByUser(ctx context.Context, userID uuid.UUID) ([]Device, error) { return s.repo.ListByUser(ctx, userID) } func (s *Service) ListAll(ctx context.Context) ([]Device, error) { return s.repo.ListAll(ctx) } func (s *Service) GetLatestEnrollmentByUser(ctx context.Context, userID uuid.UUID) (EnrollmentResponse, error) { enrollment, err := s.repo.GetLatestEnrollmentByUser(ctx, userID) if err != nil { return EnrollmentResponse{}, err } return s.applyCurrentPolicy(ctx, enrollment) } func (s *Service) GetEnrollmentByDeviceID(ctx context.Context, deviceID uuid.UUID) (EnrollmentResponse, error) { enrollment, err := s.repo.GetEnrollmentByDeviceID(ctx, deviceID) if err != nil { return EnrollmentResponse{}, err } return s.applyCurrentPolicy(ctx, enrollment) } func (s *Service) SelectProfile(ctx context.Context, userID uuid.UUID, profileID uuid.UUID) (EnrollmentResponse, error) { enrollment, err := s.repo.GetLatestEnrollmentByUser(ctx, userID) if err != nil { return EnrollmentResponse{}, err } profiles, err := s.policyService.ListSelectableProfiles(ctx, userID, &enrollment.Device.ID) if err != nil { return EnrollmentResponse{}, err } var exists bool for _, profile := range profiles { if profile.ID == profileID { exists = true break } } if !exists { return EnrollmentResponse{}, fmt.Errorf("selected access profile is not available for this device") } if err := s.repo.SetSelectedProfileID(ctx, enrollment.Device.ID, profileID); err != nil { return EnrollmentResponse{}, err } return s.applyCurrentPolicy(ctx, enrollment) } func (s *Service) GetConnectionStatus(ctx context.Context, userID uuid.UUID) (ConnectionStatus, error) { enrollment, err := s.repo.GetLatestEnrollmentByUser(ctx, userID) if err != nil { return ConnectionStatus{ Status: "disconnected", Resources: []Resource{}, }, nil } enrollment, err = s.applyCurrentPolicy(ctx, enrollment) if err != nil { return ConnectionStatus{}, err } lastSync := "just now" return ConnectionStatus{ Status: "provisioned", AssignedIP: enrollment.Peer.AssignedIP, LastSyncTime: &lastSync, Resources: enrollment.Resources, }, nil } func (s *Service) Revoke(ctx context.Context, deviceID uuid.UUID) error { return s.repo.Revoke(ctx, deviceID) } func (s *Service) Rotate(ctx context.Context, deviceID uuid.UUID) error { return s.repo.Rotate(ctx, deviceID) } func withDebugProfile(enrollment EnrollmentResponse) EnrollmentResponse { profileAllowedIPs := enrollment.Peer.AllowedIPs enrollment.Profile = ProfileView{ Format: "wireguard", Content: profile.BuildWireGuardConfig(profile.BuildInput{ PrivateKey: "__CLIENT_PRIVATE_KEY_REQUIRED__", Address: enrollment.Peer.AssignedIP, DNSServers: enrollment.Peer.DNSServers, ServerPublicKey: enrollment.Peer.Gateway.PublicKey, ServerEndpoint: enrollment.Peer.Gateway.Endpoint, AllowedIPs: profileAllowedIPs, PersistentKeepal: 25, }), } return enrollment } func (s *Service) applyCurrentPolicy(ctx context.Context, enrollment EnrollmentResponse) (EnrollmentResponse, error) { availableProfiles, selectedProfileID, selectedDestinations, err := s.resolveAccessProfiles(ctx, enrollment.Device.UserID, enrollment.Device.ID) if err != nil { return EnrollmentResponse{}, err } if len(selectedDestinations) == 0 { selectedDestinations = []string{"172.16.10.0/24"} } enrollment.Resources = resourcesFromDestinations(selectedDestinations) enrollment.AvailableProfiles = availableProfiles enrollment.SelectedProfileID = selectedProfileID enrollment.Peer.AllowedIPs = mergeProfileAllowedIPs(selectedDestinations, enrollment.Peer.DNSServers, alwaysAllowWebProxyTargets()) return withDebugProfile(enrollment), nil } func (s *Service) resolveAccessProfiles(ctx context.Context, userID uuid.UUID, deviceID uuid.UUID) ([]AccessProfile, *uuid.UUID, []string, error) { profiles, err := s.policyService.ListSelectableProfiles(ctx, userID, &deviceID) if err != nil { return nil, nil, nil, err } availableProfiles := make([]AccessProfile, 0, len(profiles)) for _, profile := range profiles { availableProfiles = append(availableProfiles, AccessProfile{ ID: profile.ID, Name: profile.Name, Description: profile.Description, FullTunnel: profile.FullTunnel, Destinations: profile.Destinations, }) } if len(availableProfiles) == 0 { return nil, nil, nil, nil } selectedProfileID, err := s.repo.GetSelectedProfileID(ctx, deviceID) if err != nil { return nil, nil, nil, err } for _, profile := range availableProfiles { if selectedProfileID != nil && profile.ID == *selectedProfileID { return availableProfiles, selectedProfileID, profile.Destinations, nil } } fallback := availableProfiles[0] if err := s.repo.SetSelectedProfileID(ctx, deviceID, fallback.ID); err != nil { return nil, nil, nil, err } return availableProfiles, &fallback.ID, fallback.Destinations, nil } func resourcesFromDestinations(destinations []string) []Resource { resources := make([]Resource, 0, len(destinations)) for _, destination := range destinations { resources = append(resources, Resource{ Type: "cidr", Value: destination, Label: destination, }) } return resources } func mergeProfileAllowedIPs(destinations []string, dnsServers []string, webProxyTargets []string) []string { seen := make(map[string]struct{}, len(destinations)+len(dnsServers)+len(webProxyTargets)) merged := make([]string, 0, len(destinations)+len(dnsServers)+len(webProxyTargets)) for _, destination := range destinations { destination = strings.TrimSpace(destination) if destination == "" { continue } if _, exists := seen[destination]; exists { continue } seen[destination] = struct{}{} merged = append(merged, destination) } for _, dnsServer := range dnsServers { route := dnsServerRoute(dnsServer) if route == "" { continue } if _, exists := seen[route]; exists { continue } seen[route] = struct{}{} merged = append(merged, route) } for _, target := range webProxyTargets { route := dnsServerRoute(target) if route == "" { continue } if _, exists := seen[route]; exists { continue } seen[route] = struct{}{} merged = append(merged, route) } return merged } func dnsServerRoute(value string) string { value = strings.TrimSpace(value) if value == "" { return "" } if strings.Contains(value, "/") { return value } return value + "/32" } func alwaysAllowWebProxyTargets() []string { raw := os.Getenv("NEXAVPN_ALWAYS_ALLOW_WEB_PROXY_IPS") if strings.TrimSpace(raw) == "" { return nil } seen := make(map[string]struct{}) targets := make([]string, 0) for _, part := range strings.Split(raw, ",") { value := strings.TrimSpace(part) if value == "" { continue } if _, ok := seen[value]; ok { continue } seen[value] = struct{}{} targets = append(targets, value) } return targets }