# backend/auth/jwt.py
"""
JWT utilities & FastAPI dependency.

• Extracts the bearer token from the ``Authorization`` header.
• Verifies the HS256 signature against SUPABASE_JWT_SECRET.
• Verifies the ``aud`` claim against SUPABASE_JWT_AUD (default "authenticated").
• Loads the caller's ``users`` row, plus its ``tenants`` row when one is
  linked, through a SQLAlchemy session.
"""
|
||
|
||
from fastapi import Request, HTTPException, status, Depends
|
||
from jose import jwt, JWTError
|
||
from sqlalchemy.orm import Session
|
||
from sqlalchemy import text
|
||
import os
|
||
|
||
# --------------------------------------------------------------------------- #
# Environment / constants
# --------------------------------------------------------------------------- #
# Signing algorithm; it is the only entry in the ``algorithms`` allow-list
# handed to ``jwt.decode``.
ALGORITHM = "HS256"

# Shared HMAC secret for verifying token signatures (None when the env var is unset).
JWT_SECRET = os.environ.get("SUPABASE_JWT_SECRET")

# Expected "aud" claim; "authenticated" is the default Supabase audience.
AUDIENCE = os.environ.get("SUPABASE_JWT_AUD", "authenticated")
|
||
|
||
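# Imported after the constants, presumably to sidestep a circular import.
# NOTE(review): this is an ordinary module-level import, not a lazy one.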
|
||
from db.session import get_db
|
||
|
||
|
||
# --------------------------------------------------------------------------- #
# Helpers
# --------------------------------------------------------------------------- #
def _get_token_from_header(request: Request) -> str:
    """
    Return the bearer token carried in the request's ``Authorization`` header.

    The auth scheme is matched case-insensitively (RFC 7235 §2.1), and
    whitespace surrounding the token itself is ignored.

    Raises:
        HTTPException(401): the header is missing, uses a scheme other than
            ``Bearer``, or carries an empty token.
    """
    auth = request.headers.get("Authorization", "")
    scheme, _, token = auth.partition(" ")
    token = token.strip()
    if scheme.lower() != "bearer" or not token:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Missing or invalid Authorization header",
            # RFC 6750 §3: a 401 for a protected resource must advertise the
            # expected authentication scheme.
            headers={"WWW-Authenticate": "Bearer"},
        )
    return token
|
||
|
||
|
||
def _verify_jwt_token(token: str) -> dict:
    """
    Decode *token* and return its claims after validating it.

    Validation delegated to ``jwt.decode`` (python-jose):
      • the HMAC signature is checked against ``JWT_SECRET``, and only
        ``ALGORITHM`` is accepted;
      • the ``aud`` claim must match ``AUDIENCE``;
      • python-jose additionally checks time claims such as ``exp`` when
        they are present.

    To skip audience verification instead, pass
    ``options={"verify_aud": False}`` to ``jwt.decode`` in place of
    ``audience=...``.

    Raises:
        HTTPException(500): ``SUPABASE_JWT_SECRET`` is unset or empty.
        HTTPException(403): the token fails any of the checks above.
    """
    if not JWT_SECRET:
        # Fail fast with an explicit server error. Otherwise jwt.decode would
        # receive a None key, which hides the misconfiguration, or an empty
        # HMAC key, which must never be used to verify signatures.
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Server authentication is not configured",
        )

    try:
        return jwt.decode(
            token,
            JWT_SECRET,
            algorithms=[ALGORITHM],
            audience=AUDIENCE,  # enforces the "aud" claim
        )
    except JWTError as exc:
        # NOTE(review): RFC 6750 suggests 401 for invalid tokens; 403 is kept
        # because existing clients may depend on it.
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=f"Invalid JWT token: {exc}",
        ) from exc
|
||
|
||
|
||
def _fetch_one(db: Session, sql: str, params: dict):
    """Run *sql* with bound *params*; return the first row as a dict, or None."""
    row = db.execute(text(sql), params).mappings().first()
    return dict(row) if row is not None else None


# --------------------------------------------------------------------------- #
# Public dependency
# --------------------------------------------------------------------------- #
def get_current_user(
    request: Request,
    db: Session = Depends(get_db),
) -> dict:
    """
    FastAPI dependency that resolves the authenticated caller.

    Inject with ``Depends(get_current_user)`` to receive::

        {
            "user":   users row (dict),
            "tenant": tenants row (dict | None),
            "scopes": the JWT "role" claim, or [] when the claim is absent,
        }

    Raises:
        HTTPException: propagated from ``_get_token_from_header`` /
            ``_verify_jwt_token`` for a missing/malformed header or an
            invalid token.
        HTTPException(403): the token has no ``sub`` claim.
        HTTPException(404): no ``users`` row has ``id == sub``.
    """
    token = _get_token_from_header(request)
    payload = _verify_jwt_token(token)

    # ---- look-up user ------------------------------------------------------ #
    uid = payload.get("sub")
    if not uid:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="JWT missing 'sub' claim",
        )

    # Bound parameters keep the token-supplied value out of the SQL text.
    user_row = _fetch_one(db, "SELECT * FROM users WHERE id = :uid", {"uid": uid})
    if user_row is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="User not found",
        )

    # ---- tenant (optional) ------------------------------------------------- #
    tenant_row = None
    if tid := user_row.get("tenant_id"):
        tenant_row = _fetch_one(
            db, "SELECT * FROM tenants WHERE id = :tid", {"tid": tid}
        )

    return {
        "user": user_row,
        "tenant": tenant_row,
        # NOTE(review): when present, "role" appears to be a single string,
        # while the fallback is a list, so `x in scopes` does substring
        # matching in the string case. Kept as-is for caller compatibility.
        "scopes": payload.get("role", []),
    }
|