// Package device implements device enrollment and the resolution of the
// WireGuard peer settings (DNS servers, allowed IPs, access profiles) that are
// returned to an enrolled device.
package device

import (
	"context"
	"fmt"
	"os"
	"strings"

	"github.com/google/uuid"

	"nexavpn/backend/internal/gateway"
	"nexavpn/backend/internal/ipam"
	"nexavpn/backend/internal/policy"
	"nexavpn/backend/internal/profile"
)

const (
	// fallbackDestinationCIDR is the single destination used when neither the
	// policy service nor the selected access profile yields any destination
	// or service.
	fallbackDestinationCIDR = "172.16.10.0/24"

	// persistentKeepalive is the keepalive value handed to the WireGuard
	// config builder for every generated profile.
	persistentKeepalive = 25
)

// Service coordinates device enrollment across the device repository and the
// policy, gateway and IPAM services.
type Service struct {
	repo           Repository
	policyService  *policy.Service
	gatewayService *gateway.Service
	ipamService    *ipam.Service
}

// NewService returns a Service wired to the given repository and collaborating
// services.
func NewService(repo Repository, policyService *policy.Service, gatewayService *gateway.Service, ipamService *ipam.Service) *Service {
	return &Service{
		repo:           repo,
		policyService:  policyService,
		gatewayService: gatewayService,
		ipamService:    ipamService,
	}
}

// Enroll registers a new device for userID on the currently active gateway and
// returns the enrollment together with a rendered WireGuard profile.
//
// privateKeyPlaceholder is written verbatim into the PrivateKey field of the
// generated WireGuard config. Errors from the collaborating services and the
// repository are returned unchanged.
func (s *Service) Enroll(ctx context.Context, userID uuid.UUID, input EnrollRequest, privateKeyPlaceholder string) (EnrollmentResponse, error) {
	selectedGateway, err := s.gatewayService.SelectActive(ctx)
	if err != nil {
		return EnrollmentResponse{}, err
	}

	// Prefer an address chosen by the repository; if that fails for any
	// reason, fall back to the IPAM allocator.
	// NOTE(review): the literal 10 is passed to both allocators; presumably a
	// starting host offset inside the VPN CIDR — confirm.
	assignedIP, err := s.repo.FindNextAvailableIP(ctx, selectedGateway.ID, selectedGateway.VPNCIDR, 10)
	if err != nil {
		assignedIP, err = s.ipamService.Allocate(selectedGateway.VPNCIDR, 10)
		if err != nil {
			return EnrollmentResponse{}, err
		}
	}

	// The device does not exist yet, so destinations are first resolved for
	// the user alone (nil device ID).
	destinations, err := s.policyService.ResolveDestinations(ctx, userID, nil)
	if err != nil {
		return EnrollmentResponse{}, err
	}
	if len(destinations) == 0 {
		destinations = []string{fallbackDestinationCIDR}
	}

	enrollment, err := s.repo.Enroll(ctx, userID, selectedGateway.ID, input, assignedIP, selectedGateway.DNSServers, destinations)
	if err != nil {
		return EnrollmentResponse{}, err
	}

	// Resolve again now that the device has an ID.
	destinations, err = s.policyService.ResolveDestinations(ctx, userID, &enrollment.Device.ID)
	if err != nil {
		return EnrollmentResponse{}, err
	}

	availableProfiles, selectedProfileID, selectedDestinations, err := s.resolveAccessProfiles(ctx, userID, enrollment.Device.ID)
	if err != nil {
		return EnrollmentResponse{}, err
	}

	selectedServices := servicesForSelectedProfile(availableProfiles, selectedProfileID)
	if len(selectedDestinations) == 0 && len(selectedServices) == 0 {
		selectedDestinations = []string{fallbackDestinationCIDR}
	}
	if len(selectedDestinations) == 0 {
		// The profile exposes services but no destinations: use the
		// device-scoped policy destinations instead.
		selectedDestinations = destinations
	}

	profileDNSServers := dnsServersForProfile(selectedGateway.DNSServers, selectedServices)
	profileAllowedIPs := allowedIPsForProfile(selectedDestinations, selectedServices, profileDNSServers)

	enrollment.Peer = PeerView{
		AssignedIP: assignedIP,
		DNSServers: profileDNSServers,
		AllowedIPs: profileAllowedIPs,
		Gateway: GatewayView{
			ID:        selectedGateway.ID,
			Name:      selectedGateway.Name,
			Endpoint:  selectedGateway.Endpoint,
			PublicKey: selectedGateway.PublicKey,
		},
		ProfileRevision: 1,
	}
	enrollment.Resources = resourcesFromProfile(selectedDestinations, selectedServices)
	enrollment.AvailableProfiles = availableProfiles
	enrollment.SelectedProfileID = selectedProfileID
	enrollment.Profile = ProfileView{
		Format: "wireguard",
		Content: profile.BuildWireGuardConfig(profile.BuildInput{
			PrivateKey:       privateKeyPlaceholder,
			Address:          assignedIP,
			DNSServers:       profileDNSServers,
			ServerPublicKey:  selectedGateway.PublicKey,
			ServerEndpoint:   selectedGateway.Endpoint,
			AllowedIPs:       profileAllowedIPs,
			PersistentKeepal: persistentKeepalive,
		}),
	}

	return enrollment, nil
}

