Fully implemented JWT request
This commit is contained in:
parent
b1527defc2
commit
1d84382621
|
|
@ -1,63 +1,118 @@
|
||||||
# auth/jwt.py
|
# backend/auth/jwt.py
|
||||||
|
"""
|
||||||
|
JWT utilities & FastAPI dependency.
|
||||||
|
|
||||||
|
• Extracts bearer token
|
||||||
|
• Verifies signature against SUPABASE_JWT_SECRET
|
||||||
|
• OPTIONAL — verifies/ignores audience
|
||||||
|
• Fetches user + tenant rows from Supabase via SQLAlchemy
|
||||||
|
"""
|
||||||
|
|
||||||
from fastapi import Request, HTTPException, status, Depends
|
from fastapi import Request, HTTPException, status, Depends
|
||||||
from jose import jwt, JWTError
|
from jose import jwt, JWTError
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from sqlalchemy import text
|
||||||
import os
|
import os
|
||||||
|
|
||||||
JWT_SECRET = os.getenv("SUPABASE_JWT_SECRET")
|
# --------------------------------------------------------------------------- #
|
||||||
ALGORITHM = "HS256"
|
# Environment / constants
|
||||||
|
# --------------------------------------------------------------------------- #
|
||||||
|
JWT_SECRET = os.getenv("SUPABASE_JWT_SECRET")
|
||||||
|
ALGORITHM = "HS256"
|
||||||
|
AUDIENCE = os.getenv("SUPABASE_JWT_AUD", "authenticated") # default Supabase aud
|
||||||
|
|
||||||
|
# Lazy-import to avoid circular refs
|
||||||
|
from db.session import get_db
|
||||||
|
|
||||||
|
|
||||||
def get_token_from_header(request: Request) -> str:
|
# --------------------------------------------------------------------------- #
|
||||||
auth_header = request.headers.get("Authorization")
|
# Helpers
|
||||||
if not auth_header or not auth_header.startswith("Bearer "):
|
# --------------------------------------------------------------------------- #
|
||||||
|
def _get_token_from_header(request: Request) -> str:
|
||||||
|
auth = request.headers.get("Authorization", "")
|
||||||
|
if not auth.startswith("Bearer "):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
detail="Missing or invalid Authorization header",
|
detail="Missing or invalid Authorization header",
|
||||||
)
|
)
|
||||||
return auth_header.split(" ")[1]
|
return auth.split(" ", 1)[1]
|
||||||
|
|
||||||
|
|
||||||
# def verify_jwt_token(token: str) -> dict:
|
def _verify_jwt_token(token: str) -> dict:
|
||||||
# try:
|
"""
|
||||||
# payload = jwt.decode(token, JWT_SECRET, algorithms=[ALGORITHM])
|
Decode & validate JWT.
|
||||||
# return payload
|
|
||||||
# except JWTError as e:
|
🔸 Option A (recommended): verify the expected audience.
|
||||||
# raise HTTPException(
|
🔸 Option B: skip audience verification entirely.
|
||||||
# status_code=status.HTTP_403_FORBIDDEN,
|
"""
|
||||||
# detail=f"Invalid JWT token: {str(e)}",
|
|
||||||
# )
|
|
||||||
def verify_jwt_token(token: str) -> dict:
|
|
||||||
try:
|
try:
|
||||||
|
# ---- Option A: enforce expected audience ---------------------------
|
||||||
payload = jwt.decode(
|
payload = jwt.decode(
|
||||||
token,
|
token,
|
||||||
JWT_SECRET,
|
JWT_SECRET,
|
||||||
algorithms=[ALGORITHM],
|
algorithms=[ALGORITHM],
|
||||||
options={"verify_aud": False} # <- Disable audience verification
|
audience=AUDIENCE, # <-- satisfies "aud" claim
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# ---- Option B (if you really don’t care) ---------------------------
|
||||||
|
# payload = jwt.decode(
|
||||||
|
# token,
|
||||||
|
# JWT_SECRET,
|
||||||
|
# algorithms=[ALGORITHM],
|
||||||
|
# options={"verify_aud": False},
|
||||||
|
# )
|
||||||
|
|
||||||
return payload
|
return payload
|
||||||
except JWTError as e:
|
except JWTError as exc:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
detail=f"Invalid JWT token: {str(e)}",
|
detail=f"Invalid JWT token: {exc}",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_current_user(request: Request):
|
# --------------------------------------------------------------------------- #
|
||||||
token = get_token_from_header(request)
|
# Public dependency
|
||||||
payload = verify_jwt_token(token)
|
# --------------------------------------------------------------------------- #
|
||||||
|
def get_current_user(
|
||||||
|
request: Request,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
) -> dict:
|
||||||
|
"""
|
||||||
|
Inject with `Depends(get_current_user)` to access:
|
||||||
|
|
||||||
user_id = payload.get("sub")
|
{
|
||||||
if not user_id:
|
"user": users row (dict),
|
||||||
raise HTTPException(
|
"tenant": tenants row (dict | None),
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
"scopes": JWT “role” claim,
|
||||||
detail="JWT missing 'sub' field",
|
}
|
||||||
)
|
"""
|
||||||
print("🔐 JWT payload seen by API →", payload)
|
token = _get_token_from_header(request)
|
||||||
return {
|
payload = _verify_jwt_token(token)
|
||||||
"id": user_id,
|
|
||||||
"email": payload.get("email"),
|
# ---- look-up user ------------------------------------------------------ #
|
||||||
"role": payload.get("role"),
|
uid = payload.get("sub")
|
||||||
"tenant_id": payload.get("tenant_id", "unknown"),
|
if not uid:
|
||||||
|
raise HTTPException(status_code=403, detail="JWT missing 'sub' claim")
|
||||||
|
|
||||||
|
user_row = (
|
||||||
|
db.execute(text("SELECT * FROM users WHERE id = :uid"), {"uid": uid})
|
||||||
|
.mappings()
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
if not user_row:
|
||||||
|
raise HTTPException(status_code=404, detail="User not found")
|
||||||
|
|
||||||
|
# ---- tenant (optional) ------------------------------------------------- #
|
||||||
|
tenant_row = None
|
||||||
|
if (tid := user_row.get("tenant_id")):
|
||||||
|
tenant_row = (
|
||||||
|
db.execute(text("SELECT * FROM tenants WHERE id = :tid"), {"tid": tid})
|
||||||
|
.mappings()
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"user": user_row,
|
||||||
|
"tenant": tenant_row,
|
||||||
|
"scopes": payload.get("role", []),
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -4,6 +4,7 @@ from sqlalchemy import create_engine
|
||||||
from sqlalchemy.ext.declarative import declarative_base
|
from sqlalchemy.ext.declarative import declarative_base
|
||||||
from sqlalchemy.orm import sessionmaker
|
from sqlalchemy.orm import sessionmaker
|
||||||
from sqlalchemy.pool import NullPool
|
from sqlalchemy.pool import NullPool
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
|
@ -28,3 +29,18 @@ SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||||
|
|
||||||
# Base class for ORM models
|
# Base class for ORM models
|
||||||
Base = declarative_base()
|
Base = declarative_base()
|
||||||
|
|
||||||
|
# ───────────────────────────────
|
||||||
|
# FastAPI helper (NEW, non-breaking)
|
||||||
|
# ───────────────────────────────
|
||||||
|
|
||||||
|
def get_db() -> Session:
|
||||||
|
"""
|
||||||
|
Yield a DB session for FastAPI dependencies.
|
||||||
|
Closes the session automatically after the request.
|
||||||
|
"""
|
||||||
|
db = SessionLocal()
|
||||||
|
try:
|
||||||
|
yield db
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
|
|
@ -25,11 +25,13 @@ def supabase_check():
|
||||||
|
|
||||||
@app.get("/me")
|
@app.get("/me")
|
||||||
def me(current: dict = Depends(get_current_user)):
|
def me(current: dict = Depends(get_current_user)):
|
||||||
|
usr = current["user"] # users table row
|
||||||
|
tenant = current["tenant"] # tenant row or None
|
||||||
return {
|
return {
|
||||||
"id": current["id"],
|
"id": usr["id"],
|
||||||
"email": current["email"],
|
"email": usr["email"],
|
||||||
"role": current["role"],
|
"role": current["scopes"],
|
||||||
"tenant_id": current["tenant_id"],
|
"tenant_id": tenant["id"] if tenant else "unknown",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue