chore: initial project scaffold with admin web, backend, desktop client, and deployment setup
Add monorepo structure for NexaVPN WireGuard control plane including: - .gitignore for node_modules, build artifacts, and environment files - README with project overview, monorepo layout, and quick start guide - Admin web UI with React, Vite, TypeScript, and nginx reverse proxy - API client with type definitions for users, devices, policies, gateways, and audit logs - Admin pages for dashboard, users, devices, policies, g
This commit is contained in:
14
backend/Dockerfile
Normal file
14
backend/Dockerfile
Normal file
@@ -0,0 +1,14 @@
|
||||
FROM golang:1.23-alpine AS builder
|
||||
WORKDIR /src
|
||||
COPY go.mod ./
|
||||
RUN go mod download
|
||||
COPY . .
|
||||
RUN CGO_ENABLED=0 GOOS=linux go build -o /out/nexavpn-api ./cmd/api
|
||||
|
||||
FROM alpine:3.21
|
||||
WORKDIR /app
|
||||
COPY --from=builder /out/nexavpn-api /usr/local/bin/nexavpn-api
|
||||
COPY migrations ./migrations
|
||||
COPY seed ./seed
|
||||
EXPOSE 8080
|
||||
ENTRYPOINT ["/usr/local/bin/nexavpn-api"]
|
||||
48
backend/cmd/api/main.go
Normal file
48
backend/cmd/api/main.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/nexavpn/nexavpn/backend/internal/app"
|
||||
"github.com/nexavpn/nexavpn/backend/internal/config"
|
||||
)
|
||||
|
||||
func main() {
|
||||
cfg := config.Load()
|
||||
|
||||
application, err := app.New(cfg)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to initialize app: %v", err)
|
||||
}
|
||||
defer application.Close()
|
||||
|
||||
server := &http.Server{
|
||||
Addr: cfg.HTTPAddress,
|
||||
Handler: application.Router,
|
||||
ReadHeaderTimeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
go func() {
|
||||
log.Printf("nexavpn backend listening on %s", cfg.HTTPAddress)
|
||||
if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
log.Fatalf("http server failed: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
stop := make(chan os.Signal, 1)
|
||||
signal.Notify(stop, syscall.SIGINT, syscall.SIGTERM)
|
||||
<-stop
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := server.Shutdown(ctx); err != nil {
|
||||
log.Printf("server shutdown error: %v", err)
|
||||
}
|
||||
}
|
||||
11
backend/go.mod
Normal file
11
backend/go.mod
Normal file
@@ -0,0 +1,11 @@
|
||||
module github.com/nexavpn/nexavpn/backend
|
||||
|
||||
go 1.23.0
|
||||
|
||||
require (
|
||||
github.com/go-chi/chi/v5 v5.2.1
|
||||
github.com/golang-jwt/jwt/v5 v5.2.1
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/jackc/pgx/v5 v5.7.2
|
||||
golang.org/x/crypto v0.36.0
|
||||
)
|
||||
26
backend/internal/apiutil/respond.go
Normal file
26
backend/internal/apiutil/respond.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package apiutil
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type ErrorResponse struct {
|
||||
Error struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
} `json:"error"`
|
||||
}
|
||||
|
||||
func JSON(w http.ResponseWriter, status int, payload any) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
_ = json.NewEncoder(w).Encode(payload)
|
||||
}
|
||||
|
||||
func Error(w http.ResponseWriter, status int, code, message string) {
|
||||
resp := ErrorResponse{}
|
||||
resp.Error.Code = code
|
||||
resp.Error.Message = message
|
||||
JSON(w, status, resp)
|
||||
}
|
||||
62
backend/internal/app/app.go
Normal file
62
backend/internal/app/app.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
|
||||
"github.com/nexavpn/nexavpn/backend/internal/audit"
|
||||
"github.com/nexavpn/nexavpn/backend/internal/auth"
|
||||
"github.com/nexavpn/nexavpn/backend/internal/config"
|
||||
"github.com/nexavpn/nexavpn/backend/internal/db"
|
||||
"github.com/nexavpn/nexavpn/backend/internal/device"
|
||||
"github.com/nexavpn/nexavpn/backend/internal/gateway"
|
||||
"github.com/nexavpn/nexavpn/backend/internal/httpserver"
|
||||
"github.com/nexavpn/nexavpn/backend/internal/ipam"
|
||||
"github.com/nexavpn/nexavpn/backend/internal/policy"
|
||||
"github.com/nexavpn/nexavpn/backend/internal/user"
|
||||
)
|
||||
|
||||
type App struct {
|
||||
DB *pgxpool.Pool
|
||||
Router http.Handler
|
||||
}
|
||||
|
||||
func New(cfg config.Config) (*App, error) {
|
||||
ctx := context.Background()
|
||||
|
||||
pool, err := db.Connect(ctx, cfg.DatabaseURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
authRepo := auth.NewPGRepository(pool)
|
||||
authService := auth.NewService(authRepo, cfg.JWTSecret, cfg.JWTIssuer, cfg.AccessTokenTTL, cfg.RefreshTokenTTL)
|
||||
|
||||
userService := user.NewService(user.NewPGRepository(pool))
|
||||
policyService := policy.NewService(policy.NewPGRepository(pool))
|
||||
gatewayService := gateway.NewService(gateway.NewPGRepository(pool))
|
||||
deviceService := device.NewService(device.NewPGRepository(pool), policyService, gatewayService, ipam.NewService())
|
||||
auditService := audit.NewService(audit.NewPGRepository(pool))
|
||||
|
||||
router := httpserver.NewRouter(cfg.JWTSecret, httpserver.Handlers{
|
||||
Auth: auth.NewHandler(authService, auditService),
|
||||
User: user.NewHandler(userService, auditService),
|
||||
Device: device.NewHandler(deviceService, auditService),
|
||||
Policy: policy.NewHandler(policyService, auditService),
|
||||
Gateway: gateway.NewHandler(gatewayService),
|
||||
Audit: audit.NewHandler(auditService),
|
||||
})
|
||||
|
||||
return &App{
|
||||
DB: pool,
|
||||
Router: router,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (a *App) Close() {
|
||||
if a.DB != nil {
|
||||
a.DB.Close()
|
||||
}
|
||||
}
|
||||
25
backend/internal/audit/handler.go
Normal file
25
backend/internal/audit/handler.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package audit
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/nexavpn/nexavpn/backend/internal/apiutil"
|
||||
)
|
||||
|
||||
type Handler struct {
|
||||
service *Service
|
||||
}
|
||||
|
||||
func NewHandler(service *Service) *Handler {
|
||||
return &Handler{service: service}
|
||||
}
|
||||
|
||||
func (h *Handler) List(w http.ResponseWriter, r *http.Request) {
|
||||
items, err := h.service.List(r.Context(), 100)
|
||||
if err != nil {
|
||||
apiutil.Error(w, http.StatusInternalServerError, "audit_list_failed", "unable to list audit logs")
|
||||
return
|
||||
}
|
||||
|
||||
apiutil.JSON(w, http.StatusOK, items)
|
||||
}
|
||||
70
backend/internal/audit/repository.go
Normal file
70
backend/internal/audit/repository.go
Normal file
@@ -0,0 +1,70 @@
|
||||
package audit
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
|
||||
type PGRepository struct {
|
||||
db *pgxpool.Pool
|
||||
}
|
||||
|
||||
func NewPGRepository(db *pgxpool.Pool) *PGRepository {
|
||||
return &PGRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *PGRepository) Write(ctx context.Context, entry Entry) error {
|
||||
const query = `
|
||||
insert into audit_logs (id, actor_user_id, entity_type, entity_id, event_type, status, message, metadata)
|
||||
values ($1, $2, $3, $4, $5, $6, $7, $8::jsonb)
|
||||
`
|
||||
|
||||
_, err := r.db.Exec(
|
||||
ctx,
|
||||
query,
|
||||
uuid.New(),
|
||||
entry.ActorUserID,
|
||||
entry.EntityType,
|
||||
entry.EntityID,
|
||||
entry.EventType,
|
||||
entry.Status,
|
||||
entry.Message,
|
||||
MarshalMetadata(entry.Metadata),
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *PGRepository) List(ctx context.Context, limit int) ([]map[string]any, error) {
|
||||
rows, err := r.db.Query(ctx, `
|
||||
select id, event_type, entity_type, status, message, created_at
|
||||
from audit_logs
|
||||
order by created_at desc
|
||||
limit $1
|
||||
`, limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var entries []map[string]any
|
||||
for rows.Next() {
|
||||
var id uuid.UUID
|
||||
var eventType, entityType, status, message string
|
||||
var createdAt any
|
||||
if err := rows.Scan(&id, &eventType, &entityType, &status, &message, &createdAt); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
entries = append(entries, map[string]any{
|
||||
"id": id,
|
||||
"event_type": eventType,
|
||||
"entity_type": entityType,
|
||||
"status": status,
|
||||
"message": message,
|
||||
"created_at": createdAt,
|
||||
})
|
||||
}
|
||||
|
||||
return entries, rows.Err()
|
||||
}
|
||||
47
backend/internal/audit/service.go
Normal file
47
backend/internal/audit/service.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package audit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type Entry struct {
|
||||
ActorUserID *uuid.UUID
|
||||
EntityType string
|
||||
EntityID *uuid.UUID
|
||||
EventType string
|
||||
Status string
|
||||
Message string
|
||||
Metadata map[string]any
|
||||
}
|
||||
|
||||
type Repository interface {
|
||||
Write(ctx context.Context, entry Entry) error
|
||||
List(ctx context.Context, limit int) ([]map[string]any, error)
|
||||
}
|
||||
|
||||
type Service struct {
|
||||
repo Repository
|
||||
}
|
||||
|
||||
func NewService(repo Repository) *Service {
|
||||
return &Service{repo: repo}
|
||||
}
|
||||
|
||||
func (s *Service) Record(ctx context.Context, entry Entry) error {
|
||||
if entry.Metadata == nil {
|
||||
entry.Metadata = map[string]any{}
|
||||
}
|
||||
return s.repo.Write(ctx, entry)
|
||||
}
|
||||
|
||||
func (s *Service) List(ctx context.Context, limit int) ([]map[string]any, error) {
|
||||
return s.repo.List(ctx, limit)
|
||||
}
|
||||
|
||||
func MarshalMetadata(metadata map[string]any) []byte {
|
||||
raw, _ := json.Marshal(metadata)
|
||||
return raw
|
||||
}
|
||||
137
backend/internal/auth/handler.go
Normal file
137
backend/internal/auth/handler.go
Normal file
@@ -0,0 +1,137 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"github.com/nexavpn/nexavpn/backend/internal/apiutil"
|
||||
"github.com/nexavpn/nexavpn/backend/internal/audit"
|
||||
"github.com/nexavpn/nexavpn/backend/internal/httpserver"
|
||||
)
|
||||
|
||||
type Handler struct {
|
||||
service *Service
|
||||
audit *audit.Service
|
||||
}
|
||||
|
||||
func NewHandler(service *Service, auditService *audit.Service) *Handler {
|
||||
return &Handler{service: service, audit: auditService}
|
||||
}
|
||||
|
||||
func (h *Handler) Login(w http.ResponseWriter, r *http.Request) {
|
||||
var input LoginRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&input); err != nil {
|
||||
apiutil.Error(w, http.StatusBadRequest, "invalid_json", "invalid request body")
|
||||
return
|
||||
}
|
||||
|
||||
response, err := h.service.Login(r.Context(), input.Username, input.Password, r.RemoteAddr, r.UserAgent())
|
||||
if err != nil {
|
||||
_ = h.audit.Record(r.Context(), audit.Entry{
|
||||
EventType: "auth.login.failed",
|
||||
EntityType: "user",
|
||||
Status: "failed",
|
||||
Message: "user login failed",
|
||||
Metadata: map[string]any{
|
||||
"username": input.Username,
|
||||
},
|
||||
})
|
||||
apiutil.Error(w, http.StatusUnauthorized, "invalid_credentials", "invalid username or password")
|
||||
return
|
||||
}
|
||||
|
||||
_ = h.audit.Record(r.Context(), audit.Entry{
|
||||
ActorUserID: &response.User.ID,
|
||||
EventType: "auth.login",
|
||||
EntityType: "user",
|
||||
EntityID: &response.User.ID,
|
||||
Status: "success",
|
||||
Message: "user login succeeded",
|
||||
})
|
||||
apiutil.JSON(w, http.StatusOK, response)
|
||||
}
|
||||
|
||||
func (h *Handler) Bootstrap(w http.ResponseWriter, r *http.Request) {
|
||||
var input BootstrapRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&input); err != nil {
|
||||
apiutil.Error(w, http.StatusBadRequest, "invalid_json", "invalid request body")
|
||||
return
|
||||
}
|
||||
if input.Username == "" || input.Password == "" {
|
||||
apiutil.Error(w, http.StatusBadRequest, "validation_error", "username and password are required")
|
||||
return
|
||||
}
|
||||
if input.DisplayName == "" {
|
||||
input.DisplayName = input.Username
|
||||
}
|
||||
|
||||
user, err := h.service.BootstrapAdmin(r.Context(), input.Username, input.DisplayName, input.Password)
|
||||
if err != nil {
|
||||
apiutil.Error(w, http.StatusConflict, "bootstrap_failed", "initial admin already exists or could not be created")
|
||||
return
|
||||
}
|
||||
|
||||
_ = h.audit.Record(r.Context(), audit.Entry{
|
||||
ActorUserID: &user.ID,
|
||||
EntityType: "user",
|
||||
EntityID: &user.ID,
|
||||
EventType: "system.bootstrap_admin",
|
||||
Status: "success",
|
||||
Message: "initial admin account created",
|
||||
})
|
||||
apiutil.JSON(w, http.StatusCreated, user)
|
||||
}
|
||||
|
||||
func (h *Handler) Refresh(w http.ResponseWriter, r *http.Request) {
|
||||
var input RefreshRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&input); err != nil {
|
||||
apiutil.Error(w, http.StatusBadRequest, "invalid_json", "invalid request body")
|
||||
return
|
||||
}
|
||||
|
||||
response, err := h.service.Refresh(r.Context(), input.RefreshToken)
|
||||
if err != nil {
|
||||
apiutil.Error(w, http.StatusUnauthorized, "invalid_refresh_token", "unable to refresh session")
|
||||
return
|
||||
}
|
||||
|
||||
apiutil.JSON(w, http.StatusOK, response)
|
||||
}
|
||||
|
||||
func (h *Handler) Logout(w http.ResponseWriter, r *http.Request) {
|
||||
var input RefreshRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&input); err != nil {
|
||||
apiutil.Error(w, http.StatusBadRequest, "invalid_json", "invalid request body")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.service.Logout(r.Context(), input.RefreshToken); err != nil {
|
||||
apiutil.Error(w, http.StatusBadRequest, "logout_failed", "unable to revoke session")
|
||||
return
|
||||
}
|
||||
|
||||
if claims, ok := httpserver.ClaimsFromContext(r.Context()); ok {
|
||||
_ = h.audit.Record(r.Context(), audit.Entry{
|
||||
ActorUserID: &claims.UserID,
|
||||
EventType: "auth.logout",
|
||||
EntityType: "session",
|
||||
Status: "success",
|
||||
Message: "session logout succeeded",
|
||||
})
|
||||
}
|
||||
apiutil.JSON(w, http.StatusOK, map[string]any{"ok": true})
|
||||
}
|
||||
|
||||
func (h *Handler) Me(w http.ResponseWriter, r *http.Request) {
|
||||
claims, ok := httpserver.ClaimsFromContext(r.Context())
|
||||
if !ok {
|
||||
apiutil.Error(w, http.StatusUnauthorized, "unauthorized", "missing auth claims")
|
||||
return
|
||||
}
|
||||
|
||||
apiutil.JSON(w, http.StatusOK, map[string]any{
|
||||
"id": claims.UserID,
|
||||
"username": claims.Username,
|
||||
"role": claims.Role,
|
||||
})
|
||||
}
|
||||
11
backend/internal/auth/hash.go
Normal file
11
backend/internal/auth/hash.go
Normal file
@@ -0,0 +1,11 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
)
|
||||
|
||||
func base64Hash(value string) string {
|
||||
sum := sha256.Sum256([]byte(value))
|
||||
return base64.RawURLEncoding.EncodeToString(sum[:])
|
||||
}
|
||||
40
backend/internal/auth/password.go
Normal file
40
backend/internal/auth/password.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/subtle"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/crypto/argon2"
|
||||
)
|
||||
|
||||
func HashPassword(password string) (string, error) {
|
||||
salt := make([]byte, 16)
|
||||
if _, err := rand.Read(salt); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
hash := argon2.IDKey([]byte(password), salt, 1, 64*1024, 4, 32)
|
||||
return fmt.Sprintf("argon2id$%s$%s", base64.RawStdEncoding.EncodeToString(salt), base64.RawStdEncoding.EncodeToString(hash)), nil
|
||||
}
|
||||
|
||||
func VerifyPassword(hashValue, password string) bool {
|
||||
parts := strings.Split(hashValue, "$")
|
||||
if len(parts) != 3 || parts[0] != "argon2id" {
|
||||
return false
|
||||
}
|
||||
|
||||
salt, err := base64.RawStdEncoding.DecodeString(parts[1])
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
expected, err := base64.RawStdEncoding.DecodeString(parts[2])
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
actual := argon2.IDKey([]byte(password), salt, 1, 64*1024, 4, 32)
|
||||
return subtle.ConstantTimeCompare(expected, actual) == 1
|
||||
}
|
||||
105
backend/internal/auth/repository.go
Normal file
105
backend/internal/auth/repository.go
Normal file
@@ -0,0 +1,105 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
|
||||
type PGRepository struct {
|
||||
db *pgxpool.Pool
|
||||
}
|
||||
|
||||
func NewPGRepository(db *pgxpool.Pool) *PGRepository {
|
||||
return &PGRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *PGRepository) FindUserByUsername(ctx context.Context, username string) (UserRecord, error) {
|
||||
const query = `
|
||||
select u.id, u.username, u.display_name, r.name, u.password_hash, u.is_active
|
||||
from users u
|
||||
join roles r on r.id = u.role_id
|
||||
where u.username = $1 and u.deleted_at is null
|
||||
`
|
||||
|
||||
row := r.db.QueryRow(ctx, query, username)
|
||||
record := UserRecord{}
|
||||
if err := row.Scan(&record.ID, &record.Username, &record.DisplayName, &record.Role, &record.PasswordHash, &record.IsActive); err != nil {
|
||||
return UserRecord{}, err
|
||||
}
|
||||
|
||||
return record, nil
|
||||
}
|
||||
|
||||
func (r *PGRepository) CreateSession(ctx context.Context, userID uuid.UUID, expiresAt time.Time, ipAddress string, userAgent string) (uuid.UUID, error) {
|
||||
const query = `
|
||||
insert into sessions (id, user_id, ip_address, user_agent, expires_at)
|
||||
values ($1, $2, nullif($3, '')::inet, $4, $5)
|
||||
`
|
||||
|
||||
id := uuid.New()
|
||||
_, err := r.db.Exec(ctx, query, id, userID, ipAddress, userAgent, expiresAt)
|
||||
return id, err
|
||||
}
|
||||
|
||||
func (r *PGRepository) StoreRefreshToken(ctx context.Context, sessionID uuid.UUID, userID uuid.UUID, tokenHash string, expiresAt time.Time) error {
|
||||
const query = `
|
||||
insert into refresh_tokens (id, session_id, user_id, token_hash, expires_at)
|
||||
values ($1, $2, $3, $4, $5)
|
||||
`
|
||||
|
||||
_, err := r.db.Exec(ctx, query, uuid.New(), sessionID, userID, tokenHash, expiresAt)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *PGRepository) FindRefreshToken(ctx context.Context, tokenHash string) (UserRecord, uuid.UUID, error) {
|
||||
const query = `
|
||||
select u.id, u.username, u.display_name, roles.name, u.password_hash, u.is_active, rt.session_id
|
||||
from refresh_tokens rt
|
||||
join users u on u.id = rt.user_id
|
||||
join roles on roles.id = u.role_id
|
||||
where rt.token_hash = $1 and rt.revoked_at is null and rt.expires_at > now()
|
||||
`
|
||||
|
||||
record := UserRecord{}
|
||||
var sessionID uuid.UUID
|
||||
row := r.db.QueryRow(ctx, query, tokenHash)
|
||||
if err := row.Scan(&record.ID, &record.Username, &record.DisplayName, &record.Role, &record.PasswordHash, &record.IsActive, &sessionID); err != nil {
|
||||
return UserRecord{}, uuid.Nil, err
|
||||
}
|
||||
if !record.IsActive {
|
||||
return UserRecord{}, uuid.Nil, errors.New("user inactive")
|
||||
}
|
||||
|
||||
return record, sessionID, nil
|
||||
}
|
||||
|
||||
func (r *PGRepository) RevokeRefreshToken(ctx context.Context, tokenHash string) error {
|
||||
const query = `update refresh_tokens set revoked_at = now() where token_hash = $1 and revoked_at is null`
|
||||
_, err := r.db.Exec(ctx, query, tokenHash)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *PGRepository) HasUsers(ctx context.Context) (bool, error) {
|
||||
var count int
|
||||
if err := r.db.QueryRow(ctx, `select count(*) from users where deleted_at is null`).Scan(&count); err != nil {
|
||||
return false, err
|
||||
}
|
||||
return count > 0, nil
|
||||
}
|
||||
|
||||
func (r *PGRepository) CreateBootstrapAdmin(ctx context.Context, username, displayName, passwordHash string) (UserRecord, error) {
|
||||
const query = `
|
||||
insert into users (id, role_id, username, display_name, password_hash, is_active)
|
||||
values ($1, (select id from roles where name = 'admin'), $2, $3, $4, true)
|
||||
returning id, username, display_name, password_hash, is_active
|
||||
`
|
||||
|
||||
record := UserRecord{Role: "admin"}
|
||||
err := r.db.QueryRow(ctx, query, uuid.New(), username, displayName, passwordHash).
|
||||
Scan(&record.ID, &record.Username, &record.DisplayName, &record.PasswordHash, &record.IsActive)
|
||||
return record, err
|
||||
}
|
||||
155
backend/internal/auth/service.go
Normal file
155
backend/internal/auth/service.go
Normal file
@@ -0,0 +1,155 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
var ErrInvalidCredentials = errors.New("invalid credentials")
|
||||
|
||||
type UserRecord struct {
|
||||
ID uuid.UUID
|
||||
Username string
|
||||
DisplayName string
|
||||
Role string
|
||||
PasswordHash string
|
||||
IsActive bool
|
||||
}
|
||||
|
||||
type Repository interface {
|
||||
FindUserByUsername(ctx context.Context, username string) (UserRecord, error)
|
||||
CreateSession(ctx context.Context, userID uuid.UUID, expiresAt time.Time, ipAddress string, userAgent string) (uuid.UUID, error)
|
||||
StoreRefreshToken(ctx context.Context, sessionID uuid.UUID, userID uuid.UUID, tokenHash string, expiresAt time.Time) error
|
||||
FindRefreshToken(ctx context.Context, tokenHash string) (UserRecord, uuid.UUID, error)
|
||||
RevokeRefreshToken(ctx context.Context, tokenHash string) error
|
||||
HasUsers(ctx context.Context) (bool, error)
|
||||
CreateBootstrapAdmin(ctx context.Context, username, displayName, passwordHash string) (UserRecord, error)
|
||||
}
|
||||
|
||||
type Service struct {
|
||||
repo Repository
|
||||
jwtSecret string
|
||||
jwtIssuer string
|
||||
accessTokenTTL time.Duration
|
||||
refreshTokenTTL time.Duration
|
||||
}
|
||||
|
||||
func NewService(repo Repository, jwtSecret, jwtIssuer string, accessTokenTTL, refreshTokenTTL time.Duration) *Service {
|
||||
return &Service{
|
||||
repo: repo,
|
||||
jwtSecret: jwtSecret,
|
||||
jwtIssuer: jwtIssuer,
|
||||
accessTokenTTL: accessTokenTTL,
|
||||
refreshTokenTTL: refreshTokenTTL,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) Login(ctx context.Context, username, password, ipAddress, userAgent string) (LoginResponse, error) {
|
||||
record, err := s.repo.FindUserByUsername(ctx, username)
|
||||
if err != nil || !record.IsActive || !VerifyPassword(record.PasswordHash, password) {
|
||||
return LoginResponse{}, ErrInvalidCredentials
|
||||
}
|
||||
|
||||
sessionID, err := s.repo.CreateSession(ctx, record.ID, time.Now().Add(s.refreshTokenTTL), ipAddress, userAgent)
|
||||
if err != nil {
|
||||
return LoginResponse{}, err
|
||||
}
|
||||
|
||||
plainRefresh, hashedRefresh, err := NewRefreshToken()
|
||||
if err != nil {
|
||||
return LoginResponse{}, err
|
||||
}
|
||||
|
||||
if err := s.repo.StoreRefreshToken(ctx, sessionID, record.ID, hashedRefresh, time.Now().Add(s.refreshTokenTTL)); err != nil {
|
||||
return LoginResponse{}, err
|
||||
}
|
||||
|
||||
access, err := SignAccessToken(s.jwtSecret, s.jwtIssuer, s.accessTokenTTL, Claims{
|
||||
UserID: record.ID,
|
||||
Username: record.Username,
|
||||
Role: record.Role,
|
||||
Session: sessionID,
|
||||
})
|
||||
if err != nil {
|
||||
return LoginResponse{}, err
|
||||
}
|
||||
|
||||
return LoginResponse{
|
||||
AccessToken: access,
|
||||
RefreshToken: plainRefresh,
|
||||
ExpiresIn: int64(s.accessTokenTTL.Seconds()),
|
||||
User: UserView{
|
||||
ID: record.ID,
|
||||
Username: record.Username,
|
||||
DisplayName: record.DisplayName,
|
||||
Role: record.Role,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Service) Refresh(ctx context.Context, refreshToken string) (LoginResponse, error) {
|
||||
record, sessionID, err := s.repo.FindRefreshToken(ctx, hashToken(refreshToken))
|
||||
if err != nil {
|
||||
return LoginResponse{}, ErrInvalidCredentials
|
||||
}
|
||||
|
||||
access, err := SignAccessToken(s.jwtSecret, s.jwtIssuer, s.accessTokenTTL, Claims{
|
||||
UserID: record.ID,
|
||||
Username: record.Username,
|
||||
Role: record.Role,
|
||||
Session: sessionID,
|
||||
})
|
||||
if err != nil {
|
||||
return LoginResponse{}, err
|
||||
}
|
||||
|
||||
return LoginResponse{
|
||||
AccessToken: access,
|
||||
RefreshToken: refreshToken,
|
||||
ExpiresIn: int64(s.accessTokenTTL.Seconds()),
|
||||
User: UserView{
|
||||
ID: record.ID,
|
||||
Username: record.Username,
|
||||
DisplayName: record.DisplayName,
|
||||
Role: record.Role,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Service) Logout(ctx context.Context, refreshToken string) error {
|
||||
return s.repo.RevokeRefreshToken(ctx, hashToken(refreshToken))
|
||||
}
|
||||
|
||||
func (s *Service) BootstrapAdmin(ctx context.Context, username, displayName, password string) (UserView, error) {
|
||||
hasUsers, err := s.repo.HasUsers(ctx)
|
||||
if err != nil {
|
||||
return UserView{}, err
|
||||
}
|
||||
if hasUsers {
|
||||
return UserView{}, errors.New("bootstrap already completed")
|
||||
}
|
||||
|
||||
passwordHash, err := HashPassword(password)
|
||||
if err != nil {
|
||||
return UserView{}, err
|
||||
}
|
||||
|
||||
record, err := s.repo.CreateBootstrapAdmin(ctx, username, displayName, passwordHash)
|
||||
if err != nil {
|
||||
return UserView{}, err
|
||||
}
|
||||
|
||||
return UserView{
|
||||
ID: record.ID,
|
||||
Username: record.Username,
|
||||
DisplayName: record.DisplayName,
|
||||
Role: record.Role,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func hashToken(plain string) string {
|
||||
return base64Hash(plain)
|
||||
}
|
||||
77
backend/internal/auth/token.go
Normal file
77
backend/internal/auth/token.go
Normal file
@@ -0,0 +1,77 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
func NewRefreshToken() (plain string, hashed string, err error) {
|
||||
raw := make([]byte, 32)
|
||||
if _, err = rand.Read(raw); err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
plain = base64.RawURLEncoding.EncodeToString(raw)
|
||||
sum := sha256.Sum256([]byte(plain))
|
||||
hashed = base64.RawURLEncoding.EncodeToString(sum[:])
|
||||
return plain, hashed, nil
|
||||
}
|
||||
|
||||
func SignAccessToken(secret, issuer string, ttl time.Duration, claims Claims) (string, error) {
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
||||
"iss": issuer,
|
||||
"sub": claims.UserID.String(),
|
||||
"username": claims.Username,
|
||||
"role": claims.Role,
|
||||
"session_id": claims.Session.String(),
|
||||
"exp": time.Now().Add(ttl).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
})
|
||||
|
||||
return token.SignedString([]byte(secret))
|
||||
}
|
||||
|
||||
func ParseAccessToken(secret string, tokenString string) (Claims, error) {
|
||||
claims := Claims{}
|
||||
|
||||
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (any, error) {
|
||||
return []byte(secret), nil
|
||||
})
|
||||
if err != nil || !token.Valid {
|
||||
return claims, err
|
||||
}
|
||||
|
||||
mapClaims, ok := token.Claims.(jwt.MapClaims)
|
||||
if !ok {
|
||||
return claims, jwt.ErrTokenMalformed
|
||||
}
|
||||
|
||||
subject, ok := mapClaims["sub"].(string)
|
||||
if !ok {
|
||||
return claims, jwt.ErrTokenMalformed
|
||||
}
|
||||
sessionValue, ok := mapClaims["session_id"].(string)
|
||||
if !ok {
|
||||
return claims, jwt.ErrTokenMalformed
|
||||
}
|
||||
|
||||
userID, err := uuid.Parse(subject)
|
||||
if err != nil {
|
||||
return claims, err
|
||||
}
|
||||
sessionID, err := uuid.Parse(sessionValue)
|
||||
if err != nil {
|
||||
return claims, err
|
||||
}
|
||||
|
||||
claims.UserID = userID
|
||||
claims.Session = sessionID
|
||||
claims.Username, _ = mapClaims["username"].(string)
|
||||
claims.Role, _ = mapClaims["role"].(string)
|
||||
return claims, nil
|
||||
}
|
||||
39
backend/internal/auth/types.go
Normal file
39
backend/internal/auth/types.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package auth
|
||||
|
||||
import "github.com/google/uuid"
|
||||
|
||||
type Claims struct {
|
||||
UserID uuid.UUID `json:"user_id"`
|
||||
Username string `json:"username"`
|
||||
Role string `json:"role"`
|
||||
Session uuid.UUID `json:"session_id"`
|
||||
}
|
||||
|
||||
type LoginRequest struct {
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
type RefreshRequest struct {
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
}
|
||||
|
||||
type BootstrapRequest struct {
|
||||
Username string `json:"username"`
|
||||
DisplayName string `json:"display_name"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
type LoginResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ExpiresIn int64 `json:"expires_in"`
|
||||
User UserView `json:"user"`
|
||||
}
|
||||
|
||||
type UserView struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
Username string `json:"username"`
|
||||
DisplayName string `json:"display_name"`
|
||||
Role string `json:"role"`
|
||||
}
|
||||
82
backend/internal/config/config.go
Normal file
82
backend/internal/config/config.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
AppName string
|
||||
Environment string
|
||||
HTTPAddress string
|
||||
DatabaseURL string
|
||||
JWTIssuer string
|
||||
JWTSecret string
|
||||
AccessTokenTTL time.Duration
|
||||
RefreshTokenTTL time.Duration
|
||||
DefaultGatewayID string
|
||||
DefaultDNS []string
|
||||
DefaultVPNCIDR string
|
||||
DefaultGatewayHost string
|
||||
DefaultGatewayPubKey string
|
||||
}
|
||||
|
||||
func Load() Config {
|
||||
return Config{
|
||||
AppName: getenv("APP_NAME", "NexaVPN"),
|
||||
Environment: getenv("APP_ENV", "development"),
|
||||
HTTPAddress: getenv("HTTP_ADDRESS", ":8080"),
|
||||
DatabaseURL: getenv("DATABASE_URL", "postgres://nexavpn:nexavpn@localhost:5432/nexavpn?sslmode=disable"),
|
||||
JWTIssuer: getenv("JWT_ISSUER", "nexavpn"),
|
||||
JWTSecret: getenv("JWT_SECRET", "change-me-in-production"),
|
||||
AccessTokenTTL: time.Duration(getenvInt("ACCESS_TOKEN_TTL_SECONDS", 900)) * time.Second,
|
||||
RefreshTokenTTL: time.Duration(getenvInt("REFRESH_TOKEN_TTL_SECONDS", 2592000)) * time.Second,
|
||||
DefaultGatewayID: getenv("DEFAULT_GATEWAY_ID", ""),
|
||||
DefaultDNS: splitCSV(getenv("DEFAULT_DNS_SERVERS", "10.20.0.53")),
|
||||
DefaultVPNCIDR: getenv("DEFAULT_VPN_CIDR", "100.96.0.0/24"),
|
||||
DefaultGatewayHost: getenv("DEFAULT_GATEWAY_ENDPOINT", "vpn.example.com:51820"),
|
||||
DefaultGatewayPubKey: getenv("DEFAULT_GATEWAY_PUBLIC_KEY", "replace-me"),
|
||||
}
|
||||
}
|
||||
|
||||
func getenv(key, fallback string) string {
|
||||
if value := os.Getenv(key); value != "" {
|
||||
return value
|
||||
}
|
||||
|
||||
return fallback
|
||||
}
|
||||
|
||||
func getenvInt(key string, fallback int) int {
|
||||
if value := os.Getenv(key); value != "" {
|
||||
parsed, err := strconv.Atoi(value)
|
||||
if err == nil {
|
||||
return parsed
|
||||
}
|
||||
}
|
||||
|
||||
return fallback
|
||||
}
|
||||
|
||||
func splitCSV(value string) []string {
|
||||
if value == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var items []string
|
||||
start := 0
|
||||
for i := range value {
|
||||
if value[i] == ',' {
|
||||
if start < i {
|
||||
items = append(items, value[start:i])
|
||||
}
|
||||
start = i + 1
|
||||
}
|
||||
}
|
||||
if start < len(value) {
|
||||
items = append(items, value[start:])
|
||||
}
|
||||
|
||||
return items
|
||||
}
|
||||
11
backend/internal/db/db.go
Normal file
11
backend/internal/db/db.go
Normal file
@@ -0,0 +1,11 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
|
||||
func Connect(ctx context.Context, databaseURL string) (*pgxpool.Pool, error) {
|
||||
return pgxpool.New(ctx, databaseURL)
|
||||
}
|
||||
176
backend/internal/device/handler.go
Normal file
176
backend/internal/device/handler.go
Normal file
@@ -0,0 +1,176 @@
|
||||
package device
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/nexavpn/nexavpn/backend/internal/apiutil"
|
||||
"github.com/nexavpn/nexavpn/backend/internal/audit"
|
||||
"github.com/nexavpn/nexavpn/backend/internal/httpserver"
|
||||
)
|
||||
|
||||
type Handler struct {
|
||||
service *Service
|
||||
audit *audit.Service
|
||||
}
|
||||
|
||||
func NewHandler(service *Service, auditService *audit.Service) *Handler {
|
||||
return &Handler{service: service, audit: auditService}
|
||||
}
|
||||
|
||||
func (h *Handler) Enroll(w http.ResponseWriter, r *http.Request) {
|
||||
var input EnrollRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&input); err != nil {
|
||||
apiutil.Error(w, http.StatusBadRequest, "invalid_json", "invalid request body")
|
||||
return
|
||||
}
|
||||
|
||||
userID, ok := httpserver.MustUserID(r.Context())
|
||||
if !ok {
|
||||
apiutil.Error(w, http.StatusUnauthorized, "unauthorized", "missing auth claims")
|
||||
return
|
||||
}
|
||||
|
||||
response, err := h.service.Enroll(r.Context(), userID, input, "__CLIENT_GENERATED_PRIVATE_KEY__")
|
||||
if err != nil {
|
||||
apiutil.Error(w, http.StatusInternalServerError, "device_enroll_failed", "unable to enroll device")
|
||||
return
|
||||
}
|
||||
|
||||
_ = h.audit.Record(r.Context(), audit.Entry{
|
||||
ActorUserID: &userID,
|
||||
EntityType: "device",
|
||||
EntityID: &response.Device.ID,
|
||||
EventType: "device.enrolled",
|
||||
Status: "success",
|
||||
Message: "device enrolled and profile issued",
|
||||
Metadata: map[string]any{
|
||||
"platform": response.Device.Platform,
|
||||
"assigned_ip": response.Peer.AssignedIP,
|
||||
},
|
||||
})
|
||||
apiutil.JSON(w, http.StatusCreated, response)
|
||||
}
|
||||
|
||||
func (h *Handler) ListOwn(w http.ResponseWriter, r *http.Request) {
|
||||
userID, ok := httpserver.MustUserID(r.Context())
|
||||
if !ok {
|
||||
apiutil.Error(w, http.StatusUnauthorized, "unauthorized", "missing auth claims")
|
||||
return
|
||||
}
|
||||
|
||||
devices, err := h.service.ListByUser(r.Context(), userID)
|
||||
if err != nil {
|
||||
apiutil.Error(w, http.StatusInternalServerError, "devices_list_failed", "unable to list devices")
|
||||
return
|
||||
}
|
||||
|
||||
apiutil.JSON(w, http.StatusOK, devices)
|
||||
}
|
||||
|
||||
func (h *Handler) ListAll(w http.ResponseWriter, r *http.Request) {
|
||||
devices, err := h.service.ListAll(r.Context())
|
||||
if err != nil {
|
||||
apiutil.Error(w, http.StatusInternalServerError, "devices_list_failed", "unable to list devices")
|
||||
return
|
||||
}
|
||||
|
||||
apiutil.JSON(w, http.StatusOK, devices)
|
||||
}
|
||||
|
||||
func (h *Handler) ConnectionStatus(w http.ResponseWriter, r *http.Request) {
|
||||
userID, ok := httpserver.MustUserID(r.Context())
|
||||
if !ok {
|
||||
apiutil.Error(w, http.StatusUnauthorized, "unauthorized", "missing auth claims")
|
||||
return
|
||||
}
|
||||
|
||||
status, err := h.service.GetConnectionStatus(r.Context(), userID)
|
||||
if err != nil {
|
||||
apiutil.Error(w, http.StatusInternalServerError, "connection_status_failed", "unable to fetch connection status")
|
||||
return
|
||||
}
|
||||
|
||||
apiutil.JSON(w, http.StatusOK, status)
|
||||
}
|
||||
|
||||
func (h *Handler) GetOwnProfile(w http.ResponseWriter, r *http.Request) {
|
||||
userID, ok := httpserver.MustUserID(r.Context())
|
||||
if !ok {
|
||||
apiutil.Error(w, http.StatusUnauthorized, "unauthorized", "missing auth claims")
|
||||
return
|
||||
}
|
||||
|
||||
response, err := h.service.GetLatestEnrollmentByUser(r.Context(), userID)
|
||||
if err != nil {
|
||||
apiutil.Error(w, http.StatusNotFound, "profile_not_found", "no active profile found")
|
||||
return
|
||||
}
|
||||
|
||||
apiutil.JSON(w, http.StatusOK, response)
|
||||
}
|
||||
|
||||
func (h *Handler) GetProfileByDeviceID(w http.ResponseWriter, r *http.Request) {
|
||||
deviceID, err := uuid.Parse(chi.URLParam(r, "id"))
|
||||
if err != nil {
|
||||
apiutil.Error(w, http.StatusBadRequest, "invalid_device_id", "invalid device id")
|
||||
return
|
||||
}
|
||||
|
||||
response, err := h.service.GetEnrollmentByDeviceID(r.Context(), deviceID)
|
||||
if err != nil {
|
||||
apiutil.Error(w, http.StatusNotFound, "profile_not_found", "device profile not found")
|
||||
return
|
||||
}
|
||||
|
||||
apiutil.JSON(w, http.StatusOK, response)
|
||||
}
|
||||
|
||||
func (h *Handler) Revoke(w http.ResponseWriter, r *http.Request) {
|
||||
deviceID, err := uuid.Parse(chi.URLParam(r, "id"))
|
||||
if err != nil {
|
||||
apiutil.Error(w, http.StatusBadRequest, "invalid_device_id", "invalid device id")
|
||||
return
|
||||
}
|
||||
if err := h.service.Revoke(r.Context(), deviceID); err != nil {
|
||||
apiutil.Error(w, http.StatusInternalServerError, "device_revoke_failed", "unable to revoke device")
|
||||
return
|
||||
}
|
||||
if claims, ok := httpserver.ClaimsFromContext(r.Context()); ok {
|
||||
_ = h.audit.Record(r.Context(), audit.Entry{
|
||||
ActorUserID: &claims.UserID,
|
||||
EntityType: "device",
|
||||
EntityID: &deviceID,
|
||||
EventType: "admin.device.revoked",
|
||||
Status: "success",
|
||||
Message: "admin revoked device",
|
||||
})
|
||||
}
|
||||
apiutil.JSON(w, http.StatusOK, map[string]any{"ok": true})
|
||||
}
|
||||
|
||||
func (h *Handler) Rotate(w http.ResponseWriter, r *http.Request) {
|
||||
deviceID, err := uuid.Parse(chi.URLParam(r, "id"))
|
||||
if err != nil {
|
||||
apiutil.Error(w, http.StatusBadRequest, "invalid_device_id", "invalid device id")
|
||||
return
|
||||
}
|
||||
if err := h.service.Rotate(r.Context(), deviceID); err != nil {
|
||||
apiutil.Error(w, http.StatusInternalServerError, "device_rotate_failed", "unable to rotate device profile")
|
||||
return
|
||||
}
|
||||
if claims, ok := httpserver.ClaimsFromContext(r.Context()); ok {
|
||||
_ = h.audit.Record(r.Context(), audit.Entry{
|
||||
ActorUserID: &claims.UserID,
|
||||
EntityType: "device",
|
||||
EntityID: &deviceID,
|
||||
EventType: "admin.device.rotated",
|
||||
Status: "success",
|
||||
Message: "admin rotated device profile",
|
||||
})
|
||||
}
|
||||
apiutil.JSON(w, http.StatusOK, map[string]any{"ok": true})
|
||||
}
|
||||
265
backend/internal/device/repository.go
Normal file
265
backend/internal/device/repository.go
Normal file
@@ -0,0 +1,265 @@
|
||||
package device
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
|
||||
type Repository interface {
|
||||
Enroll(ctx context.Context, userID uuid.UUID, gatewayID uuid.UUID, input EnrollRequest, assignedIP string, dnsServers []string, allowedIPs []string) (EnrollmentResponse, error)
|
||||
ListByUser(ctx context.Context, userID uuid.UUID) ([]Device, error)
|
||||
ListAll(ctx context.Context) ([]Device, error)
|
||||
GetLatestEnrollmentByUser(ctx context.Context, userID uuid.UUID) (EnrollmentResponse, error)
|
||||
GetEnrollmentByDeviceID(ctx context.Context, deviceID uuid.UUID) (EnrollmentResponse, error)
|
||||
Revoke(ctx context.Context, deviceID uuid.UUID) error
|
||||
Rotate(ctx context.Context, deviceID uuid.UUID) error
|
||||
}
|
||||
|
||||
type PGRepository struct {
|
||||
db *pgxpool.Pool
|
||||
}
|
||||
|
||||
func NewPGRepository(db *pgxpool.Pool) *PGRepository {
|
||||
return &PGRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *PGRepository) Enroll(ctx context.Context, userID uuid.UUID, gatewayID uuid.UUID, input EnrollRequest, assignedIP string, dnsServers []string, allowedIPs []string) (EnrollmentResponse, error) {
|
||||
tx, err := r.db.Begin(ctx)
|
||||
if err != nil {
|
||||
return EnrollmentResponse{}, err
|
||||
}
|
||||
defer tx.Rollback(ctx)
|
||||
|
||||
deviceID := uuid.New()
|
||||
peerID := uuid.New()
|
||||
|
||||
_, err = tx.Exec(ctx, `
|
||||
insert into devices (
|
||||
id, user_id, gateway_id, name, platform, os_version, app_version, device_fingerprint, public_key, status, approved_at
|
||||
) values ($1, $2, $3, $4, $5, $6, $7, $8, $9, 'active', now())
|
||||
`, deviceID, userID, gatewayID, input.Name, input.Platform, input.OSVersion, input.AppVersion, input.DeviceFingerprint, input.PublicKey)
|
||||
if err != nil {
|
||||
return EnrollmentResponse{}, err
|
||||
}
|
||||
|
||||
_, err = tx.Exec(ctx, `
|
||||
insert into wireguard_peers (
|
||||
id, device_id, gateway_id, public_key, assigned_ip, allowed_ips, dns_servers, last_profile_issued_at
|
||||
) values ($1, $2, $3, $4, $5::inet, $6::cidr[], $7::text[], now())
|
||||
`, peerID, deviceID, gatewayID, input.PublicKey, assignedIP, allowedIPs, dnsServers)
|
||||
if err != nil {
|
||||
return EnrollmentResponse{}, err
|
||||
}
|
||||
|
||||
_, err = tx.Exec(ctx, `
|
||||
insert into ip_allocations (id, gateway_id, device_id, address, status)
|
||||
values ($1, $2, $3, $4::inet, 'allocated')
|
||||
`, uuid.New(), gatewayID, deviceID, assignedIP)
|
||||
if err != nil {
|
||||
return EnrollmentResponse{}, err
|
||||
}
|
||||
|
||||
if err := tx.Commit(ctx); err != nil {
|
||||
return EnrollmentResponse{}, err
|
||||
}
|
||||
|
||||
return EnrollmentResponse{
|
||||
Device: Device{
|
||||
ID: deviceID,
|
||||
UserID: userID,
|
||||
GatewayID: gatewayID,
|
||||
Name: input.Name,
|
||||
Platform: input.Platform,
|
||||
Status: "active",
|
||||
AssignedIP: assignedIP,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *PGRepository) GetLatestEnrollmentByUser(ctx context.Context, userID uuid.UUID) (EnrollmentResponse, error) {
|
||||
row := r.db.QueryRow(ctx, `
|
||||
select
|
||||
d.id,
|
||||
d.user_id,
|
||||
d.gateway_id,
|
||||
d.name,
|
||||
d.platform,
|
||||
d.status,
|
||||
host(wp.assigned_ip),
|
||||
wp.profile_revision,
|
||||
wp.last_profile_issued_at,
|
||||
g.name,
|
||||
g.endpoint,
|
||||
g.public_key,
|
||||
wp.dns_servers,
|
||||
wp.allowed_ips
|
||||
from devices d
|
||||
join wireguard_peers wp on wp.device_id = d.id and wp.deleted_at is null
|
||||
join gateways g on g.id = wp.gateway_id
|
||||
where d.user_id = $1 and d.deleted_at is null
|
||||
order by d.created_at desc
|
||||
limit 1
|
||||
`, userID)
|
||||
return scanEnrollmentRow(row)
|
||||
}
|
||||
|
||||
func (r *PGRepository) GetEnrollmentByDeviceID(ctx context.Context, deviceID uuid.UUID) (EnrollmentResponse, error) {
|
||||
row := r.db.QueryRow(ctx, `
|
||||
select
|
||||
d.id,
|
||||
d.user_id,
|
||||
d.gateway_id,
|
||||
d.name,
|
||||
d.platform,
|
||||
d.status,
|
||||
host(wp.assigned_ip),
|
||||
wp.profile_revision,
|
||||
wp.last_profile_issued_at,
|
||||
g.name,
|
||||
g.endpoint,
|
||||
g.public_key,
|
||||
wp.dns_servers,
|
||||
wp.allowed_ips
|
||||
from devices d
|
||||
join wireguard_peers wp on wp.device_id = d.id and wp.deleted_at is null
|
||||
join gateways g on g.id = wp.gateway_id
|
||||
where d.id = $1 and d.deleted_at is null
|
||||
`, deviceID)
|
||||
return scanEnrollmentRow(row)
|
||||
}
|
||||
|
||||
func (r *PGRepository) ListByUser(ctx context.Context, userID uuid.UUID) ([]Device, error) {
|
||||
rows, err := r.db.Query(ctx, `
|
||||
select d.id, d.user_id, d.gateway_id, d.name, d.platform, d.status, coalesce(host(wp.assigned_ip), '')
|
||||
from devices d
|
||||
left join wireguard_peers wp on wp.device_id = d.id and wp.deleted_at is null
|
||||
where d.user_id = $1 and d.deleted_at is null
|
||||
order by d.created_at desc
|
||||
`, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var items []Device
|
||||
for rows.Next() {
|
||||
var item Device
|
||||
if err := rows.Scan(&item.ID, &item.UserID, &item.GatewayID, &item.Name, &item.Platform, &item.Status, &item.AssignedIP); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, item)
|
||||
}
|
||||
return items, rows.Err()
|
||||
}
|
||||
|
||||
func (r *PGRepository) ListAll(ctx context.Context) ([]Device, error) {
|
||||
rows, err := r.db.Query(ctx, `
|
||||
select d.id, d.user_id, d.gateway_id, d.name, d.platform, d.status, coalesce(host(wp.assigned_ip), '')
|
||||
from devices d
|
||||
left join wireguard_peers wp on wp.device_id = d.id and wp.deleted_at is null
|
||||
where d.deleted_at is null
|
||||
order by d.created_at desc
|
||||
`)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var items []Device
|
||||
for rows.Next() {
|
||||
var item Device
|
||||
if err := rows.Scan(&item.ID, &item.UserID, &item.GatewayID, &item.Name, &item.Platform, &item.Status, &item.AssignedIP); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, item)
|
||||
}
|
||||
return items, rows.Err()
|
||||
}
|
||||
|
||||
func (r *PGRepository) Revoke(ctx context.Context, deviceID uuid.UUID) error {
|
||||
tx, err := r.db.Begin(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback(ctx)
|
||||
|
||||
if _, err := tx.Exec(ctx, `update devices set status = 'revoked', revoked_at = now(), updated_at = now() where id = $1`, deviceID); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(ctx, `update wireguard_peers set deleted_at = now(), updated_at = now() where device_id = $1 and deleted_at is null`, deviceID); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(ctx, `update ip_allocations set status = 'released', released_at = now(), updated_at = now() where device_id = $1 and status = 'allocated'`, deviceID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return tx.Commit(ctx)
|
||||
}
|
||||
|
||||
func (r *PGRepository) Rotate(ctx context.Context, deviceID uuid.UUID) error {
|
||||
_, err := r.db.Exec(ctx, `
|
||||
update wireguard_peers
|
||||
set profile_revision = profile_revision + 1, last_profile_issued_at = now(), updated_at = now()
|
||||
where device_id = $1 and deleted_at is null
|
||||
`, deviceID)
|
||||
return err
|
||||
}
|
||||
|
||||
type enrollmentRowScanner interface {
|
||||
Scan(dest ...any) error
|
||||
}
|
||||
|
||||
func scanEnrollmentRow(row enrollmentRowScanner) (EnrollmentResponse, error) {
|
||||
var response EnrollmentResponse
|
||||
var profileRevision int
|
||||
var lastIssuedAt *time.Time
|
||||
var gatewayName string
|
||||
var gatewayEndpoint string
|
||||
var gatewayPublicKey string
|
||||
var dnsServers []string
|
||||
var allowedIPs []string
|
||||
|
||||
if err := row.Scan(
|
||||
&response.Device.ID,
|
||||
&response.Device.UserID,
|
||||
&response.Device.GatewayID,
|
||||
&response.Device.Name,
|
||||
&response.Device.Platform,
|
||||
&response.Device.Status,
|
||||
&response.Device.AssignedIP,
|
||||
&profileRevision,
|
||||
&lastIssuedAt,
|
||||
&gatewayName,
|
||||
&gatewayEndpoint,
|
||||
&gatewayPublicKey,
|
||||
&dnsServers,
|
||||
&allowedIPs,
|
||||
); err != nil {
|
||||
return EnrollmentResponse{}, err
|
||||
}
|
||||
|
||||
response.Peer = PeerView{
|
||||
AssignedIP: response.Device.AssignedIP,
|
||||
DNSServers: dnsServers,
|
||||
AllowedIPs: allowedIPs,
|
||||
Gateway: GatewayView{
|
||||
ID: response.Device.GatewayID,
|
||||
Name: gatewayName,
|
||||
Endpoint: gatewayEndpoint,
|
||||
PublicKey: gatewayPublicKey,
|
||||
},
|
||||
ProfileRevision: profileRevision,
|
||||
}
|
||||
for _, destination := range allowedIPs {
|
||||
response.Resources = append(response.Resources, Resource{
|
||||
Type: "cidr",
|
||||
Value: destination,
|
||||
Label: destination,
|
||||
})
|
||||
}
|
||||
_ = lastIssuedAt
|
||||
return response, nil
|
||||
}
|
||||
130
backend/internal/device/service.go
Normal file
130
backend/internal/device/service.go
Normal file
@@ -0,0 +1,130 @@
|
||||
package device
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/nexavpn/nexavpn/backend/internal/gateway"
|
||||
"github.com/nexavpn/nexavpn/backend/internal/ipam"
|
||||
"github.com/nexavpn/nexavpn/backend/internal/policy"
|
||||
"github.com/nexavpn/nexavpn/backend/internal/profile"
|
||||
)
|
||||
|
||||
type Service struct {
|
||||
repo Repository
|
||||
policyService *policy.Service
|
||||
gatewayService *gateway.Service
|
||||
ipamService *ipam.Service
|
||||
}
|
||||
|
||||
func NewService(repo Repository, policyService *policy.Service, gatewayService *gateway.Service, ipamService *ipam.Service) *Service {
|
||||
return &Service{
|
||||
repo: repo,
|
||||
policyService: policyService,
|
||||
gatewayService: gatewayService,
|
||||
ipamService: ipamService,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) Enroll(ctx context.Context, userID uuid.UUID, input EnrollRequest, privateKeyPlaceholder string) (EnrollmentResponse, error) {
|
||||
selectedGateway, err := s.gatewayService.SelectActive(ctx)
|
||||
if err != nil {
|
||||
return EnrollmentResponse{}, err
|
||||
}
|
||||
|
||||
assignedIP, err := s.ipamService.Allocate(selectedGateway.VPNCIDR, 10)
|
||||
if err != nil {
|
||||
return EnrollmentResponse{}, err
|
||||
}
|
||||
|
||||
enrollment, err := s.repo.Enroll(ctx, userID, selectedGateway.ID, input, assignedIP, selectedGateway.DNSServers, nil)
|
||||
if err != nil {
|
||||
return EnrollmentResponse{}, err
|
||||
}
|
||||
|
||||
destinations, err := s.policyService.ResolveDestinations(ctx, userID, &enrollment.Device.ID)
|
||||
if err != nil {
|
||||
return EnrollmentResponse{}, err
|
||||
}
|
||||
if len(destinations) == 0 {
|
||||
destinations = []string{"172.16.10.0/24"}
|
||||
}
|
||||
|
||||
enrollment.Peer = PeerView{
|
||||
AssignedIP: assignedIP,
|
||||
DNSServers: selectedGateway.DNSServers,
|
||||
AllowedIPs: destinations,
|
||||
Gateway: GatewayView{
|
||||
ID: selectedGateway.ID,
|
||||
Name: selectedGateway.Name,
|
||||
Endpoint: selectedGateway.Endpoint,
|
||||
PublicKey: selectedGateway.PublicKey,
|
||||
},
|
||||
ProfileRevision: 1,
|
||||
}
|
||||
for _, destination := range destinations {
|
||||
enrollment.Resources = append(enrollment.Resources, Resource{
|
||||
Type: "cidr",
|
||||
Value: destination,
|
||||
Label: destination,
|
||||
})
|
||||
}
|
||||
|
||||
enrollment.Profile = ProfileView{
|
||||
Format: "wireguard",
|
||||
Content: profile.BuildWireGuardConfig(profile.BuildInput{
|
||||
PrivateKey: privateKeyPlaceholder,
|
||||
Address: assignedIP,
|
||||
DNSServers: selectedGateway.DNSServers,
|
||||
ServerPublicKey: selectedGateway.PublicKey,
|
||||
ServerEndpoint: selectedGateway.Endpoint,
|
||||
AllowedIPs: destinations,
|
||||
PersistentKeepal: 25,
|
||||
}),
|
||||
}
|
||||
|
||||
return enrollment, nil
|
||||
}
|
||||
|
||||
func (s *Service) ListByUser(ctx context.Context, userID uuid.UUID) ([]Device, error) {
|
||||
return s.repo.ListByUser(ctx, userID)
|
||||
}
|
||||
|
||||
func (s *Service) ListAll(ctx context.Context) ([]Device, error) {
|
||||
return s.repo.ListAll(ctx)
|
||||
}
|
||||
|
||||
func (s *Service) GetLatestEnrollmentByUser(ctx context.Context, userID uuid.UUID) (EnrollmentResponse, error) {
|
||||
return s.repo.GetLatestEnrollmentByUser(ctx, userID)
|
||||
}
|
||||
|
||||
func (s *Service) GetEnrollmentByDeviceID(ctx context.Context, deviceID uuid.UUID) (EnrollmentResponse, error) {
|
||||
return s.repo.GetEnrollmentByDeviceID(ctx, deviceID)
|
||||
}
|
||||
|
||||
func (s *Service) GetConnectionStatus(ctx context.Context, userID uuid.UUID) (ConnectionStatus, error) {
|
||||
enrollment, err := s.repo.GetLatestEnrollmentByUser(ctx, userID)
|
||||
if err != nil {
|
||||
return ConnectionStatus{
|
||||
Status: "disconnected",
|
||||
Resources: []Resource{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
lastSync := "just now"
|
||||
return ConnectionStatus{
|
||||
Status: "provisioned",
|
||||
AssignedIP: enrollment.Peer.AssignedIP,
|
||||
LastSyncTime: &lastSync,
|
||||
Resources: enrollment.Resources,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Service) Revoke(ctx context.Context, deviceID uuid.UUID) error {
|
||||
return s.repo.Revoke(ctx, deviceID)
|
||||
}
|
||||
|
||||
func (s *Service) Rotate(ctx context.Context, deviceID uuid.UUID) error {
|
||||
return s.repo.Rotate(ctx, deviceID)
|
||||
}
|
||||
62
backend/internal/device/types.go
Normal file
62
backend/internal/device/types.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package device
|
||||
|
||||
import "github.com/google/uuid"
|
||||
|
||||
type EnrollRequest struct {
|
||||
Name string `json:"name"`
|
||||
Platform string `json:"platform"`
|
||||
OSVersion string `json:"os_version"`
|
||||
AppVersion string `json:"app_version"`
|
||||
DeviceFingerprint string `json:"device_fingerprint"`
|
||||
PublicKey string `json:"public_key"`
|
||||
}
|
||||
|
||||
type Device struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
UserID uuid.UUID `json:"user_id,omitempty"`
|
||||
GatewayID uuid.UUID `json:"gateway_id,omitempty"`
|
||||
Name string `json:"name"`
|
||||
Platform string `json:"platform"`
|
||||
Status string `json:"status"`
|
||||
AssignedIP string `json:"assigned_ip,omitempty"`
|
||||
}
|
||||
|
||||
type ConnectionStatus struct {
|
||||
Status string `json:"status"`
|
||||
AssignedIP string `json:"assigned_ip"`
|
||||
LastSyncTime *string `json:"last_sync_time"`
|
||||
Resources []Resource `json:"resources"`
|
||||
}
|
||||
|
||||
type Resource struct {
|
||||
Type string `json:"type"`
|
||||
Value string `json:"value"`
|
||||
Label string `json:"label"`
|
||||
}
|
||||
|
||||
type EnrollmentResponse struct {
|
||||
Device Device `json:"device"`
|
||||
Peer PeerView `json:"peer"`
|
||||
Profile ProfileView `json:"profile"`
|
||||
Resources []Resource `json:"resources"`
|
||||
}
|
||||
|
||||
type PeerView struct {
|
||||
AssignedIP string `json:"assigned_ip"`
|
||||
DNSServers []string `json:"dns_servers"`
|
||||
AllowedIPs []string `json:"allowed_ips"`
|
||||
Gateway GatewayView `json:"gateway"`
|
||||
ProfileRevision int `json:"profile_revision"`
|
||||
}
|
||||
|
||||
type GatewayView struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Endpoint string `json:"endpoint"`
|
||||
PublicKey string `json:"public_key"`
|
||||
}
|
||||
|
||||
type ProfileView struct {
|
||||
Format string `json:"format"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
37
backend/internal/gateway/handler.go
Normal file
37
backend/internal/gateway/handler.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package gateway
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"github.com/nexavpn/nexavpn/backend/internal/apiutil"
|
||||
)
|
||||
|
||||
type Handler struct {
|
||||
service *Service
|
||||
}
|
||||
|
||||
func NewHandler(service *Service) *Handler {
|
||||
return &Handler{service: service}
|
||||
}
|
||||
|
||||
func (h *Handler) List(w http.ResponseWriter, r *http.Request) {
|
||||
items, err := h.service.List(r.Context())
|
||||
if err != nil {
|
||||
apiutil.Error(w, http.StatusInternalServerError, "gateways_list_failed", "unable to list gateways")
|
||||
return
|
||||
}
|
||||
|
||||
apiutil.JSON(w, http.StatusOK, items)
|
||||
}
|
||||
|
||||
func (h *Handler) SyncBundle(w http.ResponseWriter, r *http.Request) {
|
||||
bundle, err := h.service.BuildSyncBundle(r.Context(), chi.URLParam(r, "id"))
|
||||
if err != nil {
|
||||
apiutil.Error(w, http.StatusBadRequest, "gateway_sync_failed", "unable to build sync bundle")
|
||||
return
|
||||
}
|
||||
|
||||
apiutil.JSON(w, http.StatusOK, bundle)
|
||||
}
|
||||
102
backend/internal/gateway/repository.go
Normal file
102
backend/internal/gateway/repository.go
Normal file
@@ -0,0 +1,102 @@
|
||||
package gateway
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
|
||||
"github.com/nexavpn/nexavpn/backend/internal/wireguard"
|
||||
)
|
||||
|
||||
type Repository interface {
|
||||
List(ctx context.Context) ([]Gateway, error)
|
||||
FirstActive(ctx context.Context) (Gateway, error)
|
||||
BuildSyncBundle(ctx context.Context, gatewayID uuid.UUID) (wireguard.GatewayBundle, error)
|
||||
}
|
||||
|
||||
type PGRepository struct {
|
||||
db *pgxpool.Pool
|
||||
}
|
||||
|
||||
func NewPGRepository(db *pgxpool.Pool) *PGRepository {
|
||||
return &PGRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *PGRepository) List(ctx context.Context) ([]Gateway, error) {
|
||||
rows, err := r.db.Query(ctx, `
|
||||
select id, name, endpoint, public_key, listen_port, vpn_cidr, dns_servers, is_active
|
||||
from gateways
|
||||
where deleted_at is null
|
||||
order by created_at desc
|
||||
`)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var items []Gateway
|
||||
for rows.Next() {
|
||||
var item Gateway
|
||||
if err := rows.Scan(&item.ID, &item.Name, &item.Endpoint, &item.PublicKey, &item.ListenPort, &item.VPNCIDR, &item.DNSServers, &item.IsActive); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, item)
|
||||
}
|
||||
return items, rows.Err()
|
||||
}
|
||||
|
||||
func (r *PGRepository) FirstActive(ctx context.Context) (Gateway, error) {
|
||||
row := r.db.QueryRow(ctx, `
|
||||
select id, name, endpoint, public_key, listen_port, vpn_cidr, dns_servers, is_active
|
||||
from gateways
|
||||
where deleted_at is null and is_active = true
|
||||
order by created_at asc
|
||||
limit 1
|
||||
`)
|
||||
|
||||
var item Gateway
|
||||
err := row.Scan(&item.ID, &item.Name, &item.Endpoint, &item.PublicKey, &item.ListenPort, &item.VPNCIDR, &item.DNSServers, &item.IsActive)
|
||||
return item, err
|
||||
}
|
||||
|
||||
func (r *PGRepository) BuildSyncBundle(ctx context.Context, gatewayID uuid.UUID) (wireguard.GatewayBundle, error) {
|
||||
var bundle wireguard.GatewayBundle
|
||||
bundle.GatewayID = gatewayID.String()
|
||||
bundle.Revision = 1
|
||||
|
||||
row := r.db.QueryRow(ctx, `
|
||||
select host(vpn_cidr), listen_port
|
||||
from gateways
|
||||
where id = $1 and deleted_at is null
|
||||
`, gatewayID)
|
||||
if err := row.Scan(&bundle.Interface.Address, &bundle.Interface.ListenPort); err != nil {
|
||||
return wireguard.GatewayBundle{}, err
|
||||
}
|
||||
|
||||
rows, err := r.db.Query(ctx, `
|
||||
select d.id, wp.public_key, host(wp.assigned_ip), coalesce(array_agg(pd.destination::text) filter (where pd.destination is not null), '{}')
|
||||
from devices d
|
||||
join wireguard_peers wp on wp.device_id = d.id and wp.deleted_at is null
|
||||
left join policy_targets pt on pt.target_id = d.id and pt.target_type = 'device'
|
||||
left join policy_destinations pd on pd.policy_id = pt.policy_id
|
||||
where d.gateway_id = $1 and d.deleted_at is null and d.status = 'active'
|
||||
group by d.id, wp.public_key, wp.assigned_ip
|
||||
`, gatewayID)
|
||||
if err != nil {
|
||||
return wireguard.GatewayBundle{}, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
var peer wireguard.Peer
|
||||
var deviceID uuid.UUID
|
||||
if err := rows.Scan(&deviceID, &peer.PublicKey, &peer.AssignedIP, &peer.AllowedDestinations); err != nil {
|
||||
return wireguard.GatewayBundle{}, err
|
||||
}
|
||||
peer.DeviceID = deviceID.String()
|
||||
bundle.Peers = append(bundle.Peers, peer)
|
||||
}
|
||||
|
||||
return bundle, rows.Err()
|
||||
}
|
||||
33
backend/internal/gateway/service.go
Normal file
33
backend/internal/gateway/service.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package gateway
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/nexavpn/nexavpn/backend/internal/wireguard"
|
||||
)
|
||||
|
||||
type Service struct {
|
||||
repo Repository
|
||||
}
|
||||
|
||||
func NewService(repo Repository) *Service {
|
||||
return &Service{repo: repo}
|
||||
}
|
||||
|
||||
func (s *Service) List(ctx context.Context) ([]Gateway, error) {
|
||||
return s.repo.List(ctx)
|
||||
}
|
||||
|
||||
func (s *Service) SelectActive(ctx context.Context) (Gateway, error) {
|
||||
return s.repo.FirstActive(ctx)
|
||||
}
|
||||
|
||||
func (s *Service) BuildSyncBundle(ctx context.Context, gatewayID string) (wireguard.GatewayBundle, error) {
|
||||
id, err := uuid.Parse(gatewayID)
|
||||
if err != nil {
|
||||
return wireguard.GatewayBundle{}, err
|
||||
}
|
||||
return s.repo.BuildSyncBundle(ctx, id)
|
||||
}
|
||||
14
backend/internal/gateway/types.go
Normal file
14
backend/internal/gateway/types.go
Normal file
@@ -0,0 +1,14 @@
|
||||
package gateway
|
||||
|
||||
import "github.com/google/uuid"
|
||||
|
||||
type Gateway struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Endpoint string `json:"endpoint"`
|
||||
PublicKey string `json:"public_key"`
|
||||
ListenPort int `json:"listen_port"`
|
||||
VPNCIDR string `json:"vpn_cidr"`
|
||||
DNSServers []string `json:"dns_servers"`
|
||||
IsActive bool `json:"is_active"`
|
||||
}
|
||||
66
backend/internal/httpserver/middleware.go
Normal file
66
backend/internal/httpserver/middleware.go
Normal file
@@ -0,0 +1,66 @@
|
||||
package httpserver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/go-chi/chi/v5/middleware"
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/nexavpn/nexavpn/backend/internal/apiutil"
|
||||
"github.com/nexavpn/nexavpn/backend/internal/auth"
|
||||
)
|
||||
|
||||
type contextKey string
|
||||
|
||||
const claimsContextKey contextKey = "claims"
|
||||
|
||||
func BaseMiddleware(next http.Handler) http.Handler {
|
||||
return middleware.RealIP(middleware.RequestID(middleware.Logger(next)))
|
||||
}
|
||||
|
||||
func AuthMiddleware(jwtSecret string) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
header := r.Header.Get("Authorization")
|
||||
if !strings.HasPrefix(header, "Bearer ") {
|
||||
apiutil.Error(w, http.StatusUnauthorized, "unauthorized", "missing bearer token")
|
||||
return
|
||||
}
|
||||
|
||||
claims, err := auth.ParseAccessToken(jwtSecret, strings.TrimPrefix(header, "Bearer "))
|
||||
if err != nil {
|
||||
apiutil.Error(w, http.StatusUnauthorized, "unauthorized", "invalid access token")
|
||||
return
|
||||
}
|
||||
|
||||
ctx := context.WithValue(r.Context(), claimsContextKey, claims)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func AdminOnly(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
claims, ok := ClaimsFromContext(r.Context())
|
||||
if !ok || claims.Role != "admin" {
|
||||
apiutil.Error(w, http.StatusForbidden, "forbidden", "admin role required")
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
func ClaimsFromContext(ctx context.Context) (auth.Claims, bool) {
|
||||
claims, ok := ctx.Value(claimsContextKey).(auth.Claims)
|
||||
return claims, ok
|
||||
}
|
||||
|
||||
func MustUserID(ctx context.Context) (uuid.UUID, bool) {
|
||||
claims, ok := ClaimsFromContext(ctx)
|
||||
if !ok {
|
||||
return uuid.Nil, false
|
||||
}
|
||||
return claims.UserID, true
|
||||
}
|
||||
68
backend/internal/httpserver/router.go
Normal file
68
backend/internal/httpserver/router.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package httpserver
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"github.com/nexavpn/nexavpn/backend/internal/apiutil"
|
||||
"github.com/nexavpn/nexavpn/backend/internal/auth"
|
||||
"github.com/nexavpn/nexavpn/backend/internal/audit"
|
||||
"github.com/nexavpn/nexavpn/backend/internal/device"
|
||||
"github.com/nexavpn/nexavpn/backend/internal/gateway"
|
||||
"github.com/nexavpn/nexavpn/backend/internal/policy"
|
||||
"github.com/nexavpn/nexavpn/backend/internal/user"
|
||||
)
|
||||
|
||||
type Handlers struct {
|
||||
Auth *auth.Handler
|
||||
User *user.Handler
|
||||
Device *device.Handler
|
||||
Policy *policy.Handler
|
||||
Gateway *gateway.Handler
|
||||
Audit *audit.Handler
|
||||
}
|
||||
|
||||
func NewRouter(jwtSecret string, handlers Handlers) http.Handler {
|
||||
r := chi.NewRouter()
|
||||
r.Use(BaseMiddleware)
|
||||
|
||||
r.Get("/healthz", func(w http.ResponseWriter, _ *http.Request) {
|
||||
apiutil.JSON(w, http.StatusOK, map[string]string{"status": "ok"})
|
||||
})
|
||||
|
||||
r.Route("/api/v1", func(r chi.Router) {
|
||||
r.Post("/auth/bootstrap", handlers.Auth.Bootstrap)
|
||||
r.Post("/auth/login", handlers.Auth.Login)
|
||||
r.Post("/auth/refresh", handlers.Auth.Refresh)
|
||||
r.Post("/auth/logout", handlers.Auth.Logout)
|
||||
|
||||
r.Group(func(r chi.Router) {
|
||||
r.Use(AuthMiddleware(jwtSecret))
|
||||
r.Get("/auth/me", handlers.Auth.Me)
|
||||
r.Post("/devices/enroll", handlers.Device.Enroll)
|
||||
r.Get("/me/devices", handlers.Device.ListOwn)
|
||||
r.Get("/me/profile", handlers.Device.GetOwnProfile)
|
||||
r.Get("/connection/status", handlers.Device.ConnectionStatus)
|
||||
|
||||
r.Route("/admin", func(r chi.Router) {
|
||||
r.Use(AdminOnly)
|
||||
r.Get("/users", handlers.User.List)
|
||||
r.Post("/users", handlers.User.Create)
|
||||
r.Post("/users/{id}/disable", handlers.User.Disable)
|
||||
r.Post("/users/{id}/enable", handlers.User.Enable)
|
||||
r.Get("/devices", handlers.Device.ListAll)
|
||||
r.Get("/devices/{id}/profile", handlers.Device.GetProfileByDeviceID)
|
||||
r.Post("/devices/{id}/revoke", handlers.Device.Revoke)
|
||||
r.Post("/devices/{id}/rotate", handlers.Device.Rotate)
|
||||
r.Get("/policies", handlers.Policy.List)
|
||||
r.Post("/policies", handlers.Policy.Create)
|
||||
r.Get("/gateways", handlers.Gateway.List)
|
||||
r.Get("/gateways/{id}/sync", handlers.Gateway.SyncBundle)
|
||||
r.Get("/audit-logs", handlers.Audit.List)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
return r
|
||||
}
|
||||
26
backend/internal/ipam/service.go
Normal file
26
backend/internal/ipam/service.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package ipam
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
type Service struct{}
|
||||
|
||||
func NewService() *Service {
|
||||
return &Service{}
|
||||
}
|
||||
|
||||
func (s *Service) Allocate(cidr string, offset int) (string, error) {
|
||||
prefix, err := netip.ParsePrefix(cidr)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
address := prefix.Addr().Next()
|
||||
for i := 1; i < offset; i++ {
|
||||
address = address.Next()
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s/32", address.String()), nil
|
||||
}
|
||||
62
backend/internal/policy/handler.go
Normal file
62
backend/internal/policy/handler.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package policy
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"github.com/nexavpn/nexavpn/backend/internal/apiutil"
|
||||
"github.com/nexavpn/nexavpn/backend/internal/audit"
|
||||
"github.com/nexavpn/nexavpn/backend/internal/httpserver"
|
||||
)
|
||||
|
||||
type Handler struct {
|
||||
service *Service
|
||||
audit *audit.Service
|
||||
}
|
||||
|
||||
func NewHandler(service *Service, auditService *audit.Service) *Handler {
|
||||
return &Handler{service: service, audit: auditService}
|
||||
}
|
||||
|
||||
func (h *Handler) List(w http.ResponseWriter, r *http.Request) {
|
||||
items, err := h.service.List(r.Context())
|
||||
if err != nil {
|
||||
apiutil.Error(w, http.StatusInternalServerError, "policies_list_failed", "unable to list policies")
|
||||
return
|
||||
}
|
||||
|
||||
apiutil.JSON(w, http.StatusOK, items)
|
||||
}
|
||||
|
||||
func (h *Handler) Create(w http.ResponseWriter, r *http.Request) {
|
||||
var input CreateRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&input); err != nil {
|
||||
apiutil.Error(w, http.StatusBadRequest, "invalid_json", "invalid request body")
|
||||
return
|
||||
}
|
||||
|
||||
claims, ok := httpserver.ClaimsFromContext(r.Context())
|
||||
if !ok {
|
||||
apiutil.Error(w, http.StatusUnauthorized, "unauthorized", "missing auth claims")
|
||||
return
|
||||
}
|
||||
|
||||
item, err := h.service.Create(r.Context(), claims.UserID, input)
|
||||
if err != nil {
|
||||
apiutil.Error(w, http.StatusInternalServerError, "policy_create_failed", "unable to create policy")
|
||||
return
|
||||
}
|
||||
|
||||
_ = h.audit.Record(r.Context(), audit.Entry{
|
||||
ActorUserID: &claims.UserID,
|
||||
EntityType: "policy",
|
||||
EntityID: &item.ID,
|
||||
EventType: "admin.policy.created",
|
||||
Status: "success",
|
||||
Message: "admin created policy",
|
||||
Metadata: map[string]any{
|
||||
"name": item.Name,
|
||||
},
|
||||
})
|
||||
apiutil.JSON(w, http.StatusCreated, item)
|
||||
}
|
||||
143
backend/internal/policy/repository.go
Normal file
143
backend/internal/policy/repository.go
Normal file
@@ -0,0 +1,143 @@
|
||||
package policy
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
|
||||
type Repository interface {
|
||||
List(ctx context.Context) ([]Policy, error)
|
||||
Create(ctx context.Context, input CreateRequest, createdBy uuid.UUID) (Policy, error)
|
||||
ResolveDestinations(ctx context.Context, userID uuid.UUID, deviceID *uuid.UUID) ([]string, error)
|
||||
}
|
||||
|
||||
type PGRepository struct {
|
||||
db *pgxpool.Pool
|
||||
}
|
||||
|
||||
func NewPGRepository(db *pgxpool.Pool) *PGRepository {
|
||||
return &PGRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *PGRepository) List(ctx context.Context) ([]Policy, error) {
|
||||
rows, err := r.db.Query(ctx, `
|
||||
select
|
||||
p.id,
|
||||
p.name,
|
||||
p.description,
|
||||
p.priority,
|
||||
p.effect,
|
||||
p.full_tunnel,
|
||||
p.is_active,
|
||||
coalesce(array_agg(pd.destination::text) filter (where pd.destination is not null), '{}')
|
||||
from policies p
|
||||
left join policy_destinations pd on pd.policy_id = p.id
|
||||
where p.deleted_at is null
|
||||
group by p.id, p.name, p.description, p.priority, p.effect, p.full_tunnel, p.is_active, p.created_at
|
||||
order by p.priority asc, p.created_at desc
|
||||
`)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var items []Policy
|
||||
for rows.Next() {
|
||||
var item Policy
|
||||
if err := rows.Scan(&item.ID, &item.Name, &item.Description, &item.Priority, &item.Effect, &item.FullTunnel, &item.IsActive, &item.Destinations); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, item)
|
||||
}
|
||||
return items, rows.Err()
|
||||
}
|
||||
|
||||
func (r *PGRepository) Create(ctx context.Context, input CreateRequest, createdBy uuid.UUID) (Policy, error) {
|
||||
tx, err := r.db.Begin(ctx)
|
||||
if err != nil {
|
||||
return Policy{}, err
|
||||
}
|
||||
defer tx.Rollback(ctx)
|
||||
|
||||
policyID := uuid.New()
|
||||
_, err = tx.Exec(ctx, `
|
||||
insert into policies (id, name, description, priority, effect, full_tunnel, created_by)
|
||||
values ($1, $2, $3, $4, $5, $6, $7)
|
||||
`, policyID, input.Name, input.Description, input.Priority, input.Effect, input.FullTunnel, createdBy)
|
||||
if err != nil {
|
||||
return Policy{}, err
|
||||
}
|
||||
|
||||
for _, destination := range input.Destinations {
|
||||
if _, err := tx.Exec(ctx, `
|
||||
insert into policy_destinations (id, policy_id, destination)
|
||||
values ($1, $2, $3::cidr)
|
||||
`, uuid.New(), policyID, destination); err != nil {
|
||||
return Policy{}, err
|
||||
}
|
||||
}
|
||||
|
||||
for _, target := range input.Targets {
|
||||
if _, err := tx.Exec(ctx, `
|
||||
insert into policy_targets (id, policy_id, target_type, target_id)
|
||||
values ($1, $2, $3, $4)
|
||||
`, uuid.New(), policyID, target.Type, target.ID); err != nil {
|
||||
return Policy{}, err
|
||||
}
|
||||
}
|
||||
|
||||
if err := tx.Commit(ctx); err != nil {
|
||||
return Policy{}, err
|
||||
}
|
||||
|
||||
inputPolicy := Policy{
|
||||
ID: policyID,
|
||||
Name: input.Name,
|
||||
Description: input.Description,
|
||||
Priority: input.Priority,
|
||||
Effect: input.Effect,
|
||||
FullTunnel: input.FullTunnel,
|
||||
IsActive: true,
|
||||
Destinations: input.Destinations,
|
||||
Targets: input.Targets,
|
||||
}
|
||||
return inputPolicy, nil
|
||||
}
|
||||
|
||||
func (r *PGRepository) ResolveDestinations(ctx context.Context, userID uuid.UUID, deviceID *uuid.UUID) ([]string, error) {
|
||||
query := `
|
||||
select distinct pd.destination::text
|
||||
from policies p
|
||||
join policy_destinations pd on pd.policy_id = p.id
|
||||
join policy_targets pt on pt.policy_id = p.id
|
||||
where p.deleted_at is null
|
||||
and p.is_active = true
|
||||
and p.effect = 'allow'
|
||||
and (
|
||||
(pt.target_type = 'user' and pt.target_id = $1)
|
||||
`
|
||||
args := []any{userID}
|
||||
if deviceID != nil {
|
||||
query += ` or (pt.target_type = 'device' and pt.target_id = $2)`
|
||||
args = append(args, *deviceID)
|
||||
}
|
||||
query += `)`
|
||||
|
||||
rows, err := r.db.Query(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var destinations []string
|
||||
for rows.Next() {
|
||||
var value string
|
||||
if err := rows.Scan(&value); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
destinations = append(destinations, value)
|
||||
}
|
||||
return destinations, rows.Err()
|
||||
}
|
||||
33
backend/internal/policy/service.go
Normal file
33
backend/internal/policy/service.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package policy
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type Service struct {
|
||||
repo Repository
|
||||
}
|
||||
|
||||
func NewService(repo Repository) *Service {
|
||||
return &Service{repo: repo}
|
||||
}
|
||||
|
||||
func (s *Service) List(ctx context.Context) ([]Policy, error) {
|
||||
return s.repo.List(ctx)
|
||||
}
|
||||
|
||||
func (s *Service) Create(ctx context.Context, actorID uuid.UUID, input CreateRequest) (Policy, error) {
|
||||
if input.Priority == 0 {
|
||||
input.Priority = 100
|
||||
}
|
||||
if input.Effect == "" {
|
||||
input.Effect = "allow"
|
||||
}
|
||||
return s.repo.Create(ctx, input, actorID)
|
||||
}
|
||||
|
||||
func (s *Service) ResolveDestinations(ctx context.Context, userID uuid.UUID, deviceID *uuid.UUID) ([]string, error) {
|
||||
return s.repo.ResolveDestinations(ctx, userID, deviceID)
|
||||
}
|
||||
30
backend/internal/policy/types.go
Normal file
30
backend/internal/policy/types.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package policy
|
||||
|
||||
import "github.com/google/uuid"
|
||||
|
||||
type Target struct {
|
||||
Type string `json:"type"`
|
||||
ID uuid.UUID `json:"id"`
|
||||
}
|
||||
|
||||
type Policy struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Priority int `json:"priority"`
|
||||
Effect string `json:"effect"`
|
||||
FullTunnel bool `json:"full_tunnel"`
|
||||
IsActive bool `json:"is_active"`
|
||||
Destinations []string `json:"destinations"`
|
||||
Targets []Target `json:"targets"`
|
||||
}
|
||||
|
||||
type CreateRequest struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Priority int `json:"priority"`
|
||||
Effect string `json:"effect"`
|
||||
FullTunnel bool `json:"full_tunnel"`
|
||||
Destinations []string `json:"destinations"`
|
||||
Targets []Target `json:"targets"`
|
||||
}
|
||||
33
backend/internal/profile/builder.go
Normal file
33
backend/internal/profile/builder.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package profile
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type BuildInput struct {
|
||||
PrivateKey string
|
||||
Address string
|
||||
DNSServers []string
|
||||
ServerPublicKey string
|
||||
ServerEndpoint string
|
||||
AllowedIPs []string
|
||||
PersistentKeepal int
|
||||
}
|
||||
|
||||
func BuildWireGuardConfig(input BuildInput) string {
|
||||
var b strings.Builder
|
||||
b.WriteString("[Interface]\n")
|
||||
b.WriteString(fmt.Sprintf("PrivateKey = %s\n", input.PrivateKey))
|
||||
b.WriteString(fmt.Sprintf("Address = %s\n", input.Address))
|
||||
if len(input.DNSServers) > 0 {
|
||||
b.WriteString(fmt.Sprintf("DNS = %s\n", strings.Join(input.DNSServers, ", ")))
|
||||
}
|
||||
|
||||
b.WriteString("\n[Peer]\n")
|
||||
b.WriteString(fmt.Sprintf("PublicKey = %s\n", input.ServerPublicKey))
|
||||
b.WriteString(fmt.Sprintf("Endpoint = %s\n", input.ServerEndpoint))
|
||||
b.WriteString(fmt.Sprintf("AllowedIPs = %s\n", strings.Join(input.AllowedIPs, ", ")))
|
||||
b.WriteString(fmt.Sprintf("PersistentKeepalive = %d\n", input.PersistentKeepal))
|
||||
return b.String()
|
||||
}
|
||||
110
backend/internal/user/handler.go
Normal file
110
backend/internal/user/handler.go
Normal file
@@ -0,0 +1,110 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/nexavpn/nexavpn/backend/internal/apiutil"
|
||||
"github.com/nexavpn/nexavpn/backend/internal/audit"
|
||||
"github.com/nexavpn/nexavpn/backend/internal/httpserver"
|
||||
)
|
||||
|
||||
type Handler struct {
|
||||
service *Service
|
||||
audit *audit.Service
|
||||
}
|
||||
|
||||
func NewHandler(service *Service, auditService *audit.Service) *Handler {
|
||||
return &Handler{service: service, audit: auditService}
|
||||
}
|
||||
|
||||
func (h *Handler) List(w http.ResponseWriter, r *http.Request) {
|
||||
users, err := h.service.List(r.Context())
|
||||
if err != nil {
|
||||
apiutil.Error(w, http.StatusInternalServerError, "users_list_failed", "unable to list users")
|
||||
return
|
||||
}
|
||||
|
||||
apiutil.JSON(w, http.StatusOK, users)
|
||||
}
|
||||
|
||||
func (h *Handler) Create(w http.ResponseWriter, r *http.Request) {
|
||||
var input CreateRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&input); err != nil {
|
||||
apiutil.Error(w, http.StatusBadRequest, "invalid_json", "invalid request body")
|
||||
return
|
||||
}
|
||||
if input.Role == "" {
|
||||
input.Role = "user"
|
||||
}
|
||||
|
||||
created, err := h.service.Create(r.Context(), input)
|
||||
if err != nil {
|
||||
apiutil.Error(w, http.StatusInternalServerError, "user_create_failed", "unable to create user")
|
||||
return
|
||||
}
|
||||
|
||||
if claims, ok := httpserver.ClaimsFromContext(r.Context()); ok {
|
||||
_ = h.audit.Record(r.Context(), audit.Entry{
|
||||
ActorUserID: &claims.UserID,
|
||||
EntityType: "user",
|
||||
EntityID: &created.ID,
|
||||
EventType: "admin.user.created",
|
||||
Status: "success",
|
||||
Message: "admin created user",
|
||||
Metadata: map[string]any{
|
||||
"username": created.Username,
|
||||
},
|
||||
})
|
||||
}
|
||||
apiutil.JSON(w, http.StatusCreated, created)
|
||||
}
|
||||
|
||||
func (h *Handler) Disable(w http.ResponseWriter, r *http.Request) {
|
||||
targetID, err := uuid.Parse(chi.URLParam(r, "id"))
|
||||
if err != nil {
|
||||
apiutil.Error(w, http.StatusBadRequest, "invalid_user_id", "invalid user id")
|
||||
return
|
||||
}
|
||||
if err := h.service.SetActive(r.Context(), targetID.String(), false); err != nil {
|
||||
apiutil.Error(w, http.StatusBadRequest, "user_disable_failed", "unable to disable user")
|
||||
return
|
||||
}
|
||||
if claims, ok := httpserver.ClaimsFromContext(r.Context()); ok {
|
||||
_ = h.audit.Record(r.Context(), audit.Entry{
|
||||
ActorUserID: &claims.UserID,
|
||||
EntityType: "user",
|
||||
EntityID: &targetID,
|
||||
EventType: "admin.user.disabled",
|
||||
Status: "success",
|
||||
Message: "admin disabled user",
|
||||
})
|
||||
}
|
||||
apiutil.JSON(w, http.StatusOK, map[string]any{"ok": true})
|
||||
}
|
||||
|
||||
func (h *Handler) Enable(w http.ResponseWriter, r *http.Request) {
|
||||
targetID, err := uuid.Parse(chi.URLParam(r, "id"))
|
||||
if err != nil {
|
||||
apiutil.Error(w, http.StatusBadRequest, "invalid_user_id", "invalid user id")
|
||||
return
|
||||
}
|
||||
if err := h.service.SetActive(r.Context(), targetID.String(), true); err != nil {
|
||||
apiutil.Error(w, http.StatusBadRequest, "user_enable_failed", "unable to enable user")
|
||||
return
|
||||
}
|
||||
if claims, ok := httpserver.ClaimsFromContext(r.Context()); ok {
|
||||
_ = h.audit.Record(r.Context(), audit.Entry{
|
||||
ActorUserID: &claims.UserID,
|
||||
EntityType: "user",
|
||||
EntityID: &targetID,
|
||||
EventType: "admin.user.enabled",
|
||||
Status: "success",
|
||||
Message: "admin enabled user",
|
||||
})
|
||||
}
|
||||
apiutil.JSON(w, http.StatusOK, map[string]any{"ok": true})
|
||||
}
|
||||
64
backend/internal/user/repository.go
Normal file
64
backend/internal/user/repository.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
|
||||
type Repository interface {
|
||||
List(ctx context.Context) ([]User, error)
|
||||
Create(ctx context.Context, input CreateRequest, passwordHash string) (User, error)
|
||||
SetActive(ctx context.Context, userID uuid.UUID, active bool) error
|
||||
}
|
||||
|
||||
type PGRepository struct {
|
||||
db *pgxpool.Pool
|
||||
}
|
||||
|
||||
func NewPGRepository(db *pgxpool.Pool) *PGRepository {
|
||||
return &PGRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *PGRepository) List(ctx context.Context) ([]User, error) {
|
||||
rows, err := r.db.Query(ctx, `
|
||||
select u.id, r.id, r.name, u.username, u.display_name, coalesce(u.email, ''), u.is_active
|
||||
from users u
|
||||
join roles r on r.id = u.role_id
|
||||
where u.deleted_at is null
|
||||
order by u.created_at desc
|
||||
`)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var users []User
|
||||
for rows.Next() {
|
||||
var item User
|
||||
if err := rows.Scan(&item.ID, &item.RoleID, &item.RoleName, &item.Username, &item.DisplayName, &item.Email, &item.IsActive); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
users = append(users, item)
|
||||
}
|
||||
return users, rows.Err()
|
||||
}
|
||||
|
||||
func (r *PGRepository) Create(ctx context.Context, input CreateRequest, passwordHash string) (User, error) {
|
||||
const query = `
|
||||
insert into users (id, role_id, username, display_name, email, password_hash)
|
||||
values ($1, (select id from roles where name = $2), $3, $4, nullif($5, ''), $6)
|
||||
returning id, username, display_name, coalesce(email, ''), is_active
|
||||
`
|
||||
|
||||
item := User{RoleName: input.Role}
|
||||
err := r.db.QueryRow(ctx, query, uuid.New(), input.Role, input.Username, input.DisplayName, input.Email, passwordHash).
|
||||
Scan(&item.ID, &item.Username, &item.DisplayName, &item.Email, &item.IsActive)
|
||||
return item, err
|
||||
}
|
||||
|
||||
func (r *PGRepository) SetActive(ctx context.Context, userID uuid.UUID, active bool) error {
|
||||
_, err := r.db.Exec(ctx, `update users set is_active = $2, updated_at = now() where id = $1`, userID, active)
|
||||
return err
|
||||
}
|
||||
37
backend/internal/user/service.go
Normal file
37
backend/internal/user/service.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/nexavpn/nexavpn/backend/internal/auth"
|
||||
)
|
||||
|
||||
type Service struct {
|
||||
repo Repository
|
||||
}
|
||||
|
||||
func NewService(repo Repository) *Service {
|
||||
return &Service{repo: repo}
|
||||
}
|
||||
|
||||
func (s *Service) List(ctx context.Context) ([]User, error) {
|
||||
return s.repo.List(ctx)
|
||||
}
|
||||
|
||||
func (s *Service) Create(ctx context.Context, input CreateRequest) (User, error) {
|
||||
passwordHash, err := auth.HashPassword(input.Password)
|
||||
if err != nil {
|
||||
return User{}, err
|
||||
}
|
||||
return s.repo.Create(ctx, input, passwordHash)
|
||||
}
|
||||
|
||||
func (s *Service) SetActive(ctx context.Context, userID string, active bool) error {
|
||||
id, err := uuid.Parse(userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return s.repo.SetActive(ctx, id, active)
|
||||
}
|
||||
27
backend/internal/user/types.go
Normal file
27
backend/internal/user/types.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package user
|
||||
|
||||
import "github.com/google/uuid"
|
||||
|
||||
type User struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
RoleID uuid.UUID `json:"role_id,omitempty"`
|
||||
RoleName string `json:"role"`
|
||||
Username string `json:"username"`
|
||||
DisplayName string `json:"display_name"`
|
||||
Email string `json:"email,omitempty"`
|
||||
IsActive bool `json:"is_active"`
|
||||
}
|
||||
|
||||
type CreateRequest struct {
|
||||
Username string `json:"username"`
|
||||
DisplayName string `json:"display_name"`
|
||||
Email string `json:"email"`
|
||||
Password string `json:"password"`
|
||||
Role string `json:"role"`
|
||||
}
|
||||
|
||||
type UpdateRequest struct {
|
||||
DisplayName *string `json:"display_name"`
|
||||
Email *string `json:"email"`
|
||||
IsActive *bool `json:"is_active"`
|
||||
}
|
||||
18
backend/internal/wireguard/types.go
Normal file
18
backend/internal/wireguard/types.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package wireguard
|
||||
|
||||
type Peer struct {
|
||||
DeviceID string `json:"device_id"`
|
||||
PublicKey string `json:"public_key"`
|
||||
AssignedIP string `json:"assigned_ip"`
|
||||
AllowedDestinations []string `json:"allowed_destinations"`
|
||||
}
|
||||
|
||||
type GatewayBundle struct {
|
||||
GatewayID string `json:"gateway_id"`
|
||||
Revision int `json:"revision"`
|
||||
Interface struct {
|
||||
Address string `json:"address"`
|
||||
ListenPort int `json:"listen_port"`
|
||||
} `json:"interface"`
|
||||
Peers []Peer `json:"peers"`
|
||||
}
|
||||
183
backend/migrations/000001_init.sql
Normal file
183
backend/migrations/000001_init.sql
Normal file
@@ -0,0 +1,183 @@
|
||||
create extension if not exists pgcrypto;
|
||||
create extension if not exists citext;
|
||||
|
||||
create table if not exists roles (
|
||||
id uuid primary key default gen_random_uuid(),
|
||||
name text unique not null,
|
||||
description text not null default '',
|
||||
created_at timestamptz not null default now(),
|
||||
updated_at timestamptz not null default now()
|
||||
);
|
||||
|
||||
create table if not exists users (
|
||||
id uuid primary key default gen_random_uuid(),
|
||||
role_id uuid not null references roles(id),
|
||||
username citext unique not null,
|
||||
display_name text not null,
|
||||
email citext unique,
|
||||
password_hash text not null,
|
||||
is_active boolean not null default true,
|
||||
last_login_at timestamptz,
|
||||
created_at timestamptz not null default now(),
|
||||
updated_at timestamptz not null default now(),
|
||||
deleted_at timestamptz
|
||||
);
|
||||
|
||||
create table if not exists sessions (
|
||||
id uuid primary key default gen_random_uuid(),
|
||||
user_id uuid not null references users(id),
|
||||
ip_address inet,
|
||||
user_agent text,
|
||||
last_seen_at timestamptz not null default now(),
|
||||
expires_at timestamptz not null,
|
||||
created_at timestamptz not null default now(),
|
||||
revoked_at timestamptz
|
||||
);
|
||||
|
||||
create table if not exists refresh_tokens (
|
||||
id uuid primary key default gen_random_uuid(),
|
||||
session_id uuid not null references sessions(id),
|
||||
user_id uuid not null references users(id),
|
||||
token_hash text not null,
|
||||
expires_at timestamptz not null,
|
||||
created_at timestamptz not null default now(),
|
||||
revoked_at timestamptz
|
||||
);
|
||||
|
||||
create table if not exists gateways (
|
||||
id uuid primary key default gen_random_uuid(),
|
||||
name text unique not null,
|
||||
endpoint text not null,
|
||||
public_key text not null,
|
||||
listen_port integer not null,
|
||||
vpn_cidr cidr not null,
|
||||
dns_servers text[] not null default '{}',
|
||||
is_active boolean not null default true,
|
||||
last_sync_at timestamptz,
|
||||
created_at timestamptz not null default now(),
|
||||
updated_at timestamptz not null default now(),
|
||||
deleted_at timestamptz
|
||||
);
|
||||
|
||||
create table if not exists devices (
|
||||
id uuid primary key default gen_random_uuid(),
|
||||
user_id uuid not null references users(id),
|
||||
gateway_id uuid references gateways(id),
|
||||
name text not null,
|
||||
platform text not null,
|
||||
os_version text not null default '',
|
||||
app_version text not null default '',
|
||||
device_fingerprint text not null,
|
||||
public_key text not null,
|
||||
status text not null default 'active',
|
||||
last_seen_at timestamptz,
|
||||
last_connected_at timestamptz,
|
||||
approved_at timestamptz,
|
||||
revoked_at timestamptz,
|
||||
created_at timestamptz not null default now(),
|
||||
updated_at timestamptz not null default now(),
|
||||
deleted_at timestamptz
|
||||
);
|
||||
|
||||
create unique index if not exists idx_devices_user_fingerprint on devices(user_id, device_fingerprint) where deleted_at is null;
|
||||
|
||||
create table if not exists wireguard_peers (
|
||||
id uuid primary key default gen_random_uuid(),
|
||||
device_id uuid not null references devices(id),
|
||||
gateway_id uuid not null references gateways(id),
|
||||
public_key text unique not null,
|
||||
assigned_ip inet not null,
|
||||
preshared_key_ciphertext text,
|
||||
allowed_ips cidr[] not null default '{}',
|
||||
dns_servers text[] not null default '{}',
|
||||
profile_revision integer not null default 1,
|
||||
last_profile_issued_at timestamptz,
|
||||
created_at timestamptz not null default now(),
|
||||
updated_at timestamptz not null default now(),
|
||||
deleted_at timestamptz
|
||||
);
|
||||
|
||||
create table if not exists ip_allocations (
|
||||
id uuid primary key default gen_random_uuid(),
|
||||
gateway_id uuid not null references gateways(id),
|
||||
device_id uuid references devices(id),
|
||||
address inet not null,
|
||||
status text not null default 'allocated',
|
||||
allocated_at timestamptz not null default now(),
|
||||
released_at timestamptz,
|
||||
created_at timestamptz not null default now(),
|
||||
updated_at timestamptz not null default now()
|
||||
);
|
||||
|
||||
create unique index if not exists idx_ip_allocations_gateway_address on ip_allocations(gateway_id, address);
|
||||
create unique index if not exists idx_ip_allocations_device_active on ip_allocations(device_id) where status = 'allocated';
|
||||
|
||||
create table if not exists policies (
|
||||
id uuid primary key default gen_random_uuid(),
|
||||
name text unique not null,
|
||||
description text not null default '',
|
||||
priority integer not null default 100,
|
||||
effect text not null default 'allow',
|
||||
is_active boolean not null default true,
|
||||
full_tunnel boolean not null default false,
|
||||
created_by uuid references users(id),
|
||||
created_at timestamptz not null default now(),
|
||||
updated_at timestamptz not null default now(),
|
||||
deleted_at timestamptz
|
||||
);
|
||||
|
||||
create table if not exists policy_targets (
|
||||
id uuid primary key default gen_random_uuid(),
|
||||
policy_id uuid not null references policies(id),
|
||||
target_type text not null,
|
||||
target_id uuid not null,
|
||||
created_at timestamptz not null default now()
|
||||
);
|
||||
|
||||
create table if not exists policy_destinations (
|
||||
id uuid primary key default gen_random_uuid(),
|
||||
policy_id uuid not null references policies(id),
|
||||
destination cidr not null,
|
||||
description text not null default '',
|
||||
created_at timestamptz not null default now()
|
||||
);
|
||||
|
||||
create table if not exists groups (
|
||||
id uuid primary key default gen_random_uuid(),
|
||||
name text unique not null,
|
||||
description text not null default '',
|
||||
created_at timestamptz not null default now(),
|
||||
updated_at timestamptz not null default now(),
|
||||
deleted_at timestamptz
|
||||
);
|
||||
|
||||
create table if not exists group_memberships (
|
||||
id uuid primary key default gen_random_uuid(),
|
||||
group_id uuid not null references groups(id),
|
||||
user_id uuid not null references users(id),
|
||||
created_at timestamptz not null default now()
|
||||
);
|
||||
|
||||
create table if not exists audit_logs (
|
||||
id uuid primary key default gen_random_uuid(),
|
||||
actor_user_id uuid references users(id),
|
||||
actor_device_id uuid references devices(id),
|
||||
event_type text not null,
|
||||
entity_type text not null,
|
||||
entity_id uuid,
|
||||
status text not null,
|
||||
ip_address inet,
|
||||
message text not null,
|
||||
metadata jsonb not null default '{}'::jsonb,
|
||||
created_at timestamptz not null default now()
|
||||
);
|
||||
|
||||
create table if not exists settings (
|
||||
id uuid primary key default gen_random_uuid(),
|
||||
category text not null,
|
||||
key text not null,
|
||||
value jsonb not null,
|
||||
created_at timestamptz not null default now(),
|
||||
updated_at timestamptz not null default now(),
|
||||
unique(category, key)
|
||||
);
|
||||
24
backend/seed/001_seed.sql
Normal file
24
backend/seed/001_seed.sql
Normal file
@@ -0,0 +1,24 @@
|
||||
insert into roles (name, description)
|
||||
values
|
||||
('admin', 'NexaVPN administrator'),
|
||||
('user', 'Standard VPN user')
|
||||
on conflict (name) do nothing;
|
||||
|
||||
insert into settings (category, key, value)
|
||||
values
|
||||
('vpn', 'default_dns_servers', '["10.20.0.53"]'::jsonb),
|
||||
('auth', 'access_token_ttl_seconds', '900'::jsonb),
|
||||
('auth', 'refresh_token_ttl_seconds', '2592000'::jsonb)
|
||||
on conflict (category, key) do nothing;
|
||||
|
||||
insert into gateways (name, endpoint, public_key, listen_port, vpn_cidr, dns_servers, is_active)
|
||||
values (
|
||||
'primary-gateway',
|
||||
'vpn.example.com:51820',
|
||||
'replace-me-with-gateway-public-key',
|
||||
51820,
|
||||
'100.96.0.0/24',
|
||||
array['10.20.0.53'],
|
||||
true
|
||||
)
|
||||
on conflict (name) do nothing;
|
||||
Reference in New Issue
Block a user