refactor: extract request context utilities into dedicated package
Move ClaimsFromContext and MustUserID helpers from httpserver to new requestctx package for better separation of concerns. Update all imports across auth, device, policy, and user handlers. Fix Dockerfile to copy go.sum and run go mod tidy before download.
This commit is contained in:
@@ -1,8 +1,10 @@
|
|||||||
FROM golang:1.23-alpine AS builder
|
FROM golang:1.23-alpine AS builder
|
||||||
WORKDIR /src
|
WORKDIR /src
|
||||||
COPY go.mod ./
|
COPY go.mod ./
|
||||||
RUN go mod download
|
COPY go.sum* ./
|
||||||
COPY . .
|
COPY . .
|
||||||
|
RUN go mod tidy
|
||||||
|
RUN go mod download
|
||||||
RUN CGO_ENABLED=0 GOOS=linux go build -o /out/nexavpn-api ./cmd/api
|
RUN CGO_ENABLED=0 GOOS=linux go build -o /out/nexavpn-api ./cmd/api
|
||||||
|
|
||||||
FROM alpine:3.21
|
FROM alpine:3.21
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import (
|
|||||||
|
|
||||||
"github.com/nexavpn/nexavpn/backend/internal/apiutil"
|
"github.com/nexavpn/nexavpn/backend/internal/apiutil"
|
||||||
"github.com/nexavpn/nexavpn/backend/internal/audit"
|
"github.com/nexavpn/nexavpn/backend/internal/audit"
|
||||||
"github.com/nexavpn/nexavpn/backend/internal/httpserver"
|
"github.com/nexavpn/nexavpn/backend/internal/requestctx"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Handler struct {
|
type Handler struct {
|
||||||
@@ -110,7 +110,7 @@ func (h *Handler) Logout(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if claims, ok := httpserver.ClaimsFromContext(r.Context()); ok {
|
if claims, ok := requestctx.ClaimsFromContext(r.Context()); ok {
|
||||||
_ = h.audit.Record(r.Context(), audit.Entry{
|
_ = h.audit.Record(r.Context(), audit.Entry{
|
||||||
ActorUserID: &claims.UserID,
|
ActorUserID: &claims.UserID,
|
||||||
EventType: "auth.logout",
|
EventType: "auth.logout",
|
||||||
@@ -123,7 +123,7 @@ func (h *Handler) Logout(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) Me(w http.ResponseWriter, r *http.Request) {
|
func (h *Handler) Me(w http.ResponseWriter, r *http.Request) {
|
||||||
claims, ok := httpserver.ClaimsFromContext(r.Context())
|
claims, ok := requestctx.ClaimsFromContext(r.Context())
|
||||||
if !ok {
|
if !ok {
|
||||||
apiutil.Error(w, http.StatusUnauthorized, "unauthorized", "missing auth claims")
|
apiutil.Error(w, http.StatusUnauthorized, "unauthorized", "missing auth claims")
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import (
|
|||||||
|
|
||||||
"github.com/nexavpn/nexavpn/backend/internal/apiutil"
|
"github.com/nexavpn/nexavpn/backend/internal/apiutil"
|
||||||
"github.com/nexavpn/nexavpn/backend/internal/audit"
|
"github.com/nexavpn/nexavpn/backend/internal/audit"
|
||||||
"github.com/nexavpn/nexavpn/backend/internal/httpserver"
|
"github.com/nexavpn/nexavpn/backend/internal/requestctx"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Handler struct {
|
type Handler struct {
|
||||||
@@ -28,7 +28,7 @@ func (h *Handler) Enroll(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
userID, ok := httpserver.MustUserID(r.Context())
|
userID, ok := requestctx.MustUserID(r.Context())
|
||||||
if !ok {
|
if !ok {
|
||||||
apiutil.Error(w, http.StatusUnauthorized, "unauthorized", "missing auth claims")
|
apiutil.Error(w, http.StatusUnauthorized, "unauthorized", "missing auth claims")
|
||||||
return
|
return
|
||||||
@@ -56,7 +56,7 @@ func (h *Handler) Enroll(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) ListOwn(w http.ResponseWriter, r *http.Request) {
|
func (h *Handler) ListOwn(w http.ResponseWriter, r *http.Request) {
|
||||||
userID, ok := httpserver.MustUserID(r.Context())
|
userID, ok := requestctx.MustUserID(r.Context())
|
||||||
if !ok {
|
if !ok {
|
||||||
apiutil.Error(w, http.StatusUnauthorized, "unauthorized", "missing auth claims")
|
apiutil.Error(w, http.StatusUnauthorized, "unauthorized", "missing auth claims")
|
||||||
return
|
return
|
||||||
@@ -82,7 +82,7 @@ func (h *Handler) ListAll(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) ConnectionStatus(w http.ResponseWriter, r *http.Request) {
|
func (h *Handler) ConnectionStatus(w http.ResponseWriter, r *http.Request) {
|
||||||
userID, ok := httpserver.MustUserID(r.Context())
|
userID, ok := requestctx.MustUserID(r.Context())
|
||||||
if !ok {
|
if !ok {
|
||||||
apiutil.Error(w, http.StatusUnauthorized, "unauthorized", "missing auth claims")
|
apiutil.Error(w, http.StatusUnauthorized, "unauthorized", "missing auth claims")
|
||||||
return
|
return
|
||||||
@@ -98,7 +98,7 @@ func (h *Handler) ConnectionStatus(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) GetOwnProfile(w http.ResponseWriter, r *http.Request) {
|
func (h *Handler) GetOwnProfile(w http.ResponseWriter, r *http.Request) {
|
||||||
userID, ok := httpserver.MustUserID(r.Context())
|
userID, ok := requestctx.MustUserID(r.Context())
|
||||||
if !ok {
|
if !ok {
|
||||||
apiutil.Error(w, http.StatusUnauthorized, "unauthorized", "missing auth claims")
|
apiutil.Error(w, http.StatusUnauthorized, "unauthorized", "missing auth claims")
|
||||||
return
|
return
|
||||||
@@ -139,7 +139,7 @@ func (h *Handler) Revoke(w http.ResponseWriter, r *http.Request) {
|
|||||||
apiutil.Error(w, http.StatusInternalServerError, "device_revoke_failed", "unable to revoke device")
|
apiutil.Error(w, http.StatusInternalServerError, "device_revoke_failed", "unable to revoke device")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if claims, ok := httpserver.ClaimsFromContext(r.Context()); ok {
|
if claims, ok := requestctx.ClaimsFromContext(r.Context()); ok {
|
||||||
_ = h.audit.Record(r.Context(), audit.Entry{
|
_ = h.audit.Record(r.Context(), audit.Entry{
|
||||||
ActorUserID: &claims.UserID,
|
ActorUserID: &claims.UserID,
|
||||||
EntityType: "device",
|
EntityType: "device",
|
||||||
@@ -162,7 +162,7 @@ func (h *Handler) Rotate(w http.ResponseWriter, r *http.Request) {
|
|||||||
apiutil.Error(w, http.StatusInternalServerError, "device_rotate_failed", "unable to rotate device profile")
|
apiutil.Error(w, http.StatusInternalServerError, "device_rotate_failed", "unable to rotate device profile")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if claims, ok := httpserver.ClaimsFromContext(r.Context()); ok {
|
if claims, ok := requestctx.ClaimsFromContext(r.Context()); ok {
|
||||||
_ = h.audit.Record(r.Context(), audit.Entry{
|
_ = h.audit.Record(r.Context(), audit.Entry{
|
||||||
ActorUserID: &claims.UserID,
|
ActorUserID: &claims.UserID,
|
||||||
EntityType: "device",
|
EntityType: "device",
|
||||||
|
|||||||
@@ -1,21 +1,16 @@
|
|||||||
package httpserver
|
package httpserver
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/go-chi/chi/v5/middleware"
|
"github.com/go-chi/chi/v5/middleware"
|
||||||
"github.com/google/uuid"
|
|
||||||
|
|
||||||
"github.com/nexavpn/nexavpn/backend/internal/apiutil"
|
"github.com/nexavpn/nexavpn/backend/internal/apiutil"
|
||||||
"github.com/nexavpn/nexavpn/backend/internal/auth"
|
"github.com/nexavpn/nexavpn/backend/internal/auth"
|
||||||
|
"github.com/nexavpn/nexavpn/backend/internal/requestctx"
|
||||||
)
|
)
|
||||||
|
|
||||||
type contextKey string
|
|
||||||
|
|
||||||
const claimsContextKey contextKey = "claims"
|
|
||||||
|
|
||||||
func BaseMiddleware(next http.Handler) http.Handler {
|
func BaseMiddleware(next http.Handler) http.Handler {
|
||||||
return middleware.RealIP(middleware.RequestID(middleware.Logger(next)))
|
return middleware.RealIP(middleware.RequestID(middleware.Logger(next)))
|
||||||
}
|
}
|
||||||
@@ -35,7 +30,7 @@ func AuthMiddleware(jwtSecret string) func(http.Handler) http.Handler {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx := context.WithValue(r.Context(), claimsContextKey, claims)
|
ctx := requestctx.WithClaims(r.Context(), claims)
|
||||||
next.ServeHTTP(w, r.WithContext(ctx))
|
next.ServeHTTP(w, r.WithContext(ctx))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -43,7 +38,7 @@ func AuthMiddleware(jwtSecret string) func(http.Handler) http.Handler {
|
|||||||
|
|
||||||
func AdminOnly(next http.Handler) http.Handler {
|
func AdminOnly(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
claims, ok := ClaimsFromContext(r.Context())
|
claims, ok := requestctx.ClaimsFromContext(r.Context())
|
||||||
if !ok || claims.Role != "admin" {
|
if !ok || claims.Role != "admin" {
|
||||||
apiutil.Error(w, http.StatusForbidden, "forbidden", "admin role required")
|
apiutil.Error(w, http.StatusForbidden, "forbidden", "admin role required")
|
||||||
return
|
return
|
||||||
@@ -51,16 +46,3 @@ func AdminOnly(next http.Handler) http.Handler {
|
|||||||
next.ServeHTTP(w, r)
|
next.ServeHTTP(w, r)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func ClaimsFromContext(ctx context.Context) (auth.Claims, bool) {
|
|
||||||
claims, ok := ctx.Value(claimsContextKey).(auth.Claims)
|
|
||||||
return claims, ok
|
|
||||||
}
|
|
||||||
|
|
||||||
func MustUserID(ctx context.Context) (uuid.UUID, bool) {
|
|
||||||
claims, ok := ClaimsFromContext(ctx)
|
|
||||||
if !ok {
|
|
||||||
return uuid.Nil, false
|
|
||||||
}
|
|
||||||
return claims.UserID, true
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import (
|
|||||||
|
|
||||||
"github.com/nexavpn/nexavpn/backend/internal/apiutil"
|
"github.com/nexavpn/nexavpn/backend/internal/apiutil"
|
||||||
"github.com/nexavpn/nexavpn/backend/internal/audit"
|
"github.com/nexavpn/nexavpn/backend/internal/audit"
|
||||||
"github.com/nexavpn/nexavpn/backend/internal/httpserver"
|
"github.com/nexavpn/nexavpn/backend/internal/requestctx"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Handler struct {
|
type Handler struct {
|
||||||
@@ -35,7 +35,7 @@ func (h *Handler) Create(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
claims, ok := httpserver.ClaimsFromContext(r.Context())
|
claims, ok := requestctx.ClaimsFromContext(r.Context())
|
||||||
if !ok {
|
if !ok {
|
||||||
apiutil.Error(w, http.StatusUnauthorized, "unauthorized", "missing auth claims")
|
apiutil.Error(w, http.StatusUnauthorized, "unauthorized", "missing auth claims")
|
||||||
return
|
return
|
||||||
|
|||||||
30
backend/internal/requestctx/context.go
Normal file
30
backend/internal/requestctx/context.go
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
package requestctx
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
|
||||||
|
"github.com/nexavpn/nexavpn/backend/internal/auth"
|
||||||
|
)
|
||||||
|
|
||||||
|
type contextKey string
|
||||||
|
|
||||||
|
const claimsKey contextKey = "claims"
|
||||||
|
|
||||||
|
func WithClaims(ctx context.Context, claims auth.Claims) context.Context {
|
||||||
|
return context.WithValue(ctx, claimsKey, claims)
|
||||||
|
}
|
||||||
|
|
||||||
|
func ClaimsFromContext(ctx context.Context) (auth.Claims, bool) {
|
||||||
|
claims, ok := ctx.Value(claimsKey).(auth.Claims)
|
||||||
|
return claims, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func MustUserID(ctx context.Context) (uuid.UUID, bool) {
|
||||||
|
claims, ok := ClaimsFromContext(ctx)
|
||||||
|
if !ok {
|
||||||
|
return uuid.Nil, false
|
||||||
|
}
|
||||||
|
return claims.UserID, true
|
||||||
|
}
|
||||||
@@ -9,7 +9,7 @@ import (
|
|||||||
|
|
||||||
"github.com/nexavpn/nexavpn/backend/internal/apiutil"
|
"github.com/nexavpn/nexavpn/backend/internal/apiutil"
|
||||||
"github.com/nexavpn/nexavpn/backend/internal/audit"
|
"github.com/nexavpn/nexavpn/backend/internal/audit"
|
||||||
"github.com/nexavpn/nexavpn/backend/internal/httpserver"
|
"github.com/nexavpn/nexavpn/backend/internal/requestctx"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Handler struct {
|
type Handler struct {
|
||||||
@@ -47,7 +47,7 @@ func (h *Handler) Create(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if claims, ok := httpserver.ClaimsFromContext(r.Context()); ok {
|
if claims, ok := requestctx.ClaimsFromContext(r.Context()); ok {
|
||||||
_ = h.audit.Record(r.Context(), audit.Entry{
|
_ = h.audit.Record(r.Context(), audit.Entry{
|
||||||
ActorUserID: &claims.UserID,
|
ActorUserID: &claims.UserID,
|
||||||
EntityType: "user",
|
EntityType: "user",
|
||||||
@@ -73,7 +73,7 @@ func (h *Handler) Disable(w http.ResponseWriter, r *http.Request) {
|
|||||||
apiutil.Error(w, http.StatusBadRequest, "user_disable_failed", "unable to disable user")
|
apiutil.Error(w, http.StatusBadRequest, "user_disable_failed", "unable to disable user")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if claims, ok := httpserver.ClaimsFromContext(r.Context()); ok {
|
if claims, ok := requestctx.ClaimsFromContext(r.Context()); ok {
|
||||||
_ = h.audit.Record(r.Context(), audit.Entry{
|
_ = h.audit.Record(r.Context(), audit.Entry{
|
||||||
ActorUserID: &claims.UserID,
|
ActorUserID: &claims.UserID,
|
||||||
EntityType: "user",
|
EntityType: "user",
|
||||||
@@ -96,7 +96,7 @@ func (h *Handler) Enable(w http.ResponseWriter, r *http.Request) {
|
|||||||
apiutil.Error(w, http.StatusBadRequest, "user_enable_failed", "unable to enable user")
|
apiutil.Error(w, http.StatusBadRequest, "user_enable_failed", "unable to enable user")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if claims, ok := httpserver.ClaimsFromContext(r.Context()); ok {
|
if claims, ok := requestctx.ClaimsFromContext(r.Context()); ok {
|
||||||
_ = h.audit.Record(r.Context(), audit.Entry{
|
_ = h.audit.Record(r.Context(), audit.Entry{
|
||||||
ActorUserID: &claims.UserID,
|
ActorUserID: &claims.UserID,
|
||||||
EntityType: "user",
|
EntityType: "user",
|
||||||
|
|||||||
Reference in New Issue
Block a user