Fully implemented JWT request
This commit is contained in:
parent
b1527defc2
commit
1d84382621
|
|
@ -1,63 +1,118 @@
|
|||
# auth/jwt.py
|
||||
# backend/auth/jwt.py
|
||||
"""
|
||||
JWT utilities & FastAPI dependency.
|
||||
|
||||
• Extracts bearer token
|
||||
• Verifies signature against SUPABASE_JWT_SECRET
|
||||
• OPTIONAL — verifies/ignores audience
|
||||
• Fetches user + tenant rows from Supabase via SQLAlchemy
|
||||
"""
|
||||
|
||||
from fastapi import Request, HTTPException, status, Depends
|
||||
from jose import jwt, JWTError
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import text
|
||||
import os
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# Environment / constants
|
||||
# --------------------------------------------------------------------------- #
|
||||
JWT_SECRET = os.getenv("SUPABASE_JWT_SECRET")
|
||||
ALGORITHM = "HS256"
|
||||
AUDIENCE = os.getenv("SUPABASE_JWT_AUD", "authenticated") # default Supabase aud
|
||||
|
||||
# Lazy-import to avoid circular refs
|
||||
from db.session import get_db
|
||||
|
||||
|
||||
def get_token_from_header(request: Request) -> str:
|
||||
auth_header = request.headers.get("Authorization")
|
||||
if not auth_header or not auth_header.startswith("Bearer "):
|
||||
# --------------------------------------------------------------------------- #
|
||||
# Helpers
|
||||
# --------------------------------------------------------------------------- #
|
||||
def _get_token_from_header(request: Request) -> str:
|
||||
auth = request.headers.get("Authorization", "")
|
||||
if not auth.startswith("Bearer "):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Missing or invalid Authorization header",
|
||||
)
|
||||
return auth_header.split(" ")[1]
|
||||
return auth.split(" ", 1)[1]
|
||||
|
||||
|
||||
# def verify_jwt_token(token: str) -> dict:
|
||||
# try:
|
||||
# payload = jwt.decode(token, JWT_SECRET, algorithms=[ALGORITHM])
|
||||
# return payload
|
||||
# except JWTError as e:
|
||||
# raise HTTPException(
|
||||
# status_code=status.HTTP_403_FORBIDDEN,
|
||||
# detail=f"Invalid JWT token: {str(e)}",
|
||||
# )
|
||||
def verify_jwt_token(token: str) -> dict:
|
||||
def _verify_jwt_token(token: str) -> dict:
|
||||
"""
|
||||
Decode & validate JWT.
|
||||
|
||||
🔸 Option A (recommended): verify the expected audience.
|
||||
🔸 Option B: skip audience verification entirely.
|
||||
"""
|
||||
try:
|
||||
# ---- Option A: enforce expected audience ---------------------------
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
JWT_SECRET,
|
||||
algorithms=[ALGORITHM],
|
||||
options={"verify_aud": False} # <- Disable audience verification
|
||||
audience=AUDIENCE, # <-- satisfies "aud" claim
|
||||
)
|
||||
|
||||
# ---- Option B (if you really don’t care) ---------------------------
|
||||
# payload = jwt.decode(
|
||||
# token,
|
||||
# JWT_SECRET,
|
||||
# algorithms=[ALGORITHM],
|
||||
# options={"verify_aud": False},
|
||||
# )
|
||||
|
||||
return payload
|
||||
except JWTError as e:
|
||||
except JWTError as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=f"Invalid JWT token: {str(e)}",
|
||||
detail=f"Invalid JWT token: {exc}",
|
||||
)
|
||||
|
||||
|
||||
def get_current_user(request: Request):
|
||||
token = get_token_from_header(request)
|
||||
payload = verify_jwt_token(token)
|
||||
# --------------------------------------------------------------------------- #
|
||||
# Public dependency
|
||||
# --------------------------------------------------------------------------- #
|
||||
def get_current_user(
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""
|
||||
Inject with `Depends(get_current_user)` to access:
|
||||
|
||||
user_id = payload.get("sub")
|
||||
if not user_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="JWT missing 'sub' field",
|
||||
)
|
||||
print("🔐 JWT payload seen by API →", payload)
|
||||
return {
|
||||
"id": user_id,
|
||||
"email": payload.get("email"),
|
||||
"role": payload.get("role"),
|
||||
"tenant_id": payload.get("tenant_id", "unknown"),
|
||||
{
|
||||
"user": users row (dict),
|
||||
"tenant": tenants row (dict | None),
|
||||
"scopes": JWT “role” claim,
|
||||
}
|
||||
"""
|
||||
token = _get_token_from_header(request)
|
||||
payload = _verify_jwt_token(token)
|
||||
|
||||
# ---- look-up user ------------------------------------------------------ #
|
||||
uid = payload.get("sub")
|
||||
if not uid:
|
||||
raise HTTPException(status_code=403, detail="JWT missing 'sub' claim")
|
||||
|
||||
user_row = (
|
||||
db.execute(text("SELECT * FROM users WHERE id = :uid"), {"uid": uid})
|
||||
.mappings()
|
||||
.first()
|
||||
)
|
||||
if not user_row:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
# ---- tenant (optional) ------------------------------------------------- #
|
||||
tenant_row = None
|
||||
if (tid := user_row.get("tenant_id")):
|
||||
tenant_row = (
|
||||
db.execute(text("SELECT * FROM tenants WHERE id = :tid"), {"tid": tid})
|
||||
.mappings()
|
||||
.first()
|
||||
)
|
||||
|
||||
return {
|
||||
"user": user_row,
|
||||
"tenant": tenant_row,
|
||||
"scopes": payload.get("role", []),
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ from sqlalchemy import create_engine
|
|||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.pool import NullPool
|
||||
from sqlalchemy.orm import Session
|
||||
from dotenv import load_dotenv
|
||||
import os
|
||||
|
||||
|
|
@ -28,3 +29,18 @@ SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
|||
|
||||
# Base class for ORM models
|
||||
Base = declarative_base()
|
||||
|
||||
# ───────────────────────────────
|
||||
# FastAPI helper (NEW, non-breaking)
|
||||
# ───────────────────────────────
|
||||
|
||||
def get_db() -> Session:
|
||||
"""
|
||||
Yield a DB session for FastAPI dependencies.
|
||||
Closes the session automatically after the request.
|
||||
"""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
|
|
@ -25,11 +25,13 @@ def supabase_check():
|
|||
|
||||
@app.get("/me")
|
||||
def me(current: dict = Depends(get_current_user)):
|
||||
usr = current["user"] # users table row
|
||||
tenant = current["tenant"] # tenant row or None
|
||||
return {
|
||||
"id": current["id"],
|
||||
"email": current["email"],
|
||||
"role": current["role"],
|
||||
"tenant_id": current["tenant_id"],
|
||||
"id": usr["id"],
|
||||
"email": usr["email"],
|
||||
"role": current["scopes"],
|
||||
"tenant_id": tenant["id"] if tenant else "unknown",
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue