feat: add access profile selection support with device-specific profile persistence

Add SelectOwnProfile handler to allow users to choose from available access profiles. Store selected profile ID per device in settings table with device_access_profile category. Implement GetSelectedProfileID and SetSelectedProfileID repository methods using JSONB storage.

Add ListSelectableProfiles to policy repository and service to query user/group/device-specific profiles ordered by priority. Filter gateway
This commit is contained in:
2026-03-18 12:21:48 +01:00
parent 1ddcbf0b14
commit aaa601a8ba
14 changed files with 549 additions and 43 deletions

View File

@@ -113,6 +113,28 @@ func (h *Handler) GetOwnProfile(w http.ResponseWriter, r *http.Request) {
apiutil.JSON(w, http.StatusOK, response)
}
func (h *Handler) SelectOwnProfile(w http.ResponseWriter, r *http.Request) {
userID, ok := requestctx.MustUserID(r.Context())
if !ok {
apiutil.Error(w, http.StatusUnauthorized, "unauthorized", "missing auth claims")
return
}
var input SelectProfileRequest
if err := json.NewDecoder(r.Body).Decode(&input); err != nil {
apiutil.Error(w, http.StatusBadRequest, "invalid_json", "invalid request body")
return
}
response, err := h.service.SelectProfile(r.Context(), userID, input.ProfileID)
if err != nil {
apiutil.Error(w, http.StatusBadRequest, "profile_selection_failed", err.Error())
return
}
apiutil.JSON(w, http.StatusOK, response)
}
func (h *Handler) GetProfileByDeviceID(w http.ResponseWriter, r *http.Request) {
deviceID, err := uuid.Parse(chi.URLParam(r, "id"))
if err != nil {

View File

@@ -8,6 +8,7 @@ import (
"time"
"github.com/google/uuid"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
)
@@ -18,6 +19,8 @@ type Repository interface {
ListAll(ctx context.Context) ([]Device, error)
GetLatestEnrollmentByUser(ctx context.Context, userID uuid.UUID) (EnrollmentResponse, error)
GetEnrollmentByDeviceID(ctx context.Context, deviceID uuid.UUID) (EnrollmentResponse, error)
GetSelectedProfileID(ctx context.Context, deviceID uuid.UUID) (*uuid.UUID, error)
SetSelectedProfileID(ctx context.Context, deviceID uuid.UUID, profileID uuid.UUID) error
Revoke(ctx context.Context, deviceID uuid.UUID) error
Rotate(ctx context.Context, deviceID uuid.UUID) error
}
@@ -178,6 +181,43 @@ func (r *PGRepository) GetEnrollmentByDeviceID(ctx context.Context, deviceID uui
return scanEnrollmentRow(row)
}
func (r *PGRepository) GetSelectedProfileID(ctx context.Context, deviceID uuid.UUID) (*uuid.UUID, error) {
row := r.db.QueryRow(ctx, `
select value->>'profile_id'
from settings
where category = 'device_access_profile' and key = $1
`, deviceID.String())
var raw string
if err := row.Scan(&raw); err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return nil, nil
}
return nil, err
}
value, err := uuid.Parse(raw)
if err != nil {
return nil, err
}
return &value, nil
}
func (r *PGRepository) SetSelectedProfileID(ctx context.Context, deviceID uuid.UUID, profileID uuid.UUID) error {
payload, err := json.Marshal(map[string]string{"profile_id": profileID.String()})
if err != nil {
return err
}
_, err = r.db.Exec(ctx, `
insert into settings (category, key, value, updated_at)
values ('device_access_profile', $1, $2::jsonb, now())
on conflict (category, key)
do update set value = excluded.value, updated_at = now()
`, deviceID.String(), string(payload))
return err
}
func (r *PGRepository) ListByUser(ctx context.Context, userID uuid.UUID) ([]Device, error) {
rows, err := r.db.Query(ctx, `
select d.id, d.user_id, d.gateway_id, d.name, d.platform, d.status, coalesce(host(wp.assigned_ip), '')

View File

@@ -2,6 +2,7 @@ package device
import (
"context"
"fmt"
"os"
"strings"
@@ -63,7 +64,14 @@ func (s *Service) Enroll(ctx context.Context, userID uuid.UUID, input EnrollRequ
if len(destinations) == 0 {
destinations = []string{"172.16.10.0/24"}
}
profileAllowedIPs := mergeProfileAllowedIPs(destinations, selectedGateway.DNSServers, alwaysAllowWebProxyTargets())
availableProfiles, selectedProfileID, selectedDestinations, err := s.resolveAccessProfiles(ctx, userID, enrollment.Device.ID)
if err != nil {
return EnrollmentResponse{}, err
}
if len(selectedDestinations) == 0 {
selectedDestinations = destinations
}
profileAllowedIPs := mergeProfileAllowedIPs(selectedDestinations, selectedGateway.DNSServers, alwaysAllowWebProxyTargets())
enrollment.Peer = PeerView{
AssignedIP: assignedIP,
@@ -77,13 +85,9 @@ func (s *Service) Enroll(ctx context.Context, userID uuid.UUID, input EnrollRequ
},
ProfileRevision: 1,
}
for _, destination := range destinations {
enrollment.Resources = append(enrollment.Resources, Resource{
Type: "cidr",
Value: destination,
Label: destination,
})
}
enrollment.Resources = resourcesFromDestinations(selectedDestinations)
enrollment.AvailableProfiles = availableProfiles
enrollment.SelectedProfileID = selectedProfileID
enrollment.Profile = ProfileView{
Format: "wireguard",
@@ -125,6 +129,35 @@ func (s *Service) GetEnrollmentByDeviceID(ctx context.Context, deviceID uuid.UUI
return s.applyCurrentPolicy(ctx, enrollment)
}
func (s *Service) SelectProfile(ctx context.Context, userID uuid.UUID, profileID uuid.UUID) (EnrollmentResponse, error) {
enrollment, err := s.repo.GetLatestEnrollmentByUser(ctx, userID)
if err != nil {
return EnrollmentResponse{}, err
}
profiles, err := s.policyService.ListSelectableProfiles(ctx, userID, &enrollment.Device.ID)
if err != nil {
return EnrollmentResponse{}, err
}
var exists bool
for _, profile := range profiles {
if profile.ID == profileID {
exists = true
break
}
}
if !exists {
return EnrollmentResponse{}, fmt.Errorf("selected access profile is not available for this device")
}
if err := s.repo.SetSelectedProfileID(ctx, enrollment.Device.ID, profileID); err != nil {
return EnrollmentResponse{}, err
}
return s.applyCurrentPolicy(ctx, enrollment)
}
func (s *Service) GetConnectionStatus(ctx context.Context, userID uuid.UUID) (ConnectionStatus, error) {
enrollment, err := s.repo.GetLatestEnrollmentByUser(ctx, userID)
if err != nil {
@@ -134,6 +167,11 @@ func (s *Service) GetConnectionStatus(ctx context.Context, userID uuid.UUID) (Co
}, nil
}
enrollment, err = s.applyCurrentPolicy(ctx, enrollment)
if err != nil {
return ConnectionStatus{}, err
}
lastSync := "just now"
return ConnectionStatus{
Status: "provisioned",
@@ -169,24 +207,70 @@ func withDebugProfile(enrollment EnrollmentResponse) EnrollmentResponse {
}
func (s *Service) applyCurrentPolicy(ctx context.Context, enrollment EnrollmentResponse) (EnrollmentResponse, error) {
destinations, err := s.policyService.ResolveDestinations(ctx, enrollment.Device.UserID, &enrollment.Device.ID)
availableProfiles, selectedProfileID, selectedDestinations, err := s.resolveAccessProfiles(ctx, enrollment.Device.UserID, enrollment.Device.ID)
if err != nil {
return EnrollmentResponse{}, err
}
if len(destinations) == 0 {
destinations = []string{"172.16.10.0/24"}
if len(selectedDestinations) == 0 {
selectedDestinations = []string{"172.16.10.0/24"}
}
enrollment.Resources = nil
enrollment.Resources = resourcesFromDestinations(selectedDestinations)
enrollment.AvailableProfiles = availableProfiles
enrollment.SelectedProfileID = selectedProfileID
enrollment.Peer.AllowedIPs = mergeProfileAllowedIPs(selectedDestinations, enrollment.Peer.DNSServers, alwaysAllowWebProxyTargets())
return withDebugProfile(enrollment), nil
}
func (s *Service) resolveAccessProfiles(ctx context.Context, userID uuid.UUID, deviceID uuid.UUID) ([]AccessProfile, *uuid.UUID, []string, error) {
profiles, err := s.policyService.ListSelectableProfiles(ctx, userID, &deviceID)
if err != nil {
return nil, nil, nil, err
}
availableProfiles := make([]AccessProfile, 0, len(profiles))
for _, profile := range profiles {
availableProfiles = append(availableProfiles, AccessProfile{
ID: profile.ID,
Name: profile.Name,
Description: profile.Description,
FullTunnel: profile.FullTunnel,
Destinations: profile.Destinations,
})
}
if len(availableProfiles) == 0 {
return nil, nil, nil, nil
}
selectedProfileID, err := s.repo.GetSelectedProfileID(ctx, deviceID)
if err != nil {
return nil, nil, nil, err
}
for _, profile := range availableProfiles {
if selectedProfileID != nil && profile.ID == *selectedProfileID {
return availableProfiles, selectedProfileID, profile.Destinations, nil
}
}
fallback := availableProfiles[0]
if err := s.repo.SetSelectedProfileID(ctx, deviceID, fallback.ID); err != nil {
return nil, nil, nil, err
}
return availableProfiles, &fallback.ID, fallback.Destinations, nil
}
func resourcesFromDestinations(destinations []string) []Resource {
resources := make([]Resource, 0, len(destinations))
for _, destination := range destinations {
enrollment.Resources = append(enrollment.Resources, Resource{
resources = append(resources, Resource{
Type: "cidr",
Value: destination,
Label: destination,
})
}
enrollment.Peer.AllowedIPs = mergeProfileAllowedIPs(destinations, enrollment.Peer.DNSServers, alwaysAllowWebProxyTargets())
return withDebugProfile(enrollment), nil
return resources
}
func mergeProfileAllowedIPs(destinations []string, dnsServers []string, webProxyTargets []string) []string {

View File

@@ -37,10 +37,20 @@ type Resource struct {
}
type EnrollmentResponse struct {
Device Device `json:"device"`
Peer PeerView `json:"peer"`
Profile ProfileView `json:"profile"`
Resources []Resource `json:"resources"`
Device Device `json:"device"`
Peer PeerView `json:"peer"`
Profile ProfileView `json:"profile"`
Resources []Resource `json:"resources"`
AvailableProfiles []AccessProfile `json:"available_profiles"`
SelectedProfileID *uuid.UUID `json:"selected_profile_id,omitempty"`
}
type AccessProfile struct {
ID uuid.UUID `json:"id"`
Name string `json:"name"`
Description string `json:"description"`
FullTunnel bool `json:"full_tunnel"`
Destinations []string `json:"destinations"`
}
type PeerView struct {
@@ -62,3 +72,7 @@ type ProfileView struct {
Format string `json:"format"`
Content string `json:"content"`
}
type SelectProfileRequest struct {
ProfileID uuid.UUID `json:"profile_id"`
}

View File

@@ -98,13 +98,22 @@ func (r *PGRepository) BuildSyncBundle(ctx context.Context, gatewayID uuid.UUID)
from devices d
join wireguard_peers wp on wp.device_id = d.id and wp.deleted_at is null
join gateways g on g.id = d.gateway_id
left join settings s on s.category = 'device_access_profile' and s.key = d.id::text
left join group_memberships gm on gm.user_id = d.user_id
left join policy_targets pt on (
(pt.target_type = 'device' and pt.target_id = d.id) or
(pt.target_type = 'user' and pt.target_id = d.user_id) or
(pt.target_type = 'group' and pt.target_id = gm.group_id)
)
left join policy_destinations pd on pd.policy_id = pt.policy_id
left join policies p on p.id = pt.policy_id
and p.deleted_at is null
and p.is_active = true
and p.effect = 'allow'
left join policy_destinations pd on pd.policy_id = p.id
and (
s.value->>'profile_id' is null
or p.id::text = s.value->>'profile_id'
)
where d.gateway_id = $1 and d.deleted_at is null and d.status = 'active'
group by d.id, wp.public_key, wp.assigned_ip, g.dns_servers
`, gatewayID)

View File

@@ -49,6 +49,7 @@ func NewRouter(jwtSecret string, handlers Handlers) http.Handler {
r.Post("/devices/enroll", handlers.Device.Enroll)
r.Get("/me/devices", handlers.Device.ListOwn)
r.Get("/me/profile", handlers.Device.GetOwnProfile)
r.Put("/me/profile-selection", handlers.Device.SelectOwnProfile)
r.Get("/connection/status", handlers.Device.ConnectionStatus)
r.Route("/admin", func(r chi.Router) {

View File

@@ -14,6 +14,7 @@ type Repository interface {
Update(ctx context.Context, policyID uuid.UUID, input UpdateRequest) (Policy, error)
Delete(ctx context.Context, policyID uuid.UUID) error
ResolveDestinations(ctx context.Context, userID uuid.UUID, deviceID *uuid.UUID) ([]string, error)
ListSelectableProfiles(ctx context.Context, userID uuid.UUID, deviceID *uuid.UUID) ([]SelectableProfile, error)
}
type PGRepository struct {
@@ -206,6 +207,53 @@ func (r *PGRepository) ResolveDestinations(ctx context.Context, userID uuid.UUID
return destinations, rows.Err()
}
func (r *PGRepository) ListSelectableProfiles(ctx context.Context, userID uuid.UUID, deviceID *uuid.UUID) ([]SelectableProfile, error) {
query := `
select
p.id,
p.name,
p.description,
p.full_tunnel,
coalesce(array_agg(pd.destination::text order by pd.destination::text) filter (where pd.destination is not null), '{}')
from policies p
left join policy_destinations pd on pd.policy_id = p.id
join policy_targets pt on pt.policy_id = p.id
where p.deleted_at is null
and p.is_active = true
and p.effect = 'allow'
and (
(pt.target_type = 'user' and pt.target_id = $1)
or (pt.target_type = 'group' and exists (
select 1 from group_memberships gm
where gm.group_id = pt.target_id and gm.user_id = $1
))
`
args := []any{userID}
if deviceID != nil {
query += ` or (pt.target_type = 'device' and pt.target_id = $2)`
args = append(args, *deviceID)
}
query += `)
group by p.id, p.name, p.description, p.full_tunnel, p.priority, p.created_at
order by p.priority asc, p.created_at desc`
rows, err := r.db.Query(ctx, query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
var profiles []SelectableProfile
for rows.Next() {
var item SelectableProfile
if err := rows.Scan(&item.ID, &item.Name, &item.Description, &item.FullTunnel, &item.Destinations); err != nil {
return nil, err
}
profiles = append(profiles, item)
}
return profiles, rows.Err()
}
func (r *PGRepository) getByID(ctx context.Context, policyID uuid.UUID) (Policy, error) {
items, err := r.List(ctx)
if err != nil {

View File

@@ -39,3 +39,7 @@ func (s *Service) Delete(ctx context.Context, policyID uuid.UUID) error {
func (s *Service) ResolveDestinations(ctx context.Context, userID uuid.UUID, deviceID *uuid.UUID) ([]string, error) {
return s.repo.ResolveDestinations(ctx, userID, deviceID)
}
func (s *Service) ListSelectableProfiles(ctx context.Context, userID uuid.UUID, deviceID *uuid.UUID) ([]SelectableProfile, error) {
return s.repo.ListSelectableProfiles(ctx, userID, deviceID)
}

View File

@@ -40,3 +40,11 @@ type UpdateRequest struct {
Destinations []string `json:"destinations"`
Targets []Target `json:"targets"`
}
type SelectableProfile struct {
ID uuid.UUID `json:"id"`
Name string `json:"name"`
Description string `json:"description"`
FullTunnel bool `json:"full_tunnel"`
Destinations []string `json:"destinations"`
}