From 1d843826217550f037b34c17f98b3bb42d06b6a1 Mon Sep 17 00:00:00 2001 From: Mike Kell Date: Thu, 24 Jul 2025 17:19:45 -0400 Subject: [PATCH] Fully implemented JWT request --- backend/auth/jwt.py | 125 ++++++++++++++++++++++++++++++------------ backend/db/session.py | 16 ++++++ backend/main.py | 10 ++-- 3 files changed, 112 insertions(+), 39 deletions(-) diff --git a/backend/auth/jwt.py b/backend/auth/jwt.py index fcf6b20..6c73132 100644 --- a/backend/auth/jwt.py +++ b/backend/auth/jwt.py @@ -1,63 +1,118 @@ -# auth/jwt.py +# backend/auth/jwt.py +""" +JWT utilities & FastAPI dependency. + +• Extracts bearer token +• Verifies signature against SUPABASE_JWT_SECRET +• OPTIONAL — verifies/ignores audience +• Fetches user + tenant rows from Supabase via SQLAlchemy +""" from fastapi import Request, HTTPException, status, Depends from jose import jwt, JWTError +from sqlalchemy.orm import Session +from sqlalchemy import text import os -JWT_SECRET = os.getenv("SUPABASE_JWT_SECRET") -ALGORITHM = "HS256" +# --------------------------------------------------------------------------- # +# Environment / constants +# --------------------------------------------------------------------------- # +JWT_SECRET = os.getenv("SUPABASE_JWT_SECRET") +ALGORITHM = "HS256" +AUDIENCE = os.getenv("SUPABASE_JWT_AUD", "authenticated") # default Supabase aud + +# Lazy-import to avoid circular refs +from db.session import get_db -def get_token_from_header(request: Request) -> str: - auth_header = request.headers.get("Authorization") - if not auth_header or not auth_header.startswith("Bearer "): +# --------------------------------------------------------------------------- # +# Helpers +# --------------------------------------------------------------------------- # +def _get_token_from_header(request: Request) -> str: + auth = request.headers.get("Authorization", "") + if not auth.startswith("Bearer "): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Missing or invalid Authorization header", ) - return auth_header.split(" ")[1] + return auth.split(" ", 1)[1] -# def verify_jwt_token(token: str) -> dict: -# try: -# payload = jwt.decode(token, JWT_SECRET, algorithms=[ALGORITHM]) -# return payload -# except JWTError as e: -# raise HTTPException( -# status_code=status.HTTP_403_FORBIDDEN, -# detail=f"Invalid JWT token: {str(e)}", -# ) -def verify_jwt_token(token: str) -> dict: +def _verify_jwt_token(token: str) -> dict: + """ + Decode & validate JWT. + + 🔸 Option A (recommended): verify the expected audience. + 🔸 Option B: skip audience verification entirely. + """ try: + # ---- Option A: enforce expected audience --------------------------- payload = jwt.decode( token, JWT_SECRET, algorithms=[ALGORITHM], - options={"verify_aud": False} # <- Disable audience verification + audience=AUDIENCE, # <-- satisfies "aud" claim ) + # ---- Option B (if you really don’t care) --------------------------- + # payload = jwt.decode( + # token, + # JWT_SECRET, + # algorithms=[ALGORITHM], + # options={"verify_aud": False}, + # ) + return payload - except JWTError as e: + except JWTError as exc: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, - detail=f"Invalid JWT token: {str(e)}", + detail=f"Invalid JWT token: {exc}", ) -def get_current_user(request: Request): - token = get_token_from_header(request) - payload = verify_jwt_token(token) +# --------------------------------------------------------------------------- # +# Public dependency +# --------------------------------------------------------------------------- # +def get_current_user( + request: Request, + db: Session = Depends(get_db), +) -> dict: + """ + Inject with `Depends(get_current_user)` to access: - user_id = payload.get("sub") - if not user_id: - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="JWT missing 'sub' field", - ) - print("🔐 JWT payload seen by API →", payload) - return { - "id": user_id, - "email": payload.get("email"), - "role": payload.get("role"), - "tenant_id": payload.get("tenant_id", "unknown"), + { + "user": users row (dict), + "tenant": tenants row (dict | None), + "scopes": JWT “role” claim, + } + """ + token = _get_token_from_header(request) + payload = _verify_jwt_token(token) + + # ---- look-up user ------------------------------------------------------ # + uid = payload.get("sub") + if not uid: + raise HTTPException(status_code=403, detail="JWT missing 'sub' claim") + + user_row = ( + db.execute(text("SELECT * FROM users WHERE id = :uid"), {"uid": uid}) + .mappings() + .first() + ) + if not user_row: + raise HTTPException(status_code=404, detail="User not found") + + # ---- tenant (optional) ------------------------------------------------- # + tenant_row = None + if (tid := user_row.get("tenant_id")): + tenant_row = ( + db.execute(text("SELECT * FROM tenants WHERE id = :tid"), {"tid": tid}) + .mappings() + .first() + ) + + return { + "user": user_row, + "tenant": tenant_row, + "scopes": payload.get("role", []), } diff --git a/backend/db/session.py b/backend/db/session.py index f70d36f..40e4762 100644 --- a/backend/db/session.py +++ b/backend/db/session.py @@ -4,6 +4,7 @@ from sqlalchemy import create_engine from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import sessionmaker from sqlalchemy.pool import NullPool +from sqlalchemy.orm import Session from dotenv import load_dotenv import os @@ -28,3 +29,18 @@ SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) # Base class for ORM models Base = declarative_base() + +# ─────────────────────────────── +# FastAPI helper (NEW, non-breaking) +# ─────────────────────────────── + +def get_db() -> Session: + """ + Yield a DB session for FastAPI dependencies. + Closes the session automatically after the request. + """ + db = SessionLocal() + try: + yield db + finally: + db.close() \ No newline at end of file diff --git a/backend/main.py b/backend/main.py index 69c51cb..884e11c 100644 --- a/backend/main.py +++ b/backend/main.py @@ -25,11 +25,13 @@ def supabase_check(): @app.get("/me") def me(current: dict = Depends(get_current_user)): + usr = current["user"] # users table row + tenant = current["tenant"] # tenant row or None return { - "id": current["id"], - "email": current["email"], - "role": current["role"], - "tenant_id": current["tenant_id"], + "id": usr["id"], + "email": usr["email"], + "role": current["scopes"], + "tenant_id": tenant["id"] if tenant else "unknown", }