// ListByUser returns the devices the repository holds for userID.
func (s *Service) ListByUser(ctx context.Context, userID uuid.UUID) ([]Device, error) {
	return s.repo.ListByUser(ctx, userID)
}

// ListAll returns every device held by the repository.
func (s *Service) ListAll(ctx context.Context) ([]Device, error) {
	return s.repo.ListAll(ctx)
}

// GetLatestEnrollmentByUser loads the latest enrollment for userID and
// refreshes it against the current access policy.
func (s *Service) GetLatestEnrollmentByUser(ctx context.Context, userID uuid.UUID) (EnrollmentResponse, error) {
	enrollment, err := s.repo.GetLatestEnrollmentByUser(ctx, userID)
	if err != nil {
		return EnrollmentResponse{}, err
	}
	return s.applyCurrentPolicy(ctx, enrollment)
}

// GetEnrollmentByDeviceID loads the enrollment for deviceID and refreshes it
// against the current access policy.
func (s *Service) GetEnrollmentByDeviceID(ctx context.Context, deviceID uuid.UUID) (EnrollmentResponse, error) {
	enrollment, err := s.repo.GetEnrollmentByDeviceID(ctx, deviceID)
	if err != nil {
		return EnrollmentResponse{}, err
	}
	return s.applyCurrentPolicy(ctx, enrollment)
}

// SelectProfile stores profileID as the selected access profile of the user's
// latest enrolled device and returns the refreshed enrollment.
//
// It returns an error when profileID is not among the profiles selectable for
// that device.
func (s *Service) SelectProfile(ctx context.Context, userID uuid.UUID, profileID uuid.UUID) (EnrollmentResponse, error) {
	enrollment, err := s.repo.GetLatestEnrollmentByUser(ctx, userID)
	if err != nil {
		return EnrollmentResponse{}, err
	}

	selectable, err := s.policyService.ListSelectableProfiles(ctx, userID, &enrollment.Device.ID)
	if err != nil {
		return EnrollmentResponse{}, err
	}

	exists := false
	for _, candidate := range selectable {
		if candidate.ID == profileID {
			exists = true
			break
		}
	}
	if !exists {
		return EnrollmentResponse{}, fmt.Errorf("selected access profile is not available for this device")
	}

	if err := s.repo.SetSelectedProfileID(ctx, enrollment.Device.ID, profileID); err != nil {
		return EnrollmentResponse{}, err
	}

	return s.applyCurrentPolicy(ctx, enrollment)
}

