User authentication with FastAPI

Adding user relationships to other tables

Want more?

This lesson is for enrolled students only. Join the course to unlock it!

You can see the code changes implemented in this lecture below.

If you purchased the course on a different platform, you can still access the code changes for each lecture here on Teclado. The lecture video and lecture notes remain locked.
Join course for $30

Modified files

storeapi/database.py
--- 
+++ 
@@ -9,7 +9,8 @@
     "posts",
     metadata,
     sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True),
-    sqlalchemy.Column("body", sqlalchemy.String)
+    sqlalchemy.Column("body", sqlalchemy.String),
+    sqlalchemy.Column("user_id", sqlalchemy.ForeignKey("users.id"), nullable=False),
 )

 user_table = sqlalchemy.Table(
@@ -20,12 +21,14 @@
     sqlalchemy.Column("password", sqlalchemy.String),
 )

+
 comment_table = sqlalchemy.Table(
     "comments",
     metadata,
     sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True),
     sqlalchemy.Column("body", sqlalchemy.String),
-    sqlalchemy.Column("post_id", sqlalchemy.ForeignKey("posts.id"), nullable=False)
+    sqlalchemy.Column("post_id", sqlalchemy.ForeignKey("posts.id"), nullable=False),
+    sqlalchemy.Column("user_id", sqlalchemy.ForeignKey("users.id"), nullable=False),
 )

 engine = sqlalchemy.create_engine(
storeapi/security.py
--- 
+++ 
@@ -1,25 +1,25 @@
 import datetime
+import logging
 from typing import Annotated
-import logging

 from fastapi import Depends, HTTPException, status
 from fastapi.security import OAuth2PasswordBearer
 from jose import ExpiredSignatureError, JWTError, jwt
 from passlib.context import CryptContext
+
 from storeapi.database import database, user_table

 logger = logging.getLogger(__name__)

 SECRET_KEY = "9b73f2a1bdd7ae163444473d29a6885ffa22ab26117068f72a5a56a74d12d1fc"
 ALGORITHM = "HS256"
-
 oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
 pwd_context = CryptContext(schemes=["bcrypt"])

 credentials_exception = HTTPException(
     status_code=status.HTTP_401_UNAUTHORIZED,
     detail="Could not validate credentials",
-    headers={"WWW-Authenticate": "Bearer"},  # Tells client to authenticate with JWT
+    headers={"WWW-Authenticate": "Bearer"},
 )


@@ -29,11 +29,11 @@

 def create_access_token(email: str):
     logger.debug("Creating access token", extra={"email": email})
-    expire = datetime.datetime.utcnow() + datetime.timedelta(
+    expire = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(
         minutes=access_token_expire_minutes()
     )
     jwt_data = {"sub": email, "exp": expire}
-    encoded_jwt = jwt.encode(jwt_data, SECRET_KEY, algorithm=ALGORITHM)
+    encoded_jwt = jwt.encode(jwt_data, key=SECRET_KEY, algorithm=ALGORITHM)
     return encoded_jwt


@@ -65,7 +65,7 @@

 async def get_current_user(token: Annotated[str, Depends(oauth2_scheme)]):
     try:
-        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
+        payload = jwt.decode(token, key=SECRET_KEY, algorithms=[ALGORITHM])
         email = payload.get("sub")
         if email is None:
             raise credentials_exception
storeapi/logging_conf.py
--- 
+++ 
@@ -1,11 +1,10 @@
 import logging
 from logging.config import dictConfig

-from storeapi.config import DevConfig, config
+from storeapi.config import DevConfig, ProdConfig, config


-def obfuscated(email: str, obfuscated_length: int):
-    """Obfuscate email address for logging purposes."""
+def obfuscated(email: str, obfuscated_length: int) -> str:
     characters = email[:obfuscated_length]
     first, last = email.split("@")
     return characters + ("*" * (len(first) - obfuscated_length)) + "@" + last
@@ -23,7 +22,7 @@


 handlers = ["default", "rotating_file"]
-if config.ENV_STATE == "prod":
+if isinstance(config, ProdConfig):
     handlers = ["default", "rotating_file", "logtail"]


@@ -52,46 +51,43 @@
                 "file": {
                     "class": "pythonjsonlogger.jsonlogger.JsonFormatter",
                     "datefmt": "%Y-%m-%dT%H:%M:%S",
-                    # For JsonFormatter, the format string just defines what keys are included in the log record
-                    # It's a bit clunky, but it's the way to do it for now
                     "format": "%(asctime)s %(msecs)03d %(levelname)s %(correlation_id)s %(name)s %(lineno)d %(message)s",
                 },
             },
             "handlers": {
                 "default": {
-                    "class": "rich.logging.RichHandler",  # could use logging.StreamHandler instead
+                    "class": "rich.logging.RichHandler",
                     "level": "DEBUG",
                     "formatter": "console",
-                    "filters": ["correlation_id", "email_obfuscation"],
-                },
-                "logtail": {
-                    # https://betterstack.com/docs/logs/python/
-                    "class": "logtail.LogtailHandler",
-                    "level": "DEBUG",
-                    "formatter": "console",
-                    "filters": ["correlation_id", "email_obfuscation"],
-                    "source_token": config.LOGTAIL_API_KEY,  # gets passed to LogtailHandler constructor as kwargs
+                    "filters": ["correlation_id", "email_obfuscation"]
                 },
                 "rotating_file": {
                     "class": "logging.handlers.RotatingFileHandler",
                     "level": "DEBUG",
                     "formatter": "file",
+                    "filename": "storeapi.log",
+                    "maxBytes": 1024 * 1024,  # 1MB
+                    "backupCount": 5,
+                    "encoding": "utf8",
+                    "filters": ["correlation_id", "email_obfuscation"]
+                },
+                "logtail": {
+                    "class": "logtail.LogtailHandler",
+                    "level": "DEBUG",
+                    "formatter": "console",
                     "filters": ["correlation_id", "email_obfuscation"],
-                    "filename": "storeapi.log",
-                    "maxBytes": 1024 * 1024,  # 1 MB
-                    "backupCount": 2,
-                    "encoding": "utf8",
-                },
+                    "source_token": config.LOGTAIL_API_KEY
+                }
             },
             "loggers": {
                 "uvicorn": {"handlers": ["default", "rotating_file"], "level": "INFO"},
                 "storeapi": {
                     "handlers": handlers,
                     "level": "DEBUG" if isinstance(config, DevConfig) else "INFO",
-                    "propagate": False,
+                    "propagate": False
                 },
                 "databases": {"handlers": ["default"], "level": "WARNING"},
-                "aiosqlite": {"handlers": ["default"], "level": "WARNING"},
-            },
+                "aiosqlite": {"handlers": ["default"], "level": "WARNING"}
+            }
         }
