feat: add access profile selection support with device-specific profile persistence

Add SelectOwnProfile handler to allow users to choose from available access profiles. Store selected profile ID per device in settings table with device_access_profile category. Implement GetSelectedProfileID and SetSelectedProfileID repository methods using JSONB storage.

Add ListSelectableProfiles to policy repository and service to query user/group/device-specific profiles ordered by priority. Filter gateway
This commit is contained in:
2026-03-18 12:21:48 +01:00
parent 1ddcbf0b14
commit aaa601a8ba
14 changed files with 549 additions and 43 deletions

View File

@@ -113,6 +113,28 @@ func (h *Handler) GetOwnProfile(w http.ResponseWriter, r *http.Request) {
apiutil.JSON(w, http.StatusOK, response) apiutil.JSON(w, http.StatusOK, response)
} }
func (h *Handler) SelectOwnProfile(w http.ResponseWriter, r *http.Request) {
userID, ok := requestctx.MustUserID(r.Context())
if !ok {
apiutil.Error(w, http.StatusUnauthorized, "unauthorized", "missing auth claims")
return
}
var input SelectProfileRequest
if err := json.NewDecoder(r.Body).Decode(&input); err != nil {
apiutil.Error(w, http.StatusBadRequest, "invalid_json", "invalid request body")
return
}
response, err := h.service.SelectProfile(r.Context(), userID, input.ProfileID)
if err != nil {
apiutil.Error(w, http.StatusBadRequest, "profile_selection_failed", err.Error())
return
}
apiutil.JSON(w, http.StatusOK, response)
}
func (h *Handler) GetProfileByDeviceID(w http.ResponseWriter, r *http.Request) { func (h *Handler) GetProfileByDeviceID(w http.ResponseWriter, r *http.Request) {
deviceID, err := uuid.Parse(chi.URLParam(r, "id")) deviceID, err := uuid.Parse(chi.URLParam(r, "id"))
if err != nil { if err != nil {

View File

@@ -8,6 +8,7 @@ import (
"time" "time"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool" "github.com/jackc/pgx/v5/pgxpool"
) )
@@ -18,6 +19,8 @@ type Repository interface {
ListAll(ctx context.Context) ([]Device, error) ListAll(ctx context.Context) ([]Device, error)
GetLatestEnrollmentByUser(ctx context.Context, userID uuid.UUID) (EnrollmentResponse, error) GetLatestEnrollmentByUser(ctx context.Context, userID uuid.UUID) (EnrollmentResponse, error)
GetEnrollmentByDeviceID(ctx context.Context, deviceID uuid.UUID) (EnrollmentResponse, error) GetEnrollmentByDeviceID(ctx context.Context, deviceID uuid.UUID) (EnrollmentResponse, error)
GetSelectedProfileID(ctx context.Context, deviceID uuid.UUID) (*uuid.UUID, error)
SetSelectedProfileID(ctx context.Context, deviceID uuid.UUID, profileID uuid.UUID) error
Revoke(ctx context.Context, deviceID uuid.UUID) error Revoke(ctx context.Context, deviceID uuid.UUID) error
Rotate(ctx context.Context, deviceID uuid.UUID) error Rotate(ctx context.Context, deviceID uuid.UUID) error
} }
@@ -178,6 +181,43 @@ func (r *PGRepository) GetEnrollmentByDeviceID(ctx context.Context, deviceID uui
return scanEnrollmentRow(row) return scanEnrollmentRow(row)
} }
func (r *PGRepository) GetSelectedProfileID(ctx context.Context, deviceID uuid.UUID) (*uuid.UUID, error) {
row := r.db.QueryRow(ctx, `
select value->>'profile_id'
from settings
where category = 'device_access_profile' and key = $1
`, deviceID.String())
var raw string
if err := row.Scan(&raw); err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return nil, nil
}
return nil, err
}
value, err := uuid.Parse(raw)
if err != nil {
return nil, err
}
return &value, nil
}
func (r *PGRepository) SetSelectedProfileID(ctx context.Context, deviceID uuid.UUID, profileID uuid.UUID) error {
payload, err := json.Marshal(map[string]string{"profile_id": profileID.String()})
if err != nil {
return err
}
_, err = r.db.Exec(ctx, `
insert into settings (category, key, value, updated_at)
values ('device_access_profile', $1, $2::jsonb, now())
on conflict (category, key)
do update set value = excluded.value, updated_at = now()
`, deviceID.String(), string(payload))
return err
}
func (r *PGRepository) ListByUser(ctx context.Context, userID uuid.UUID) ([]Device, error) { func (r *PGRepository) ListByUser(ctx context.Context, userID uuid.UUID) ([]Device, error) {
rows, err := r.db.Query(ctx, ` rows, err := r.db.Query(ctx, `
select d.id, d.user_id, d.gateway_id, d.name, d.platform, d.status, coalesce(host(wp.assigned_ip), '') select d.id, d.user_id, d.gateway_id, d.name, d.platform, d.status, coalesce(host(wp.assigned_ip), '')

View File

@@ -2,6 +2,7 @@ package device
import ( import (
"context" "context"
"fmt"
"os" "os"
"strings" "strings"
@@ -63,7 +64,14 @@ func (s *Service) Enroll(ctx context.Context, userID uuid.UUID, input EnrollRequ
if len(destinations) == 0 { if len(destinations) == 0 {
destinations = []string{"172.16.10.0/24"} destinations = []string{"172.16.10.0/24"}
} }
profileAllowedIPs := mergeProfileAllowedIPs(destinations, selectedGateway.DNSServers, alwaysAllowWebProxyTargets()) availableProfiles, selectedProfileID, selectedDestinations, err := s.resolveAccessProfiles(ctx, userID, enrollment.Device.ID)
if err != nil {
return EnrollmentResponse{}, err
}
if len(selectedDestinations) == 0 {
selectedDestinations = destinations
}
profileAllowedIPs := mergeProfileAllowedIPs(selectedDestinations, selectedGateway.DNSServers, alwaysAllowWebProxyTargets())
enrollment.Peer = PeerView{ enrollment.Peer = PeerView{
AssignedIP: assignedIP, AssignedIP: assignedIP,
@@ -77,13 +85,9 @@ func (s *Service) Enroll(ctx context.Context, userID uuid.UUID, input EnrollRequ
}, },
ProfileRevision: 1, ProfileRevision: 1,
} }
for _, destination := range destinations { enrollment.Resources = resourcesFromDestinations(selectedDestinations)
enrollment.Resources = append(enrollment.Resources, Resource{ enrollment.AvailableProfiles = availableProfiles
Type: "cidr", enrollment.SelectedProfileID = selectedProfileID
Value: destination,
Label: destination,
})
}
enrollment.Profile = ProfileView{ enrollment.Profile = ProfileView{
Format: "wireguard", Format: "wireguard",
@@ -125,6 +129,35 @@ func (s *Service) GetEnrollmentByDeviceID(ctx context.Context, deviceID uuid.UUI
return s.applyCurrentPolicy(ctx, enrollment) return s.applyCurrentPolicy(ctx, enrollment)
} }
func (s *Service) SelectProfile(ctx context.Context, userID uuid.UUID, profileID uuid.UUID) (EnrollmentResponse, error) {
enrollment, err := s.repo.GetLatestEnrollmentByUser(ctx, userID)
if err != nil {
return EnrollmentResponse{}, err
}
profiles, err := s.policyService.ListSelectableProfiles(ctx, userID, &enrollment.Device.ID)
if err != nil {
return EnrollmentResponse{}, err
}
var exists bool
for _, profile := range profiles {
if profile.ID == profileID {
exists = true
break
}
}
if !exists {
return EnrollmentResponse{}, fmt.Errorf("selected access profile is not available for this device")
}
if err := s.repo.SetSelectedProfileID(ctx, enrollment.Device.ID, profileID); err != nil {
return EnrollmentResponse{}, err
}
return s.applyCurrentPolicy(ctx, enrollment)
}
func (s *Service) GetConnectionStatus(ctx context.Context, userID uuid.UUID) (ConnectionStatus, error) { func (s *Service) GetConnectionStatus(ctx context.Context, userID uuid.UUID) (ConnectionStatus, error) {
enrollment, err := s.repo.GetLatestEnrollmentByUser(ctx, userID) enrollment, err := s.repo.GetLatestEnrollmentByUser(ctx, userID)
if err != nil { if err != nil {
@@ -134,6 +167,11 @@ func (s *Service) GetConnectionStatus(ctx context.Context, userID uuid.UUID) (Co
}, nil }, nil
} }
enrollment, err = s.applyCurrentPolicy(ctx, enrollment)
if err != nil {
return ConnectionStatus{}, err
}
lastSync := "just now" lastSync := "just now"
return ConnectionStatus{ return ConnectionStatus{
Status: "provisioned", Status: "provisioned",
@@ -169,24 +207,70 @@ func withDebugProfile(enrollment EnrollmentResponse) EnrollmentResponse {
} }
func (s *Service) applyCurrentPolicy(ctx context.Context, enrollment EnrollmentResponse) (EnrollmentResponse, error) { func (s *Service) applyCurrentPolicy(ctx context.Context, enrollment EnrollmentResponse) (EnrollmentResponse, error) {
destinations, err := s.policyService.ResolveDestinations(ctx, enrollment.Device.UserID, &enrollment.Device.ID) availableProfiles, selectedProfileID, selectedDestinations, err := s.resolveAccessProfiles(ctx, enrollment.Device.UserID, enrollment.Device.ID)
if err != nil { if err != nil {
return EnrollmentResponse{}, err return EnrollmentResponse{}, err
} }
if len(destinations) == 0 { if len(selectedDestinations) == 0 {
destinations = []string{"172.16.10.0/24"} selectedDestinations = []string{"172.16.10.0/24"}
} }
enrollment.Resources = nil enrollment.Resources = resourcesFromDestinations(selectedDestinations)
enrollment.AvailableProfiles = availableProfiles
enrollment.SelectedProfileID = selectedProfileID
enrollment.Peer.AllowedIPs = mergeProfileAllowedIPs(selectedDestinations, enrollment.Peer.DNSServers, alwaysAllowWebProxyTargets())
return withDebugProfile(enrollment), nil
}
func (s *Service) resolveAccessProfiles(ctx context.Context, userID uuid.UUID, deviceID uuid.UUID) ([]AccessProfile, *uuid.UUID, []string, error) {
profiles, err := s.policyService.ListSelectableProfiles(ctx, userID, &deviceID)
if err != nil {
return nil, nil, nil, err
}
availableProfiles := make([]AccessProfile, 0, len(profiles))
for _, profile := range profiles {
availableProfiles = append(availableProfiles, AccessProfile{
ID: profile.ID,
Name: profile.Name,
Description: profile.Description,
FullTunnel: profile.FullTunnel,
Destinations: profile.Destinations,
})
}
if len(availableProfiles) == 0 {
return nil, nil, nil, nil
}
selectedProfileID, err := s.repo.GetSelectedProfileID(ctx, deviceID)
if err != nil {
return nil, nil, nil, err
}
for _, profile := range availableProfiles {
if selectedProfileID != nil && profile.ID == *selectedProfileID {
return availableProfiles, selectedProfileID, profile.Destinations, nil
}
}
fallback := availableProfiles[0]
if err := s.repo.SetSelectedProfileID(ctx, deviceID, fallback.ID); err != nil {
return nil, nil, nil, err
}
return availableProfiles, &fallback.ID, fallback.Destinations, nil
}
func resourcesFromDestinations(destinations []string) []Resource {
resources := make([]Resource, 0, len(destinations))
for _, destination := range destinations { for _, destination := range destinations {
enrollment.Resources = append(enrollment.Resources, Resource{ resources = append(resources, Resource{
Type: "cidr", Type: "cidr",
Value: destination, Value: destination,
Label: destination, Label: destination,
}) })
} }
enrollment.Peer.AllowedIPs = mergeProfileAllowedIPs(destinations, enrollment.Peer.DNSServers, alwaysAllowWebProxyTargets()) return resources
return withDebugProfile(enrollment), nil
} }
func mergeProfileAllowedIPs(destinations []string, dnsServers []string, webProxyTargets []string) []string { func mergeProfileAllowedIPs(destinations []string, dnsServers []string, webProxyTargets []string) []string {

View File

@@ -41,6 +41,16 @@ type EnrollmentResponse struct {
Peer PeerView `json:"peer"` Peer PeerView `json:"peer"`
Profile ProfileView `json:"profile"` Profile ProfileView `json:"profile"`
Resources []Resource `json:"resources"` Resources []Resource `json:"resources"`
AvailableProfiles []AccessProfile `json:"available_profiles"`
SelectedProfileID *uuid.UUID `json:"selected_profile_id,omitempty"`
}
type AccessProfile struct {
ID uuid.UUID `json:"id"`
Name string `json:"name"`
Description string `json:"description"`
FullTunnel bool `json:"full_tunnel"`
Destinations []string `json:"destinations"`
} }
type PeerView struct { type PeerView struct {
@@ -62,3 +72,7 @@ type ProfileView struct {
Format string `json:"format"` Format string `json:"format"`
Content string `json:"content"` Content string `json:"content"`
} }
type SelectProfileRequest struct {
ProfileID uuid.UUID `json:"profile_id"`
}

View File

@@ -98,13 +98,22 @@ func (r *PGRepository) BuildSyncBundle(ctx context.Context, gatewayID uuid.UUID)
from devices d from devices d
join wireguard_peers wp on wp.device_id = d.id and wp.deleted_at is null join wireguard_peers wp on wp.device_id = d.id and wp.deleted_at is null
join gateways g on g.id = d.gateway_id join gateways g on g.id = d.gateway_id
left join settings s on s.category = 'device_access_profile' and s.key = d.id::text
left join group_memberships gm on gm.user_id = d.user_id left join group_memberships gm on gm.user_id = d.user_id
left join policy_targets pt on ( left join policy_targets pt on (
(pt.target_type = 'device' and pt.target_id = d.id) or (pt.target_type = 'device' and pt.target_id = d.id) or
(pt.target_type = 'user' and pt.target_id = d.user_id) or (pt.target_type = 'user' and pt.target_id = d.user_id) or
(pt.target_type = 'group' and pt.target_id = gm.group_id) (pt.target_type = 'group' and pt.target_id = gm.group_id)
) )
left join policy_destinations pd on pd.policy_id = pt.policy_id left join policies p on p.id = pt.policy_id
and p.deleted_at is null
and p.is_active = true
and p.effect = 'allow'
left join policy_destinations pd on pd.policy_id = p.id
and (
s.value->>'profile_id' is null
or p.id::text = s.value->>'profile_id'
)
where d.gateway_id = $1 and d.deleted_at is null and d.status = 'active' where d.gateway_id = $1 and d.deleted_at is null and d.status = 'active'
group by d.id, wp.public_key, wp.assigned_ip, g.dns_servers group by d.id, wp.public_key, wp.assigned_ip, g.dns_servers
`, gatewayID) `, gatewayID)

View File

@@ -49,6 +49,7 @@ func NewRouter(jwtSecret string, handlers Handlers) http.Handler {
r.Post("/devices/enroll", handlers.Device.Enroll) r.Post("/devices/enroll", handlers.Device.Enroll)
r.Get("/me/devices", handlers.Device.ListOwn) r.Get("/me/devices", handlers.Device.ListOwn)
r.Get("/me/profile", handlers.Device.GetOwnProfile) r.Get("/me/profile", handlers.Device.GetOwnProfile)
r.Put("/me/profile-selection", handlers.Device.SelectOwnProfile)
r.Get("/connection/status", handlers.Device.ConnectionStatus) r.Get("/connection/status", handlers.Device.ConnectionStatus)
r.Route("/admin", func(r chi.Router) { r.Route("/admin", func(r chi.Router) {

View File

@@ -14,6 +14,7 @@ type Repository interface {
Update(ctx context.Context, policyID uuid.UUID, input UpdateRequest) (Policy, error) Update(ctx context.Context, policyID uuid.UUID, input UpdateRequest) (Policy, error)
Delete(ctx context.Context, policyID uuid.UUID) error Delete(ctx context.Context, policyID uuid.UUID) error
ResolveDestinations(ctx context.Context, userID uuid.UUID, deviceID *uuid.UUID) ([]string, error) ResolveDestinations(ctx context.Context, userID uuid.UUID, deviceID *uuid.UUID) ([]string, error)
ListSelectableProfiles(ctx context.Context, userID uuid.UUID, deviceID *uuid.UUID) ([]SelectableProfile, error)
} }
type PGRepository struct { type PGRepository struct {
@@ -206,6 +207,53 @@ func (r *PGRepository) ResolveDestinations(ctx context.Context, userID uuid.UUID
return destinations, rows.Err() return destinations, rows.Err()
} }
func (r *PGRepository) ListSelectableProfiles(ctx context.Context, userID uuid.UUID, deviceID *uuid.UUID) ([]SelectableProfile, error) {
query := `
select
p.id,
p.name,
p.description,
p.full_tunnel,
coalesce(array_agg(pd.destination::text order by pd.destination::text) filter (where pd.destination is not null), '{}')
from policies p
left join policy_destinations pd on pd.policy_id = p.id
join policy_targets pt on pt.policy_id = p.id
where p.deleted_at is null
and p.is_active = true
and p.effect = 'allow'
and (
(pt.target_type = 'user' and pt.target_id = $1)
or (pt.target_type = 'group' and exists (
select 1 from group_memberships gm
where gm.group_id = pt.target_id and gm.user_id = $1
))
`
args := []any{userID}
if deviceID != nil {
query += ` or (pt.target_type = 'device' and pt.target_id = $2)`
args = append(args, *deviceID)
}
query += `)
group by p.id, p.name, p.description, p.full_tunnel, p.priority, p.created_at
order by p.priority asc, p.created_at desc`
rows, err := r.db.Query(ctx, query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
var profiles []SelectableProfile
for rows.Next() {
var item SelectableProfile
if err := rows.Scan(&item.ID, &item.Name, &item.Description, &item.FullTunnel, &item.Destinations); err != nil {
return nil, err
}
profiles = append(profiles, item)
}
return profiles, rows.Err()
}
func (r *PGRepository) getByID(ctx context.Context, policyID uuid.UUID) (Policy, error) { func (r *PGRepository) getByID(ctx context.Context, policyID uuid.UUID) (Policy, error) {
items, err := r.List(ctx) items, err := r.List(ctx)
if err != nil { if err != nil {

View File

@@ -39,3 +39,7 @@ func (s *Service) Delete(ctx context.Context, policyID uuid.UUID) error {
func (s *Service) ResolveDestinations(ctx context.Context, userID uuid.UUID, deviceID *uuid.UUID) ([]string, error) { func (s *Service) ResolveDestinations(ctx context.Context, userID uuid.UUID, deviceID *uuid.UUID) ([]string, error) {
return s.repo.ResolveDestinations(ctx, userID, deviceID) return s.repo.ResolveDestinations(ctx, userID, deviceID)
} }
func (s *Service) ListSelectableProfiles(ctx context.Context, userID uuid.UUID, deviceID *uuid.UUID) ([]SelectableProfile, error) {
return s.repo.ListSelectableProfiles(ctx, userID, deviceID)
}

View File

@@ -40,3 +40,11 @@ type UpdateRequest struct {
Destinations []string `json:"destinations"` Destinations []string `json:"destinations"`
Targets []Target `json:"targets"` Targets []Target `json:"targets"`
} }
type SelectableProfile struct {
ID uuid.UUID `json:"id"`
Name string `json:"name"`
Description string `json:"description"`
FullTunnel bool `json:"full_tunnel"`
Destinations []string `json:"destinations"`
}

View File

@@ -58,7 +58,12 @@ struct EnrollmentPayload {
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
struct EnrollmentResult { struct EnrollmentResult {
assigned_ip: String, assigned_ip: String,
#[serde(default)]
resources: Vec<String>, resources: Vec<String>,
#[serde(default)]
available_profiles: Vec<AccessProfile>,
#[serde(default)]
selected_profile_id: Option<String>,
profile_revision: u32, profile_revision: u32,
gateway_endpoint: String, gateway_endpoint: String,
profile_path: String, profile_path: String,
@@ -66,6 +71,16 @@ struct EnrollmentResult {
tunnel_strategy: String, tunnel_strategy: String,
} }
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct AccessProfile {
id: String,
name: String,
description: String,
full_tunnel: bool,
destinations: Vec<String>,
}
#[derive(Debug, Clone, Copy, Serialize, Deserialize)] #[derive(Debug, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
struct TunnelMetrics { struct TunnelMetrics {
@@ -113,6 +128,8 @@ struct EnrollResponse {
peer: PeerView, peer: PeerView,
profile: ProfileView, profile: ProfileView,
resources: Vec<ResourceView>, resources: Vec<ResourceView>,
available_profiles: Vec<AccessProfileView>,
selected_profile_id: Option<String>,
} }
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
@@ -139,11 +156,25 @@ struct ResourceView {
value: String, value: String,
} }
#[derive(Debug, Deserialize)]
struct AccessProfileView {
id: String,
name: String,
description: String,
full_tunnel: bool,
destinations: Vec<String>,
}
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
struct ProfileView { struct ProfileView {
content: String, content: String,
} }
#[derive(Debug, Serialize)]
struct SelectProfilePayload<'a> {
profile_id: &'a str,
}
#[tauri::command] #[tauri::command]
async fn enroll_device( async fn enroll_device(
app: AppHandle, app: AppHandle,
@@ -213,6 +244,8 @@ async fn enroll_device(
let result = EnrollmentResult { let result = EnrollmentResult {
assigned_ip: enroll.peer.assigned_ip, assigned_ip: enroll.peer.assigned_ip,
resources: enroll.resources.into_iter().map(|resource| resource.value).collect(), resources: enroll.resources.into_iter().map(|resource| resource.value).collect(),
available_profiles: map_access_profiles(enroll.available_profiles),
selected_profile_id: enroll.selected_profile_id,
profile_revision: enroll.peer.profile_revision, profile_revision: enroll.peer.profile_revision,
gateway_endpoint: enroll.peer.gateway.endpoint, gateway_endpoint: enroll.peer.gateway.endpoint,
profile_path: profile_path.display().to_string(), profile_path: profile_path.display().to_string(),
@@ -275,6 +308,88 @@ async fn sync_profile(app: AppHandle, _state: State<'_, AppState>) -> Result<Enr
Ok(session_state.enrollment) Ok(session_state.enrollment)
} }
#[tauri::command]
async fn select_access_profile(app: AppHandle, profile_id: String) -> Result<EnrollmentResult, String> {
let mut existing = {
let state = app.state::<AppState>();
let session = state.session.lock().map_err(|_| "Unable to read client state".to_string())?;
session.clone().ok_or_else(|| "No enrolled profile is available yet".to_string())?
};
let client = Client::builder()
.use_rustls_tls()
.build()
.map_err(|err| err.to_string())?;
let mut response = client
.put(format!(
"{}/api/v1/me/profile-selection",
existing.server_url.trim_end_matches('/')
))
.bearer_auth(&existing.access_token)
.json(&SelectProfilePayload {
profile_id: &profile_id,
})
.send()
.await
.map_err(|err| format!("Profile selection failed: {}", err))?;
if response.status().as_u16() == 401 {
let refresh = client
.post(format!("{}/api/v1/auth/refresh", existing.server_url.trim_end_matches('/')))
.json(&RefreshRequest {
refresh_token: &existing.refresh_token,
})
.send()
.await
.map_err(|err| format!("Session refresh failed: {}", err))?;
if !refresh.status().is_success() {
let status = refresh.status();
let body = refresh
.text()
.await
.unwrap_or_else(|_| "<unable to read response body>".into());
return Err(format!("Session refresh failed with status {}: {}", status, body));
}
let refreshed = refresh
.json::<LoginResponse>()
.await
.map_err(|err| format!("Unable to decode refresh response: {}", err))?;
existing.access_token = refreshed.access_token;
existing.refresh_token = refreshed.refresh_token;
response = client
.put(format!(
"{}/api/v1/me/profile-selection",
existing.server_url.trim_end_matches('/')
))
.bearer_auth(&existing.access_token)
.json(&SelectProfilePayload {
profile_id: &profile_id,
})
.send()
.await
.map_err(|err| format!("Profile selection failed: {}", err))?;
}
if !response.status().is_success() {
let status = response.status();
let body = response
.text()
.await
.unwrap_or_else(|_| "<unable to read response body>".into());
return Err(format!("Profile selection failed with status {}: {}", status, body));
}
let _ = response;
existing = sync_current_session(&app).await?;
refresh_tray_menu(&app);
Ok(existing.enrollment)
}
#[tauri::command] #[tauri::command]
async fn connect_tunnel(app: AppHandle) -> Result<EnrollmentResult, String> { async fn connect_tunnel(app: AppHandle) -> Result<EnrollmentResult, String> {
let session_state = sync_current_session(&app).await?; let session_state = sync_current_session(&app).await?;
@@ -346,6 +461,19 @@ fn materialize_profile(profile_content: &str, private_key: &str) -> String {
.replace("__CLIENT_PRIVATE_KEY_REQUIRED__", private_key) .replace("__CLIENT_PRIVATE_KEY_REQUIRED__", private_key)
} }
fn map_access_profiles(items: Vec<AccessProfileView>) -> Vec<AccessProfile> {
items
.into_iter()
.map(|item| AccessProfile {
id: item.id,
name: item.name,
description: item.description,
full_tunnel: item.full_tunnel,
destinations: item.destinations,
})
.collect()
}
fn write_profile(app: &AppHandle, profile_content: &str) -> Result<PathBuf, String> { fn write_profile(app: &AppHandle, profile_content: &str) -> Result<PathBuf, String> {
let app_dir = ensure_app_dir(app)?; let app_dir = ensure_app_dir(app)?;
let profile_path = app_dir.join(format!("{}.conf", PROFILE_NAME)); let profile_path = app_dir.join(format!("{}.conf", PROFILE_NAME));
@@ -499,6 +627,8 @@ async fn sync_current_session(app: &AppHandle) -> Result<SessionState, String> {
let result = EnrollmentResult { let result = EnrollmentResult {
assigned_ip: enroll.peer.assigned_ip, assigned_ip: enroll.peer.assigned_ip,
resources: enroll.resources.into_iter().map(|resource| resource.value).collect(), resources: enroll.resources.into_iter().map(|resource| resource.value).collect(),
available_profiles: map_access_profiles(enroll.available_profiles),
selected_profile_id: enroll.selected_profile_id,
profile_revision: enroll.peer.profile_revision, profile_revision: enroll.peer.profile_revision,
gateway_endpoint: enroll.peer.gateway.endpoint, gateway_endpoint: enroll.peer.gateway.endpoint,
profile_path: profile_path.display().to_string(), profile_path: profile_path.display().to_string(),
@@ -805,7 +935,17 @@ pub fn run() {
} }
_ => {} _ => {}
}) })
.invoke_handler(tauri::generate_handler![load_state, clear_session, enroll_device, sync_profile, connect_tunnel, disconnect_tunnel, tunnel_status, tunnel_metrics]) .invoke_handler(tauri::generate_handler![
load_state,
clear_session,
enroll_device,
sync_profile,
select_access_profile,
connect_tunnel,
disconnect_tunnel,
tunnel_status,
tunnel_metrics
])
.run(tauri::generate_context!()) .run(tauri::generate_context!())
.expect("error while running tauri application"); .expect("error while running tauri application");
} }

View File

@@ -4,9 +4,19 @@ import { AppHeader } from "./components/AppHeader";
import { ResourcePanel } from "./components/ResourcePanel"; import { ResourcePanel } from "./components/ResourcePanel";
import { StatusCard } from "./components/StatusCard"; import { StatusCard } from "./components/StatusCard";
type AccessProfile = {
id: string;
name: string;
description: string;
fullTunnel: boolean;
destinations: string[];
};
type EnrollmentState = { type EnrollmentState = {
assignedIp: string; assignedIp: string;
resources: string[]; resources: string[];
availableProfiles: AccessProfile[];
selectedProfileId: string | null;
profileRevision: number; profileRevision: number;
gatewayEndpoint: string; gatewayEndpoint: string;
profilePath: string; profilePath: string;
@@ -36,6 +46,11 @@ function currentProfileLabel(state: EnrollmentState | null) {
return "Not provisioned"; return "Not provisioned";
} }
const selectedProfile = state.availableProfiles.find((profile) => profile.id === state.selectedProfileId);
if (selectedProfile) {
return selectedProfile.name;
}
if (state.resources.includes("0.0.0.0/0")) { if (state.resources.includes("0.0.0.0/0")) {
return "Full tunnel"; return "Full tunnel";
} }
@@ -53,6 +68,7 @@ export function App() {
const [password, setPassword] = useState(""); const [password, setPassword] = useState("");
const [loading, setLoading] = useState(false); const [loading, setLoading] = useState(false);
const [syncing, setSyncing] = useState(false); const [syncing, setSyncing] = useState(false);
const [selectingProfile, setSelectingProfile] = useState(false);
const [error, setError] = useState<string | null>(null); const [error, setError] = useState<string | null>(null);
const [connected, setConnected] = useState(false); const [connected, setConnected] = useState(false);
const [state, setState] = useState<EnrollmentState | null>(null); const [state, setState] = useState<EnrollmentState | null>(null);
@@ -150,6 +166,25 @@ export function App() {
} }
} }
async function onSelectProfile(profileId: string) {
if (!state || profileId === state.selectedProfileId) {
return;
}
setSelectingProfile(true);
setError(null);
try {
const result = await invoke<EnrollmentState>("select_access_profile", { profileId });
setState(result);
await refreshTunnelStatus();
} catch (err) {
setError(formatInvokeError(err, "Profile selection failed"));
} finally {
setSelectingProfile(false);
}
}
async function toggleConnection() { async function toggleConnection() {
const command = connected ? "disconnect_tunnel" : "connect_tunnel"; const command = connected ? "disconnect_tunnel" : "connect_tunnel";
try { try {
@@ -248,9 +283,14 @@ export function App() {
)} )}
<ResourcePanel <ResourcePanel
connected={connected}
onReset={resetEnrollment} onReset={resetEnrollment}
onSelectProfile={onSelectProfile}
profileLabel={profileLabel} profileLabel={profileLabel}
profiles={state?.availableProfiles ?? []}
resources={state?.resources ?? []} resources={state?.resources ?? []}
selectedProfileId={state?.selectedProfileId ?? null}
selectingProfile={selectingProfile}
/> />
</div> </div>
</div> </div>

View File

@@ -33,6 +33,7 @@ export function AppHeader({
</div> </div>
<div className="header-actions"> <div className="header-actions">
<div className="header-actions-secondary">
{enrolled ? ( {enrolled ? (
<> <>
<ActionButton disabled={syncing} onClick={onSync} variant="secondary"> <ActionButton disabled={syncing} onClick={onSync} variant="secondary">
@@ -43,6 +44,8 @@ export function AppHeader({
</ActionButton> </ActionButton>
</> </>
) : null} ) : null}
</div>
<div className="header-actions-primary">
<ActionButton <ActionButton
disabled={!enrolled} disabled={!enrolled}
onClick={onToggleConnection} onClick={onToggleConnection}
@@ -51,6 +54,7 @@ export function AppHeader({
{!enrolled ? "Provision first" : connected ? "Disconnect" : "Connect"} {!enrolled ? "Provision first" : connected ? "Disconnect" : "Connect"}
</ActionButton> </ActionButton>
</div> </div>
</div>
</header> </header>
); );
} }

View File

@@ -10,13 +10,34 @@ function ResourceListItem({ value }: { value: string }) {
} }
type ResourcePanelProps = { type ResourcePanelProps = {
connected: boolean;
profiles: Array<{
id: string;
name: string;
description: string;
destinations: string[];
}>;
resources: string[]; resources: string[];
profileLabel: string; profileLabel: string;
selectedProfileId: string | null;
selectingProfile: boolean;
onSelectProfile: (profileId: string) => void;
onReset: () => void; onReset: () => void;
}; };
export function ResourcePanel({ resources, profileLabel, onReset }: ResourcePanelProps) { export function ResourcePanel({
connected,
profiles,
resources,
profileLabel,
selectedProfileId,
selectingProfile,
onSelectProfile,
onReset
}: ResourcePanelProps) {
const effectiveResources = resources.length > 0 ? resources : ["Keine Ressourcen zugewiesen"]; const effectiveResources = resources.length > 0 ? resources : ["Keine Ressourcen zugewiesen"];
const showSelector = profiles.length > 1;
const selectedProfile = profiles.find((profile) => profile.id === selectedProfileId) ?? null;
return ( return (
<aside className="resource-panel"> <aside className="resource-panel">
@@ -31,8 +52,29 @@ export function ResourcePanel({ resources, profileLabel, onReset }: ResourcePane
<div className="resource-meta"> <div className="resource-meta">
<span className="resource-meta-label">Zugriffsprofil</span> <span className="resource-meta-label">Zugriffsprofil</span>
<strong>{profileLabel}</strong> <strong>{profileLabel}</strong>
{selectedProfile?.description ? <small>{selectedProfile.description}</small> : null}
</div> </div>
{showSelector ? (
<label className="resource-selector">
<span className="resource-meta-label">Ressource auswählen</span>
<select
disabled={connected || selectingProfile}
onChange={(event) => onSelectProfile(event.target.value)}
value={selectedProfileId ?? profiles[0]?.id ?? ""}
>
{profiles.map((profile) => (
<option key={profile.id} value={profile.id}>
{profile.name}
</option>
))}
</select>
<small>
{connected ? "Auswahl ist nur getrennt möglich." : selectingProfile ? "Profil wird aktualisiert..." : "Auswahl wird vor dem Verbinden in die Config übernommen."}
</small>
</label>
) : null}
<ul className="resource-list"> <ul className="resource-list">
{effectiveResources.map((resource) => ( {effectiveResources.map((resource) => (
<ResourceListItem key={resource} value={resource} /> <ResourceListItem key={resource} value={resource} />

View File

@@ -69,6 +69,7 @@ input {
align-items: center; align-items: center;
gap: 18px; gap: 18px;
min-width: 0; min-width: 0;
flex: 1 1 auto;
} }
.brand-icon-shell { .brand-icon-shell {
@@ -119,14 +120,23 @@ input {
margin: 0; margin: 0;
color: #9fb3d4; color: #9fb3d4;
font-size: 0.98rem; font-size: 0.98rem;
max-width: 620px; max-width: 520px;
} }
.header-actions { .header-actions {
display: flex; display: flex;
align-items: center; align-items: center;
gap: 10px; gap: 10px;
flex-wrap: wrap; flex-wrap: nowrap;
flex: none;
}
.header-actions-secondary,
.header-actions-primary {
display: flex;
align-items: center;
gap: 10px;
flex-wrap: nowrap;
} }
.action-button { .action-button {
@@ -135,6 +145,7 @@ input {
min-height: 44px; min-height: 44px;
padding: 0 18px; padding: 0 18px;
font-weight: 700; font-weight: 700;
white-space: nowrap;
cursor: pointer; cursor: pointer;
transition: transition:
background-color 180ms ease, background-color 180ms ease,
@@ -382,7 +393,8 @@ input {
} }
.resource-panel { .resource-panel {
grid-template-rows: auto auto 1fr auto; grid-auto-rows: min-content;
align-content: start;
} }
.resource-count { .resource-count {
@@ -408,6 +420,40 @@ input {
border: 1px solid rgba(154, 181, 228, 0.08); border: 1px solid rgba(154, 181, 228, 0.08);
} }
.resource-meta small,
.resource-selector small {
color: #91a8cc;
font-size: 0.82rem;
line-height: 1.35;
}
.resource-selector {
display: grid;
gap: 8px;
}
.resource-selector select {
width: 100%;
border: 1px solid rgba(177, 197, 229, 0.16);
background: rgba(7, 14, 27, 0.9);
color: #f5f7fb;
border-radius: 16px;
padding: 12px 14px;
outline: none;
transition: border-color 180ms ease, background-color 180ms ease;
}
.resource-selector select:hover:not(:disabled),
.resource-selector select:focus-visible {
border-color: rgba(118, 218, 200, 0.28);
background: rgba(10, 18, 32, 0.96);
}
.resource-selector select:disabled {
opacity: 0.7;
cursor: default;
}
.resource-meta-label { .resource-meta-label {
color: #8fa6ca; color: #8fa6ca;
font-size: 0.75rem; font-size: 0.75rem;
@@ -506,7 +552,7 @@ input {
} }
.brand-text h1 { .brand-text h1 {
font-size: 1.8rem; font-size: 1.7rem;
} }
.body-grid { .body-grid {
@@ -539,12 +585,16 @@ input {
} }
.header-actions, .header-actions,
.header-actions-secondary,
.header-actions-primary,
.resource-footer, .resource-footer,
.login-actions { .login-actions {
width: 100%; width: 100%;
} }
.header-actions .action-button, .header-actions .action-button,
.header-actions-secondary .action-button,
.header-actions-primary .action-button,
.resource-footer .action-button, .resource-footer .action-button,
.login-actions .action-button { .login-actions .action-button {
width: 100%; width: 100%;