// GetConnectionStatus reports the connection status derived from the user's
// latest enrollment.
//
// When the enrollment cannot be loaded the status is "disconnected" with an
// empty resource list and a nil error.
// NOTE(review): this treats every repository error, not only "not found", as
// disconnected — confirm that is intended.
func (s *Service) GetConnectionStatus(ctx context.Context, userID uuid.UUID) (ConnectionStatus, error) {
	enrollment, err := s.repo.GetLatestEnrollmentByUser(ctx, userID)
	if err != nil {
		return ConnectionStatus{
			Status:    "disconnected",
			Resources: []Resource{},
		}, nil
	}

	enrollment, err = s.applyCurrentPolicy(ctx, enrollment)
	if err != nil {
		return ConnectionStatus{}, err
	}

	lastSync := "just now"
	return ConnectionStatus{
		Status:       "provisioned",
		AssignedIP:   enrollment.Peer.AssignedIP,
		LastSyncTime: &lastSync,
		Resources:    enrollment.Resources,
	}, nil
}

// Revoke delegates revocation of deviceID to the repository.
func (s *Service) Revoke(ctx context.Context, deviceID uuid.UUID) error {
	return s.repo.Revoke(ctx, deviceID)
}

// Delete delegates deletion of deviceID to the repository.
func (s *Service) Delete(ctx context.Context, deviceID uuid.UUID) error {
	return s.repo.Delete(ctx, deviceID)
}

// Rotate delegates rotation for deviceID to the repository.
func (s *Service) Rotate(ctx context.Context, deviceID uuid.UUID) error {
	return s.repo.Rotate(ctx, deviceID)
}

// withDebugProfile rebuilds enrollment.Profile as a WireGuard config from the
// enrollment's peer view. The private key is not available here, so a fixed
// placeholder string is written in its place.
func withDebugProfile(enrollment EnrollmentResponse) EnrollmentResponse {
	enrollment.Profile = ProfileView{
		Format: "wireguard",
		Content: profile.BuildWireGuardConfig(profile.BuildInput{
			PrivateKey:       "__CLIENT_PRIVATE_KEY_REQUIRED__",
			Address:          enrollment.Peer.AssignedIP,
			DNSServers:       enrollment.Peer.DNSServers,
			ServerPublicKey:  enrollment.Peer.Gateway.PublicKey,
			ServerEndpoint:   enrollment.Peer.Gateway.Endpoint,
			AllowedIPs:       enrollment.Peer.AllowedIPs,
			PersistentKeepal: persistentKeepalive,
		}),
	}
	return enrollment
}

// applyCurrentPolicy recomputes the DNS servers, resources, available and
// selected profiles, allowed IPs and rendered profile of a stored enrollment
// from the access profiles currently selectable for its device.
func (s *Service) applyCurrentPolicy(ctx context.Context, enrollment EnrollmentResponse) (EnrollmentResponse, error) {
	availableProfiles, selectedProfileID, selectedDestinations, err := s.resolveAccessProfiles(ctx, enrollment.Device.UserID, enrollment.Device.ID)
	if err != nil {
		return EnrollmentResponse{}, err
	}

	selectedServices := servicesForSelectedProfile(availableProfiles, selectedProfileID)
	if len(selectedDestinations) == 0 && len(selectedServices) == 0 {
		selectedDestinations = []string{fallbackDestinationCIDR}
	}

	// DNS servers must be settled first: they also contribute routes to the
	// allowed-IP list computed below.
	enrollment.Peer.DNSServers = dnsServersForProfile(enrollment.Peer.DNSServers, selectedServices)
	enrollment.Resources = resourcesFromProfile(selectedDestinations, selectedServices)
	enrollment.AvailableProfiles = availableProfiles
	enrollment.SelectedProfileID = selectedProfileID
	enrollment.Peer.AllowedIPs = allowedIPsForProfile(selectedDestinations, selectedServices, enrollment.Peer.DNSServers)

	return withDebugProfile(enrollment), nil
}