-    )
+    )
storeapi/main.py
--- 
+++ 
@@ -4,6 +4,7 @@
 from asgi_correlation_id import CorrelationIdMiddleware
 from fastapi import FastAPI, HTTPException
 from fastapi.exception_handlers import http_exception_handler
+
 from storeapi.database import database
 from storeapi.logging_conf import configure_logging
 from storeapi.routers.post import router as post_router
@@ -20,13 +21,12 @@

 app = FastAPI(lifespan=lifespan)
 app.add_middleware(CorrelationIdMiddleware)
+
+app.include_router(post_router)
 app.include_router(user_router)
-app.include_router(post_router)


 @app.exception_handler(HTTPException)
 async def http_exception_handle_logging(request, exc):
     logger.error(f"HTTPException: {exc.status_code} {exc.detail}")
-    return await http_exception_handler(request, exc)
-
-
+    return await http_exception_handler(request, exc)
storeapi/routers/post.py
--- 
+++ 
@@ -30,12 +30,11 @@

 @router.post("/post", response_model=UserPost, status_code=201)
 async def create_post(
-    post: UserPostIn,
-    current_user: Annotated[User, Depends(get_current_user)],
+    post: UserPostIn, current_user: Annotated[User, Depends(get_current_user)]
 ):
     logger.info("Creating post")

-    data = post.model_dump()  # previously .dict()
+    data = {**post.model_dump(), "user_id": current_user.id}
     query = post_table.insert().values(data)

     logger.debug(query)
@@ -62,11 +61,10 @@
     logger.info("Creating comment")

     post = await find_post(comment.post_id)
-
     if not post:
         raise HTTPException(status_code=404, detail="Post not found")

-    data = comment.model_dump()  # previously .dict()
+    data = {**comment.model_dump(), "user_id": current_user.id}
     query = comment_table.insert().values(data)

     logger.debug(query)
@@ -91,7 +89,6 @@
     logger.info("Getting post and its comments")

     post = await find_post(post_id)
-
     if not post:
         raise HTTPException(status_code=404, detail="Post not found")
storeapi/routers/user.py
--- 
+++ 
@@ -1,6 +1,7 @@
 import logging

 from fastapi import APIRouter, HTTPException, status
