diff --git a/backend/Dockerfile b/backend/Dockerfile index 68e8166..7c6f9f7 100644 --- a/backend/Dockerfile +++ b/backend/Dockerfile @@ -1,8 +1,10 @@ FROM golang:1.23-alpine AS builder WORKDIR /src COPY go.mod ./ -RUN go mod download +COPY go.sum* ./ COPY . . +RUN go mod tidy +RUN go mod download RUN CGO_ENABLED=0 GOOS=linux go build -o /out/nexavpn-api ./cmd/api FROM alpine:3.21 diff --git a/backend/internal/auth/handler.go b/backend/internal/auth/handler.go index 91f9033..5e3198f 100644 --- a/backend/internal/auth/handler.go +++ b/backend/internal/auth/handler.go @@ -6,7 +6,7 @@ import ( "github.com/nexavpn/nexavpn/backend/internal/apiutil" "github.com/nexavpn/nexavpn/backend/internal/audit" - "github.com/nexavpn/nexavpn/backend/internal/httpserver" + "github.com/nexavpn/nexavpn/backend/internal/requestctx" ) type Handler struct { @@ -110,7 +110,7 @@ func (h *Handler) Logout(w http.ResponseWriter, r *http.Request) { return } - if claims, ok := httpserver.ClaimsFromContext(r.Context()); ok { + if claims, ok := requestctx.ClaimsFromContext(r.Context()); ok { _ = h.audit.Record(r.Context(), audit.Entry{ ActorUserID: &claims.UserID, EventType: "auth.logout", @@ -123,7 +123,7 @@ func (h *Handler) Logout(w http.ResponseWriter, r *http.Request) { } func (h *Handler) Me(w http.ResponseWriter, r *http.Request) { - claims, ok := httpserver.ClaimsFromContext(r.Context()) + claims, ok := requestctx.ClaimsFromContext(r.Context()) if !ok { apiutil.Error(w, http.StatusUnauthorized, "unauthorized", "missing auth claims") return diff --git a/backend/internal/device/handler.go b/backend/internal/device/handler.go index 2a6e1f4..48adfc7 100644 --- a/backend/internal/device/handler.go +++ b/backend/internal/device/handler.go @@ -9,7 +9,7 @@ import ( "github.com/nexavpn/nexavpn/backend/internal/apiutil" "github.com/nexavpn/nexavpn/backend/internal/audit" - "github.com/nexavpn/nexavpn/backend/internal/httpserver" + "github.com/nexavpn/nexavpn/backend/internal/requestctx" ) type Handler struct { @@ -28,7 +28,7 @@ func (h *Handler) Enroll(w http.ResponseWriter, r *http.Request) { return } - userID, ok := httpserver.MustUserID(r.Context()) + userID, ok := requestctx.MustUserID(r.Context()) if !ok { apiutil.Error(w, http.StatusUnauthorized, "unauthorized", "missing auth claims") return @@ -56,7 +56,7 @@ func (h *Handler) Enroll(w http.ResponseWriter, r *http.Request) { } func (h *Handler) ListOwn(w http.ResponseWriter, r *http.Request) { - userID, ok := httpserver.MustUserID(r.Context()) + userID, ok := requestctx.MustUserID(r.Context()) if !ok { apiutil.Error(w, http.StatusUnauthorized, "unauthorized", "missing auth claims") return @@ -82,7 +82,7 @@ func (h *Handler) ListAll(w http.ResponseWriter, r *http.Request) { } func (h *Handler) ConnectionStatus(w http.ResponseWriter, r *http.Request) { - userID, ok := httpserver.MustUserID(r.Context()) + userID, ok := requestctx.MustUserID(r.Context()) if !ok { apiutil.Error(w, http.StatusUnauthorized, "unauthorized", "missing auth claims") return @@ -98,7 +98,7 @@ func (h *Handler) ConnectionStatus(w http.ResponseWriter, r *http.Request) { } func (h *Handler) GetOwnProfile(w http.ResponseWriter, r *http.Request) { - userID, ok := httpserver.MustUserID(r.Context()) + userID, ok := requestctx.MustUserID(r.Context()) if !ok { apiutil.Error(w, http.StatusUnauthorized, "unauthorized", "missing auth claims") return @@ -139,7 +139,7 @@ func (h *Handler) Revoke(w http.ResponseWriter, r *http.Request) { apiutil.Error(w, http.StatusInternalServerError, "device_revoke_failed", "unable to revoke device") return } - if claims, ok := httpserver.ClaimsFromContext(r.Context()); ok { + if claims, ok := requestctx.ClaimsFromContext(r.Context()); ok { _ = h.audit.Record(r.Context(), audit.Entry{ ActorUserID: &claims.UserID, EntityType: "device", @@ -162,7 +162,7 @@ func (h *Handler) Rotate(w http.ResponseWriter, r *http.Request) { apiutil.Error(w, http.StatusInternalServerError, "device_rotate_failed", "unable to rotate device profile") return } - if claims, ok := httpserver.ClaimsFromContext(r.Context()); ok { + if claims, ok := requestctx.ClaimsFromContext(r.Context()); ok { _ = h.audit.Record(r.Context(), audit.Entry{ ActorUserID: &claims.UserID, EntityType: "device", diff --git a/backend/internal/httpserver/middleware.go b/backend/internal/httpserver/middleware.go index b5dd746..51521c7 100644 --- a/backend/internal/httpserver/middleware.go +++ b/backend/internal/httpserver/middleware.go @@ -1,21 +1,16 @@ package httpserver import ( - "context" "net/http" "strings" "github.com/go-chi/chi/v5/middleware" - "github.com/google/uuid" "github.com/nexavpn/nexavpn/backend/internal/apiutil" "github.com/nexavpn/nexavpn/backend/internal/auth" + "github.com/nexavpn/nexavpn/backend/internal/requestctx" ) -type contextKey string - -const claimsContextKey contextKey = "claims" - func BaseMiddleware(next http.Handler) http.Handler { return middleware.RealIP(middleware.RequestID(middleware.Logger(next))) } @@ -35,7 +30,7 @@ func AuthMiddleware(jwtSecret string) func(http.Handler) http.Handler { return } - ctx := context.WithValue(r.Context(), claimsContextKey, claims) + ctx := requestctx.WithClaims(r.Context(), claims) next.ServeHTTP(w, r.WithContext(ctx)) }) } @@ -43,7 +38,7 @@ func AuthMiddleware(jwtSecret string) func(http.Handler) http.Handler { func AdminOnly(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - claims, ok := ClaimsFromContext(r.Context()) + claims, ok := requestctx.ClaimsFromContext(r.Context()) if !ok || claims.Role != "admin" { apiutil.Error(w, http.StatusForbidden, "forbidden", "admin role required") return @@ -51,16 +46,3 @@ func AdminOnly(next http.Handler) http.Handler { next.ServeHTTP(w, r) }) } - -func ClaimsFromContext(ctx context.Context) (auth.Claims, bool) { - claims, ok := ctx.Value(claimsContextKey).(auth.Claims) - return claims, ok -} - -func MustUserID(ctx context.Context) (uuid.UUID, bool) { - claims, ok := ClaimsFromContext(ctx) - if !ok { - return uuid.Nil, false - } - return claims.UserID, true -} diff --git a/backend/internal/policy/handler.go b/backend/internal/policy/handler.go index 6c1578a..bbebdb8 100644 --- a/backend/internal/policy/handler.go +++ b/backend/internal/policy/handler.go @@ -6,7 +6,7 @@ import ( "github.com/nexavpn/nexavpn/backend/internal/apiutil" "github.com/nexavpn/nexavpn/backend/internal/audit" - "github.com/nexavpn/nexavpn/backend/internal/httpserver" + "github.com/nexavpn/nexavpn/backend/internal/requestctx" ) type Handler struct { @@ -35,7 +35,7 @@ func (h *Handler) Create(w http.ResponseWriter, r *http.Request) { return } - claims, ok := httpserver.ClaimsFromContext(r.Context()) + claims, ok := requestctx.ClaimsFromContext(r.Context()) if !ok { apiutil.Error(w, http.StatusUnauthorized, "unauthorized", "missing auth claims") return diff --git a/backend/internal/requestctx/context.go b/backend/internal/requestctx/context.go new file mode 100644 index 0000000..a2c56eb --- /dev/null +++ b/backend/internal/requestctx/context.go @@ -0,0 +1,30 @@ +package requestctx + +import ( + "context" + + "github.com/google/uuid" + + "github.com/nexavpn/nexavpn/backend/internal/auth" +) + +type contextKey string + +const claimsKey contextKey = "claims" + +func WithClaims(ctx context.Context, claims auth.Claims) context.Context { + return context.WithValue(ctx, claimsKey, claims) +} + +func ClaimsFromContext(ctx context.Context) (auth.Claims, bool) { + claims, ok := ctx.Value(claimsKey).(auth.Claims) + return claims, ok +} + +func MustUserID(ctx context.Context) (uuid.UUID, bool) { + claims, ok := ClaimsFromContext(ctx) + if !ok { + return uuid.Nil, false + } + return claims.UserID, true +} diff --git a/backend/internal/user/handler.go b/backend/internal/user/handler.go index 3b0a1a0..3f7082c 100644 --- a/backend/internal/user/handler.go +++ b/backend/internal/user/handler.go @@ -9,7 +9,7 @@ import ( "github.com/nexavpn/nexavpn/backend/internal/apiutil" "github.com/nexavpn/nexavpn/backend/internal/audit" - "github.com/nexavpn/nexavpn/backend/internal/httpserver" + "github.com/nexavpn/nexavpn/backend/internal/requestctx" ) type Handler struct { @@ -47,7 +47,7 @@ func (h *Handler) Create(w http.ResponseWriter, r *http.Request) { return } - if claims, ok := httpserver.ClaimsFromContext(r.Context()); ok { + if claims, ok := requestctx.ClaimsFromContext(r.Context()); ok { _ = h.audit.Record(r.Context(), audit.Entry{ ActorUserID: &claims.UserID, EntityType: "user", @@ -73,7 +73,7 @@ func (h *Handler) Disable(w http.ResponseWriter, r *http.Request) { apiutil.Error(w, http.StatusBadRequest, "user_disable_failed", "unable to disable user") return } - if claims, ok := httpserver.ClaimsFromContext(r.Context()); ok { + if claims, ok := requestctx.ClaimsFromContext(r.Context()); ok { _ = h.audit.Record(r.Context(), audit.Entry{ ActorUserID: &claims.UserID, EntityType: "user", @@ -96,7 +96,7 @@ func (h *Handler) Enable(w http.ResponseWriter, r *http.Request) { apiutil.Error(w, http.StatusBadRequest, "user_enable_failed", "unable to enable user") return } - if claims, ok := httpserver.ClaimsFromContext(r.Context()); ok { + if claims, ok := requestctx.ClaimsFromContext(r.Context()); ok { _ = h.audit.Record(r.Context(), audit.Entry{ ActorUserID: &claims.UserID, EntityType: "user",