// allowedIPsForProfile combines the profile destinations, the proxy routes of
// its services and the DNS server routes into one de-duplicated list.
//
// The combined list is built in a fresh slice so that destinations, which is
// typically owned by a profile, is never appended to in place.
func allowedIPsForProfile(destinations []string, services []AccessService, dnsServers []string) []string {
	proxyRoutes := proxyRoutesForServices(services)
	routes := make([]string, 0, len(destinations)+len(proxyRoutes))
	routes = append(routes, destinations...)
	routes = append(routes, proxyRoutes...)
	return mergeProfileAllowedIPs(routes, dnsServers)
}

// resolveAccessProfiles returns the access profiles selectable for the device,
// the ID of the selected one and that profile's destinations.
//
// When no profile is selectable all results are nil. When the stored selection
// is missing or no longer selectable, the first available profile is persisted
// as the new selection and returned.
func (s *Service) resolveAccessProfiles(ctx context.Context, userID uuid.UUID, deviceID uuid.UUID) ([]AccessProfile, *uuid.UUID, []string, error) {
	selectable, err := s.policyService.ListSelectableProfiles(ctx, userID, &deviceID)
	if err != nil {
		return nil, nil, nil, err
	}

	availableProfiles := make([]AccessProfile, 0, len(selectable))
	for _, p := range selectable {
		services := make([]AccessService, 0, len(p.Services))
		for _, svc := range p.Services {
			services = append(services, AccessService{
				ID:          svc.ID,
				Name:        svc.Name,
				Description: svc.Description,
				Domain:      svc.Domain,
				UpstreamIP:  svc.UpstreamIP,
				ProxyIP:     svc.ProxyIP,
				Ports:       svc.Ports,
			})
		}
		availableProfiles = append(availableProfiles, AccessProfile{
			ID:           p.ID,
			Name:         p.Name,
			Description:  p.Description,
			FullTunnel:   p.FullTunnel,
			Destinations: p.Destinations,
			Services:     services,
		})
	}
	if len(availableProfiles) == 0 {
		return nil, nil, nil, nil
	}

	selectedProfileID, err := s.repo.GetSelectedProfileID(ctx, deviceID)
	if err != nil {
		return nil, nil, nil, err
	}
	if selectedProfileID != nil {
		for _, candidate := range availableProfiles {
			if candidate.ID == *selectedProfileID {
				return availableProfiles, selectedProfileID, candidate.Destinations, nil
			}
		}
	}

	// No usable stored selection: fall back to the first profile and persist
	// that choice.
	fallback := availableProfiles[0]
	if err := s.repo.SetSelectedProfileID(ctx, deviceID, fallback.ID); err != nil {
		return nil, nil, nil, err
	}
	return availableProfiles, &fallback.ID, fallback.Destinations, nil
}

// resourcesFromDestinations maps each destination to a "cidr" resource whose
// value and label are both the destination string.
func resourcesFromDestinations(destinations []string) []Resource {
	resources := make([]Resource, 0, len(destinations))
	for _, destination := range destinations {
		resources = append(resources, Resource{
			Type:  "cidr",
			Value: destination,
			Label: destination,
		})
	}
	return resources
}

// resourcesFromProfile returns the "cidr" resources for destinations followed
// by one "service" resource per service.
func resourcesFromProfile(destinations []string, services []AccessService) []Resource {
	resources := resourcesFromDestinations(destinations)
	for _, svc := range services {
		resources = append(resources, Resource{
			Type:   "service",
			Value:  svc.Domain,
			Label:  svc.Name,
			Domain: svc.Domain,
		})
	}
	return resources
}

// servicesForSelectedProfile returns the services of the profile identified by
// selectedProfileID. With a nil ID it returns the services of the first
// profile, if any. It returns nil when nothing matches.
func servicesForSelectedProfile(profiles []AccessProfile, selectedProfileID *uuid.UUID) []AccessService {
	if selectedProfileID == nil {
		if len(profiles) == 0 {
			return nil
		}
		return profiles[0].Services
	}
	for _, p := range profiles {
		if p.ID == *selectedProfileID {
			return p.Services
		}
	}
	return nil
}

