diff --git a/backend/internal/device/service.go b/backend/internal/device/service.go index da60f24..5818032 100644 --- a/backend/internal/device/service.go +++ b/backend/internal/device/service.go @@ -113,7 +113,7 @@ func (s *Service) GetLatestEnrollmentByUser(ctx context.Context, userID uuid.UUI if err != nil { return EnrollmentResponse{}, err } - return withDebugProfile(enrollment), nil + return s.applyCurrentPolicy(ctx, enrollment) } func (s *Service) GetEnrollmentByDeviceID(ctx context.Context, deviceID uuid.UUID) (EnrollmentResponse, error) { @@ -121,7 +121,7 @@ func (s *Service) GetEnrollmentByDeviceID(ctx context.Context, deviceID uuid.UUI if err != nil { return EnrollmentResponse{}, err } - return withDebugProfile(enrollment), nil + return s.applyCurrentPolicy(ctx, enrollment) } func (s *Service) GetConnectionStatus(ctx context.Context, userID uuid.UUID) (ConnectionStatus, error) { @@ -151,7 +151,7 @@ func (s *Service) Rotate(ctx context.Context, deviceID uuid.UUID) error { } func withDebugProfile(enrollment EnrollmentResponse) EnrollmentResponse { - profileAllowedIPs := mergeProfileAllowedIPs(enrollment.Peer.AllowedIPs, enrollment.Peer.DNSServers) + profileAllowedIPs := enrollment.Peer.AllowedIPs enrollment.Profile = ProfileView{ Format: "wireguard", Content: profile.BuildWireGuardConfig(profile.BuildInput{ @@ -167,6 +167,27 @@ func withDebugProfile(enrollment EnrollmentResponse) EnrollmentResponse { return enrollment } +func (s *Service) applyCurrentPolicy(ctx context.Context, enrollment EnrollmentResponse) (EnrollmentResponse, error) { + destinations, err := s.policyService.ResolveDestinations(ctx, enrollment.Device.UserID, &enrollment.Device.ID) + if err != nil { + return EnrollmentResponse{}, err + } + if len(destinations) == 0 { + destinations = []string{"172.16.10.0/24"} + } + + enrollment.Resources = nil + for _, destination := range destinations { + enrollment.Resources = append(enrollment.Resources, Resource{ + Type: "cidr", + Value: destination, + Label: destination, + }) + } + enrollment.Peer.AllowedIPs = mergeProfileAllowedIPs(destinations, enrollment.Peer.DNSServers) + return withDebugProfile(enrollment), nil +} + func mergeProfileAllowedIPs(destinations []string, dnsServers []string) []string { seen := make(map[string]struct{}, len(destinations)+len(dnsServers)) merged := make([]string, 0, len(destinations)+len(dnsServers)) diff --git a/desktop-client/src-tauri/src/lib.rs b/desktop-client/src-tauri/src/lib.rs index 3f967d5..fe95b0c 100644 --- a/desktop-client/src-tauri/src/lib.rs +++ b/desktop-client/src-tauri/src/lib.rs @@ -260,77 +260,18 @@ fn clear_session(app: AppHandle, state: State<'_, AppState>) -> Result<(), Strin #[tauri::command] async fn sync_profile(app: AppHandle, state: State<'_, AppState>) -> Result { - let existing = { - let session = state.session.lock().map_err(|_| "Unable to read client state".to_string())?; - session.clone().ok_or_else(|| "No enrolled profile is available yet".to_string())? - }; - - let client = Client::builder() - .use_rustls_tls() - .build() - .map_err(|err| err.to_string())?; - - let response = client - .get(format!("{}/api/v1/me/profile", existing.server_url.trim_end_matches('/'))) - .bearer_auth(&existing.access_token) - .send() - .await - .map_err(|err| format!("Profile sync failed: {}", err))?; - - if !response.status().is_success() { - let status = response.status(); - let body = response - .text() - .await - .unwrap_or_else(|_| "".into()); - return Err(format!("Profile sync failed with status {}: {}", status, body)); - } - - let enroll = response - .json::() - .await - .map_err(|err| format!("Unable to decode profile sync response: {}", err))?; - - let profile_content = materialize_profile(&enroll.profile.content, &existing.private_key); - let profile_path = write_profile(&app, &profile_content)?; - let result = EnrollmentResult { - assigned_ip: enroll.peer.assigned_ip, - resources: enroll.resources.into_iter().map(|resource| resource.value).collect(), - profile_revision: enroll.peer.profile_revision, - gateway_endpoint: enroll.peer.gateway.endpoint, - profile_path: profile_path.display().to_string(), - last_sync_time: now_label(), - tunnel_strategy: tunnel_manager::current_tunnel_strategy().into(), - }; - - let session_state = SessionState { - access_token: existing.access_token, - refresh_token: existing.refresh_token, - server_url: existing.server_url, - profile_path: result.profile_path.clone(), - private_key: existing.private_key, - enrollment: result.clone(), - }; - - write_session_state(&app, &session_state)?; - let mut session = state.session.lock().map_err(|_| "Unable to store client state".to_string())?; - *session = Some(session_state); - drop(session); + let session_state = sync_current_session(&app).await?; refresh_tray_menu(&app); - - Ok(result) + Ok(session_state.enrollment) } #[tauri::command] -fn connect_tunnel(app: AppHandle, state: State<'_, AppState>) -> Result<(), String> { - let profile_path = { - let session = state.session.lock().map_err(|_| "Unable to read client state".to_string())?; - let session = session.as_ref().ok_or_else(|| "No enrolled profile is available yet".to_string())?; - session.profile_path.clone() - }; - let result = tunnel_manager::connect(&app, std::path::Path::new(&profile_path)); +async fn connect_tunnel(app: AppHandle) -> Result { + let session_state = sync_current_session(&app).await?; + let result = tunnel_manager::connect(&app, std::path::Path::new(&session_state.profile_path)); refresh_tray_menu(&app); - result + result?; + Ok(session_state.enrollment) } #[tauri::command] @@ -464,6 +405,67 @@ fn current_metrics(app: &AppHandle) -> Result { }) } +async fn sync_current_session(app: &AppHandle) -> Result { + let existing = { + let state = app.state::(); + let session = state.session.lock().map_err(|_| "Unable to read client state".to_string())?; + session.clone().ok_or_else(|| "No enrolled profile is available yet".to_string())? + }; + + let client = Client::builder() + .use_rustls_tls() + .build() + .map_err(|err| err.to_string())?; + + let response = client + .get(format!("{}/api/v1/me/profile", existing.server_url.trim_end_matches('/'))) + .bearer_auth(&existing.access_token) + .send() + .await + .map_err(|err| format!("Profile sync failed: {}", err))?; + + if !response.status().is_success() { + let status = response.status(); + let body = response + .text() + .await + .unwrap_or_else(|_| "".into()); + return Err(format!("Profile sync failed with status {}: {}", status, body)); + } + + let enroll = response + .json::() + .await + .map_err(|err| format!("Unable to decode profile sync response: {}", err))?; + + let profile_content = materialize_profile(&enroll.profile.content, &existing.private_key); + let profile_path = write_profile(app, &profile_content)?; + let result = EnrollmentResult { + assigned_ip: enroll.peer.assigned_ip, + resources: enroll.resources.into_iter().map(|resource| resource.value).collect(), + profile_revision: enroll.peer.profile_revision, + gateway_endpoint: enroll.peer.gateway.endpoint, + profile_path: profile_path.display().to_string(), + last_sync_time: now_label(), + tunnel_strategy: tunnel_manager::current_tunnel_strategy().into(), + }; + + let session_state = SessionState { + access_token: existing.access_token, + refresh_token: existing.refresh_token, + server_url: existing.server_url, + profile_path: result.profile_path.clone(), + private_key: existing.private_key, + enrollment: result.clone(), + }; + + write_session_state(app, &session_state)?; + let state = app.state::(); + let mut session = state.session.lock().map_err(|_| "Unable to store client state".to_string())?; + *session = Some(session_state.clone()); + Ok(session_state) +} + fn update_tray_menu(app: &AppHandle, metrics: TunnelMetrics) -> Result<(), String> { let state = app.state::(); let tray = state.tray.lock().map_err(|_| "Unable to update tray state".to_string())?; @@ -510,15 +512,20 @@ fn toggle_tray_connection(app: &AppHandle) { Err(_) => return, }; - let result = if metrics.active { - tunnel_manager::disconnect(app, std::path::Path::new(&profile_path)) - } else { - tunnel_manager::connect(app, std::path::Path::new(&profile_path)) - }; - - if result.is_ok() { - refresh_tray_menu(app); + if metrics.active { + if tunnel_manager::disconnect(app, std::path::Path::new(&profile_path)).is_ok() { + refresh_tray_menu(app); + } + return; } + + let app_handle = app.clone(); + tauri::async_runtime::spawn(async move { + if let Ok(session_state) = sync_current_session(&app_handle).await { + let _ = tunnel_manager::connect(&app_handle, std::path::Path::new(&session_state.profile_path)); + refresh_tray_menu(&app_handle); + } + }); } fn restore_webview_window(window: &WebviewWindow) { diff --git a/desktop-client/src/App.tsx b/desktop-client/src/App.tsx index 42eb795..5bb6e0c 100644 --- a/desktop-client/src/App.tsx +++ b/desktop-client/src/App.tsx @@ -183,7 +183,12 @@ export function App() { async function toggleConnection() { const command = connected ? "disconnect_tunnel" : "connect_tunnel"; try { - await invoke(command); + if (!connected) { + const syncedState = await invoke("connect_tunnel"); + setState(syncedState); + } else { + await invoke(command); + } const active = await waitForTunnelStatus(!connected); setConnected(active); if (!connected && !active) {