import hmac
import logging
import math
import time
from collections import defaultdict, deque
from collections.abc import Callable

from fastapi import FastAPI, Request, Response
from fastapi.middleware.cors import CORSMiddleware
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.types import ASGIApp

from app.api.router import api_router
from app.core.config import get_settings
from app.db.session import Base, engine
from app.models import *  # noqa: F403  # presumably imported so ORM models register on Base.metadata — confirm

settings = get_settings()
logging.basicConfig(level=settings.log_level)


class SecurityMiddleware(BaseHTTPMiddleware):
    """Enforce double-submit CSRF checks on unsafe API calls and add security headers.

    For POST/PUT/PATCH/DELETE requests whose path starts with ``/api/`` (other
    than the exempt endpoints in ``CSRF_EXEMPT_PATHS``), the ``np_csrf`` cookie
    must be present and equal to the ``X-CSRF-Token`` header; otherwise a 403
    is returned. Every response, including that 403, receives a baseline set
    of security headers (existing values set by handlers are kept).
    """

    UNSAFE_METHODS = frozenset({"POST", "PUT", "PATCH", "DELETE"})
    # NOTE(review): presumably exempt because no CSRF cookie has been issued
    # yet when these endpoints are called — confirm against the auth flow.
    CSRF_EXEMPT_PATHS = frozenset({"/api/auth/login", "/api/setup/complete", "/api/auth/accept-invite"})

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Reject CSRF failures with 403, otherwise forward; then add security headers."""
        if self._csrf_required(request) and not self._csrf_tokens_match(request):
            response = Response("CSRF validation failed", status_code=403)
        else:
            response = await call_next(request)
        self._apply_security_headers(response)
        return response

    def _csrf_required(self, request: Request) -> bool:
        """Return True when *request* is an unsafe, non-exempt call under /api/."""
        path = request.url.path
        return (
            request.method in self.UNSAFE_METHODS
            and path.startswith("/api/")
            and path not in self.CSRF_EXEMPT_PATHS
        )

    @staticmethod
    def _csrf_tokens_match(request: Request) -> bool:
        """Return True only when the CSRF cookie and header are both present and equal."""
        csrf_cookie = request.cookies.get("np_csrf")
        csrf_header = request.headers.get("x-csrf-token")
        # Both must be present: a plain `cookie != header` test would accept a
        # request carrying neither value, since None == None.
        if not csrf_cookie or not csrf_header:
            return False
        # Constant-time comparison so the token cannot be probed via timing.
        return hmac.compare_digest(csrf_cookie.encode(), csrf_header.encode())

    @staticmethod
    def _apply_security_headers(response: Response) -> None:
        """Add default hardening headers without overriding ones already set."""
        response.headers.setdefault("X-Content-Type-Options", "nosniff")
        response.headers.setdefault("X-Frame-Options", "DENY")
        response.headers.setdefault("Referrer-Policy", "strict-origin-when-cross-origin")
        response.headers.setdefault("Permissions-Policy", "camera=(self), geolocation=(), microphone=()")


class RateLimitMiddleware(BaseHTTPMiddleware):
    """Sliding-window rate limiter for sensitive endpoints.

    Requests whose path starts with one of ``limited_paths`` are counted per
    ``"<client host>:<request path>"`` key. At most ``max_requests`` are
    accepted per ``window_seconds``; further requests get a 429 with a
    ``Retry-After`` header until older entries age out of the window.

    NOTE(review): behind a reverse proxy ``request.client.host`` is the proxy's
    address, so all users would share one bucket — confirm deployment setup.
    """

    # Shared by all instances (class attribute): key -> monotonic timestamps
    # of accepted requests, oldest first.
    buckets: dict[str, deque[float]] = defaultdict(deque)

    def __init__(
        self,
        app: ASGIApp,
        dispatch: Callable | None = None,
        *,
        max_requests: int = 10,
        window_seconds: float = 60.0,
        limited_paths: tuple[str, ...] = ("/api/auth/login", "/api/auth/accept-invite", "/api/admin/users"),
    ) -> None:
        super().__init__(app, dispatch)
        self.max_requests = max_requests
        self.window_seconds = window_seconds
        self.limited_paths = limited_paths
        self._last_sweep = time.monotonic()

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Return 429 when the caller's bucket is full, otherwise record and forward."""
        path = request.url.path
        if path.startswith(self.limited_paths):
            # Monotonic clock: immune to wall-clock jumps (NTP, DST, manual changes).
            now = time.monotonic()
            self._sweep_stale_buckets(now)
            host = request.client.host if request.client else "unknown"
            bucket = self.buckets[f"{host}:{path}"]
            while bucket and now - bucket[0] > self.window_seconds:
                bucket.popleft()
            if len(bucket) >= self.max_requests:
                retry_after = max(1, math.ceil(self.window_seconds - (now - bucket[0])))
                return Response(
                    "Too many requests",
                    status_code=429,
                    headers={"Retry-After": str(retry_after)},
                )
            bucket.append(now)
        return await call_next(request)

    def _sweep_stale_buckets(self, now: float) -> None:
        """At most once per window, drop keys with no timestamps left in the window.

        Without this, every distinct host/path pair ever seen would stay in
        ``buckets`` forever.
        """
        if now - self._last_sweep < self.window_seconds:
            return
        self._last_sweep = now
        stale = [
            key
            for key, bucket in self.buckets.items()
            if not bucket or now - bucket[-1] > self.window_seconds
        ]
        for key in stale:
            del self.buckets[key]
# ASGI application instance.
app = FastAPI(title="NexaPantry API", version="0.1.0")

# Starlette wraps middleware in reverse registration order: the last one added
# runs first on an incoming request. The effective chain is
# CORSMiddleware (outermost) -> RateLimitMiddleware -> SecurityMiddleware
# (innermost) -> routes. Do not reorder these calls without considering that.
app.add_middleware(SecurityMiddleware)
app.add_middleware(RateLimitMiddleware)
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.cors_origins,
    # Cookies are allowed on cross-origin requests.
    # NOTE(review): confirm settings.cors_origins never contains "*" while
    # credentials are enabled.
    allow_credentials=True,
    allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"],
    # X-CSRF-Token must be listed so browsers may send the CSRF header cross-origin.
    allow_headers=["Content-Type", "X-CSRF-Token"],
)


@app.on_event("startup")
def on_startup() -> None:
    """Create any missing database tables when the application starts."""
    # NOTE(review): create_all only creates missing tables; it does not alter
    # existing ones, so schema changes need a separate migration step.
    Base.metadata.create_all(bind=engine)


@app.get("/healthz")
def healthz() -> dict:
    """Health-check endpoint returning a static payload (no database access)."""
    return {"status": "ok"}


# API routes are mounted under /api, the prefix the CSRF and rate-limit
# middleware match on; /healthz above is outside it.
app.include_router(api_router, prefix="/api")