// proxyRoutesForServices returns the de-duplicated routes for the effective
// proxy IPs of services, in first-seen order. Services without a proxy IP are
// skipped.
func proxyRoutesForServices(services []AccessService) []string {
	seen := make(map[string]struct{}, len(services))
	routes := make([]string, 0, len(services))
	for _, svc := range services {
		route := dnsServerRoute(effectiveServiceProxyIP(svc.ProxyIP))
		if route == "" {
			continue
		}
		if _, ok := seen[route]; ok {
			continue
		}
		seen[route] = struct{}{}
		routes = append(routes, route)
	}
	return routes
}

// effectiveServiceProxyIP returns the trimmed value of the
// NEXAVPN_ACCESS_PROXY_IP environment variable when it is non-empty, and
// proxyIP otherwise.
func effectiveServiceProxyIP(proxyIP string) string {
	if override := strings.TrimSpace(os.Getenv("NEXAVPN_ACCESS_PROXY_IP")); override != "" {
		return override
	}
	return proxyIP
}

// mergeProfileAllowedIPs returns the trimmed, non-empty destinations followed
// by the routes of dnsServers, with duplicates removed and first-seen order
// preserved. The result is always a newly allocated slice.
func mergeProfileAllowedIPs(destinations []string, dnsServers []string) []string {
	seen := make(map[string]struct{}, len(destinations)+len(dnsServers))
	merged := make([]string, 0, len(destinations)+len(dnsServers))

	add := func(route string) {
		if route == "" {
			return
		}
		if _, exists := seen[route]; exists {
			return
		}
		seen[route] = struct{}{}
		merged = append(merged, route)
	}

	for _, destination := range destinations {
		add(strings.TrimSpace(destination))
	}
	for _, dnsServer := range dnsServers {
		add(dnsServerRoute(dnsServer))
	}
	return merged
}

// dnsServerRoute turns an address into a route: the trimmed value is returned
// unchanged when it already contains "/", otherwise "/32" is appended. A blank
// value yields "".
func dnsServerRoute(value string) string {
	value = strings.TrimSpace(value)
	if value == "" {
		return ""
	}
	if strings.Contains(value, "/") {
		return value
	}
	return value + "/32"
}

// dnsServersForProfile returns the DNS servers to hand to the client.
//
// An environment override wins when it is non-empty:
// NEXAVPN_CLIENT_DNS_SERVERS when the profile has services,
// DEFAULT_DNS_SERVERS otherwise. Without an override the de-duplicated base
// list is returned.
func dnsServersForProfile(base []string, services []AccessService) []string {
	key := "DEFAULT_DNS_SERVERS"
	if len(services) > 0 {
		key = "NEXAVPN_CLIENT_DNS_SERVERS"
	}
	if override := parseEnvList(key); len(override) > 0 {
		return override
	}
	return dedupeList(base)
}

// parseEnvList parses the environment variable key as a comma-separated list.
func parseEnvList(key string) []string {
	return parseCommaList(os.Getenv(key))
}

// parseCommaList splits raw on commas and returns the trimmed, non-empty,
// de-duplicated entries in first-seen order. The result is never nil.
func parseCommaList(raw string) []string {
	parts := strings.Split(raw, ",")
	seen := make(map[string]struct{}, len(parts))
	values := make([]string, 0, len(parts))
	for _, part := range parts {
		value := strings.TrimSpace(part)
		if value == "" {
			continue
		}
		if _, ok := seen[value]; ok {
			continue
		}
		seen[value] = struct{}{}
		values = append(values, value)
	}
	return values
}

// dedupeList returns values trimmed, with blanks and duplicates removed.
//
// It works by joining on "," and re-parsing, so an element that itself
// contains commas is split into several entries.
// NOTE(review): confirm whether callers rely on that splitting.
func dedupeList(values []string) []string {
	return parseCommaList(strings.Join(values, ","))
}