feat: add access profile selection support with device-specific profile persistence
Add SelectOwnProfile handler to allow users to choose from available access profiles. Store selected profile ID per device in settings table with device_access_profile category. Implement GetSelectedProfileID and SetSelectedProfileID repository methods using JSONB storage. Add ListSelectableProfiles to policy repository and service to query user/group/device-specific profiles ordered by priority. Filter gateway
This commit is contained in:
@@ -113,6 +113,28 @@ func (h *Handler) GetOwnProfile(w http.ResponseWriter, r *http.Request) {
|
||||
apiutil.JSON(w, http.StatusOK, response)
|
||||
}
|
||||
|
||||
func (h *Handler) SelectOwnProfile(w http.ResponseWriter, r *http.Request) {
|
||||
userID, ok := requestctx.MustUserID(r.Context())
|
||||
if !ok {
|
||||
apiutil.Error(w, http.StatusUnauthorized, "unauthorized", "missing auth claims")
|
||||
return
|
||||
}
|
||||
|
||||
var input SelectProfileRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&input); err != nil {
|
||||
apiutil.Error(w, http.StatusBadRequest, "invalid_json", "invalid request body")
|
||||
return
|
||||
}
|
||||
|
||||
response, err := h.service.SelectProfile(r.Context(), userID, input.ProfileID)
|
||||
if err != nil {
|
||||
apiutil.Error(w, http.StatusBadRequest, "profile_selection_failed", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
apiutil.JSON(w, http.StatusOK, response)
|
||||
}
|
||||
|
||||
func (h *Handler) GetProfileByDeviceID(w http.ResponseWriter, r *http.Request) {
|
||||
deviceID, err := uuid.Parse(chi.URLParam(r, "id"))
|
||||
if err != nil {
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
|
||||
@@ -18,6 +19,8 @@ type Repository interface {
|
||||
ListAll(ctx context.Context) ([]Device, error)
|
||||
GetLatestEnrollmentByUser(ctx context.Context, userID uuid.UUID) (EnrollmentResponse, error)
|
||||
GetEnrollmentByDeviceID(ctx context.Context, deviceID uuid.UUID) (EnrollmentResponse, error)
|
||||
GetSelectedProfileID(ctx context.Context, deviceID uuid.UUID) (*uuid.UUID, error)
|
||||
SetSelectedProfileID(ctx context.Context, deviceID uuid.UUID, profileID uuid.UUID) error
|
||||
Revoke(ctx context.Context, deviceID uuid.UUID) error
|
||||
Rotate(ctx context.Context, deviceID uuid.UUID) error
|
||||
}
|
||||
@@ -178,6 +181,43 @@ func (r *PGRepository) GetEnrollmentByDeviceID(ctx context.Context, deviceID uui
|
||||
return scanEnrollmentRow(row)
|
||||
}
|
||||
|
||||
func (r *PGRepository) GetSelectedProfileID(ctx context.Context, deviceID uuid.UUID) (*uuid.UUID, error) {
|
||||
row := r.db.QueryRow(ctx, `
|
||||
select value->>'profile_id'
|
||||
from settings
|
||||
where category = 'device_access_profile' and key = $1
|
||||
`, deviceID.String())
|
||||
|
||||
var raw string
|
||||
if err := row.Scan(&raw); err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
value, err := uuid.Parse(raw)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &value, nil
|
||||
}
|
||||
|
||||
func (r *PGRepository) SetSelectedProfileID(ctx context.Context, deviceID uuid.UUID, profileID uuid.UUID) error {
|
||||
payload, err := json.Marshal(map[string]string{"profile_id": profileID.String()})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = r.db.Exec(ctx, `
|
||||
insert into settings (category, key, value, updated_at)
|
||||
values ('device_access_profile', $1, $2::jsonb, now())
|
||||
on conflict (category, key)
|
||||
do update set value = excluded.value, updated_at = now()
|
||||
`, deviceID.String(), string(payload))
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *PGRepository) ListByUser(ctx context.Context, userID uuid.UUID) ([]Device, error) {
|
||||
rows, err := r.db.Query(ctx, `
|
||||
select d.id, d.user_id, d.gateway_id, d.name, d.platform, d.status, coalesce(host(wp.assigned_ip), '')
|
||||
|
||||
@@ -2,6 +2,7 @@ package device
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
@@ -63,7 +64,14 @@ func (s *Service) Enroll(ctx context.Context, userID uuid.UUID, input EnrollRequ
|
||||
if len(destinations) == 0 {
|
||||
destinations = []string{"172.16.10.0/24"}
|
||||
}
|
||||
profileAllowedIPs := mergeProfileAllowedIPs(destinations, selectedGateway.DNSServers, alwaysAllowWebProxyTargets())
|
||||
availableProfiles, selectedProfileID, selectedDestinations, err := s.resolveAccessProfiles(ctx, userID, enrollment.Device.ID)
|
||||
if err != nil {
|
||||
return EnrollmentResponse{}, err
|
||||
}
|
||||
if len(selectedDestinations) == 0 {
|
||||
selectedDestinations = destinations
|
||||
}
|
||||
profileAllowedIPs := mergeProfileAllowedIPs(selectedDestinations, selectedGateway.DNSServers, alwaysAllowWebProxyTargets())
|
||||
|
||||
enrollment.Peer = PeerView{
|
||||
AssignedIP: assignedIP,
|
||||
@@ -77,13 +85,9 @@ func (s *Service) Enroll(ctx context.Context, userID uuid.UUID, input EnrollRequ
|
||||
},
|
||||
ProfileRevision: 1,
|
||||
}
|
||||
for _, destination := range destinations {
|
||||
enrollment.Resources = append(enrollment.Resources, Resource{
|
||||
Type: "cidr",
|
||||
Value: destination,
|
||||
Label: destination,
|
||||
})
|
||||
}
|
||||
enrollment.Resources = resourcesFromDestinations(selectedDestinations)
|
||||
enrollment.AvailableProfiles = availableProfiles
|
||||
enrollment.SelectedProfileID = selectedProfileID
|
||||
|
||||
enrollment.Profile = ProfileView{
|
||||
Format: "wireguard",
|
||||
@@ -125,6 +129,35 @@ func (s *Service) GetEnrollmentByDeviceID(ctx context.Context, deviceID uuid.UUI
|
||||
return s.applyCurrentPolicy(ctx, enrollment)
|
||||
}
|
||||
|
||||
func (s *Service) SelectProfile(ctx context.Context, userID uuid.UUID, profileID uuid.UUID) (EnrollmentResponse, error) {
|
||||
enrollment, err := s.repo.GetLatestEnrollmentByUser(ctx, userID)
|
||||
if err != nil {
|
||||
return EnrollmentResponse{}, err
|
||||
}
|
||||
|
||||
profiles, err := s.policyService.ListSelectableProfiles(ctx, userID, &enrollment.Device.ID)
|
||||
if err != nil {
|
||||
return EnrollmentResponse{}, err
|
||||
}
|
||||
|
||||
var exists bool
|
||||
for _, profile := range profiles {
|
||||
if profile.ID == profileID {
|
||||
exists = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !exists {
|
||||
return EnrollmentResponse{}, fmt.Errorf("selected access profile is not available for this device")
|
||||
}
|
||||
|
||||
if err := s.repo.SetSelectedProfileID(ctx, enrollment.Device.ID, profileID); err != nil {
|
||||
return EnrollmentResponse{}, err
|
||||
}
|
||||
|
||||
return s.applyCurrentPolicy(ctx, enrollment)
|
||||
}
|
||||
|
||||
func (s *Service) GetConnectionStatus(ctx context.Context, userID uuid.UUID) (ConnectionStatus, error) {
|
||||
enrollment, err := s.repo.GetLatestEnrollmentByUser(ctx, userID)
|
||||
if err != nil {
|
||||
@@ -134,6 +167,11 @@ func (s *Service) GetConnectionStatus(ctx context.Context, userID uuid.UUID) (Co
|
||||
}, nil
|
||||
}
|
||||
|
||||
enrollment, err = s.applyCurrentPolicy(ctx, enrollment)
|
||||
if err != nil {
|
||||
return ConnectionStatus{}, err
|
||||
}
|
||||
|
||||
lastSync := "just now"
|
||||
return ConnectionStatus{
|
||||
Status: "provisioned",
|
||||
@@ -169,24 +207,70 @@ func withDebugProfile(enrollment EnrollmentResponse) EnrollmentResponse {
|
||||
}
|
||||
|
||||
func (s *Service) applyCurrentPolicy(ctx context.Context, enrollment EnrollmentResponse) (EnrollmentResponse, error) {
|
||||
destinations, err := s.policyService.ResolveDestinations(ctx, enrollment.Device.UserID, &enrollment.Device.ID)
|
||||
availableProfiles, selectedProfileID, selectedDestinations, err := s.resolveAccessProfiles(ctx, enrollment.Device.UserID, enrollment.Device.ID)
|
||||
if err != nil {
|
||||
return EnrollmentResponse{}, err
|
||||
}
|
||||
if len(destinations) == 0 {
|
||||
destinations = []string{"172.16.10.0/24"}
|
||||
if len(selectedDestinations) == 0 {
|
||||
selectedDestinations = []string{"172.16.10.0/24"}
|
||||
}
|
||||
|
||||
enrollment.Resources = nil
|
||||
enrollment.Resources = resourcesFromDestinations(selectedDestinations)
|
||||
enrollment.AvailableProfiles = availableProfiles
|
||||
enrollment.SelectedProfileID = selectedProfileID
|
||||
enrollment.Peer.AllowedIPs = mergeProfileAllowedIPs(selectedDestinations, enrollment.Peer.DNSServers, alwaysAllowWebProxyTargets())
|
||||
return withDebugProfile(enrollment), nil
|
||||
}
|
||||
|
||||
func (s *Service) resolveAccessProfiles(ctx context.Context, userID uuid.UUID, deviceID uuid.UUID) ([]AccessProfile, *uuid.UUID, []string, error) {
|
||||
profiles, err := s.policyService.ListSelectableProfiles(ctx, userID, &deviceID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
availableProfiles := make([]AccessProfile, 0, len(profiles))
|
||||
for _, profile := range profiles {
|
||||
availableProfiles = append(availableProfiles, AccessProfile{
|
||||
ID: profile.ID,
|
||||
Name: profile.Name,
|
||||
Description: profile.Description,
|
||||
FullTunnel: profile.FullTunnel,
|
||||
Destinations: profile.Destinations,
|
||||
})
|
||||
}
|
||||
|
||||
if len(availableProfiles) == 0 {
|
||||
return nil, nil, nil, nil
|
||||
}
|
||||
|
||||
selectedProfileID, err := s.repo.GetSelectedProfileID(ctx, deviceID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
for _, profile := range availableProfiles {
|
||||
if selectedProfileID != nil && profile.ID == *selectedProfileID {
|
||||
return availableProfiles, selectedProfileID, profile.Destinations, nil
|
||||
}
|
||||
}
|
||||
|
||||
fallback := availableProfiles[0]
|
||||
if err := s.repo.SetSelectedProfileID(ctx, deviceID, fallback.ID); err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
return availableProfiles, &fallback.ID, fallback.Destinations, nil
|
||||
}
|
||||
|
||||
func resourcesFromDestinations(destinations []string) []Resource {
|
||||
resources := make([]Resource, 0, len(destinations))
|
||||
for _, destination := range destinations {
|
||||
enrollment.Resources = append(enrollment.Resources, Resource{
|
||||
resources = append(resources, Resource{
|
||||
Type: "cidr",
|
||||
Value: destination,
|
||||
Label: destination,
|
||||
})
|
||||
}
|
||||
enrollment.Peer.AllowedIPs = mergeProfileAllowedIPs(destinations, enrollment.Peer.DNSServers, alwaysAllowWebProxyTargets())
|
||||
return withDebugProfile(enrollment), nil
|
||||
return resources
|
||||
}
|
||||
|
||||
func mergeProfileAllowedIPs(destinations []string, dnsServers []string, webProxyTargets []string) []string {
|
||||
|
||||
@@ -37,10 +37,20 @@ type Resource struct {
|
||||
}
|
||||
|
||||
type EnrollmentResponse struct {
|
||||
Device Device `json:"device"`
|
||||
Peer PeerView `json:"peer"`
|
||||
Profile ProfileView `json:"profile"`
|
||||
Resources []Resource `json:"resources"`
|
||||
Device Device `json:"device"`
|
||||
Peer PeerView `json:"peer"`
|
||||
Profile ProfileView `json:"profile"`
|
||||
Resources []Resource `json:"resources"`
|
||||
AvailableProfiles []AccessProfile `json:"available_profiles"`
|
||||
SelectedProfileID *uuid.UUID `json:"selected_profile_id,omitempty"`
|
||||
}
|
||||
|
||||
type AccessProfile struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
FullTunnel bool `json:"full_tunnel"`
|
||||
Destinations []string `json:"destinations"`
|
||||
}
|
||||
|
||||
type PeerView struct {
|
||||
@@ -62,3 +72,7 @@ type ProfileView struct {
|
||||
Format string `json:"format"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type SelectProfileRequest struct {
|
||||
ProfileID uuid.UUID `json:"profile_id"`
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user