Fully implemented JWT request

This commit is contained in:
Mike Kell 2025-07-24 17:19:45 -04:00
parent b1527defc2
commit 1d84382621
3 changed files with 112 additions and 39 deletions

View File

@ -1,63 +1,118 @@
# auth/jwt.py
# backend/auth/jwt.py
"""
JWT utilities & FastAPI dependency.
Extracts bearer token
Verifies signature against SUPABASE_JWT_SECRET
OPTIONAL verifies/ignores audience
Fetches user + tenant rows from Supabase via SQLAlchemy
"""
from fastapi import Request, HTTPException, status, Depends
from jose import jwt, JWTError
from sqlalchemy.orm import Session
from sqlalchemy import text
import os
JWT_SECRET = os.getenv("SUPABASE_JWT_SECRET")
ALGORITHM = "HS256"
# --------------------------------------------------------------------------- #
# Environment / constants
# --------------------------------------------------------------------------- #
JWT_SECRET = os.getenv("SUPABASE_JWT_SECRET")
ALGORITHM = "HS256"
AUDIENCE = os.getenv("SUPABASE_JWT_AUD", "authenticated") # default Supabase aud
# Lazy-import to avoid circular refs
from db.session import get_db
def get_token_from_header(request: Request) -> str:
auth_header = request.headers.get("Authorization")
if not auth_header or not auth_header.startswith("Bearer "):
# --------------------------------------------------------------------------- #
# Helpers
# --------------------------------------------------------------------------- #
def _get_token_from_header(request: Request) -> str:
auth = request.headers.get("Authorization", "")
if not auth.startswith("Bearer "):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Missing or invalid Authorization header",
)
return auth_header.split(" ")[1]
return auth.split(" ", 1)[1]
# def verify_jwt_token(token: str) -> dict:
# try:
# payload = jwt.decode(token, JWT_SECRET, algorithms=[ALGORITHM])
# return payload
# except JWTError as e:
# raise HTTPException(
# status_code=status.HTTP_403_FORBIDDEN,
# detail=f"Invalid JWT token: {str(e)}",
# )
def verify_jwt_token(token: str) -> dict:
def _verify_jwt_token(token: str) -> dict:
"""
Decode & validate JWT.
🔸 Option A (recommended): verify the expected audience.
🔸 Option B: skip audience verification entirely.
"""
try:
# ---- Option A: enforce expected audience ---------------------------
payload = jwt.decode(
token,
JWT_SECRET,
algorithms=[ALGORITHM],
options={"verify_aud": False} # <- Disable audience verification
audience=AUDIENCE, # <-- satisfies "aud" claim
)
# ---- Option B (if you really dont care) ---------------------------
# payload = jwt.decode(
# token,
# JWT_SECRET,
# algorithms=[ALGORITHM],
# options={"verify_aud": False},
# )
return payload
except JWTError as e:
except JWTError as exc:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Invalid JWT token: {str(e)}",
detail=f"Invalid JWT token: {exc}",
)
def get_current_user(request: Request):
token = get_token_from_header(request)
payload = verify_jwt_token(token)
# --------------------------------------------------------------------------- #
# Public dependency
# --------------------------------------------------------------------------- #
def get_current_user(
request: Request,
db: Session = Depends(get_db),
) -> dict:
"""
Inject with `Depends(get_current_user)` to access:
user_id = payload.get("sub")
if not user_id:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="JWT missing 'sub' field",
)
print("🔐 JWT payload seen by API →", payload)
return {
"id": user_id,
"email": payload.get("email"),
"role": payload.get("role"),
"tenant_id": payload.get("tenant_id", "unknown"),
{
"user": users row (dict),
"tenant": tenants row (dict | None),
"scopes": JWT role claim,
}
"""
token = _get_token_from_header(request)
payload = _verify_jwt_token(token)
# ---- look-up user ------------------------------------------------------ #
uid = payload.get("sub")
if not uid:
raise HTTPException(status_code=403, detail="JWT missing 'sub' claim")
user_row = (
db.execute(text("SELECT * FROM users WHERE id = :uid"), {"uid": uid})
.mappings()
.first()
)
if not user_row:
raise HTTPException(status_code=404, detail="User not found")
# ---- tenant (optional) ------------------------------------------------- #
tenant_row = None
if (tid := user_row.get("tenant_id")):
tenant_row = (
db.execute(text("SELECT * FROM tenants WHERE id = :tid"), {"tid": tid})
.mappings()
.first()
)
return {
"user": user_row,
"tenant": tenant_row,
"scopes": payload.get("role", []),
}

View File

@ -4,6 +4,7 @@ from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import NullPool
from sqlalchemy.orm import Session
from dotenv import load_dotenv
import os
@ -28,3 +29,18 @@ SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
# Base class for ORM models
Base = declarative_base()
# ───────────────────────────────
# FastAPI helper (NEW, non-breaking)
# ───────────────────────────────
def get_db() -> Session:
"""
Yield a DB session for FastAPI dependencies.
Closes the session automatically after the request.
"""
db = SessionLocal()
try:
yield db
finally:
db.close()

View File

@ -25,11 +25,13 @@ def supabase_check():
@app.get("/me")
def me(current: dict = Depends(get_current_user)):
usr = current["user"] # users table row
tenant = current["tenant"] # tenant row or None
return {
"id": current["id"],
"email": current["email"],
"role": current["role"],
"tenant_id": current["tenant_id"],
"id": usr["id"],
"email": usr["email"],
"role": current["scopes"],
"tenant_id": tenant["id"] if tenant else "unknown",
}