feat: add service catalog management with policy integration for domain-based resource access control

Add ServiceCatalogItem type and services CRUD API endpoints (list, create, update, delete). Extend Policy type to include services array with domain, upstream_ip, proxy_ip, and ports metadata.

Add ServicesPage component with table view and create/edit modals for managing service definitions. Include service name, domain, proxy, and upstream columns with port parsing logic.

Integrate service selection
This commit is contained in:
2026-03-18 13:09:54 +01:00
parent 0ac93dfeb6
commit 6cf49ff3e0
25 changed files with 1375 additions and 99 deletions

View File

@@ -55,6 +55,15 @@ export type Policy = {
full_tunnel: boolean; full_tunnel: boolean;
is_active: boolean; is_active: boolean;
destinations?: string[]; destinations?: string[];
services?: Array<{
id: string;
name: string;
domain: string;
upstream_ip: string;
proxy_ip: string;
ports: number[];
description: string;
}>;
targets?: Array<{ targets?: Array<{
type: string; type: string;
id: string; id: string;
@@ -62,6 +71,17 @@ export type Policy = {
}>; }>;
}; };
export type ServiceCatalogItem = {
id: string;
name: string;
description: string;
domain: string;
upstream_ip: string;
proxy_ip: string;
ports: number[];
is_active: boolean;
};
export type Group = { export type Group = {
id: string; id: string;
name: string; name: string;
@@ -168,6 +188,37 @@ export const api = {
method: "DELETE" method: "DELETE"
}), }),
devices: () => request<Device[]>("/admin/devices"), devices: () => request<Device[]>("/admin/devices"),
services: () => request<ServiceCatalogItem[]>("/admin/services"),
createService: (payload: {
name: string;
description: string;
domain: string;
upstream_ip: string;
proxy_ip: string;
ports: number[];
is_active?: boolean;
}) =>
request<ServiceCatalogItem>("/admin/services", {
method: "POST",
body: JSON.stringify(payload)
}),
updateService: (serviceId: string, payload: {
name?: string;
description?: string;
domain?: string;
upstream_ip?: string;
proxy_ip?: string;
ports?: number[];
is_active?: boolean;
}) =>
request<ServiceCatalogItem>(`/admin/services/${serviceId}`, {
method: "PATCH",
body: JSON.stringify(payload)
}),
deleteService: (serviceId: string) =>
request<{ ok: boolean }>(`/admin/services/${serviceId}`, {
method: "DELETE"
}),
deviceProfile: (deviceId: string) => request<DeviceProfile>(`/admin/devices/${deviceId}/profile`), deviceProfile: (deviceId: string) => request<DeviceProfile>(`/admin/devices/${deviceId}/profile`),
revokeDevice: (deviceId: string) => revokeDevice: (deviceId: string) =>
request<{ ok: boolean }>(`/admin/devices/${deviceId}/revoke`, { request<{ ok: boolean }>(`/admin/devices/${deviceId}/revoke`, {
@@ -187,6 +238,7 @@ export const api = {
effect: string; effect: string;
full_tunnel: boolean; full_tunnel: boolean;
destinations: string[]; destinations: string[];
service_ids: string[];
targets: Array<{ type: string; id: string }>; targets: Array<{ type: string; id: string }>;
}) => }) =>
request<Policy>("/admin/policies", { request<Policy>("/admin/policies", {
@@ -201,6 +253,7 @@ export const api = {
full_tunnel?: boolean; full_tunnel?: boolean;
is_active?: boolean; is_active?: boolean;
destinations?: string[]; destinations?: string[];
service_ids?: string[];
targets?: Array<{ type: string; id: string }>; targets?: Array<{ type: string; id: string }>;
}) => }) =>
request<Policy>(`/admin/policies/${policyId}`, { request<Policy>(`/admin/policies/${policyId}`, {

View File

@@ -10,6 +10,7 @@ import { DevicesPage } from "../features/devices/DevicesPage";
import { GatewaysPage } from "../features/gateways/GatewaysPage"; import { GatewaysPage } from "../features/gateways/GatewaysPage";
import { GroupsPage } from "../features/groups/GroupsPage"; import { GroupsPage } from "../features/groups/GroupsPage";
import { PoliciesPage } from "../features/policies/PoliciesPage"; import { PoliciesPage } from "../features/policies/PoliciesPage";
import { ServicesPage } from "../features/services/ServicesPage";
import { SettingsPage } from "../features/settings/SettingsPage"; import { SettingsPage } from "../features/settings/SettingsPage";
import { UsersPage } from "../features/users/UsersPage"; import { UsersPage } from "../features/users/UsersPage";
@@ -44,6 +45,7 @@ export function App() {
<Route path="/users" element={<UsersPage />} /> <Route path="/users" element={<UsersPage />} />
<Route path="/groups" element={<GroupsPage />} /> <Route path="/groups" element={<GroupsPage />} />
<Route path="/devices" element={<DevicesPage />} /> <Route path="/devices" element={<DevicesPage />} />
<Route path="/services" element={<ServicesPage />} />
<Route path="/policies" element={<PoliciesPage />} /> <Route path="/policies" element={<PoliciesPage />} />
<Route path="/gateways" element={<GatewaysPage />} /> <Route path="/gateways" element={<GatewaysPage />} />
<Route path="/audit" element={<AuditPage />} /> <Route path="/audit" element={<AuditPage />} />

View File

@@ -5,6 +5,7 @@ const items = [
["Users", "/users"], ["Users", "/users"],
["Groups", "/groups"], ["Groups", "/groups"],
["Devices", "/devices"], ["Devices", "/devices"],
["Services", "/services"],
["Policies", "/policies"], ["Policies", "/policies"],
["Gateways", "/gateways"], ["Gateways", "/gateways"],
["Audit", "/audit"], ["Audit", "/audit"],

View File

@@ -28,10 +28,15 @@ export function PoliciesPage() {
queryKey: ["groups"], queryKey: ["groups"],
queryFn: api.groups queryFn: api.groups
}); });
const servicesQuery = useQuery({
queryKey: ["services"],
queryFn: api.services
});
const [form, setForm] = useState({ const [form, setForm] = useState({
name: "", name: "",
description: "", description: "",
destinations: "", destinations: "",
serviceIds: [] as string[],
targetType: "user", targetType: "user",
targetIds: [] as string[], targetIds: [] as string[],
fullTunnel: false fullTunnel: false
@@ -42,6 +47,7 @@ export function PoliciesPage() {
name: "", name: "",
description: "", description: "",
destinations: "", destinations: "",
serviceIds: [] as string[],
fullTunnel: false, fullTunnel: false,
isActive: true, isActive: true,
targetType: "user", targetType: "user",
@@ -52,17 +58,17 @@ export function PoliciesPage() {
mutationFn: api.createPolicy, mutationFn: api.createPolicy,
onSuccess: () => { onSuccess: () => {
setCreateOpen(false); setCreateOpen(false);
setForm({ name: "", description: "", destinations: "", targetType: "user", targetIds: [], fullTunnel: false }); setForm({ name: "", description: "", destinations: "", serviceIds: [], targetType: "user", targetIds: [], fullTunnel: false });
void queryClient.invalidateQueries({ queryKey: ["policies"] }); void queryClient.invalidateQueries({ queryKey: ["policies"] });
} }
}); });
const updateMutation = useMutation({ const updateMutation = useMutation({
mutationFn: ({ policyId, payload }: { policyId: string; payload: { name: string; description: string; destinations: string[]; full_tunnel: boolean; is_active: boolean; targets: Array<{ type: string; id: string }> } }) => mutationFn: ({ policyId, payload }: { policyId: string; payload: { name: string; description: string; destinations: string[]; service_ids: string[]; full_tunnel: boolean; is_active: boolean; targets: Array<{ type: string; id: string }> } }) =>
api.updatePolicy(policyId, payload), api.updatePolicy(policyId, payload),
onSuccess: () => { onSuccess: () => {
setEditingPolicyId(null); setEditingPolicyId(null);
setEditForm({ name: "", description: "", destinations: "", fullTunnel: false, isActive: true, targetType: "user", targetIds: [] }); setEditForm({ name: "", description: "", destinations: "", serviceIds: [], fullTunnel: false, isActive: true, targetType: "user", targetIds: [] });
void queryClient.invalidateQueries({ queryKey: ["policies"] }); void queryClient.invalidateQueries({ queryKey: ["policies"] });
} }
}); });
@@ -78,12 +84,16 @@ export function PoliciesPage() {
id: policy.id, id: policy.id,
name: policy.name, name: policy.name,
targets: policy.targets?.length ? policy.targets.map((target) => `${target.type}: ${target.name ?? target.id}`).join(", ") : "No targets", targets: policy.targets?.length ? policy.targets.map((target) => `${target.type}: ${target.name ?? target.id}`).join(", ") : "No targets",
destinations: policy.destinations?.join(", ") ?? (policy.full_tunnel ? "0.0.0.0/0" : "-"), destinations: [
...(policy.services?.map((service) => service.name) ?? []),
...(policy.destinations ?? [])
].join(", ") || (policy.full_tunnel ? "0.0.0.0/0" : "-"),
mode: policy.full_tunnel ? "Full tunnel" : "Split tunnel" mode: policy.full_tunnel ? "Full tunnel" : "Split tunnel"
})) ?? []; })) ?? [];
const selectableUsers = useMemo(() => usersQuery.data ?? [], [usersQuery.data]); const selectableUsers = useMemo(() => usersQuery.data ?? [], [usersQuery.data]);
const selectableGroups = useMemo(() => groupsQuery.data ?? [], [groupsQuery.data]); const selectableGroups = useMemo(() => groupsQuery.data ?? [], [groupsQuery.data]);
const selectableServices = useMemo(() => servicesQuery.data ?? [], [servicesQuery.data]);
const selectableTargets = form.targetType === "group" ? selectableGroups : selectableUsers; const selectableTargets = form.targetType === "group" ? selectableGroups : selectableUsers;
const editableTargets = editForm.targetType === "group" ? selectableGroups : selectableUsers; const editableTargets = editForm.targetType === "group" ? selectableGroups : selectableUsers;
@@ -99,6 +109,7 @@ export function PoliciesPage() {
effect: "allow", effect: "allow",
full_tunnel: form.fullTunnel, full_tunnel: form.fullTunnel,
destinations: form.fullTunnel ? ["0.0.0.0/0"] : form.destinations.split(",").map((value) => value.trim()).filter(Boolean), destinations: form.fullTunnel ? ["0.0.0.0/0"] : form.destinations.split(",").map((value) => value.trim()).filter(Boolean),
service_ids: form.serviceIds,
targets: form.targetIds.map((id) => ({ type: form.targetType, id })) targets: form.targetIds.map((id) => ({ type: form.targetType, id }))
}); });
} }
@@ -114,6 +125,7 @@ export function PoliciesPage() {
name: policy.name, name: policy.name,
description: policy.description, description: policy.description,
destinations: policy.destinations?.join(", ") ?? "", destinations: policy.destinations?.join(", ") ?? "",
serviceIds: policy.services?.map((service) => service.id) ?? [],
fullTunnel: policy.full_tunnel, fullTunnel: policy.full_tunnel,
isActive: policy.is_active, isActive: policy.is_active,
targetType, targetType,
@@ -132,6 +144,7 @@ export function PoliciesPage() {
name: editForm.name, name: editForm.name,
description: editForm.description, description: editForm.description,
destinations: editForm.fullTunnel ? ["0.0.0.0/0"] : editForm.destinations.split(",").map((value) => value.trim()).filter(Boolean), destinations: editForm.fullTunnel ? ["0.0.0.0/0"] : editForm.destinations.split(",").map((value) => value.trim()).filter(Boolean),
service_ids: editForm.serviceIds,
full_tunnel: editForm.fullTunnel, full_tunnel: editForm.fullTunnel,
is_active: editForm.isActive, is_active: editForm.isActive,
targets: editForm.targetIds.map((id) => ({ type: editForm.targetType, id })) targets: editForm.targetIds.map((id) => ({ type: editForm.targetType, id }))
@@ -158,6 +171,25 @@ export function PoliciesPage() {
})); }));
} }
function toggleService(id: string, editing = false) {
if (editing) {
setEditForm((value) => ({
...value,
serviceIds: value.serviceIds.includes(id)
? value.serviceIds.filter((item) => item !== id)
: [...value.serviceIds, id]
}));
return;
}
setForm((value) => ({
...value,
serviceIds: value.serviceIds.includes(id)
? value.serviceIds.filter((item) => item !== id)
: [...value.serviceIds, id]
}));
}
return ( return (
<Page <Page
title="Policies" title="Policies"
@@ -214,6 +246,17 @@ export function PoliciesPage() {
onChange={(event) => setForm((value) => ({ ...value, destinations: event.target.value }))} onChange={(event) => setForm((value) => ({ ...value, destinations: event.target.value }))}
disabled={form.fullTunnel} disabled={form.fullTunnel}
/> />
<div className="selection-panel">
<p className="eyebrow">Allowed services</p>
<div className="selection-list">
{selectableServices.map((service) => (
<label className="selection-item" key={service.id}>
<input type="checkbox" checked={form.serviceIds.includes(service.id)} onChange={() => toggleService(service.id)} />
<span>{service.name} ({service.domain})</span>
</label>
))}
</div>
</div>
<label className="checkbox"> <label className="checkbox">
<input type="checkbox" checked={form.fullTunnel} onChange={(event) => setForm((value) => ({ ...value, fullTunnel: event.target.checked }))} /> <input type="checkbox" checked={form.fullTunnel} onChange={(event) => setForm((value) => ({ ...value, fullTunnel: event.target.checked }))} />
Full tunnel Full tunnel
@@ -251,6 +294,17 @@ export function PoliciesPage() {
onChange={(event) => setEditForm((value) => ({ ...value, destinations: event.target.value }))} onChange={(event) => setEditForm((value) => ({ ...value, destinations: event.target.value }))}
disabled={editForm.fullTunnel} disabled={editForm.fullTunnel}
/> />
<div className="selection-panel">
<p className="eyebrow">Allowed services</p>
<div className="selection-list">
{selectableServices.map((service) => (
<label className="selection-item" key={service.id}>
<input type="checkbox" checked={editForm.serviceIds.includes(service.id)} onChange={() => toggleService(service.id, true)} />
<span>{service.name} ({service.domain})</span>
</label>
))}
</div>
</div>
<label className="checkbox"> <label className="checkbox">
<input type="checkbox" checked={editForm.fullTunnel} onChange={(event) => setEditForm((value) => ({ ...value, fullTunnel: event.target.checked }))} /> <input type="checkbox" checked={editForm.fullTunnel} onChange={(event) => setEditForm((value) => ({ ...value, fullTunnel: event.target.checked }))} />
Full tunnel Full tunnel

View File

@@ -0,0 +1,203 @@
import { FormEvent, useState } from "react";
import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query";
import { api } from "../../api/client";
import { Modal } from "../../components/Modal";
import { Page } from "../../components/Page";
import { Table } from "../../components/Table";
const columns = [
{ key: "name", label: "Service" },
{ key: "domain", label: "Domain" },
{ key: "proxy", label: "Proxy" },
{ key: "upstream", label: "Upstream" },
{ key: "actions", label: "Actions" }
];
export function ServicesPage() {
const queryClient = useQueryClient();
const query = useQuery({
queryKey: ["services"],
queryFn: api.services
});
const [createOpen, setCreateOpen] = useState(false);
const [editingServiceId, setEditingServiceId] = useState<string | null>(null);
const [form, setForm] = useState({
name: "",
description: "",
domain: "",
upstream_ip: "",
proxy_ip: "",
ports: "80,443",
is_active: true
});
const createMutation = useMutation({
mutationFn: api.createService,
onSuccess: () => {
setCreateOpen(false);
resetForm();
void queryClient.invalidateQueries({ queryKey: ["services"] });
}
});
const updateMutation = useMutation({
mutationFn: ({ serviceId, payload }: { serviceId: string; payload: Parameters<typeof api.updateService>[1] }) =>
api.updateService(serviceId, payload),
onSuccess: () => {
setEditingServiceId(null);
resetForm();
void queryClient.invalidateQueries({ queryKey: ["services"] });
}
});
const deleteMutation = useMutation({
mutationFn: api.deleteService,
onSuccess: () => {
void queryClient.invalidateQueries({ queryKey: ["services"] });
}
});
function resetForm() {
setForm({
name: "",
description: "",
domain: "",
upstream_ip: "",
proxy_ip: "",
ports: "80,443",
is_active: true
});
}
function parsePorts(raw: string) {
return raw
.split(",")
.map((value) => Number.parseInt(value.trim(), 10))
.filter((value) => Number.isFinite(value) && value > 0);
}
function onCreate(event: FormEvent) {
event.preventDefault();
createMutation.mutate({
...form,
ports: parsePorts(form.ports)
});
}
function startEdit(serviceId: string) {
const service = query.data?.find((item) => item.id === serviceId);
if (!service) {
return;
}
setEditingServiceId(serviceId);
setForm({
name: service.name,
description: service.description,
domain: service.domain,
upstream_ip: service.upstream_ip,
proxy_ip: service.proxy_ip,
ports: service.ports.join(","),
is_active: service.is_active
});
}
function onEdit(event: FormEvent) {
event.preventDefault();
if (!editingServiceId) {
return;
}
updateMutation.mutate({
serviceId: editingServiceId,
payload: {
...form,
ports: parsePorts(form.ports)
}
});
}
const rows = query.data?.map((service) => ({
id: service.id,
name: service.name,
domain: service.domain,
proxy: `${service.proxy_ip}:${service.ports.join(",")}`,
upstream: service.upstream_ip
})) ?? [];
return (
<Page
title="Services"
subtitle="Define named internal services with domain, proxy hop, and upstream metadata."
actions={(
<div className="action-row">
<button className="button" type="button" onClick={() => setCreateOpen(true)}>New service</button>
</div>
)}
>
{query.isError ? <p className="notice">Unable to load services from the API.</p> : null}
{createMutation.isError ? <p className="notice">Unable to create service.</p> : null}
{updateMutation.isError ? <p className="notice">Unable to update service.</p> : null}
{deleteMutation.isError ? <p className="notice">Unable to delete service.</p> : null}
<Table
columns={columns}
rows={rows}
renderCell={(row, column) => {
if (column.key === "actions") {
return (
<div className="action-row">
<button className="ghost-button" type="button" onClick={() => startEdit(row.id)}>Edit</button>
<button className="ghost-button" type="button" onClick={() => deleteMutation.mutate(row.id)}>Delete</button>
</div>
);
}
return <span>{row[column.key as keyof (typeof rows)[number]]}</span>;
}}
/>
{createOpen ? (
<Modal title="Create service" subtitle="Register a named domain-based resource for NexaVPN policies." onClose={() => setCreateOpen(false)}>
<form className="stacked-form" onSubmit={onCreate}>
<input placeholder="service name" value={form.name} onChange={(event) => setForm((value) => ({ ...value, name: event.target.value }))} />
<input placeholder="description" value={form.description} onChange={(event) => setForm((value) => ({ ...value, description: event.target.value }))} />
<input placeholder="domain" value={form.domain} onChange={(event) => setForm((value) => ({ ...value, domain: event.target.value }))} />
<input placeholder="upstream ip" value={form.upstream_ip} onChange={(event) => setForm((value) => ({ ...value, upstream_ip: event.target.value }))} />
<input placeholder="proxy ip" value={form.proxy_ip} onChange={(event) => setForm((value) => ({ ...value, proxy_ip: event.target.value }))} />
<input placeholder="ports: 80,443" value={form.ports} onChange={(event) => setForm((value) => ({ ...value, ports: event.target.value }))} />
<label className="checkbox">
<input type="checkbox" checked={form.is_active} onChange={(event) => setForm((value) => ({ ...value, is_active: event.target.checked }))} />
Active
</label>
<div className="action-row">
<button className="button" type="submit" disabled={createMutation.isPending}>Create service</button>
<button className="ghost-button" type="button" onClick={() => setCreateOpen(false)}>Cancel</button>
</div>
</form>
</Modal>
) : null}
{editingServiceId ? (
<Modal title="Edit service" subtitle="Update the service metadata and proxy target." onClose={() => setEditingServiceId(null)}>
<form className="stacked-form" onSubmit={onEdit}>
<input placeholder="service name" value={form.name} onChange={(event) => setForm((value) => ({ ...value, name: event.target.value }))} />
<input placeholder="description" value={form.description} onChange={(event) => setForm((value) => ({ ...value, description: event.target.value }))} />
<input placeholder="domain" value={form.domain} onChange={(event) => setForm((value) => ({ ...value, domain: event.target.value }))} />
<input placeholder="upstream ip" value={form.upstream_ip} onChange={(event) => setForm((value) => ({ ...value, upstream_ip: event.target.value }))} />
<input placeholder="proxy ip" value={form.proxy_ip} onChange={(event) => setForm((value) => ({ ...value, proxy_ip: event.target.value }))} />
<input placeholder="ports" value={form.ports} onChange={(event) => setForm((value) => ({ ...value, ports: event.target.value }))} />
<label className="checkbox">
<input type="checkbox" checked={form.is_active} onChange={(event) => setForm((value) => ({ ...value, is_active: event.target.checked }))} />
Active
</label>
<div className="action-row">
<button className="button" type="submit" disabled={updateMutation.isPending}>Save changes</button>
<button className="ghost-button" type="button" onClick={() => setEditingServiceId(null)}>Cancel</button>
</div>
</form>
</Modal>
) : null}
</Page>
);
}

View File

@@ -16,6 +16,7 @@ import (
"nexavpn/backend/internal/httpserver" "nexavpn/backend/internal/httpserver"
"nexavpn/backend/internal/ipam" "nexavpn/backend/internal/ipam"
"nexavpn/backend/internal/policy" "nexavpn/backend/internal/policy"
"nexavpn/backend/internal/servicecatalog"
"nexavpn/backend/internal/user" "nexavpn/backend/internal/user"
) )
@@ -31,12 +32,16 @@ func New(cfg config.Config) (*App, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
if err := db.EnsureSchema(ctx, pool); err != nil {
return nil, err
}
authRepo := auth.NewPGRepository(pool) authRepo := auth.NewPGRepository(pool)
authService := auth.NewService(authRepo, cfg.JWTSecret, cfg.JWTIssuer, cfg.AccessTokenTTL, cfg.RefreshTokenTTL) authService := auth.NewService(authRepo, cfg.JWTSecret, cfg.JWTIssuer, cfg.AccessTokenTTL, cfg.RefreshTokenTTL)
userService := user.NewService(user.NewPGRepository(pool)) userService := user.NewService(user.NewPGRepository(pool))
groupService := group.NewService(group.NewPGRepository(pool)) groupService := group.NewService(group.NewPGRepository(pool))
serviceCatalogService := servicecatalog.NewService(servicecatalog.NewPGRepository(pool))
policyService := policy.NewService(policy.NewPGRepository(pool)) policyService := policy.NewService(policy.NewPGRepository(pool))
gatewayService := gateway.NewService(gateway.NewPGRepository(pool)) gatewayService := gateway.NewService(gateway.NewPGRepository(pool))
deviceService := device.NewService(device.NewPGRepository(pool), policyService, gatewayService, ipam.NewService()) deviceService := device.NewService(device.NewPGRepository(pool), policyService, gatewayService, ipam.NewService())
@@ -47,6 +52,7 @@ func New(cfg config.Config) (*App, error) {
User: user.NewHandler(userService, auditService), User: user.NewHandler(userService, auditService),
Device: device.NewHandler(deviceService, auditService), Device: device.NewHandler(deviceService, auditService),
Group: group.NewHandler(groupService, auditService), Group: group.NewHandler(groupService, auditService),
Service: servicecatalog.NewHandler(serviceCatalogService),
Policy: policy.NewHandler(policyService, auditService), Policy: policy.NewHandler(policyService, auditService),
Gateway: gateway.NewHandler(gatewayService, cfg.GatewayBootstrapToken), Gateway: gateway.NewHandler(gatewayService, cfg.GatewayBootstrapToken),
Audit: audit.NewHandler(auditService), Audit: audit.NewHandler(auditService),

View File

@@ -9,3 +9,35 @@ import (
func Connect(ctx context.Context, databaseURL string) (*pgxpool.Pool, error) { func Connect(ctx context.Context, databaseURL string) (*pgxpool.Pool, error) {
return pgxpool.New(ctx, databaseURL) return pgxpool.New(ctx, databaseURL)
} }
func EnsureSchema(ctx context.Context, db *pgxpool.Pool) error {
_, err := db.Exec(ctx, `
create table if not exists services (
id uuid primary key default gen_random_uuid(),
name text not null unique,
description text not null default '',
domain text not null,
upstream_ip inet not null,
proxy_ip inet not null,
ports integer[] not null default '{80,443}',
is_active boolean not null default true,
created_at timestamptz not null default now(),
updated_at timestamptz not null default now(),
deleted_at timestamptz
);
create table if not exists policy_services (
id uuid primary key default gen_random_uuid(),
policy_id uuid not null references policies(id) on delete cascade,
service_id uuid not null references services(id) on delete cascade,
created_at timestamptz not null default now(),
unique(policy_id, service_id)
);
create index if not exists idx_services_domain on services(domain) where deleted_at is null;
create unique index if not exists idx_services_domain_unique on services(lower(domain)) where deleted_at is null;
create index if not exists idx_policy_services_policy_id on policy_services(policy_id);
create index if not exists idx_policy_services_service_id on policy_services(service_id);
`)
return err
}

View File

@@ -71,7 +71,11 @@ func (s *Service) Enroll(ctx context.Context, userID uuid.UUID, input EnrollRequ
if len(selectedDestinations) == 0 { if len(selectedDestinations) == 0 {
selectedDestinations = destinations selectedDestinations = destinations
} }
profileAllowedIPs := mergeProfileAllowedIPs(selectedDestinations, selectedGateway.DNSServers, alwaysAllowWebProxyTargets()) selectedServices := servicesForSelectedProfile(availableProfiles, selectedProfileID)
profileAllowedIPs := mergeProfileAllowedIPs(
append(selectedDestinations, proxyRoutesForServices(selectedServices)...),
selectedGateway.DNSServers,
)
enrollment.Peer = PeerView{ enrollment.Peer = PeerView{
AssignedIP: assignedIP, AssignedIP: assignedIP,
@@ -85,7 +89,7 @@ func (s *Service) Enroll(ctx context.Context, userID uuid.UUID, input EnrollRequ
}, },
ProfileRevision: 1, ProfileRevision: 1,
} }
enrollment.Resources = resourcesFromDestinations(selectedDestinations) enrollment.Resources = resourcesFromProfile(selectedDestinations, selectedServices)
enrollment.AvailableProfiles = availableProfiles enrollment.AvailableProfiles = availableProfiles
enrollment.SelectedProfileID = selectedProfileID enrollment.SelectedProfileID = selectedProfileID
@@ -215,10 +219,14 @@ func (s *Service) applyCurrentPolicy(ctx context.Context, enrollment EnrollmentR
selectedDestinations = []string{"172.16.10.0/24"} selectedDestinations = []string{"172.16.10.0/24"}
} }
enrollment.Resources = resourcesFromDestinations(selectedDestinations) selectedServices := servicesForSelectedProfile(availableProfiles, selectedProfileID)
enrollment.Resources = resourcesFromProfile(selectedDestinations, selectedServices)
enrollment.AvailableProfiles = availableProfiles enrollment.AvailableProfiles = availableProfiles
enrollment.SelectedProfileID = selectedProfileID enrollment.SelectedProfileID = selectedProfileID
enrollment.Peer.AllowedIPs = mergeProfileAllowedIPs(selectedDestinations, enrollment.Peer.DNSServers, alwaysAllowWebProxyTargets()) enrollment.Peer.AllowedIPs = mergeProfileAllowedIPs(
append(selectedDestinations, proxyRoutesForServices(selectedServices)...),
enrollment.Peer.DNSServers,
)
return withDebugProfile(enrollment), nil return withDebugProfile(enrollment), nil
} }
@@ -230,12 +238,25 @@ func (s *Service) resolveAccessProfiles(ctx context.Context, userID uuid.UUID, d
availableProfiles := make([]AccessProfile, 0, len(profiles)) availableProfiles := make([]AccessProfile, 0, len(profiles))
for _, profile := range profiles { for _, profile := range profiles {
services := make([]AccessService, 0, len(profile.Services))
for _, service := range profile.Services {
services = append(services, AccessService{
ID: service.ID,
Name: service.Name,
Description: service.Description,
Domain: service.Domain,
UpstreamIP: service.UpstreamIP,
ProxyIP: service.ProxyIP,
Ports: service.Ports,
})
}
availableProfiles = append(availableProfiles, AccessProfile{ availableProfiles = append(availableProfiles, AccessProfile{
ID: profile.ID, ID: profile.ID,
Name: profile.Name, Name: profile.Name,
Description: profile.Description, Description: profile.Description,
FullTunnel: profile.FullTunnel, FullTunnel: profile.FullTunnel,
Destinations: profile.Destinations, Destinations: profile.Destinations,
Services: services,
}) })
} }
@@ -273,9 +294,64 @@ func resourcesFromDestinations(destinations []string) []Resource {
return resources return resources
} }
func mergeProfileAllowedIPs(destinations []string, dnsServers []string, webProxyTargets []string) []string { func resourcesFromProfile(destinations []string, services []AccessService) []Resource {
seen := make(map[string]struct{}, len(destinations)+len(dnsServers)+len(webProxyTargets)) resources := resourcesFromDestinations(destinations)
merged := make([]string, 0, len(destinations)+len(dnsServers)+len(webProxyTargets)) for _, service := range services {
resources = append(resources, Resource{
Type: "service",
Value: service.Domain,
Label: service.Name,
Domain: service.Domain,
})
}
return resources
}
func servicesForSelectedProfile(profiles []AccessProfile, selectedProfileID *uuid.UUID) []AccessService {
if selectedProfileID == nil {
if len(profiles) == 0 {
return nil
}
return profiles[0].Services
}
for _, profile := range profiles {
if profile.ID == *selectedProfileID {
return profile.Services
}
}
return nil
}
func proxyRoutesForServices(services []AccessService) []string {
seen := make(map[string]struct{}, len(services))
routes := make([]string, 0, len(services))
for _, service := range services {
route := dnsServerRoute(effectiveServiceProxyIP(service.ProxyIP))
if route == "" {
continue
}
if _, ok := seen[route]; ok {
continue
}
seen[route] = struct{}{}
routes = append(routes, route)
}
return routes
}
func effectiveServiceProxyIP(proxyIP string) string {
override := strings.TrimSpace(os.Getenv("NEXAVPN_ACCESS_PROXY_IP"))
if override != "" {
return override
}
return proxyIP
}
func mergeProfileAllowedIPs(destinations []string, dnsServers []string) []string {
seen := make(map[string]struct{}, len(destinations)+len(dnsServers))
merged := make([]string, 0, len(destinations)+len(dnsServers))
for _, destination := range destinations { for _, destination := range destinations {
destination = strings.TrimSpace(destination) destination = strings.TrimSpace(destination)
@@ -301,18 +377,6 @@ func mergeProfileAllowedIPs(destinations []string, dnsServers []string, webProxy
merged = append(merged, route) merged = append(merged, route)
} }
for _, target := range webProxyTargets {
route := dnsServerRoute(target)
if route == "" {
continue
}
if _, exists := seen[route]; exists {
continue
}
seen[route] = struct{}{}
merged = append(merged, route)
}
return merged return merged
} }
@@ -326,25 +390,3 @@ func dnsServerRoute(value string) string {
} }
return value + "/32" return value + "/32"
} }
func alwaysAllowWebProxyTargets() []string {
raw := os.Getenv("NEXAVPN_ALWAYS_ALLOW_WEB_PROXY_IPS")
if strings.TrimSpace(raw) == "" {
return nil
}
seen := make(map[string]struct{})
targets := make([]string, 0)
for _, part := range strings.Split(raw, ",") {
value := strings.TrimSpace(part)
if value == "" {
continue
}
if _, ok := seen[value]; ok {
continue
}
seen[value] = struct{}{}
targets = append(targets, value)
}
return targets
}

View File

@@ -34,6 +34,7 @@ type Resource struct {
Type string `json:"type"` Type string `json:"type"`
Value string `json:"value"` Value string `json:"value"`
Label string `json:"label"` Label string `json:"label"`
Domain string `json:"domain,omitempty"`
} }
type EnrollmentResponse struct { type EnrollmentResponse struct {
@@ -51,6 +52,17 @@ type AccessProfile struct {
Description string `json:"description"` Description string `json:"description"`
FullTunnel bool `json:"full_tunnel"` FullTunnel bool `json:"full_tunnel"`
Destinations []string `json:"destinations"` Destinations []string `json:"destinations"`
Services []AccessService `json:"services"`
}
type AccessService struct {
ID uuid.UUID `json:"id"`
Name string `json:"name"`
Description string `json:"description"`
Domain string `json:"domain"`
UpstreamIP string `json:"upstream_ip"`
ProxyIP string `json:"proxy_ip"`
Ports []int `json:"ports"`
} }
type PeerView struct { type PeerView struct {

View File

@@ -94,7 +94,8 @@ func (r *PGRepository) BuildSyncBundle(ctx context.Context, gatewayID uuid.UUID)
wp.public_key, wp.public_key,
set_masklen(wp.assigned_ip, 32)::text, set_masklen(wp.assigned_ip, 32)::text,
coalesce(array_agg(distinct pd.destination::text) filter (where pd.destination is not null), '{}'), coalesce(array_agg(distinct pd.destination::text) filter (where pd.destination is not null), '{}'),
coalesce(g.dns_servers, '{}')::text[] coalesce(g.dns_servers, '{}')::text[],
s.value->>'profile_id'
from devices d from devices d
join wireguard_peers wp on wp.device_id = d.id and wp.deleted_at is null join wireguard_peers wp on wp.device_id = d.id and wp.deleted_at is null
join gateways g on g.id = d.gateway_id join gateways g on g.id = d.gateway_id
@@ -115,7 +116,7 @@ func (r *PGRepository) BuildSyncBundle(ctx context.Context, gatewayID uuid.UUID)
or p.id::text = s.value->>'profile_id' or p.id::text = s.value->>'profile_id'
) )
where d.gateway_id = $1 and d.deleted_at is null and d.status = 'active' where d.gateway_id = $1 and d.deleted_at is null and d.status = 'active'
group by d.id, wp.public_key, wp.assigned_ip, g.dns_servers group by d.id, wp.public_key, wp.assigned_ip, g.dns_servers, s.value
`, gatewayID) `, gatewayID)
if err != nil { if err != nil {
return wireguard.GatewayBundle{}, err return wireguard.GatewayBundle{}, err
@@ -125,37 +126,70 @@ func (r *PGRepository) BuildSyncBundle(ctx context.Context, gatewayID uuid.UUID)
for rows.Next() { for rows.Next() {
var peer wireguard.Peer var peer wireguard.Peer
var deviceID uuid.UUID var deviceID uuid.UUID
if err := rows.Scan(&deviceID, &peer.PublicKey, &peer.AssignedIP, &peer.AllowedDestinations, &peer.DNSServers); err != nil { var selectedProfileID *string
if err := rows.Scan(&deviceID, &peer.PublicKey, &peer.AssignedIP, &peer.AllowedDestinations, &peer.DNSServers, &selectedProfileID); err != nil {
return wireguard.GatewayBundle{}, err return wireguard.GatewayBundle{}, err
} }
peer.DeviceID = deviceID.String() peer.DeviceID = deviceID.String()
peer.WebProxyTargets = alwaysAllowWebProxyTargets() services, err := r.listAllowedServices(ctx, deviceID, selectedProfileID)
if err != nil {
return wireguard.GatewayBundle{}, err
}
peer.AllowedServices = services
bundle.Peers = append(bundle.Peers, peer) bundle.Peers = append(bundle.Peers, peer)
} }
return bundle, rows.Err() return bundle, rows.Err()
} }
func alwaysAllowWebProxyTargets() []string { func (r *PGRepository) listAllowedServices(ctx context.Context, deviceID uuid.UUID, selectedProfileID *string) ([]wireguard.AllowedService, error) {
raw := os.Getenv("NEXAVPN_ALWAYS_ALLOW_WEB_PROXY_IPS") rows, err := r.db.Query(ctx, `
if strings.TrimSpace(raw) == "" { select distinct
return nil s.name,
s.domain,
host(s.upstream_ip),
host(s.proxy_ip),
s.ports
from devices d
left join group_memberships gm on gm.user_id = d.user_id
join policy_targets pt on (
(pt.target_type = 'device' and pt.target_id = d.id) or
(pt.target_type = 'user' and pt.target_id = d.user_id) or
(pt.target_type = 'group' and pt.target_id = gm.group_id)
)
join policies p on p.id = pt.policy_id
and p.deleted_at is null
and p.is_active = true
and p.effect = 'allow'
join policy_services ps on ps.policy_id = p.id
join services s on s.id = ps.service_id and s.deleted_at is null and s.is_active = true
where d.id = $1
and ($2::text is null or p.id::text = $2::text)
order by s.name asc
`, deviceID, selectedProfileID)
if err != nil {
return nil, err
} }
defer rows.Close()
seen := make(map[string]struct{}) var items []wireguard.AllowedService
targets := make([]string, 0) for rows.Next() {
for _, part := range strings.Split(raw, ",") { var item wireguard.AllowedService
value := strings.TrimSpace(part) if err := rows.Scan(&item.Name, &item.Domain, &item.UpstreamIP, &item.ProxyIP, &item.Ports); err != nil {
if value == "" { return nil, err
continue
} }
if _, ok := seen[value]; ok { item.AccessProxyIP = effectiveAccessProxyIP(item.ProxyIP)
continue items = append(items, item)
} }
seen[value] = struct{}{} return items, rows.Err()
targets = append(targets, value) }
func effectiveAccessProxyIP(proxyIP string) string {
override := strings.TrimSpace(os.Getenv("NEXAVPN_ACCESS_PROXY_IP"))
if override != "" {
return override
} }
return targets return proxyIP
} }
func (r *PGRepository) Update(ctx context.Context, gatewayID uuid.UUID, input UpdateRequest) (Gateway, error) { func (r *PGRepository) Update(ctx context.Context, gatewayID uuid.UUID, input UpdateRequest) (Gateway, error) {

View File

@@ -12,6 +12,7 @@ import (
"nexavpn/backend/internal/gateway" "nexavpn/backend/internal/gateway"
"nexavpn/backend/internal/group" "nexavpn/backend/internal/group"
"nexavpn/backend/internal/policy" "nexavpn/backend/internal/policy"
"nexavpn/backend/internal/servicecatalog"
"nexavpn/backend/internal/user" "nexavpn/backend/internal/user"
) )
@@ -19,6 +20,7 @@ type Handlers struct {
Auth *auth.Handler Auth *auth.Handler
User *user.Handler User *user.Handler
Device *device.Handler Device *device.Handler
Service *servicecatalog.Handler
Policy *policy.Handler Policy *policy.Handler
Gateway *gateway.Handler Gateway *gateway.Handler
Group *group.Handler Group *group.Handler
@@ -68,6 +70,10 @@ func NewRouter(jwtSecret string, handlers Handlers) http.Handler {
r.Post("/groups", handlers.Group.Create) r.Post("/groups", handlers.Group.Create)
r.Patch("/groups/{id}", handlers.Group.Update) r.Patch("/groups/{id}", handlers.Group.Update)
r.Delete("/groups/{id}", handlers.Group.Delete) r.Delete("/groups/{id}", handlers.Group.Delete)
r.Get("/services", handlers.Service.List)
r.Post("/services", handlers.Service.Create)
r.Patch("/services/{id}", handlers.Service.Update)
r.Delete("/services/{id}", handlers.Service.Delete)
r.Get("/policies", handlers.Policy.List) r.Get("/policies", handlers.Policy.List)
r.Post("/policies", handlers.Policy.Create) r.Post("/policies", handlers.Policy.Create)
r.Patch("/policies/{id}", handlers.Policy.Update) r.Patch("/policies/{id}", handlers.Policy.Update)

View File

@@ -53,6 +53,11 @@ func (r *PGRepository) List(ctx context.Context) ([]Policy, error) {
if err := rows.Scan(&item.ID, &item.Name, &item.Description, &item.Priority, &item.Effect, &item.FullTunnel, &item.IsActive, &item.Destinations); err != nil { if err := rows.Scan(&item.ID, &item.Name, &item.Description, &item.Priority, &item.Effect, &item.FullTunnel, &item.IsActive, &item.Destinations); err != nil {
return nil, err return nil, err
} }
services, err := r.listServices(ctx, item.ID)
if err != nil {
return nil, err
}
item.Services = services
targets, err := r.listTargets(ctx, item.ID) targets, err := r.listTargets(ctx, item.ID)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -88,6 +93,15 @@ func (r *PGRepository) Create(ctx context.Context, input CreateRequest, createdB
} }
} }
for _, serviceID := range input.ServiceIDs {
if _, err := tx.Exec(ctx, `
insert into policy_services (id, policy_id, service_id)
values ($1, $2, $3)
`, uuid.New(), policyID, serviceID); err != nil {
return Policy{}, err
}
}
for _, target := range input.Targets { for _, target := range input.Targets {
if _, err := tx.Exec(ctx, ` if _, err := tx.Exec(ctx, `
insert into policy_targets (id, policy_id, target_type, target_id) insert into policy_targets (id, policy_id, target_type, target_id)
@@ -141,6 +155,20 @@ func (r *PGRepository) Update(ctx context.Context, policyID uuid.UUID, input Upd
} }
} }
if input.ServiceIDs != nil {
if _, err := tx.Exec(ctx, `delete from policy_services where policy_id = $1`, policyID); err != nil {
return Policy{}, err
}
for _, serviceID := range input.ServiceIDs {
if _, err := tx.Exec(ctx, `
insert into policy_services (id, policy_id, service_id)
values ($1, $2, $3)
`, uuid.New(), policyID, serviceID); err != nil {
return Policy{}, err
}
}
}
if input.Targets != nil { if input.Targets != nil {
if _, err := tx.Exec(ctx, `delete from policy_targets where policy_id = $1`, policyID); err != nil { if _, err := tx.Exec(ctx, `delete from policy_targets where policy_id = $1`, policyID); err != nil {
return Policy{}, err return Policy{}, err
@@ -169,7 +197,9 @@ func (r *PGRepository) Delete(ctx context.Context, policyID uuid.UUID) error {
func (r *PGRepository) ResolveDestinations(ctx context.Context, userID uuid.UUID, deviceID *uuid.UUID) ([]string, error) { func (r *PGRepository) ResolveDestinations(ctx context.Context, userID uuid.UUID, deviceID *uuid.UUID) ([]string, error) {
query := ` query := `
select distinct pd.destination::text select distinct destination
from (
select pd.destination::text as destination
from policies p from policies p
join policy_destinations pd on pd.policy_id = p.id join policy_destinations pd on pd.policy_id = p.id
join policy_targets pt on pt.policy_id = p.id join policy_targets pt on pt.policy_id = p.id
@@ -188,7 +218,28 @@ func (r *PGRepository) ResolveDestinations(ctx context.Context, userID uuid.UUID
query += ` or (pt.target_type = 'device' and pt.target_id = $2)` query += ` or (pt.target_type = 'device' and pt.target_id = $2)`
args = append(args, *deviceID) args = append(args, *deviceID)
} }
query += `)` query += `)
union
select host(s.proxy_ip)::text || '/32' as destination
from policies p
join policy_services ps on ps.policy_id = p.id
join services s on s.id = ps.service_id and s.deleted_at is null and s.is_active = true
join policy_targets pt on pt.policy_id = p.id
where p.deleted_at is null
and p.is_active = true
and p.effect = 'allow'
and (
(pt.target_type = 'user' and pt.target_id = $1)
or (pt.target_type = 'group' and exists (
select 1 from group_memberships gm
where gm.group_id = pt.target_id and gm.user_id = $1
))
`
if deviceID != nil {
query += ` or (pt.target_type = 'device' and pt.target_id = $2)`
}
query += `)
) destinations`
rows, err := r.db.Query(ctx, query, args...) rows, err := r.db.Query(ctx, query, args...)
if err != nil { if err != nil {
@@ -249,6 +300,11 @@ func (r *PGRepository) ListSelectableProfiles(ctx context.Context, userID uuid.U
if err := rows.Scan(&item.ID, &item.Name, &item.Description, &item.FullTunnel, &item.Destinations); err != nil { if err := rows.Scan(&item.ID, &item.Name, &item.Description, &item.FullTunnel, &item.Destinations); err != nil {
return nil, err return nil, err
} }
services, err := r.listServices(ctx, item.ID)
if err != nil {
return nil, err
}
item.Services = services
profiles = append(profiles, item) profiles = append(profiles, item)
} }
return profiles, rows.Err() return profiles, rows.Err()
@@ -296,3 +352,34 @@ func (r *PGRepository) listTargets(ctx context.Context, policyID uuid.UUID) ([]T
return items, rows.Err() return items, rows.Err()
} }
func (r *PGRepository) listServices(ctx context.Context, policyID uuid.UUID) ([]PolicyService, error) {
rows, err := r.db.Query(ctx, `
select
s.id,
s.name,
s.domain,
host(s.upstream_ip),
host(s.proxy_ip),
s.ports,
s.description
from policy_services ps
join services s on s.id = ps.service_id
where ps.policy_id = $1 and s.deleted_at is null
order by s.name asc
`, policyID)
if err != nil {
return nil, err
}
defer rows.Close()
var items []PolicyService
for rows.Next() {
var item PolicyService
if err := rows.Scan(&item.ID, &item.Name, &item.Domain, &item.UpstreamIP, &item.ProxyIP, &item.Ports, &item.Description); err != nil {
return nil, err
}
items = append(items, item)
}
return items, rows.Err()
}

View File

@@ -17,9 +17,20 @@ type Policy struct {
FullTunnel bool `json:"full_tunnel"` FullTunnel bool `json:"full_tunnel"`
IsActive bool `json:"is_active"` IsActive bool `json:"is_active"`
Destinations []string `json:"destinations"` Destinations []string `json:"destinations"`
Services []PolicyService `json:"services"`
Targets []Target `json:"targets"` Targets []Target `json:"targets"`
} }
type PolicyService struct {
ID uuid.UUID `json:"id"`
Name string `json:"name"`
Domain string `json:"domain"`
UpstreamIP string `json:"upstream_ip"`
ProxyIP string `json:"proxy_ip"`
Ports []int `json:"ports"`
Description string `json:"description"`
}
type CreateRequest struct { type CreateRequest struct {
Name string `json:"name"` Name string `json:"name"`
Description string `json:"description"` Description string `json:"description"`
@@ -27,6 +38,7 @@ type CreateRequest struct {
Effect string `json:"effect"` Effect string `json:"effect"`
FullTunnel bool `json:"full_tunnel"` FullTunnel bool `json:"full_tunnel"`
Destinations []string `json:"destinations"` Destinations []string `json:"destinations"`
ServiceIDs []uuid.UUID `json:"service_ids"`
Targets []Target `json:"targets"` Targets []Target `json:"targets"`
} }
@@ -38,6 +50,7 @@ type UpdateRequest struct {
FullTunnel *bool `json:"full_tunnel"` FullTunnel *bool `json:"full_tunnel"`
IsActive *bool `json:"is_active"` IsActive *bool `json:"is_active"`
Destinations []string `json:"destinations"` Destinations []string `json:"destinations"`
ServiceIDs []uuid.UUID `json:"service_ids"`
Targets []Target `json:"targets"` Targets []Target `json:"targets"`
} }
@@ -47,4 +60,5 @@ type SelectableProfile struct {
Description string `json:"description"` Description string `json:"description"`
FullTunnel bool `json:"full_tunnel"` FullTunnel bool `json:"full_tunnel"`
Destinations []string `json:"destinations"` Destinations []string `json:"destinations"`
Services []PolicyService `json:"services"`
} }

View File

@@ -0,0 +1,78 @@
package servicecatalog
import (
"encoding/json"
"net/http"
"github.com/go-chi/chi/v5"
"github.com/google/uuid"
"nexavpn/backend/internal/apiutil"
)
type Handler struct {
service *CatalogService
}
func NewHandler(service *CatalogService) *Handler {
return &Handler{service: service}
}
func (h *Handler) List(w http.ResponseWriter, r *http.Request) {
items, err := h.service.List(r.Context())
if err != nil {
apiutil.Error(w, http.StatusInternalServerError, "services_list_failed", "unable to list services")
return
}
apiutil.JSON(w, http.StatusOK, items)
}
func (h *Handler) Create(w http.ResponseWriter, r *http.Request) {
var input CreateRequest
if err := json.NewDecoder(r.Body).Decode(&input); err != nil {
apiutil.Error(w, http.StatusBadRequest, "invalid_json", "invalid request body")
return
}
item, err := h.service.Create(r.Context(), input)
if err != nil {
apiutil.Error(w, http.StatusInternalServerError, "service_create_failed", "unable to create service")
return
}
apiutil.JSON(w, http.StatusCreated, item)
}
func (h *Handler) Update(w http.ResponseWriter, r *http.Request) {
serviceID, err := uuid.Parse(chi.URLParam(r, "id"))
if err != nil {
apiutil.Error(w, http.StatusBadRequest, "invalid_service_id", "invalid service id")
return
}
var input UpdateRequest
if err := json.NewDecoder(r.Body).Decode(&input); err != nil {
apiutil.Error(w, http.StatusBadRequest, "invalid_json", "invalid request body")
return
}
item, err := h.service.Update(r.Context(), serviceID, input)
if err != nil {
apiutil.Error(w, http.StatusInternalServerError, "service_update_failed", "unable to update service")
return
}
apiutil.JSON(w, http.StatusOK, item)
}
func (h *Handler) Delete(w http.ResponseWriter, r *http.Request) {
serviceID, err := uuid.Parse(chi.URLParam(r, "id"))
if err != nil {
apiutil.Error(w, http.StatusBadRequest, "invalid_service_id", "invalid service id")
return
}
if err := h.service.Delete(r.Context(), serviceID); err != nil {
apiutil.Error(w, http.StatusInternalServerError, "service_delete_failed", "unable to delete service")
return
}
apiutil.JSON(w, http.StatusOK, map[string]any{"ok": true})
}

View File

@@ -0,0 +1,152 @@
package servicecatalog
import (
"context"
"errors"
"github.com/google/uuid"
"github.com/jackc/pgx/v5/pgxpool"
)
type Repository interface {
List(ctx context.Context) ([]Service, error)
Create(ctx context.Context, input CreateRequest) (Service, error)
Update(ctx context.Context, serviceID uuid.UUID, input UpdateRequest) (Service, error)
Delete(ctx context.Context, serviceID uuid.UUID) error
ByIDs(ctx context.Context, ids []uuid.UUID) ([]Service, error)
}
type PGRepository struct {
db *pgxpool.Pool
}
func NewPGRepository(db *pgxpool.Pool) *PGRepository {
return &PGRepository{db: db}
}
func (r *PGRepository) List(ctx context.Context) ([]Service, error) {
rows, err := r.db.Query(ctx, `
select id, name, description, domain, host(upstream_ip), host(proxy_ip), ports, is_active
from services
where deleted_at is null
order by name asc
`)
if err != nil {
return nil, err
}
defer rows.Close()
var items []Service
for rows.Next() {
var item Service
if err := rows.Scan(&item.ID, &item.Name, &item.Description, &item.Domain, &item.UpstreamIP, &item.ProxyIP, &item.Ports, &item.IsActive); err != nil {
return nil, err
}
items = append(items, item)
}
return items, rows.Err()
}
func (r *PGRepository) Create(ctx context.Context, input CreateRequest) (Service, error) {
row := r.db.QueryRow(ctx, `
insert into services (id, name, description, domain, upstream_ip, proxy_ip, ports, is_active)
values ($1, $2, $3, $4, $5::inet, $6::inet, $7::integer[], coalesce($8, true))
returning id, name, description, domain, host(upstream_ip), host(proxy_ip), ports, is_active
`, uuid.New(), input.Name, input.Description, input.Domain, input.UpstreamIP, input.ProxyIP, normalizePorts(input.Ports), input.IsActive)
var item Service
err := row.Scan(&item.ID, &item.Name, &item.Description, &item.Domain, &item.UpstreamIP, &item.ProxyIP, &item.Ports, &item.IsActive)
return item, err
}
func (r *PGRepository) Update(ctx context.Context, serviceID uuid.UUID, input UpdateRequest) (Service, error) {
var ports *[]int
if input.Ports != nil {
normalized := normalizePorts(input.Ports)
ports = &normalized
}
row := r.db.QueryRow(ctx, `
update services
set
name = coalesce($2, name),
description = coalesce($3, description),
domain = coalesce($4, domain),
upstream_ip = coalesce($5::inet, upstream_ip),
proxy_ip = coalesce($6::inet, proxy_ip),
ports = coalesce($7::integer[], ports),
is_active = coalesce($8, is_active),
updated_at = now()
where id = $1 and deleted_at is null
returning id, name, description, domain, host(upstream_ip), host(proxy_ip), ports, is_active
`, serviceID, input.Name, input.Description, input.Domain, input.UpstreamIP, input.ProxyIP, ports, input.IsActive)
var item Service
err := row.Scan(&item.ID, &item.Name, &item.Description, &item.Domain, &item.UpstreamIP, &item.ProxyIP, &item.Ports, &item.IsActive)
return item, err
}
func (r *PGRepository) Delete(ctx context.Context, serviceID uuid.UUID) error {
_, err := r.db.Exec(ctx, `
update services
set deleted_at = now(), updated_at = now()
where id = $1 and deleted_at is null
`, serviceID)
return err
}
func (r *PGRepository) ByIDs(ctx context.Context, ids []uuid.UUID) ([]Service, error) {
if len(ids) == 0 {
return nil, nil
}
rows, err := r.db.Query(ctx, `
select id, name, description, domain, host(upstream_ip), host(proxy_ip), ports, is_active
from services
where deleted_at is null and id = any($1::uuid[])
order by name asc
`, ids)
if err != nil {
return nil, err
}
defer rows.Close()
var items []Service
for rows.Next() {
var item Service
if err := rows.Scan(&item.ID, &item.Name, &item.Description, &item.Domain, &item.UpstreamIP, &item.ProxyIP, &item.Ports, &item.IsActive); err != nil {
return nil, err
}
items = append(items, item)
}
if err := rows.Err(); err != nil {
return nil, err
}
if len(items) == 0 {
return nil, errors.New("services not found")
}
return items, nil
}
func normalizePorts(ports []int) []int {
if len(ports) == 0 {
return []int{80, 443}
}
seen := make(map[int]struct{}, len(ports))
result := make([]int, 0, len(ports))
for _, port := range ports {
if port <= 0 {
continue
}
if _, ok := seen[port]; ok {
continue
}
seen[port] = struct{}{}
result = append(result, port)
}
if len(result) == 0 {
return []int{80, 443}
}
return result
}

View File

@@ -0,0 +1,35 @@
package servicecatalog
import (
"context"
"github.com/google/uuid"
)
type CatalogService struct {
repo Repository
}
func NewService(repo Repository) *CatalogService {
return &CatalogService{repo: repo}
}
func (s *CatalogService) List(ctx context.Context) ([]Service, error) {
return s.repo.List(ctx)
}
func (s *CatalogService) Create(ctx context.Context, input CreateRequest) (Service, error) {
return s.repo.Create(ctx, input)
}
func (s *CatalogService) Update(ctx context.Context, serviceID uuid.UUID, input UpdateRequest) (Service, error) {
return s.repo.Update(ctx, serviceID, input)
}
func (s *CatalogService) Delete(ctx context.Context, serviceID uuid.UUID) error {
return s.repo.Delete(ctx, serviceID)
}
func (s *CatalogService) ByIDs(ctx context.Context, ids []uuid.UUID) ([]Service, error) {
return s.repo.ByIDs(ctx, ids)
}

View File

@@ -0,0 +1,34 @@
package servicecatalog
import "github.com/google/uuid"
type Service struct {
ID uuid.UUID `json:"id"`
Name string `json:"name"`
Description string `json:"description"`
Domain string `json:"domain"`
UpstreamIP string `json:"upstream_ip"`
ProxyIP string `json:"proxy_ip"`
Ports []int `json:"ports"`
IsActive bool `json:"is_active"`
}
type CreateRequest struct {
Name string `json:"name"`
Description string `json:"description"`
Domain string `json:"domain"`
UpstreamIP string `json:"upstream_ip"`
ProxyIP string `json:"proxy_ip"`
Ports []int `json:"ports"`
IsActive *bool `json:"is_active,omitempty"`
}
type UpdateRequest struct {
Name *string `json:"name"`
Description *string `json:"description"`
Domain *string `json:"domain"`
UpstreamIP *string `json:"upstream_ip"`
ProxyIP *string `json:"proxy_ip"`
Ports []int `json:"ports"`
IsActive *bool `json:"is_active"`
}

View File

@@ -6,7 +6,16 @@ type Peer struct {
AssignedIP string `json:"assigned_ip"` AssignedIP string `json:"assigned_ip"`
AllowedDestinations []string `json:"allowed_destinations"` AllowedDestinations []string `json:"allowed_destinations"`
DNSServers []string `json:"dns_servers"` DNSServers []string `json:"dns_servers"`
WebProxyTargets []string `json:"web_proxy_targets"` AllowedServices []AllowedService `json:"allowed_services"`
}
type AllowedService struct {
Name string `json:"name"`
Domain string `json:"domain"`
UpstreamIP string `json:"upstream_ip"`
ProxyIP string `json:"proxy_ip"`
AccessProxyIP string `json:"access_proxy_ip"`
Ports []int `json:"ports"`
} }
type GatewayBundle struct { type GatewayBundle struct {

View File

@@ -24,4 +24,6 @@ NEXAVPN_GATEWAY_INTERFACE=wg0
NEXAVPN_UPLINK_INTERFACE=eth0 NEXAVPN_UPLINK_INTERFACE=eth0
NEXAVPN_ENABLE_MASQUERADE=true NEXAVPN_ENABLE_MASQUERADE=true
NEXAVPN_BACKEND_HOST=127.0.0.1 NEXAVPN_BACKEND_HOST=127.0.0.1
NEXAVPN_ALWAYS_ALLOW_WEB_PROXY_IPS=172.16.0.109 NEXAVPN_ACCESS_PROXY_IP=172.16.0.120
NEXAVPN_ACCESS_PROXY_HTTP_ADDR=172.16.0.120:80
NEXAVPN_ACCESS_PROXY_HTTPS_ADDR=172.16.0.120:443

View File

@@ -0,0 +1,10 @@
FROM golang:1.23-alpine AS builder
WORKDIR /src
COPY access-proxy/ ./
RUN go build -o /out/nexavpn-access-proxy ./main.go
FROM alpine:3.21
RUN apk add --no-cache ca-certificates
COPY --from=builder /out/nexavpn-access-proxy /usr/local/bin/nexavpn-access-proxy
ENTRYPOINT ["nexavpn-access-proxy"]

View File

@@ -0,0 +1,3 @@
module nexavpn/access-proxy
go 1.23

328
deploy/access-proxy/main.go Normal file
View File

@@ -0,0 +1,328 @@
package main
import (
"bufio"
"context"
"encoding/json"
"errors"
"io"
"log"
"net"
"net/http"
"net/http/httputil"
"net/url"
"os"
"path/filepath"
"strings"
"sync"
"time"
)
type syncBundle struct {
Peers []peerConfig `json:"peers"`
}
type peerConfig struct {
AssignedIP string `json:"assigned_ip"`
AllowedServices []allowedService `json:"allowed_services"`
}
type allowedService struct {
Name string `json:"name"`
Domain string `json:"domain"`
ProxyIP string `json:"proxy_ip"`
AccessProxyIP string `json:"access_proxy_ip"`
}
type proxyState struct {
mu sync.RWMutex
allowed map[string]map[string]allowedService
}
func main() {
state := &proxyState{
allowed: make(map[string]map[string]allowedService),
}
ctx := context.Background()
if err := refreshConfig(ctx, state); err != nil {
log.Printf("initial config refresh failed: %v", err)
}
go func() {
ticker := time.NewTicker(15 * time.Second)
defer ticker.Stop()
for range ticker.C {
if err := refreshConfig(ctx, state); err != nil {
log.Printf("config refresh failed: %v", err)
}
}
}()
httpAddr := envOrDefault("NEXAVPN_ACCESS_PROXY_HTTP_ADDR", ":8088")
httpsAddr := envOrDefault("NEXAVPN_ACCESS_PROXY_HTTPS_ADDR", ":8448")
go func() {
log.Printf("HTTP access proxy listening on %s", httpAddr)
if err := http.ListenAndServe(httpAddr, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
handleHTTP(state, w, r)
})); err != nil {
log.Fatalf("http server failed: %v", err)
}
}()
log.Printf("HTTPS access proxy listening on %s", httpsAddr)
if err := serveTLSProxy(httpsAddr, state); err != nil {
log.Fatalf("https proxy failed: %v", err)
}
}
func handleHTTP(state *proxyState, w http.ResponseWriter, r *http.Request) {
clientIP, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
http.Error(w, "invalid client address", http.StatusForbidden)
return
}
host := normalizeHost(r.Host)
service, ok := state.lookup(clientIP, host)
if !ok {
http.Error(w, "service not allowed", http.StatusForbidden)
return
}
targetURL := &url.URL{
Scheme: "http",
Host: net.JoinHostPort(service.ProxyIP, "80"),
}
proxy := httputil.NewSingleHostReverseProxy(targetURL)
originalDirector := proxy.Director
proxy.Director = func(req *http.Request) {
originalDirector(req)
req.Host = host
req.Header.Set("Host", host)
req.Header.Set("X-Forwarded-Host", host)
}
proxy.ErrorHandler = func(rw http.ResponseWriter, _ *http.Request, proxyErr error) {
http.Error(rw, proxyErr.Error(), http.StatusBadGateway)
}
proxy.ServeHTTP(w, r)
}
func serveTLSProxy(addr string, state *proxyState) error {
listener, err := net.Listen("tcp", addr)
if err != nil {
return err
}
defer listener.Close()
for {
conn, err := listener.Accept()
if err != nil {
return err
}
go handleTLSConn(conn, state)
}
}
func handleTLSConn(clientConn net.Conn, state *proxyState) {
defer clientConn.Close()
clientIP, _, err := net.SplitHostPort(clientConn.RemoteAddr().String())
if err != nil {
return
}
reader := bufio.NewReader(clientConn)
hello, host, err := readClientHello(reader)
if err != nil {
return
}
service, ok := state.lookup(clientIP, host)
if !ok {
return
}
upstreamConn, err := net.DialTimeout("tcp", net.JoinHostPort(service.ProxyIP, "443"), 10*time.Second)
if err != nil {
return
}
defer upstreamConn.Close()
if _, err := upstreamConn.Write(hello); err != nil {
return
}
errCh := make(chan error, 2)
go proxyCopy(errCh, upstreamConn, reader)
go proxyCopy(errCh, clientConn, upstreamConn)
<-errCh
}
func proxyCopy(errCh chan<- error, dst io.Writer, src io.Reader) {
_, err := io.Copy(dst, src)
errCh <- err
}
func readClientHello(reader *bufio.Reader) ([]byte, string, error) {
header, err := reader.Peek(5)
if err != nil {
return nil, "", err
}
if header[0] != 22 {
return nil, "", errors.New("not a tls client hello")
}
recordLen := int(header[3])<<8 | int(header[4])
full, err := reader.Peek(5 + recordLen)
if err != nil {
return nil, "", err
}
host, err := extractSNI(full)
if err != nil {
return nil, "", err
}
return append([]byte(nil), full...), host, nil
}
func extractSNI(packet []byte) (string, error) {
if len(packet) < 43 {
return "", errors.New("tls packet too short")
}
sessionIDLen := int(packet[43])
offset := 44 + sessionIDLen
if len(packet) < offset+2 {
return "", errors.New("missing cipher suites")
}
cipherLen := int(packet[offset])<<8 | int(packet[offset+1])
offset += 2 + cipherLen
if len(packet) < offset+1 {
return "", errors.New("missing compression methods")
}
compressionLen := int(packet[offset])
offset += 1 + compressionLen
if len(packet) < offset+2 {
return "", errors.New("missing extensions")
}
extensionsLen := int(packet[offset])<<8 | int(packet[offset+1])
offset += 2
end := offset + extensionsLen
if len(packet) < end {
return "", errors.New("invalid extensions length")
}
for offset+4 <= end {
extensionType := int(packet[offset])<<8 | int(packet[offset+1])
extensionLen := int(packet[offset+2])<<8 | int(packet[offset+3])
offset += 4
if offset+extensionLen > end {
return "", errors.New("invalid extension")
}
if extensionType == 0 {
if extensionLen < 5 {
return "", errors.New("invalid sni extension")
}
serverNameLen := int(packet[offset+3])<<8 | int(packet[offset+4])
if offset+5+serverNameLen > end {
return "", errors.New("invalid sni length")
}
return normalizeHost(string(packet[offset+5 : offset+5+serverNameLen])), nil
}
offset += extensionLen
}
return "", errors.New("sni not found")
}
func refreshConfig(ctx context.Context, state *proxyState) error {
gatewayID, err := resolveGatewayID()
if err != nil {
return err
}
syncURL := strings.TrimRight(envOrDefault("NEXAVPN_GATEWAY_SYNC_URL", "http://127.0.0.1:8080/api/v1/gateway-agent"), "/") + "/" + gatewayID + "/sync"
req, err := http.NewRequestWithContext(ctx, http.MethodGet, syncURL, nil)
if err != nil {
return err
}
req.Header.Set("X-Gateway-Bootstrap-Token", os.Getenv("GATEWAY_BOOTSTRAP_TOKEN"))
resp, err := http.DefaultClient.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return errors.New("sync request failed with status " + resp.Status)
}
var bundle syncBundle
if err := json.NewDecoder(resp.Body).Decode(&bundle); err != nil {
return err
}
allowed := make(map[string]map[string]allowedService)
for _, peer := range bundle.Peers {
hostMap := make(map[string]allowedService)
for _, service := range peer.AllowedServices {
hostMap[normalizeHost(service.Domain)] = service
}
allowed[stripCIDR(peer.AssignedIP)] = hostMap
}
state.mu.Lock()
state.allowed = allowed
state.mu.Unlock()
return nil
}
func (s *proxyState) lookup(clientIP string, host string) (allowedService, bool) {
s.mu.RLock()
defer s.mu.RUnlock()
services, ok := s.allowed[clientIP]
if !ok {
return allowedService{}, false
}
service, ok := services[normalizeHost(host)]
return service, ok
}
func resolveGatewayID() (string, error) {
if value := strings.TrimSpace(os.Getenv("NEXAVPN_GATEWAY_ID")); value != "" {
return value, nil
}
stateFile := envOrDefault("NEXAVPN_GATEWAY_ID_FILE", "/var/lib/nexavpn/gateway-id")
raw, err := os.ReadFile(filepath.Clean(stateFile))
if err != nil {
return "", err
}
return strings.TrimSpace(string(raw)), nil
}
func envOrDefault(key string, fallback string) string {
if value := strings.TrimSpace(os.Getenv(key)); value != "" {
return value
}
return fallback
}
func stripCIDR(value string) string {
if index := strings.IndexByte(value, '/'); index >= 0 {
return value[:index]
}
return value
}
func normalizeHost(host string) string {
host = strings.TrimSpace(strings.ToLower(host))
host = strings.TrimSuffix(host, ".")
if strings.Contains(host, ":") {
if parsedHost, _, err := net.SplitHostPort(host); err == nil {
host = parsedHost
}
}
return host
}

View File

@@ -89,10 +89,28 @@ services:
NEXAVPN_UPLINK_INTERFACE: ${NEXAVPN_UPLINK_INTERFACE:-eth0} NEXAVPN_UPLINK_INTERFACE: ${NEXAVPN_UPLINK_INTERFACE:-eth0}
NEXAVPN_ENABLE_MASQUERADE: ${NEXAVPN_ENABLE_MASQUERADE:-true} NEXAVPN_ENABLE_MASQUERADE: ${NEXAVPN_ENABLE_MASQUERADE:-true}
NEXAVPN_BACKEND_HOST: ${NEXAVPN_BACKEND_HOST:-127.0.0.1} NEXAVPN_BACKEND_HOST: ${NEXAVPN_BACKEND_HOST:-127.0.0.1}
NEXAVPN_ACCESS_PROXY_IP: ${NEXAVPN_ACCESS_PROXY_IP:-}
volumes: volumes:
- ./scripts/gateway-entrypoint.sh:/scripts/gateway-entrypoint.sh:ro - ./scripts/gateway-entrypoint.sh:/scripts/gateway-entrypoint.sh:ro
- gateway-state:/var/lib/nexavpn - gateway-state:/var/lib/nexavpn
access-proxy:
build:
context: .
dockerfile: access-proxy/Dockerfile
depends_on:
- backend
network_mode: host
environment:
GATEWAY_BOOTSTRAP_TOKEN: ${GATEWAY_BOOTSTRAP_TOKEN:-nexavpn-gateway-bootstrap}
NEXAVPN_GATEWAY_ID: ${NEXAVPN_GATEWAY_ID:-}
NEXAVPN_GATEWAY_ID_FILE: /var/lib/nexavpn/gateway-id
NEXAVPN_GATEWAY_SYNC_URL: ${NEXAVPN_GATEWAY_SYNC_URL:-http://127.0.0.1:8080/api/v1/gateway-agent}
NEXAVPN_ACCESS_PROXY_HTTP_ADDR: ${NEXAVPN_ACCESS_PROXY_HTTP_ADDR:-172.16.0.120:80}
NEXAVPN_ACCESS_PROXY_HTTPS_ADDR: ${NEXAVPN_ACCESS_PROXY_HTTPS_ADDR:-172.16.0.120:443}
volumes:
- gateway-state:/var/lib/nexavpn
volumes: volumes:
postgres-data: postgres-data:
gateway-state: gateway-state:

View File

@@ -118,9 +118,11 @@ EOF
echo " iifname \"${IFACE}\" ip saddr ${ASSIGNED_IP} ip daddr ${dns_server} udp dport 53 accept" echo " iifname \"${IFACE}\" ip saddr ${ASSIGNED_IP} ip daddr ${dns_server} udp dport 53 accept"
echo " iifname \"${IFACE}\" ip saddr ${ASSIGNED_IP} ip daddr ${dns_server} tcp dport 53 accept" echo " iifname \"${IFACE}\" ip saddr ${ASSIGNED_IP} ip daddr ${dns_server} tcp dport 53 accept"
done done
printf '%s' "${peer}" | jq -r '.web_proxy_targets[]?' | while read -r proxy_target; do printf '%s' "${peer}" | jq -c '.allowed_services[]?' | while read -r service; do
echo " iifname \"${IFACE}\" ip saddr ${ASSIGNED_IP} ip daddr ${proxy_target} tcp dport 80 accept" SERVICE_PROXY_IP="$(printf '%s' "${service}" | jq -r '.access_proxy_ip')"
echo " iifname \"${IFACE}\" ip saddr ${ASSIGNED_IP} ip daddr ${proxy_target} tcp dport 443 accept" printf '%s' "${service}" | jq -r '.ports[]?' | while read -r service_port; do
echo " iifname \"${IFACE}\" ip saddr ${ASSIGNED_IP} ip daddr ${SERVICE_PROXY_IP} tcp dport ${service_port} accept"
done
done done
printf '%s' "${peer}" | jq -r '.allowed_destinations[]?' | while read -r destination; do printf '%s' "${peer}" | jq -r '.allowed_destinations[]?' | while read -r destination; do
echo " iifname \"${IFACE}\" ip saddr ${ASSIGNED_IP} ip daddr ${destination} accept" echo " iifname \"${IFACE}\" ip saddr ${ASSIGNED_IP} ip daddr ${destination} accept"

View File

@@ -1,6 +1,15 @@
mod tunnel_manager; mod tunnel_manager;
use std::{fs, io::Cursor, net::TcpListener, path::PathBuf, sync::Mutex}; use std::{
fs,
io::Cursor,
net::TcpListener,
path::PathBuf,
sync::{
atomic::{AtomicBool, Ordering},
Mutex,
},
};
use base64::{engine::general_purpose::STANDARD, Engine as _}; use base64::{engine::general_purpose::STANDARD, Engine as _};
use png::{ColorType, Decoder}; use png::{ColorType, Decoder};
@@ -22,6 +31,7 @@ const SINGLE_INSTANCE_ADDR: &str = "127.0.0.1:53190";
struct AppState { struct AppState {
session: Mutex<Option<SessionState>>, session: Mutex<Option<SessionState>>,
tray: Mutex<Option<TrayState>>, tray: Mutex<Option<TrayState>>,
tunnel_action_in_progress: AtomicBool,
single_instance_lock: TcpListener, single_instance_lock: TcpListener,
} }
@@ -392,14 +402,29 @@ async fn select_access_profile(app: AppHandle, profile_id: String) -> Result<Enr
#[tauri::command] #[tauri::command]
async fn connect_tunnel(app: AppHandle) -> Result<EnrollmentResult, String> { async fn connect_tunnel(app: AppHandle) -> Result<EnrollmentResult, String> {
let session_state = sync_current_session(&app).await?; let state = app.state::<AppState>();
state.tunnel_action_in_progress.store(true, Ordering::SeqCst);
let session_state = match sync_current_session(&app).await {
Ok(value) => value,
Err(err) => {
state.tunnel_action_in_progress.store(false, Ordering::SeqCst);
return Err(err);
}
};
let app_handle = app.clone(); let app_handle = app.clone();
let profile_path = session_state.profile_path.clone(); let profile_path = session_state.profile_path.clone();
let result = tauri::async_runtime::spawn_blocking(move || { let result = match tauri::async_runtime::spawn_blocking(move || {
tunnel_manager::connect(&app_handle, std::path::Path::new(&profile_path)) tunnel_manager::connect(&app_handle, std::path::Path::new(&profile_path))
}) })
.await .await
.map_err(|err| format!("Unable to join tunnel connect task: {err}"))?; .map_err(|err| format!("Unable to join tunnel connect task: {err}")) {
Ok(value) => value,
Err(err) => {
state.tunnel_action_in_progress.store(false, Ordering::SeqCst);
return Err(err);
}
};
state.tunnel_action_in_progress.store(false, Ordering::SeqCst);
refresh_tray_menu(&app); refresh_tray_menu(&app);
result?; result?;
Ok(session_state.enrollment) Ok(session_state.enrollment)
@@ -407,39 +432,65 @@ async fn connect_tunnel(app: AppHandle) -> Result<EnrollmentResult, String> {
#[tauri::command] #[tauri::command]
async fn disconnect_tunnel(app: AppHandle, state: State<'_, AppState>) -> Result<(), String> { async fn disconnect_tunnel(app: AppHandle, state: State<'_, AppState>) -> Result<(), String> {
state.tunnel_action_in_progress.store(true, Ordering::SeqCst);
let profile_path = { let profile_path = {
let session = state.session.lock().map_err(|_| "Unable to read client state".to_string())?; let session = state.session.lock().map_err(|_| "Unable to read client state".to_string())?;
let session = session.as_ref().ok_or_else(|| "No active session is available".to_string())?; let session = session.as_ref().ok_or_else(|| "No active session is available".to_string())?;
session.profile_path.clone() session.profile_path.clone()
}; };
let app_handle = app.clone(); let app_handle = app.clone();
let result = tauri::async_runtime::spawn_blocking(move || { let result = match tauri::async_runtime::spawn_blocking(move || {
tunnel_manager::disconnect(&app_handle, std::path::Path::new(&profile_path)) tunnel_manager::disconnect(&app_handle, std::path::Path::new(&profile_path))
}) })
.await .await
.map_err(|err| format!("Unable to join tunnel disconnect task: {err}"))?; .map_err(|err| format!("Unable to join tunnel disconnect task: {err}")) {
Ok(value) => value,
Err(err) => {
state.tunnel_action_in_progress.store(false, Ordering::SeqCst);
return Err(err);
}
};
state.tunnel_action_in_progress.store(false, Ordering::SeqCst);
refresh_tray_menu(&app); refresh_tray_menu(&app);
result result
} }
#[tauri::command] #[tauri::command]
fn tunnel_status(app: AppHandle, state: State<'_, AppState>) -> Result<bool, String> { async fn tunnel_status(app: AppHandle, state: State<'_, AppState>) -> Result<bool, String> {
if state.tunnel_action_in_progress.load(Ordering::SeqCst) {
return Ok(false);
}
let profile_path = { let profile_path = {
let session = state.session.lock().map_err(|_| "Unable to read client state".to_string())?; let session = state.session.lock().map_err(|_| "Unable to read client state".to_string())?;
let session = session.as_ref().ok_or_else(|| "No active session is available".to_string())?; let session = session.as_ref().ok_or_else(|| "No active session is available".to_string())?;
session.profile_path.clone() session.profile_path.clone()
}; };
tauri::async_runtime::spawn_blocking(move || {
tunnel_manager::is_active(&app, std::path::Path::new(&profile_path)) tunnel_manager::is_active(&app, std::path::Path::new(&profile_path))
})
.await
.map_err(|err| format!("Unable to join tunnel status task: {err}"))?
} }
#[tauri::command] #[tauri::command]
fn tunnel_metrics(app: AppHandle, state: State<'_, AppState>) -> Result<TunnelMetrics, String> { async fn tunnel_metrics(app: AppHandle, state: State<'_, AppState>) -> Result<TunnelMetrics, String> {
if state.tunnel_action_in_progress.load(Ordering::SeqCst) {
return Ok(TunnelMetrics {
active: false,
rx_bytes: 0,
tx_bytes: 0,
});
}
let profile_path = { let profile_path = {
let session = state.session.lock().map_err(|_| "Unable to read client state".to_string())?; let session = state.session.lock().map_err(|_| "Unable to read client state".to_string())?;
let session = session.as_ref().ok_or_else(|| "No active session is available".to_string())?; let session = session.as_ref().ok_or_else(|| "No active session is available".to_string())?;
session.profile_path.clone() session.profile_path.clone()
}; };
let metrics = tunnel_manager::metrics(&app, std::path::Path::new(&profile_path))?; let metrics = tauri::async_runtime::spawn_blocking(move || {
tunnel_manager::metrics(&app, std::path::Path::new(&profile_path))
})
.await
.map_err(|err| format!("Unable to join tunnel metrics task: {err}"))??;
let mapped = TunnelMetrics { let mapped = TunnelMetrics {
active: metrics.active, active: metrics.active,
rx_bytes: metrics.rx_bytes, rx_bytes: metrics.rx_bytes,
@@ -542,6 +593,13 @@ fn format_data_size(bytes: u64) -> String {
fn current_metrics(app: &AppHandle) -> Result<TunnelMetrics, String> { fn current_metrics(app: &AppHandle) -> Result<TunnelMetrics, String> {
let state = app.state::<AppState>(); let state = app.state::<AppState>();
if state.tunnel_action_in_progress.load(Ordering::SeqCst) {
return Ok(TunnelMetrics {
active: false,
rx_bytes: 0,
tx_bytes: 0,
});
}
let profile_path = { let profile_path = {
let session = state.session.lock().map_err(|_| "Unable to read client state".to_string())?; let session = state.session.lock().map_err(|_| "Unable to read client state".to_string())?;
let session = session.as_ref().ok_or_else(|| "No active session is available".to_string())?; let session = session.as_ref().ok_or_else(|| "No active session is available".to_string())?;
@@ -923,6 +981,7 @@ pub fn run() {
sent_item, sent_item,
toggle_item, toggle_item,
})), })),
tunnel_action_in_progress: AtomicBool::new(false),
single_instance_lock, single_instance_lock,
}); });
refresh_tray_menu(app.handle()); refresh_tray_menu(app.handle());