+
 from storeapi.database import database, user_table
 from storeapi.models.user import UserIn
 from storeapi.security import (
storeapi/tests/test_security.py
--- 
+++ 
@@ -1,5 +1,6 @@
 import pytest
 from jose import jwt
+
 from storeapi import security


@@ -10,7 +11,7 @@
 def test_create_access_token():
     token = security.create_access_token("123")
     assert {"sub": "123"}.items() <= jwt.decode(
-        token, security.SECRET_KEY, algorithms=[security.ALGORITHM]
+        token, key=security.SECRET_KEY, algorithms=[security.ALGORITHM]
     ).items()


@@ -22,6 +23,7 @@
 @pytest.mark.anyio
 async def test_get_user(registered_user: dict):
     user = await security.get_user(registered_user["email"])
+
     assert user.email == registered_user["email"]


@@ -42,7 +44,7 @@
 @pytest.mark.anyio
 async def test_authenticate_user_not_found():
     with pytest.raises(security.HTTPException):
-        await security.authenticate_user("test@example.com", "1234")
+        await security.authenticate_user("test@example.net", "1234")


 @pytest.mark.anyio
storeapi/tests/conftest.py
--- 
+++ 
@@ -34,7 +34,7 @@


 @pytest.fixture()
-async def registered_user(async_client: AsyncClient):
+async def registered_user(async_client: AsyncClient) -> dict:
     user_details = {"email": "test@example.net", "password": "1234"}
     await async_client.post("/register", json=user_details)
     query = user_table.select().where(user_table.c.email == user_details["email"])
@@ -44,6 +44,6 @@


 @pytest.fixture()
-async def logged_in_token(async_client: AsyncClient, registered_user: dict):
+async def logged_in_token(async_client: AsyncClient, registered_user: dict) -> str:
     response = await async_client.post("/token", json=registered_user)
     return response.json()["access_token"]
storeapi/tests/routers/test_user.py
--- 
+++ 
@@ -16,7 +16,7 @@


 @pytest.mark.anyio
-async def test_register_already_exists(
+async def test_register_user_already_exists(
     async_client: AsyncClient, registered_user: dict
 ):
     response = await register_user(
storeapi/tests/routers/test_post.py
--- 
+++ 
@@ -1,5 +1,6 @@
 import pytest
 from httpx import AsyncClient
+
 from storeapi import security


@@ -40,7 +41,9 @@


 @pytest.mark.anyio
-async def test_create_post(async_client: AsyncClient, logged_in_token: str):
+async def test_create_post(
+    async_client: AsyncClient, registered_user: dict, logged_in_token: str
+):
     body = "Test Post"

     response = await async_client.post(
@@ -48,8 +51,13 @@
         json={"body": body},
         headers={"Authorization": f"Bearer {logged_in_token}"},
     )
+
     assert response.status_code == 201
-    assert {"id": 1, "body": body}.items() <= response.json().items()
+    assert {
+        "id": 1,
+        "body": body,
+        "user_id": registered_user["id"],
+    }.items() <= response.json().items()


 @pytest.mark.anyio
@@ -63,6 +71,7 @@
         json={"body": "Test Post"},
         headers={"Authorization": f"Bearer {token}"},
     )
+
     assert response.status_code == 401
     assert "Token has expired" in response.json()["detail"]

@@ -72,10 +81,9 @@
     async_client: AsyncClient, logged_in_token: str
 ):
     response = await async_client.post(
-        "/post",
-        json={},
-        headers={"Authorization": f"Bearer {logged_in_token}"},
+        "/post", json={}, headers={"Authorization": f"Bearer {logged_in_token}"}
     )
+
     assert response.status_code == 422


@@ -91,6 +99,7 @@
 async def test_create_comment(
     async_client: AsyncClient,
     created_post: dict,
+    registered_user: dict,
     logged_in_token: str,
 ):
     body = "Test Comment"
@@ -105,6 +114,7 @@
         "id": 1,
         "body": body,
         "post_id": created_post["id"],
+        "user_id": registered_user["id"],
     }.items() <= response.json().items()
storeapi/models/post.py
--- 
+++ 
@@ -9,6 +9,7 @@
     model_config = ConfigDict(from_attributes=True)

     id: int
+    user_id: int


 class CommentIn(BaseModel):
@@ -20,6 +21,7 @@
     model_config = ConfigDict(from_attributes=True)

     id: int
+    user_id: int


 class UserPostWithComments(BaseModel):