Add SelectOwnProfile handler to allow users to choose from available access profiles. Store selected profile ID per device in settings table with device_access_profile category. Implement GetSelectedProfileID and SetSelectedProfileID repository methods using JSONB storage. Add ListSelectableProfiles to policy repository and service to query user/group/device-specific profiles ordered by priority. Filter gateway
351 lines
9.9 KiB
Go
351 lines
9.9 KiB
Go
package device
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"os"
|
|
"strings"
|
|
|
|
"github.com/google/uuid"
|
|
|
|
"nexavpn/backend/internal/gateway"
|
|
"nexavpn/backend/internal/ipam"
|
|
"nexavpn/backend/internal/policy"
|
|
"nexavpn/backend/internal/profile"
|
|
)
|
|
|
|
type Service struct {
|
|
repo Repository
|
|
policyService *policy.Service
|
|
gatewayService *gateway.Service
|
|
ipamService *ipam.Service
|
|
}
|
|
|
|
func NewService(repo Repository, policyService *policy.Service, gatewayService *gateway.Service, ipamService *ipam.Service) *Service {
|
|
return &Service{
|
|
repo: repo,
|
|
policyService: policyService,
|
|
gatewayService: gatewayService,
|
|
ipamService: ipamService,
|
|
}
|
|
}
|
|
|
|
func (s *Service) Enroll(ctx context.Context, userID uuid.UUID, input EnrollRequest, privateKeyPlaceholder string) (EnrollmentResponse, error) {
|
|
selectedGateway, err := s.gatewayService.SelectActive(ctx)
|
|
if err != nil {
|
|
return EnrollmentResponse{}, err
|
|
}
|
|
|
|
assignedIP, err := s.repo.FindNextAvailableIP(ctx, selectedGateway.ID, selectedGateway.VPNCIDR, 10)
|
|
if err != nil {
|
|
assignedIP, err = s.ipamService.Allocate(selectedGateway.VPNCIDR, 10)
|
|
if err != nil {
|
|
return EnrollmentResponse{}, err
|
|
}
|
|
}
|
|
|
|
destinations, err := s.policyService.ResolveDestinations(ctx, userID, nil)
|
|
if err != nil {
|
|
return EnrollmentResponse{}, err
|
|
}
|
|
if len(destinations) == 0 {
|
|
destinations = []string{"172.16.10.0/24"}
|
|
}
|
|
|
|
enrollment, err := s.repo.Enroll(ctx, userID, selectedGateway.ID, input, assignedIP, selectedGateway.DNSServers, destinations)
|
|
if err != nil {
|
|
return EnrollmentResponse{}, err
|
|
}
|
|
|
|
destinations, err = s.policyService.ResolveDestinations(ctx, userID, &enrollment.Device.ID)
|
|
if err != nil {
|
|
return EnrollmentResponse{}, err
|
|
}
|
|
if len(destinations) == 0 {
|
|
destinations = []string{"172.16.10.0/24"}
|
|
}
|
|
availableProfiles, selectedProfileID, selectedDestinations, err := s.resolveAccessProfiles(ctx, userID, enrollment.Device.ID)
|
|
if err != nil {
|
|
return EnrollmentResponse{}, err
|
|
}
|
|
if len(selectedDestinations) == 0 {
|
|
selectedDestinations = destinations
|
|
}
|
|
profileAllowedIPs := mergeProfileAllowedIPs(selectedDestinations, selectedGateway.DNSServers, alwaysAllowWebProxyTargets())
|
|
|
|
enrollment.Peer = PeerView{
|
|
AssignedIP: assignedIP,
|
|
DNSServers: selectedGateway.DNSServers,
|
|
AllowedIPs: profileAllowedIPs,
|
|
Gateway: GatewayView{
|
|
ID: selectedGateway.ID,
|
|
Name: selectedGateway.Name,
|
|
Endpoint: selectedGateway.Endpoint,
|
|
PublicKey: selectedGateway.PublicKey,
|
|
},
|
|
ProfileRevision: 1,
|
|
}
|
|
enrollment.Resources = resourcesFromDestinations(selectedDestinations)
|
|
enrollment.AvailableProfiles = availableProfiles
|
|
enrollment.SelectedProfileID = selectedProfileID
|
|
|
|
enrollment.Profile = ProfileView{
|
|
Format: "wireguard",
|
|
Content: profile.BuildWireGuardConfig(profile.BuildInput{
|
|
PrivateKey: privateKeyPlaceholder,
|
|
Address: assignedIP,
|
|
DNSServers: selectedGateway.DNSServers,
|
|
ServerPublicKey: selectedGateway.PublicKey,
|
|
ServerEndpoint: selectedGateway.Endpoint,
|
|
AllowedIPs: profileAllowedIPs,
|
|
PersistentKeepal: 25,
|
|
}),
|
|
}
|
|
|
|
return enrollment, nil
|
|
}
|
|
|
|
// ListByUser returns the devices owned by userID exactly as the repository
// reports them, including any error.
func (s *Service) ListByUser(ctx context.Context, userID uuid.UUID) ([]Device, error) {
	devices, err := s.repo.ListByUser(ctx, userID)
	return devices, err
}
|
|
|
|
// ListAll returns every device known to the repository, passing its result
// and error through unchanged.
func (s *Service) ListAll(ctx context.Context) ([]Device, error) {
	all, err := s.repo.ListAll(ctx)
	return all, err
}
|
|
|
|
// GetLatestEnrollmentByUser loads userID's most recent enrollment and
// refreshes it with the current access-profile state via applyCurrentPolicy.
// Repository errors are returned unchanged with a zero EnrollmentResponse.
func (s *Service) GetLatestEnrollmentByUser(ctx context.Context, userID uuid.UUID) (EnrollmentResponse, error) {
	stored, lookupErr := s.repo.GetLatestEnrollmentByUser(ctx, userID)
	if lookupErr != nil {
		return EnrollmentResponse{}, lookupErr
	}
	return s.applyCurrentPolicy(ctx, stored)
}
|
|
|
|
// GetEnrollmentByDeviceID loads the enrollment for deviceID and refreshes it
// with the current access-profile state via applyCurrentPolicy. Repository
// errors are returned unchanged with a zero EnrollmentResponse.
func (s *Service) GetEnrollmentByDeviceID(ctx context.Context, deviceID uuid.UUID) (EnrollmentResponse, error) {
	stored, lookupErr := s.repo.GetEnrollmentByDeviceID(ctx, deviceID)
	if lookupErr != nil {
		return EnrollmentResponse{}, lookupErr
	}
	return s.applyCurrentPolicy(ctx, stored)
}
|
|
|
|
func (s *Service) SelectProfile(ctx context.Context, userID uuid.UUID, profileID uuid.UUID) (EnrollmentResponse, error) {
|
|
enrollment, err := s.repo.GetLatestEnrollmentByUser(ctx, userID)
|
|
if err != nil {
|
|
return EnrollmentResponse{}, err
|
|
}
|
|
|
|
profiles, err := s.policyService.ListSelectableProfiles(ctx, userID, &enrollment.Device.ID)
|
|
if err != nil {
|
|
return EnrollmentResponse{}, err
|
|
}
|
|
|
|
var exists bool
|
|
for _, profile := range profiles {
|
|
if profile.ID == profileID {
|
|
exists = true
|
|
break
|
|
}
|
|
}
|
|
if !exists {
|
|
return EnrollmentResponse{}, fmt.Errorf("selected access profile is not available for this device")
|
|
}
|
|
|
|
if err := s.repo.SetSelectedProfileID(ctx, enrollment.Device.ID, profileID); err != nil {
|
|
return EnrollmentResponse{}, err
|
|
}
|
|
|
|
return s.applyCurrentPolicy(ctx, enrollment)
|
|
}
|
|
|
|
// GetConnectionStatus reports whether userID has a provisioned device, based
// on the most recent enrollment refreshed with current policy.
//
// If the enrollment lookup fails for any reason, the status is "disconnected"
// with an empty (non-nil) resource list and a nil error.
// NOTE(review): that also masks infrastructure failures (DB errors, cancelled
// contexts) as "disconnected"; consider narrowing to the repository's
// not-found error.
// Errors from applying policy are returned. LastSyncTime is always the
// literal "just now" and is not derived from any stored timestamp.
func (s *Service) GetConnectionStatus(ctx context.Context, userID uuid.UUID) (ConnectionStatus, error) {
	latest, err := s.repo.GetLatestEnrollmentByUser(ctx, userID)
	if err != nil {
		return ConnectionStatus{Status: "disconnected", Resources: []Resource{}}, nil
	}

	refreshed, err := s.applyCurrentPolicy(ctx, latest)
	if err != nil {
		return ConnectionStatus{}, err
	}

	status := ConnectionStatus{
		Status:     "provisioned",
		AssignedIP: refreshed.Peer.AssignedIP,
		Resources:  refreshed.Resources,
	}
	syncLabel := "just now"
	status.LastSyncTime = &syncLabel
	return status, nil
}
|
|
|
|
// Revoke delegates revocation of deviceID to the repository and returns its
// error, if any.
func (s *Service) Revoke(ctx context.Context, deviceID uuid.UUID) error {
	if err := s.repo.Revoke(ctx, deviceID); err != nil {
		return err
	}
	return nil
}
|
|
|
|
// Rotate delegates rotation for deviceID to the repository and returns its
// error, if any.
func (s *Service) Rotate(ctx context.Context, deviceID uuid.UUID) error {
	if err := s.repo.Rotate(ctx, deviceID); err != nil {
		return err
	}
	return nil
}
|
|
|
|
func withDebugProfile(enrollment EnrollmentResponse) EnrollmentResponse {
|
|
profileAllowedIPs := enrollment.Peer.AllowedIPs
|
|
enrollment.Profile = ProfileView{
|
|
Format: "wireguard",
|
|
Content: profile.BuildWireGuardConfig(profile.BuildInput{
|
|
PrivateKey: "__CLIENT_PRIVATE_KEY_REQUIRED__",
|
|
Address: enrollment.Peer.AssignedIP,
|
|
DNSServers: enrollment.Peer.DNSServers,
|
|
ServerPublicKey: enrollment.Peer.Gateway.PublicKey,
|
|
ServerEndpoint: enrollment.Peer.Gateway.Endpoint,
|
|
AllowedIPs: profileAllowedIPs,
|
|
PersistentKeepal: 25,
|
|
}),
|
|
}
|
|
return enrollment
|
|
}
|
|
|
|
// applyCurrentPolicy refreshes a stored enrollment with the device's current
// access-profile state. It overwrites Resources, AvailableProfiles,
// SelectedProfileID and Peer.AllowedIPs, then re-renders the profile via
// withDebugProfile.
//
// Destinations come from the selected access profile. When no profile supplies
// any, it falls back to the device-scoped policy destinations and only then to
// the default CIDR. This is the same precedence Enroll uses, so a refreshed
// enrollment keeps the policy routes it was issued with instead of collapsing
// to the default.
func (s *Service) applyCurrentPolicy(ctx context.Context, enrollment EnrollmentResponse) (EnrollmentResponse, error) {
	availableProfiles, selectedProfileID, selectedDestinations, err := s.resolveAccessProfiles(ctx, enrollment.Device.UserID, enrollment.Device.ID)
	if err != nil {
		return EnrollmentResponse{}, err
	}
	if len(selectedDestinations) == 0 {
		selectedDestinations, err = s.policyService.ResolveDestinations(ctx, enrollment.Device.UserID, &enrollment.Device.ID)
		if err != nil {
			return EnrollmentResponse{}, err
		}
	}
	if len(selectedDestinations) == 0 {
		selectedDestinations = []string{"172.16.10.0/24"}
	}

	enrollment.Resources = resourcesFromDestinations(selectedDestinations)
	enrollment.AvailableProfiles = availableProfiles
	enrollment.SelectedProfileID = selectedProfileID
	enrollment.Peer.AllowedIPs = mergeProfileAllowedIPs(selectedDestinations, enrollment.Peer.DNSServers, alwaysAllowWebProxyTargets())
	return withDebugProfile(enrollment), nil
}
|
|
|
|
// resolveAccessProfiles lists the access profiles selectable by userID on
// deviceID and determines which one is active.
//
// It returns the selectable profiles, the active profile's ID and that
// profile's destinations. If no profile is selectable, all three results are
// nil (and no selection is read or written).
//
// If the device's stored selection is missing or no longer selectable, the
// first listed profile becomes active and is persisted as the device's
// selection. This read path may therefore write.
func (s *Service) resolveAccessProfiles(ctx context.Context, userID uuid.UUID, deviceID uuid.UUID) ([]AccessProfile, *uuid.UUID, []string, error) {
	selectable, err := s.policyService.ListSelectableProfiles(ctx, userID, &deviceID)
	if err != nil {
		return nil, nil, nil, err
	}
	if len(selectable) == 0 {
		return nil, nil, nil, nil
	}

	views := make([]AccessProfile, len(selectable))
	for i, candidate := range selectable {
		views[i] = AccessProfile{
			ID:           candidate.ID,
			Name:         candidate.Name,
			Description:  candidate.Description,
			FullTunnel:   candidate.FullTunnel,
			Destinations: candidate.Destinations,
		}
	}

	storedID, err := s.repo.GetSelectedProfileID(ctx, deviceID)
	if err != nil {
		return nil, nil, nil, err
	}
	if storedID != nil {
		for i := range views {
			if views[i].ID == *storedID {
				return views, storedID, views[i].Destinations, nil
			}
		}
	}

	// Stored selection absent or stale: adopt the first listed profile
	// (presumably the highest-priority one, per ListSelectableProfiles' ordering —
	// confirm) and persist it.
	fallback := views[0]
	if err := s.repo.SetSelectedProfileID(ctx, deviceID, fallback.ID); err != nil {
		return nil, nil, nil, err
	}
	return views, &fallback.ID, fallback.Destinations, nil
}
|
|
|
|
func resourcesFromDestinations(destinations []string) []Resource {
|
|
resources := make([]Resource, 0, len(destinations))
|
|
for _, destination := range destinations {
|
|
resources = append(resources, Resource{
|
|
Type: "cidr",
|
|
Value: destination,
|
|
Label: destination,
|
|
})
|
|
}
|
|
return resources
|
|
}
|
|
|
|
// mergeProfileAllowedIPs builds a WireGuard AllowedIPs list from, in this
// order, the profile destinations, the DNS servers and the web proxy targets.
// Empty entries are skipped and duplicates dropped, keeping the first
// occurrence. Destinations are only whitespace-trimmed; DNS servers and
// proxy targets go through dnsServerRoute, which adds a prefix length to
// entries that lack one. The result is never nil.
func mergeProfileAllowedIPs(destinations []string, dnsServers []string, webProxyTargets []string) []string {
	total := len(destinations) + len(dnsServers) + len(webProxyTargets)
	seen := make(map[string]struct{}, total)
	merged := make([]string, 0, total)

	add := func(route string) {
		if route == "" {
			return
		}
		if _, dup := seen[route]; dup {
			return
		}
		seen[route] = struct{}{}
		merged = append(merged, route)
	}

	for _, dest := range destinations {
		add(strings.TrimSpace(dest))
	}
	for _, server := range dnsServers {
		add(dnsServerRoute(server))
	}
	for _, target := range webProxyTargets {
		add(dnsServerRoute(target))
	}
	return merged
}
|
|
|
|
// dnsServerRoute turns a single address into an AllowedIPs entry.
//
// Surrounding whitespace is trimmed, and a blank value yields "". A value that
// already carries a prefix length (e.g. "10.0.0.0/8", "fd00::/64") is returned
// unchanged. A bare address gets a single-host prefix: /32 for IPv4, /128 for
// IPv6. Appending /32 to an IPv6 address would route an entire /32 IPv6 block
// through the tunnel rather than one host.
func dnsServerRoute(value string) string {
	value = strings.TrimSpace(value)
	switch {
	case value == "":
		return ""
	case strings.Contains(value, "/"):
		return value
	case strings.Contains(value, ":"):
		// IPv6 literals always contain ':'; IPv4 dotted-quads never do.
		return value + "/128"
	default:
		return value + "/32"
	}
}
|
|
|
|
// alwaysAllowWebProxyTargets reads the comma-separated environment variable
// NEXAVPN_ALWAYS_ALLOW_WEB_PROXY_IPS on every call.
//
// It returns the entries whitespace-trimmed, with empty entries and repeats
// removed (first occurrence wins, order preserved). It returns nil when the
// variable is unset or blank, and a non-nil empty slice when the variable
// holds only separators.
func alwaysAllowWebProxyTargets() []string {
	raw := os.Getenv("NEXAVPN_ALWAYS_ALLOW_WEB_PROXY_IPS")
	if len(strings.TrimSpace(raw)) == 0 {
		return nil
	}

	fields := strings.Split(raw, ",")
	unique := []string{}
	taken := make(map[string]struct{})
	for i := 0; i < len(fields); i++ {
		entry := strings.TrimSpace(fields[i])
		if entry == "" {
			continue
		}
		if _, dup := taken[entry]; !dup {
			taken[entry] = struct{}{}
			unique = append(unique, entry)
		}
	}
	return unique
}
|