Files
NexaVPN/backend/internal/device/service.go
nessi 5d5f736e1b refactor: move default destination fallback after profile resolution and add nftables input chain filtering for VPN clients
Move default 172.16.10.0/24 destination assignment to after profile resolution and only apply when both selectedDestinations and services are empty. Extract selectedServices calculation before conditional check in applyCurrentPolicy.

Add nftables input chain to gateway with per-peer filtering. Accept established connections and non-WireGuard traffic. Allow DNS queries to configured
2026-03-19 22:26:03 +01:00

393 lines
11 KiB
Go

package device
import (
"context"
"fmt"
"os"
"strings"
"github.com/google/uuid"
"nexavpn/backend/internal/gateway"
"nexavpn/backend/internal/ipam"
"nexavpn/backend/internal/policy"
"nexavpn/backend/internal/profile"
)
// Service implements device enrollment and the per-device view of access
// policy. Persistence goes through repo; gateway selection, destination and
// profile resolution, and fallback address allocation are delegated to the
// injected services.
type Service struct {
// repo stores devices, enrollments and each device's selected access profile.
repo Repository
// policyService resolves destinations and lists selectable access profiles.
policyService *policy.Service
// gatewayService picks the active gateway a new device is enrolled on.
gatewayService *gateway.Service
// ipamService allocates a tunnel address when the repository lookup fails.
ipamService *ipam.Service
}
// NewService returns a Service wired to the given repository and
// collaborating services. Arguments are stored as-is; nothing is validated.
func NewService(repo Repository, policyService *policy.Service, gatewayService *gateway.Service, ipamService *ipam.Service) *Service {
	svc := new(Service)
	svc.repo = repo
	svc.policyService = policyService
	svc.gatewayService = gatewayService
	svc.ipamService = ipamService
	return svc
}
// Enroll registers a new device for userID on the currently active gateway
// and returns the enrollment together with a WireGuard profile for it.
//
// Steps:
//  1. select the active gateway and reserve a tunnel address on it (the
//     repository is asked first; the IPAM allocator is the fallback whenever
//     that lookup fails);
//  2. resolve the user's destinations (defaulting to 172.16.10.0/24) and
//     persist the enrollment;
//  3. resolve the device's access profiles and take destinations/services
//     from the selected one;
//  4. build the peer view, resources and WireGuard config from the result.
//
// privateKeyPlaceholder is written verbatim into the profile's PrivateKey.
// ProfileRevision is always 1 for a fresh enrollment. Errors from the
// gateway, repository, IPAM or policy services are returned unchanged along
// with a zero EnrollmentResponse.
func (s *Service) Enroll(ctx context.Context, userID uuid.UUID, input EnrollRequest, privateKeyPlaceholder string) (EnrollmentResponse, error) {
	selectedGateway, err := s.gatewayService.SelectActive(ctx)
	if err != nil {
		return EnrollmentResponse{}, err
	}
	assignedIP, err := s.repo.FindNextAvailableIP(ctx, selectedGateway.ID, selectedGateway.VPNCIDR, 10)
	if err != nil {
		// Deliberate fallback: whatever made the repository lookup fail, try
		// the IPAM allocator instead; only the allocator's error is surfaced.
		assignedIP, err = s.ipamService.Allocate(selectedGateway.VPNCIDR, 10)
		if err != nil {
			return EnrollmentResponse{}, err
		}
	}
	destinations, err := s.policyService.ResolveDestinations(ctx, userID, nil)
	if err != nil {
		return EnrollmentResponse{}, err
	}
	if len(destinations) == 0 {
		destinations = []string{"172.16.10.0/24"}
	}
	enrollment, err := s.repo.Enroll(ctx, userID, selectedGateway.ID, input, assignedIP, selectedGateway.DNSServers, destinations)
	if err != nil {
		return EnrollmentResponse{}, err
	}
	// Resolve again, now scoped to the new device ID. The result is only used
	// below as the fallback for a profile with services but no destinations.
	destinations, err = s.policyService.ResolveDestinations(ctx, userID, &enrollment.Device.ID)
	if err != nil {
		return EnrollmentResponse{}, err
	}
	availableProfiles, selectedProfileID, selectedDestinations, err := s.resolveAccessProfiles(ctx, userID, enrollment.Device.ID)
	if err != nil {
		return EnrollmentResponse{}, err
	}
	selectedServices := servicesForSelectedProfile(availableProfiles, selectedProfileID)
	if len(selectedDestinations) == 0 {
		if len(selectedServices) == 0 {
			// Nothing to route at all (this includes "no profiles"): use the
			// default network.
			selectedDestinations = []string{"172.16.10.0/24"}
		} else {
			selectedDestinations = destinations
		}
	}
	// selectedDestinations may share its backing array with an access profile
	// or the policy result; the full slice expression caps its capacity so
	// append allocates instead of writing proxy routes into shared storage.
	routed := append(selectedDestinations[:len(selectedDestinations):len(selectedDestinations)], proxyRoutesForServices(selectedServices)...)
	profileAllowedIPs := mergeProfileAllowedIPs(routed, selectedGateway.DNSServers)
	enrollment.Peer = PeerView{
		AssignedIP: assignedIP,
		DNSServers: selectedGateway.DNSServers,
		AllowedIPs: profileAllowedIPs,
		Gateway: GatewayView{
			ID:        selectedGateway.ID,
			Name:      selectedGateway.Name,
			Endpoint:  selectedGateway.Endpoint,
			PublicKey: selectedGateway.PublicKey,
		},
		ProfileRevision: 1,
	}
	enrollment.Resources = resourcesFromProfile(selectedDestinations, selectedServices)
	enrollment.AvailableProfiles = availableProfiles
	enrollment.SelectedProfileID = selectedProfileID
	enrollment.Profile = ProfileView{
		Format: "wireguard",
		Content: profile.BuildWireGuardConfig(profile.BuildInput{
			PrivateKey:       privateKeyPlaceholder,
			Address:          assignedIP,
			DNSServers:       selectedGateway.DNSServers,
			ServerPublicKey:  selectedGateway.PublicKey,
			ServerEndpoint:   selectedGateway.Endpoint,
			AllowedIPs:       profileAllowedIPs,
			PersistentKeepal: 25,
		}),
	}
	return enrollment, nil
}
// ListByUser returns the devices the repository reports for userID, passing
// its result and error through untouched.
func (s *Service) ListByUser(ctx context.Context, userID uuid.UUID) ([]Device, error) {
	devices, err := s.repo.ListByUser(ctx, userID)
	return devices, err
}
// ListAll returns every device known to the repository, passing its result
// and error through untouched.
func (s *Service) ListAll(ctx context.Context) ([]Device, error) {
	devices, err := s.repo.ListAll(ctx)
	return devices, err
}
// GetLatestEnrollmentByUser loads the user's most recent enrollment and
// returns it refreshed against the device's current access policy.
func (s *Service) GetLatestEnrollmentByUser(ctx context.Context, userID uuid.UUID) (EnrollmentResponse, error) {
	stored, lookupErr := s.repo.GetLatestEnrollmentByUser(ctx, userID)
	if lookupErr != nil {
		return EnrollmentResponse{}, lookupErr
	}
	return s.applyCurrentPolicy(ctx, stored)
}
// GetEnrollmentByDeviceID loads the enrollment for deviceID and returns it
// refreshed against the device's current access policy.
func (s *Service) GetEnrollmentByDeviceID(ctx context.Context, deviceID uuid.UUID) (EnrollmentResponse, error) {
	stored, lookupErr := s.repo.GetEnrollmentByDeviceID(ctx, deviceID)
	if lookupErr != nil {
		return EnrollmentResponse{}, lookupErr
	}
	return s.applyCurrentPolicy(ctx, stored)
}
// SelectProfile makes profileID the selected access profile of the user's
// most recent device and returns that enrollment refreshed with the current
// policy. profileID must be among the device's selectable profiles;
// otherwise an error is returned and no selection is written.
func (s *Service) SelectProfile(ctx context.Context, userID uuid.UUID, profileID uuid.UUID) (EnrollmentResponse, error) {
	enrollment, err := s.repo.GetLatestEnrollmentByUser(ctx, userID)
	if err != nil {
		return EnrollmentResponse{}, err
	}
	selectable, err := s.policyService.ListSelectableProfiles(ctx, userID, &enrollment.Device.ID)
	if err != nil {
		return EnrollmentResponse{}, err
	}
	allowed := false
	for _, candidate := range selectable {
		if candidate.ID != profileID {
			continue
		}
		allowed = true
		break
	}
	if !allowed {
		return EnrollmentResponse{}, fmt.Errorf("selected access profile is not available for this device")
	}
	if err = s.repo.SetSelectedProfileID(ctx, enrollment.Device.ID, profileID); err != nil {
		return EnrollmentResponse{}, err
	}
	return s.applyCurrentPolicy(ctx, enrollment)
}
// GetConnectionStatus reports the provisioning state of the user's most
// recent device. When no enrollment can be loaded it returns status
// "disconnected" with an empty (non-nil) resource list and a nil error;
// otherwise it returns "provisioned" with the assigned IP and the resources
// from the current policy. LastSyncTime is always the literal "just now".
//
// NOTE(review): every repository error, not just "not found", is reported
// as "disconnected"; confirm that hiding storage failures here is intended.
func (s *Service) GetConnectionStatus(ctx context.Context, userID uuid.UUID) (ConnectionStatus, error) {
	stored, lookupErr := s.repo.GetLatestEnrollmentByUser(ctx, userID)
	if lookupErr != nil {
		return ConnectionStatus{Status: "disconnected", Resources: []Resource{}}, nil
	}
	current, err := s.applyCurrentPolicy(ctx, stored)
	if err != nil {
		return ConnectionStatus{}, err
	}
	lastSync := "just now"
	status := ConnectionStatus{
		Status:       "provisioned",
		AssignedIP:   current.Peer.AssignedIP,
		LastSyncTime: &lastSync,
		Resources:    current.Resources,
	}
	return status, nil
}
// Revoke delegates revocation of deviceID to the repository and returns its
// error unchanged.
func (s *Service) Revoke(ctx context.Context, deviceID uuid.UUID) error {
	err := s.repo.Revoke(ctx, deviceID)
	return err
}
// Rotate delegates key rotation for deviceID to the repository and returns
// its error unchanged.
func (s *Service) Rotate(ctx context.Context, deviceID uuid.UUID) error {
	err := s.repo.Rotate(ctx, deviceID)
	return err
}
// withDebugProfile returns enrollment with its Profile regenerated as a
// WireGuard config built from the enrollment's Peer data (address, DNS,
// gateway key/endpoint, AllowedIPs, keepalive 25). The private key is the
// fixed placeholder "__CLIENT_PRIVATE_KEY_REQUIRED__"; presumably the client
// substitutes its own key.
func withDebugProfile(enrollment EnrollmentResponse) EnrollmentResponse {
	peer := enrollment.Peer
	config := profile.BuildWireGuardConfig(profile.BuildInput{
		PrivateKey:       "__CLIENT_PRIVATE_KEY_REQUIRED__",
		Address:          peer.AssignedIP,
		DNSServers:       peer.DNSServers,
		ServerPublicKey:  peer.Gateway.PublicKey,
		ServerEndpoint:   peer.Gateway.Endpoint,
		AllowedIPs:       peer.AllowedIPs,
		PersistentKeepal: 25,
	})
	enrollment.Profile = ProfileView{Format: "wireguard", Content: config}
	return enrollment
}
// applyCurrentPolicy refreshes a stored enrollment with the device's current
// access profiles. It re-resolves the selectable profiles (resolveAccessProfiles
// persists a fallback selection when needed), recomputes Resources and
// Peer.AllowedIPs from the selected profile, and regenerates the WireGuard
// profile via withDebugProfile.
//
// When the selected profile yields neither destinations nor services (this
// includes the case of no profiles at all), 172.16.10.0/24 is used as the
// single destination. On error a zero EnrollmentResponse is returned.
func (s *Service) applyCurrentPolicy(ctx context.Context, enrollment EnrollmentResponse) (EnrollmentResponse, error) {
	availableProfiles, selectedProfileID, selectedDestinations, err := s.resolveAccessProfiles(ctx, enrollment.Device.UserID, enrollment.Device.ID)
	if err != nil {
		return EnrollmentResponse{}, err
	}
	selectedServices := servicesForSelectedProfile(availableProfiles, selectedProfileID)
	if len(selectedDestinations) == 0 && len(selectedServices) == 0 {
		selectedDestinations = []string{"172.16.10.0/24"}
	}
	enrollment.Resources = resourcesFromProfile(selectedDestinations, selectedServices)
	enrollment.AvailableProfiles = availableProfiles
	enrollment.SelectedProfileID = selectedProfileID
	// selectedDestinations is usually the selected profile's Destinations
	// slice itself; cap its capacity so append must allocate rather than
	// write proxy routes into storage shared with the profile data.
	routed := append(selectedDestinations[:len(selectedDestinations):len(selectedDestinations)], proxyRoutesForServices(selectedServices)...)
	enrollment.Peer.AllowedIPs = mergeProfileAllowedIPs(routed, enrollment.Peer.DNSServers)
	return withDebugProfile(enrollment), nil
}
// resolveAccessProfiles converts the policy service's selectable profiles for
// (userID, deviceID) into AccessProfile values and determines the device's
// selected profile.
//
// It returns the converted profiles, the selected profile ID, and that
// profile's destinations. When the stored selection is absent or no longer
// selectable, the first profile becomes the selection and is persisted via
// the repository. When there are no selectable profiles, all results are nil
// and the stored selection is neither read nor written.
func (s *Service) resolveAccessProfiles(ctx context.Context, userID uuid.UUID, deviceID uuid.UUID) ([]AccessProfile, *uuid.UUID, []string, error) {
	policyProfiles, err := s.policyService.ListSelectableProfiles(ctx, userID, &deviceID)
	if err != nil {
		return nil, nil, nil, err
	}
	if len(policyProfiles) == 0 {
		return nil, nil, nil, nil
	}
	converted := make([]AccessProfile, 0, len(policyProfiles))
	for _, src := range policyProfiles {
		svcs := make([]AccessService, 0, len(src.Services))
		for _, svc := range src.Services {
			svcs = append(svcs, AccessService{
				ID:          svc.ID,
				Name:        svc.Name,
				Description: svc.Description,
				Domain:      svc.Domain,
				UpstreamIP:  svc.UpstreamIP,
				ProxyIP:     svc.ProxyIP,
				Ports:       svc.Ports,
			})
		}
		converted = append(converted, AccessProfile{
			ID:           src.ID,
			Name:         src.Name,
			Description:  src.Description,
			FullTunnel:   src.FullTunnel,
			Destinations: src.Destinations,
			Services:     svcs,
		})
	}
	stored, err := s.repo.GetSelectedProfileID(ctx, deviceID)
	if err != nil {
		return nil, nil, nil, err
	}
	if stored != nil {
		for _, candidate := range converted {
			if candidate.ID == *stored {
				return converted, stored, candidate.Destinations, nil
			}
		}
	}
	// No valid stored selection: default to the first profile and remember it.
	first := converted[0]
	if err := s.repo.SetSelectedProfileID(ctx, deviceID, first.ID); err != nil {
		return nil, nil, nil, err
	}
	return converted, &first.ID, first.Destinations, nil
}
// resourcesFromDestinations wraps each destination in a Resource of type
// "cidr" whose Value and Label are both the destination string, preserving
// input order. The returned slice is non-nil even for empty input.
func resourcesFromDestinations(destinations []string) []Resource {
	out := make([]Resource, len(destinations))
	for i, cidr := range destinations {
		out[i] = Resource{Type: "cidr", Value: cidr, Label: cidr}
	}
	return out
}
// resourcesFromProfile lists a profile's resources: first one "cidr" entry
// per destination, then one "service" entry per service. A service entry
// uses the service Domain as Value and Domain, and its Name as Label.
func resourcesFromProfile(destinations []string, services []AccessService) []Resource {
	out := resourcesFromDestinations(destinations)
	for i := range services {
		svc := services[i]
		out = append(out, Resource{
			Type:   "service",
			Value:  svc.Domain,
			Label:  svc.Name,
			Domain: svc.Domain,
		})
	}
	return out
}
// servicesForSelectedProfile returns the Services of the profile whose ID is
// *selectedProfileID, or nil when no profile matches. With a nil
// selectedProfileID it returns the first profile's Services, or nil when
// profiles is empty.
func servicesForSelectedProfile(profiles []AccessProfile, selectedProfileID *uuid.UUID) []AccessService {
	if selectedProfileID != nil {
		for i := range profiles {
			if profiles[i].ID == *selectedProfileID {
				return profiles[i].Services
			}
		}
		return nil
	}
	if len(profiles) > 0 {
		return profiles[0].Services
	}
	return nil
}
// proxyRoutesForServices returns the distinct routes, in first-seen order,
// for each service's effective proxy address (see effectiveServiceProxyIP),
// normalized by dnsServerRoute. Services whose address normalizes to "" are
// skipped. The result is non-nil even when empty.
func proxyRoutesForServices(services []AccessService) []string {
	routes := make([]string, 0, len(services))
	seen := make(map[string]struct{}, len(services))
	for _, svc := range services {
		route := dnsServerRoute(effectiveServiceProxyIP(svc.ProxyIP))
		if _, dup := seen[route]; route == "" || dup {
			continue
		}
		seen[route] = struct{}{}
		routes = append(routes, route)
	}
	return routes
}
// effectiveServiceProxyIP returns the value of the NEXAVPN_ACCESS_PROXY_IP
// environment variable, trimmed of surrounding whitespace, when it is set to
// something non-blank; otherwise it returns proxyIP unchanged (untrimmed).
// The environment is read on every call.
func effectiveServiceProxyIP(proxyIP string) string {
	if override := strings.TrimSpace(os.Getenv("NEXAVPN_ACCESS_PROXY_IP")); override != "" {
		return override
	}
	return proxyIP
}
// mergeProfileAllowedIPs builds the AllowedIPs list used for the WireGuard
// profile. It starts with the whitespace-trimmed destinations and follows
// them with each DNS server normalized by dnsServerRoute. Blank entries are
// dropped, and a value already present (compared after trimming/normalizing)
// keeps only its first position. The result is non-nil even when empty.
func mergeProfileAllowedIPs(destinations []string, dnsServers []string) []string {
	total := len(destinations) + len(dnsServers)
	merged := make([]string, 0, total)
	seen := make(map[string]struct{}, total)
	add := func(entry string) {
		if entry == "" {
			return
		}
		if _, dup := seen[entry]; dup {
			return
		}
		seen[entry] = struct{}{}
		merged = append(merged, entry)
	}
	for _, dst := range destinations {
		add(strings.TrimSpace(dst))
	}
	for _, dns := range dnsServers {
		add(dnsServerRoute(dns))
	}
	return merged
}
// dnsServerRoute turns an address into a route suitable for AllowedIPs.
// Surrounding whitespace is trimmed and a blank value yields "". A value
// that already carries a prefix length (contains "/") is returned as-is.
// A bare IPv4 address gets "/32". A bare IPv6 address (any value containing
// ':') gets "/128", so only that host is routed.
func dnsServerRoute(value string) string {
	value = strings.TrimSpace(value)
	switch {
	case value == "":
		return ""
	case strings.Contains(value, "/"):
		return value
	case strings.Contains(value, ":"):
		// IPv6 host route. Appending "/32" here would route an entire IPv6
		// /32 network through the tunnel instead of a single address.
		return value + "/128"
	default:
		return value + "/32"
	}
}