Add dnsServersForProfile and dnsServersForPeer helpers with conditional DNS server selection based on service presence. Use NEXAVPN_CLIENT_DNS_SERVERS override when services are configured, otherwise fall back to DEFAULT_DNS_SERVERS or gateway base DNS servers. Replace direct gateway DNS server usage in Enroll and applyCurrentPolicy with profileDNSServers variable. Update BuildSyncBundle to scan gateway DNS servers separately
438 lines
12 KiB
Go
438 lines
12 KiB
Go
package device
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"os"
|
|
"strings"
|
|
|
|
"github.com/google/uuid"
|
|
|
|
"nexavpn/backend/internal/gateway"
|
|
"nexavpn/backend/internal/ipam"
|
|
"nexavpn/backend/internal/policy"
|
|
"nexavpn/backend/internal/profile"
|
|
)
|
|
|
|
type Service struct {
|
|
repo Repository
|
|
policyService *policy.Service
|
|
gatewayService *gateway.Service
|
|
ipamService *ipam.Service
|
|
}
|
|
|
|
func NewService(repo Repository, policyService *policy.Service, gatewayService *gateway.Service, ipamService *ipam.Service) *Service {
|
|
return &Service{
|
|
repo: repo,
|
|
policyService: policyService,
|
|
gatewayService: gatewayService,
|
|
ipamService: ipamService,
|
|
}
|
|
}
|
|
|
|
func (s *Service) Enroll(ctx context.Context, userID uuid.UUID, input EnrollRequest, privateKeyPlaceholder string) (EnrollmentResponse, error) {
|
|
selectedGateway, err := s.gatewayService.SelectActive(ctx)
|
|
if err != nil {
|
|
return EnrollmentResponse{}, err
|
|
}
|
|
|
|
assignedIP, err := s.repo.FindNextAvailableIP(ctx, selectedGateway.ID, selectedGateway.VPNCIDR, 10)
|
|
if err != nil {
|
|
assignedIP, err = s.ipamService.Allocate(selectedGateway.VPNCIDR, 10)
|
|
if err != nil {
|
|
return EnrollmentResponse{}, err
|
|
}
|
|
}
|
|
|
|
destinations, err := s.policyService.ResolveDestinations(ctx, userID, nil)
|
|
if err != nil {
|
|
return EnrollmentResponse{}, err
|
|
}
|
|
if len(destinations) == 0 {
|
|
destinations = []string{"172.16.10.0/24"}
|
|
}
|
|
|
|
enrollment, err := s.repo.Enroll(ctx, userID, selectedGateway.ID, input, assignedIP, selectedGateway.DNSServers, destinations)
|
|
if err != nil {
|
|
return EnrollmentResponse{}, err
|
|
}
|
|
|
|
destinations, err = s.policyService.ResolveDestinations(ctx, userID, &enrollment.Device.ID)
|
|
if err != nil {
|
|
return EnrollmentResponse{}, err
|
|
}
|
|
availableProfiles, selectedProfileID, selectedDestinations, err := s.resolveAccessProfiles(ctx, userID, enrollment.Device.ID)
|
|
if err != nil {
|
|
return EnrollmentResponse{}, err
|
|
}
|
|
if len(selectedDestinations) == 0 && len(servicesForSelectedProfile(availableProfiles, selectedProfileID)) == 0 {
|
|
selectedDestinations = []string{"172.16.10.0/24"}
|
|
}
|
|
if len(selectedDestinations) == 0 {
|
|
selectedDestinations = destinations
|
|
}
|
|
selectedServices := servicesForSelectedProfile(availableProfiles, selectedProfileID)
|
|
profileDNSServers := dnsServersForProfile(selectedGateway.DNSServers, selectedServices)
|
|
profileAllowedIPs := mergeProfileAllowedIPs(
|
|
append(selectedDestinations, proxyRoutesForServices(selectedServices)...),
|
|
profileDNSServers,
|
|
)
|
|
|
|
enrollment.Peer = PeerView{
|
|
AssignedIP: assignedIP,
|
|
DNSServers: profileDNSServers,
|
|
AllowedIPs: profileAllowedIPs,
|
|
Gateway: GatewayView{
|
|
ID: selectedGateway.ID,
|
|
Name: selectedGateway.Name,
|
|
Endpoint: selectedGateway.Endpoint,
|
|
PublicKey: selectedGateway.PublicKey,
|
|
},
|
|
ProfileRevision: 1,
|
|
}
|
|
enrollment.Resources = resourcesFromProfile(selectedDestinations, selectedServices)
|
|
enrollment.AvailableProfiles = availableProfiles
|
|
enrollment.SelectedProfileID = selectedProfileID
|
|
|
|
enrollment.Profile = ProfileView{
|
|
Format: "wireguard",
|
|
Content: profile.BuildWireGuardConfig(profile.BuildInput{
|
|
PrivateKey: privateKeyPlaceholder,
|
|
Address: assignedIP,
|
|
DNSServers: profileDNSServers,
|
|
ServerPublicKey: selectedGateway.PublicKey,
|
|
ServerEndpoint: selectedGateway.Endpoint,
|
|
AllowedIPs: profileAllowedIPs,
|
|
PersistentKeepal: 25,
|
|
}),
|
|
}
|
|
|
|
return enrollment, nil
|
|
}
|
|
|
|
func (s *Service) ListByUser(ctx context.Context, userID uuid.UUID) ([]Device, error) {
|
|
return s.repo.ListByUser(ctx, userID)
|
|
}
|
|
|
|
func (s *Service) ListAll(ctx context.Context) ([]Device, error) {
|
|
return s.repo.ListAll(ctx)
|
|
}
|
|
|
|
func (s *Service) GetLatestEnrollmentByUser(ctx context.Context, userID uuid.UUID) (EnrollmentResponse, error) {
|
|
enrollment, err := s.repo.GetLatestEnrollmentByUser(ctx, userID)
|
|
if err != nil {
|
|
return EnrollmentResponse{}, err
|
|
}
|
|
return s.applyCurrentPolicy(ctx, enrollment)
|
|
}
|
|
|
|
func (s *Service) GetEnrollmentByDeviceID(ctx context.Context, deviceID uuid.UUID) (EnrollmentResponse, error) {
|
|
enrollment, err := s.repo.GetEnrollmentByDeviceID(ctx, deviceID)
|
|
if err != nil {
|
|
return EnrollmentResponse{}, err
|
|
}
|
|
return s.applyCurrentPolicy(ctx, enrollment)
|
|
}
|
|
|
|
func (s *Service) SelectProfile(ctx context.Context, userID uuid.UUID, profileID uuid.UUID) (EnrollmentResponse, error) {
|
|
enrollment, err := s.repo.GetLatestEnrollmentByUser(ctx, userID)
|
|
if err != nil {
|
|
return EnrollmentResponse{}, err
|
|
}
|
|
|
|
profiles, err := s.policyService.ListSelectableProfiles(ctx, userID, &enrollment.Device.ID)
|
|
if err != nil {
|
|
return EnrollmentResponse{}, err
|
|
}
|
|
|
|
var exists bool
|
|
for _, profile := range profiles {
|
|
if profile.ID == profileID {
|
|
exists = true
|
|
break
|
|
}
|
|
}
|
|
if !exists {
|
|
return EnrollmentResponse{}, fmt.Errorf("selected access profile is not available for this device")
|
|
}
|
|
|
|
if err := s.repo.SetSelectedProfileID(ctx, enrollment.Device.ID, profileID); err != nil {
|
|
return EnrollmentResponse{}, err
|
|
}
|
|
|
|
return s.applyCurrentPolicy(ctx, enrollment)
|
|
}
|
|
|
|
func (s *Service) GetConnectionStatus(ctx context.Context, userID uuid.UUID) (ConnectionStatus, error) {
|
|
enrollment, err := s.repo.GetLatestEnrollmentByUser(ctx, userID)
|
|
if err != nil {
|
|
return ConnectionStatus{
|
|
Status: "disconnected",
|
|
Resources: []Resource{},
|
|
}, nil
|
|
}
|
|
|
|
enrollment, err = s.applyCurrentPolicy(ctx, enrollment)
|
|
if err != nil {
|
|
return ConnectionStatus{}, err
|
|
}
|
|
|
|
lastSync := "just now"
|
|
return ConnectionStatus{
|
|
Status: "provisioned",
|
|
AssignedIP: enrollment.Peer.AssignedIP,
|
|
LastSyncTime: &lastSync,
|
|
Resources: enrollment.Resources,
|
|
}, nil
|
|
}
|
|
|
|
func (s *Service) Revoke(ctx context.Context, deviceID uuid.UUID) error {
|
|
return s.repo.Revoke(ctx, deviceID)
|
|
}
|
|
|
|
func (s *Service) Delete(ctx context.Context, deviceID uuid.UUID) error {
|
|
return s.repo.Delete(ctx, deviceID)
|
|
}
|
|
|
|
func (s *Service) Rotate(ctx context.Context, deviceID uuid.UUID) error {
|
|
return s.repo.Rotate(ctx, deviceID)
|
|
}
|
|
|
|
func withDebugProfile(enrollment EnrollmentResponse) EnrollmentResponse {
|
|
profileAllowedIPs := enrollment.Peer.AllowedIPs
|
|
enrollment.Profile = ProfileView{
|
|
Format: "wireguard",
|
|
Content: profile.BuildWireGuardConfig(profile.BuildInput{
|
|
PrivateKey: "__CLIENT_PRIVATE_KEY_REQUIRED__",
|
|
Address: enrollment.Peer.AssignedIP,
|
|
DNSServers: enrollment.Peer.DNSServers,
|
|
ServerPublicKey: enrollment.Peer.Gateway.PublicKey,
|
|
ServerEndpoint: enrollment.Peer.Gateway.Endpoint,
|
|
AllowedIPs: profileAllowedIPs,
|
|
PersistentKeepal: 25,
|
|
}),
|
|
}
|
|
return enrollment
|
|
}
|
|
|
|
func (s *Service) applyCurrentPolicy(ctx context.Context, enrollment EnrollmentResponse) (EnrollmentResponse, error) {
|
|
availableProfiles, selectedProfileID, selectedDestinations, err := s.resolveAccessProfiles(ctx, enrollment.Device.UserID, enrollment.Device.ID)
|
|
if err != nil {
|
|
return EnrollmentResponse{}, err
|
|
}
|
|
selectedServices := servicesForSelectedProfile(availableProfiles, selectedProfileID)
|
|
if len(selectedDestinations) == 0 && len(selectedServices) == 0 {
|
|
selectedDestinations = []string{"172.16.10.0/24"}
|
|
}
|
|
enrollment.Peer.DNSServers = dnsServersForProfile(enrollment.Peer.DNSServers, selectedServices)
|
|
|
|
enrollment.Resources = resourcesFromProfile(selectedDestinations, selectedServices)
|
|
enrollment.AvailableProfiles = availableProfiles
|
|
enrollment.SelectedProfileID = selectedProfileID
|
|
enrollment.Peer.AllowedIPs = mergeProfileAllowedIPs(
|
|
append(selectedDestinations, proxyRoutesForServices(selectedServices)...),
|
|
enrollment.Peer.DNSServers,
|
|
)
|
|
return withDebugProfile(enrollment), nil
|
|
}
|
|
|
|
func (s *Service) resolveAccessProfiles(ctx context.Context, userID uuid.UUID, deviceID uuid.UUID) ([]AccessProfile, *uuid.UUID, []string, error) {
|
|
profiles, err := s.policyService.ListSelectableProfiles(ctx, userID, &deviceID)
|
|
if err != nil {
|
|
return nil, nil, nil, err
|
|
}
|
|
|
|
availableProfiles := make([]AccessProfile, 0, len(profiles))
|
|
for _, profile := range profiles {
|
|
services := make([]AccessService, 0, len(profile.Services))
|
|
for _, service := range profile.Services {
|
|
services = append(services, AccessService{
|
|
ID: service.ID,
|
|
Name: service.Name,
|
|
Description: service.Description,
|
|
Domain: service.Domain,
|
|
UpstreamIP: service.UpstreamIP,
|
|
ProxyIP: service.ProxyIP,
|
|
Ports: service.Ports,
|
|
})
|
|
}
|
|
availableProfiles = append(availableProfiles, AccessProfile{
|
|
ID: profile.ID,
|
|
Name: profile.Name,
|
|
Description: profile.Description,
|
|
FullTunnel: profile.FullTunnel,
|
|
Destinations: profile.Destinations,
|
|
Services: services,
|
|
})
|
|
}
|
|
|
|
if len(availableProfiles) == 0 {
|
|
return nil, nil, nil, nil
|
|
}
|
|
|
|
selectedProfileID, err := s.repo.GetSelectedProfileID(ctx, deviceID)
|
|
if err != nil {
|
|
return nil, nil, nil, err
|
|
}
|
|
|
|
for _, profile := range availableProfiles {
|
|
if selectedProfileID != nil && profile.ID == *selectedProfileID {
|
|
return availableProfiles, selectedProfileID, profile.Destinations, nil
|
|
}
|
|
}
|
|
|
|
fallback := availableProfiles[0]
|
|
if err := s.repo.SetSelectedProfileID(ctx, deviceID, fallback.ID); err != nil {
|
|
return nil, nil, nil, err
|
|
}
|
|
return availableProfiles, &fallback.ID, fallback.Destinations, nil
|
|
}
|
|
|
|
func resourcesFromDestinations(destinations []string) []Resource {
|
|
resources := make([]Resource, 0, len(destinations))
|
|
for _, destination := range destinations {
|
|
resources = append(resources, Resource{
|
|
Type: "cidr",
|
|
Value: destination,
|
|
Label: destination,
|
|
})
|
|
}
|
|
return resources
|
|
}
|
|
|
|
func resourcesFromProfile(destinations []string, services []AccessService) []Resource {
|
|
resources := resourcesFromDestinations(destinations)
|
|
for _, service := range services {
|
|
resources = append(resources, Resource{
|
|
Type: "service",
|
|
Value: service.Domain,
|
|
Label: service.Name,
|
|
Domain: service.Domain,
|
|
})
|
|
}
|
|
return resources
|
|
}
|
|
|
|
func servicesForSelectedProfile(profiles []AccessProfile, selectedProfileID *uuid.UUID) []AccessService {
|
|
if selectedProfileID == nil {
|
|
if len(profiles) == 0 {
|
|
return nil
|
|
}
|
|
return profiles[0].Services
|
|
}
|
|
|
|
for _, profile := range profiles {
|
|
if profile.ID == *selectedProfileID {
|
|
return profile.Services
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func proxyRoutesForServices(services []AccessService) []string {
|
|
seen := make(map[string]struct{}, len(services))
|
|
routes := make([]string, 0, len(services))
|
|
for _, service := range services {
|
|
route := dnsServerRoute(effectiveServiceProxyIP(service.ProxyIP))
|
|
if route == "" {
|
|
continue
|
|
}
|
|
if _, ok := seen[route]; ok {
|
|
continue
|
|
}
|
|
seen[route] = struct{}{}
|
|
routes = append(routes, route)
|
|
}
|
|
return routes
|
|
}
|
|
|
|
// effectiveServiceProxyIP returns the proxy IP to use for a service. When the
// NEXAVPN_ACCESS_PROXY_IP environment variable is set to a non-blank value,
// its trimmed value takes precedence; otherwise proxyIP is returned unchanged.
func effectiveServiceProxyIP(proxyIP string) string {
	if fromEnv := strings.TrimSpace(os.Getenv("NEXAVPN_ACCESS_PROXY_IP")); fromEnv != "" {
		return fromEnv
	}
	return proxyIP
}
|
|
|
|
func mergeProfileAllowedIPs(destinations []string, dnsServers []string) []string {
|
|
seen := make(map[string]struct{}, len(destinations)+len(dnsServers))
|
|
merged := make([]string, 0, len(destinations)+len(dnsServers))
|
|
|
|
for _, destination := range destinations {
|
|
destination = strings.TrimSpace(destination)
|
|
if destination == "" {
|
|
continue
|
|
}
|
|
if _, exists := seen[destination]; exists {
|
|
continue
|
|
}
|
|
seen[destination] = struct{}{}
|
|
merged = append(merged, destination)
|
|
}
|
|
|
|
for _, dnsServer := range dnsServers {
|
|
route := dnsServerRoute(dnsServer)
|
|
if route == "" {
|
|
continue
|
|
}
|
|
if _, exists := seen[route]; exists {
|
|
continue
|
|
}
|
|
seen[route] = struct{}{}
|
|
merged = append(merged, route)
|
|
}
|
|
|
|
return merged
|
|
}
|
|
|
|
// dnsServerRoute converts an address into the CIDR route used in AllowedIPs.
//
// The value is trimmed first. An empty value yields "". A value that already
// contains a "/" is assumed to be in CIDR form and returned as is. A bare
// address becomes a single-host route: "/128" when it contains a ":" (IPv6
// literal), "/32" otherwise. The value is not validated as an IP address.
func dnsServerRoute(value string) string {
	value = strings.TrimSpace(value)
	switch {
	case value == "":
		return ""
	case strings.Contains(value, "/"):
		return value
	case strings.Contains(value, ":"):
		// IPv6 host route. Appending "/32" here would route an entire IPv6
		// /32 prefix through the tunnel instead of the single address.
		return value + "/128"
	default:
		return value + "/32"
	}
}
|
|
|
|
func dnsServersForProfile(base []string, services []AccessService) []string {
|
|
if len(services) > 0 {
|
|
if override := parseEnvList("NEXAVPN_CLIENT_DNS_SERVERS"); len(override) > 0 {
|
|
return override
|
|
}
|
|
return dedupeList(base)
|
|
}
|
|
|
|
if override := parseEnvList("DEFAULT_DNS_SERVERS"); len(override) > 0 {
|
|
return override
|
|
}
|
|
return dedupeList(base)
|
|
}
|
|
|
|
func parseEnvList(key string) []string {
|
|
return parseCommaList(os.Getenv(key))
|
|
}
|
|
|
|
// parseCommaList splits raw on commas and returns the trimmed, non-empty,
// distinct items in first-seen order. The result is never nil; input with no
// usable items yields an empty slice.
func parseCommaList(raw string) []string {
	values := []string{}
	seen := map[string]struct{}{}
	for _, field := range strings.Split(raw, ",") {
		item := strings.TrimSpace(field)
		if item == "" {
			continue
		}
		if _, duplicate := seen[item]; !duplicate {
			seen[item] = struct{}{}
			values = append(values, item)
		}
	}
	return values
}
|
|
|
|
func dedupeList(values []string) []string {
|
|
return parseCommaList(strings.Join(values, ","))
|
|
}
|