refactor: extract request context utilities into dedicated package

Move ClaimsFromContext and MustUserID helpers from httpserver to new requestctx package for better separation of concerns. Update all imports across auth, device, policy, and user handlers. Fix Dockerfile to copy go.sum and run go mod tidy before download.
This commit is contained in:
2026-03-15 16:37:01 +01:00
parent 830491cb0d
commit 298d301ce8
7 changed files with 52 additions and 38 deletions

View File

@@ -1,8 +1,10 @@
FROM golang:1.23-alpine AS builder FROM golang:1.23-alpine AS builder
WORKDIR /src WORKDIR /src
COPY go.mod ./ COPY go.mod ./
RUN go mod download COPY go.sum* ./
COPY . . COPY . .
RUN go mod tidy
RUN go mod download
RUN CGO_ENABLED=0 GOOS=linux go build -o /out/nexavpn-api ./cmd/api RUN CGO_ENABLED=0 GOOS=linux go build -o /out/nexavpn-api ./cmd/api
FROM alpine:3.21 FROM alpine:3.21

View File

@@ -6,7 +6,7 @@ import (
"github.com/nexavpn/nexavpn/backend/internal/apiutil" "github.com/nexavpn/nexavpn/backend/internal/apiutil"
"github.com/nexavpn/nexavpn/backend/internal/audit" "github.com/nexavpn/nexavpn/backend/internal/audit"
"github.com/nexavpn/nexavpn/backend/internal/httpserver" "github.com/nexavpn/nexavpn/backend/internal/requestctx"
) )
type Handler struct { type Handler struct {
@@ -110,7 +110,7 @@ func (h *Handler) Logout(w http.ResponseWriter, r *http.Request) {
return return
} }
if claims, ok := httpserver.ClaimsFromContext(r.Context()); ok { if claims, ok := requestctx.ClaimsFromContext(r.Context()); ok {
_ = h.audit.Record(r.Context(), audit.Entry{ _ = h.audit.Record(r.Context(), audit.Entry{
ActorUserID: &claims.UserID, ActorUserID: &claims.UserID,
EventType: "auth.logout", EventType: "auth.logout",
@@ -123,7 +123,7 @@ func (h *Handler) Logout(w http.ResponseWriter, r *http.Request) {
} }
func (h *Handler) Me(w http.ResponseWriter, r *http.Request) { func (h *Handler) Me(w http.ResponseWriter, r *http.Request) {
claims, ok := httpserver.ClaimsFromContext(r.Context()) claims, ok := requestctx.ClaimsFromContext(r.Context())
if !ok { if !ok {
apiutil.Error(w, http.StatusUnauthorized, "unauthorized", "missing auth claims") apiutil.Error(w, http.StatusUnauthorized, "unauthorized", "missing auth claims")
return return

View File

@@ -9,7 +9,7 @@ import (
"github.com/nexavpn/nexavpn/backend/internal/apiutil" "github.com/nexavpn/nexavpn/backend/internal/apiutil"
"github.com/nexavpn/nexavpn/backend/internal/audit" "github.com/nexavpn/nexavpn/backend/internal/audit"
"github.com/nexavpn/nexavpn/backend/internal/httpserver" "github.com/nexavpn/nexavpn/backend/internal/requestctx"
) )
type Handler struct { type Handler struct {
@@ -28,7 +28,7 @@ func (h *Handler) Enroll(w http.ResponseWriter, r *http.Request) {
return return
} }
userID, ok := httpserver.MustUserID(r.Context()) userID, ok := requestctx.MustUserID(r.Context())
if !ok { if !ok {
apiutil.Error(w, http.StatusUnauthorized, "unauthorized", "missing auth claims") apiutil.Error(w, http.StatusUnauthorized, "unauthorized", "missing auth claims")
return return
@@ -56,7 +56,7 @@ func (h *Handler) Enroll(w http.ResponseWriter, r *http.Request) {
} }
func (h *Handler) ListOwn(w http.ResponseWriter, r *http.Request) { func (h *Handler) ListOwn(w http.ResponseWriter, r *http.Request) {
userID, ok := httpserver.MustUserID(r.Context()) userID, ok := requestctx.MustUserID(r.Context())
if !ok { if !ok {
apiutil.Error(w, http.StatusUnauthorized, "unauthorized", "missing auth claims") apiutil.Error(w, http.StatusUnauthorized, "unauthorized", "missing auth claims")
return return
@@ -82,7 +82,7 @@ func (h *Handler) ListAll(w http.ResponseWriter, r *http.Request) {
} }
func (h *Handler) ConnectionStatus(w http.ResponseWriter, r *http.Request) { func (h *Handler) ConnectionStatus(w http.ResponseWriter, r *http.Request) {
userID, ok := httpserver.MustUserID(r.Context()) userID, ok := requestctx.MustUserID(r.Context())
if !ok { if !ok {
apiutil.Error(w, http.StatusUnauthorized, "unauthorized", "missing auth claims") apiutil.Error(w, http.StatusUnauthorized, "unauthorized", "missing auth claims")
return return
@@ -98,7 +98,7 @@ func (h *Handler) ConnectionStatus(w http.ResponseWriter, r *http.Request) {
} }
func (h *Handler) GetOwnProfile(w http.ResponseWriter, r *http.Request) { func (h *Handler) GetOwnProfile(w http.ResponseWriter, r *http.Request) {
userID, ok := httpserver.MustUserID(r.Context()) userID, ok := requestctx.MustUserID(r.Context())
if !ok { if !ok {
apiutil.Error(w, http.StatusUnauthorized, "unauthorized", "missing auth claims") apiutil.Error(w, http.StatusUnauthorized, "unauthorized", "missing auth claims")
return return
@@ -139,7 +139,7 @@ func (h *Handler) Revoke(w http.ResponseWriter, r *http.Request) {
apiutil.Error(w, http.StatusInternalServerError, "device_revoke_failed", "unable to revoke device") apiutil.Error(w, http.StatusInternalServerError, "device_revoke_failed", "unable to revoke device")
return return
} }
if claims, ok := httpserver.ClaimsFromContext(r.Context()); ok { if claims, ok := requestctx.ClaimsFromContext(r.Context()); ok {
_ = h.audit.Record(r.Context(), audit.Entry{ _ = h.audit.Record(r.Context(), audit.Entry{
ActorUserID: &claims.UserID, ActorUserID: &claims.UserID,
EntityType: "device", EntityType: "device",
@@ -162,7 +162,7 @@ func (h *Handler) Rotate(w http.ResponseWriter, r *http.Request) {
apiutil.Error(w, http.StatusInternalServerError, "device_rotate_failed", "unable to rotate device profile") apiutil.Error(w, http.StatusInternalServerError, "device_rotate_failed", "unable to rotate device profile")
return return
} }
if claims, ok := httpserver.ClaimsFromContext(r.Context()); ok { if claims, ok := requestctx.ClaimsFromContext(r.Context()); ok {
_ = h.audit.Record(r.Context(), audit.Entry{ _ = h.audit.Record(r.Context(), audit.Entry{
ActorUserID: &claims.UserID, ActorUserID: &claims.UserID,
EntityType: "device", EntityType: "device",

View File

@@ -1,21 +1,16 @@
package httpserver package httpserver
import ( import (
"context"
"net/http" "net/http"
"strings" "strings"
"github.com/go-chi/chi/v5/middleware" "github.com/go-chi/chi/v5/middleware"
"github.com/google/uuid"
"github.com/nexavpn/nexavpn/backend/internal/apiutil" "github.com/nexavpn/nexavpn/backend/internal/apiutil"
"github.com/nexavpn/nexavpn/backend/internal/auth" "github.com/nexavpn/nexavpn/backend/internal/auth"
"github.com/nexavpn/nexavpn/backend/internal/requestctx"
) )
type contextKey string
const claimsContextKey contextKey = "claims"
func BaseMiddleware(next http.Handler) http.Handler { func BaseMiddleware(next http.Handler) http.Handler {
return middleware.RealIP(middleware.RequestID(middleware.Logger(next))) return middleware.RealIP(middleware.RequestID(middleware.Logger(next)))
} }
@@ -35,7 +30,7 @@ func AuthMiddleware(jwtSecret string) func(http.Handler) http.Handler {
return return
} }
ctx := context.WithValue(r.Context(), claimsContextKey, claims) ctx := requestctx.WithClaims(r.Context(), claims)
next.ServeHTTP(w, r.WithContext(ctx)) next.ServeHTTP(w, r.WithContext(ctx))
}) })
} }
@@ -43,7 +38,7 @@ func AuthMiddleware(jwtSecret string) func(http.Handler) http.Handler {
func AdminOnly(next http.Handler) http.Handler { func AdminOnly(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
claims, ok := ClaimsFromContext(r.Context()) claims, ok := requestctx.ClaimsFromContext(r.Context())
if !ok || claims.Role != "admin" { if !ok || claims.Role != "admin" {
apiutil.Error(w, http.StatusForbidden, "forbidden", "admin role required") apiutil.Error(w, http.StatusForbidden, "forbidden", "admin role required")
return return
@@ -51,16 +46,3 @@ func AdminOnly(next http.Handler) http.Handler {
next.ServeHTTP(w, r) next.ServeHTTP(w, r)
}) })
} }
func ClaimsFromContext(ctx context.Context) (auth.Claims, bool) {
claims, ok := ctx.Value(claimsContextKey).(auth.Claims)
return claims, ok
}
func MustUserID(ctx context.Context) (uuid.UUID, bool) {
claims, ok := ClaimsFromContext(ctx)
if !ok {
return uuid.Nil, false
}
return claims.UserID, true
}

View File

@@ -6,7 +6,7 @@ import (
"github.com/nexavpn/nexavpn/backend/internal/apiutil" "github.com/nexavpn/nexavpn/backend/internal/apiutil"
"github.com/nexavpn/nexavpn/backend/internal/audit" "github.com/nexavpn/nexavpn/backend/internal/audit"
"github.com/nexavpn/nexavpn/backend/internal/httpserver" "github.com/nexavpn/nexavpn/backend/internal/requestctx"
) )
type Handler struct { type Handler struct {
@@ -35,7 +35,7 @@ func (h *Handler) Create(w http.ResponseWriter, r *http.Request) {
return return
} }
claims, ok := httpserver.ClaimsFromContext(r.Context()) claims, ok := requestctx.ClaimsFromContext(r.Context())
if !ok { if !ok {
apiutil.Error(w, http.StatusUnauthorized, "unauthorized", "missing auth claims") apiutil.Error(w, http.StatusUnauthorized, "unauthorized", "missing auth claims")
return return

View File

@@ -0,0 +1,30 @@
package requestctx
import (
"context"
"github.com/google/uuid"
"github.com/nexavpn/nexavpn/backend/internal/auth"
)
type contextKey string
const claimsKey contextKey = "claims"
func WithClaims(ctx context.Context, claims auth.Claims) context.Context {
return context.WithValue(ctx, claimsKey, claims)
}
func ClaimsFromContext(ctx context.Context) (auth.Claims, bool) {
claims, ok := ctx.Value(claimsKey).(auth.Claims)
return claims, ok
}
func MustUserID(ctx context.Context) (uuid.UUID, bool) {
claims, ok := ClaimsFromContext(ctx)
if !ok {
return uuid.Nil, false
}
return claims.UserID, true
}

View File

@@ -9,7 +9,7 @@ import (
"github.com/nexavpn/nexavpn/backend/internal/apiutil" "github.com/nexavpn/nexavpn/backend/internal/apiutil"
"github.com/nexavpn/nexavpn/backend/internal/audit" "github.com/nexavpn/nexavpn/backend/internal/audit"
"github.com/nexavpn/nexavpn/backend/internal/httpserver" "github.com/nexavpn/nexavpn/backend/internal/requestctx"
) )
type Handler struct { type Handler struct {
@@ -47,7 +47,7 @@ func (h *Handler) Create(w http.ResponseWriter, r *http.Request) {
return return
} }
if claims, ok := httpserver.ClaimsFromContext(r.Context()); ok { if claims, ok := requestctx.ClaimsFromContext(r.Context()); ok {
_ = h.audit.Record(r.Context(), audit.Entry{ _ = h.audit.Record(r.Context(), audit.Entry{
ActorUserID: &claims.UserID, ActorUserID: &claims.UserID,
EntityType: "user", EntityType: "user",
@@ -73,7 +73,7 @@ func (h *Handler) Disable(w http.ResponseWriter, r *http.Request) {
apiutil.Error(w, http.StatusBadRequest, "user_disable_failed", "unable to disable user") apiutil.Error(w, http.StatusBadRequest, "user_disable_failed", "unable to disable user")
return return
} }
if claims, ok := httpserver.ClaimsFromContext(r.Context()); ok { if claims, ok := requestctx.ClaimsFromContext(r.Context()); ok {
_ = h.audit.Record(r.Context(), audit.Entry{ _ = h.audit.Record(r.Context(), audit.Entry{
ActorUserID: &claims.UserID, ActorUserID: &claims.UserID,
EntityType: "user", EntityType: "user",
@@ -96,7 +96,7 @@ func (h *Handler) Enable(w http.ResponseWriter, r *http.Request) {
apiutil.Error(w, http.StatusBadRequest, "user_enable_failed", "unable to enable user") apiutil.Error(w, http.StatusBadRequest, "user_enable_failed", "unable to enable user")
return return
} }
if claims, ok := httpserver.ClaimsFromContext(r.Context()); ok { if claims, ok := requestctx.ClaimsFromContext(r.Context()); ok {
_ = h.audit.Record(r.Context(), audit.Entry{ _ = h.audit.Record(r.Context(), audit.Entry{
ActorUserID: &claims.UserID, ActorUserID: &claims.UserID,
EntityType: "user", EntityType: "user",