Working with async databases

Using a database in our FastAPI router

Want more?

This lesson is for enrolled students only. Join the course to unlock it!

You can see the code changes implemented in this lecture below.

If you purchased the course on a different platform, you still have access to each lecture's code changes here on Teclado. The lecture video and lecture notes remain locked.
Join course for $30

New files

storeapi/routers/post.py
from fastapi import APIRouter, HTTPException
from storeapi.database import comment_table, database, post_table
from storeapi.models.post import (
    Comment,
    CommentIn,
    UserPost,
    UserPostIn,
    UserPostWithComments,
)

router = APIRouter()


async def find_post(post_id: int):
    """Look up a single post by its primary key.

    Returns the row produced by ``database.fetch_one`` for the matching post.
    When nothing matches, the result is falsy; callers test it with
    ``if not post``.
    """
    statement = post_table.select().where(post_table.c.id == post_id)
    row = await database.fetch_one(statement)
    return row


@router.post("/post", response_model=UserPost, status_code=201)
async def create_post(post: UserPostIn):
    """Insert a new post and return its fields plus the id the database assigned."""
    values = post.model_dump()  # Pydantic v2 name for the old .dict()
    new_id = await database.execute(post_table.insert().values(values))
    return dict(values, id=new_id)


@router.get("/post", response_model=list[UserPost])
async def get_all_posts():
    """Return every row stored in the posts table."""
    return await database.fetch_all(post_table.select())


@router.post("/comment", response_model=Comment, status_code=201)
async def create_comment(comment: CommentIn):
    """Store a comment on an existing post and return it with its new id.

    Raises:
        HTTPException: 404 when ``comment.post_id`` does not match any post.
    """
    parent = await find_post(comment.post_id)
    if not parent:
        raise HTTPException(status_code=404, detail="Post not found")

    values = comment.model_dump()  # Pydantic v2 name for the old .dict()
    insert_stmt = comment_table.insert().values(values)
    new_id = await database.execute(insert_stmt)
    return dict(values, id=new_id)


@router.get("/post/{post_id}/comment", response_model=list[Comment])
async def get_comments_on_post(post_id: int):
    """Return every comment whose ``post_id`` equals the given id.

    This handler does not check that the post exists. An unknown id simply
    yields an empty result. ``get_post_with_comments`` also awaits it
    directly as a plain coroutine.
    """
    belongs_to_post = comment_table.c.post_id == post_id
    stmt = comment_table.select().where(belongs_to_post)
    return await database.fetch_all(stmt)


@router.get("/post/{post_id}", response_model=UserPostWithComments)
async def get_post_with_comments(post_id: int):
    """Return a post bundled with all of its comments.

    Raises:
        HTTPException: 404 when no post has ``post_id``.
    """
    post = await find_post(post_id)
    if not post:
        raise HTTPException(status_code=404, detail="Post not found")

    # Reuse the comments endpoint's handler as an ordinary coroutine.
    comments = await get_comments_on_post(post_id)
    return {"post": post, "comments": comments}
storeapi/tests/routers/test_post.py
import pytest
from httpx import AsyncClient


async def create_post(body: str, async_client: AsyncClient) -> dict:
    """POST a new post with *body* and return the decoded JSON response."""
    payload = {"body": body}
    return (await async_client.post("/post", json=payload)).json()


async def create_comment(body: str, post_id: int, async_client: AsyncClient) -> dict:
    """POST a comment with *body* on post *post_id* and return the decoded JSON."""
    payload = {"body": body, "post_id": post_id}
    reply = await async_client.post("/comment", json=payload)
    return reply.json()


@pytest.fixture()
async def created_post(async_client: AsyncClient):
    """Fixture: a post created through the API, as returned by ``POST /post``."""
    post = await create_post("Test Post", async_client)
    return post


@pytest.fixture()
async def created_comment(async_client: AsyncClient, created_post: dict):
    """Fixture: a comment attached to ``created_post``, as returned by the API."""
    parent_id = created_post["id"]
    comment = await create_comment("Test Comment", parent_id, async_client)
    return comment


@pytest.mark.anyio
async def test_create_post(async_client: AsyncClient):
    """Creating a post returns 201 and echoes the body alongside id 1."""
    text = "Test Post"

    response = await async_client.post("/post", json={"body": text})
    expected = {"id": 1, "body": text}

    assert response.status_code == 201
    # Subset comparison: extra fields in the response are allowed.
    assert expected.items() <= response.json().items()


@pytest.mark.anyio
async def test_create_post_missing_data(async_client: AsyncClient):
    """A post payload with no fields is rejected with a 422 validation error."""
    empty_payload = {}
    response = await async_client.post("/post", json=empty_payload)
    assert response.status_code == 422


@pytest.mark.anyio
async def test_get_all_posts(async_client: AsyncClient, created_post: dict):
    """Listing posts returns exactly the single post created by the fixture."""
    listing = await async_client.get("/post")

    assert listing.status_code == 200
    assert listing.json() == [created_post]


@pytest.mark.anyio
async def test_create_comment(async_client: AsyncClient, created_post: dict):
    """Commenting on an existing post returns 201 and the stored comment."""
    text = "Test Comment"
    parent_id = created_post["id"]

    response = await async_client.post(
        "/comment", json={"body": text, "post_id": parent_id}
    )
    expected = {"id": 1, "body": text, "post_id": parent_id}

    assert response.status_code == 201
    # Subset comparison: extra fields in the response are allowed.
    assert expected.items() <= response.json().items()


@pytest.mark.anyio
async def test_get_comments_on_post(
    async_client: AsyncClient, created_post: dict, created_comment: dict
):
    """The comments endpoint lists the one comment created by the fixture."""
    url = f"/post/{created_post['id']}/comment"
    response = await async_client.get(url)

    assert response.status_code == 200
    assert response.json() == [created_comment]


@pytest.mark.anyio
async def test_get_comments_on_post_empty(
    async_client: AsyncClient, created_post: dict
):
    """A post with no comments yields an empty list from the comments endpoint."""
    url = f"/post/{created_post['id']}/comment"
    response = await async_client.get(url)

    assert response.status_code == 200
    assert response.json() == []


@pytest.mark.anyio
async def test_get_post_with_comments(
    async_client: AsyncClient, created_post: dict, created_comment: dict
):
    """Fetching a single post bundles it with its comments."""
    response = await async_client.get(f"/post/{created_post['id']}")
    expected = {"post": created_post, "comments": [created_comment]}

    assert response.status_code == 200
    assert response.json() == expected


@pytest.mark.anyio
async def test_get_missing_post_with_comments(
    async_client: AsyncClient, created_post: dict, created_comment: dict
):
    """Requesting a post id that does not exist returns 404.

    A post and a comment are created first, so the 404 comes from the id
    lookup rather than from an empty database. The missing id is derived from
    the fixture post's id instead of being hard-coded. That way the test does
    not depend on the database handing out id 1 to the fixture post.
    """
    # Only one post exists in this test, so the next id cannot match any row.
    missing_id = created_post["id"] + 1

    response = await async_client.get(f"/post/{missing_id}")

    assert response.status_code == 404
    assert response.json() == {"detail": "Post not found"}

Modified files

storeapi/config.py
--- 
+++ 
@@ -35,7 +35,7 @@


 @lru_cache()
-def get_config(env_state):
+def get_config(env_state: str):
     """Instantiate config based on the environment."""
     configs = {"dev": DevConfig, "prod": ProdConfig, "test": TestConfig}
     return configs[env_state]()
storeapi/database.py
--- 
+++ 
@@ -1,5 +1,6 @@
 import databases
 import sqlalchemy
+
 from storeapi.config import config

 metadata = sqlalchemy.MetaData()
@@ -8,20 +9,21 @@
     "posts",
     metadata,
     sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True),
-    sqlalchemy.Column("body", sqlalchemy.String),
+    sqlalchemy.Column("body", sqlalchemy.String)
 )

-comments_table = sqlalchemy.Table(
+comment_table = sqlalchemy.Table(
     "comments",
     metadata,
     sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True),
     sqlalchemy.Column("body", sqlalchemy.String),
-    sqlalchemy.Column("post_id", sqlalchemy.ForeignKey("posts.id"), nullable=False),
+    sqlalchemy.Column("post_id", sqlalchemy.ForeignKey("posts.id"), nullable=False)
 )

 engine = sqlalchemy.create_engine(
     config.DATABASE_URL, connect_args={"check_same_thread": False}
 )
+
 metadata.create_all(engine)
 database = databases.Database(
     config.DATABASE_URL, force_rollback=config.DB_FORCE_ROLL_BACK
storeapi/main.py
--- 
+++ 
@@ -1,8 +1,9 @@
 from contextlib import asynccontextmanager

 from fastapi import FastAPI
+
 from storeapi.database import database
-from storeapi.routers.posts import router as posts_router
+from storeapi.routers.post import router as post_router


 @asynccontextmanager
@@ -12,4 +13,6 @@
     await database.disconnect()

 app = FastAPI(lifespan=lifespan)
-app.include_router(posts_router)
+
+
+app.include_router(post_router)
storeapi/tests/conftest.py
--- 
+++ 
@@ -4,9 +4,9 @@
 import pytest
 from fastapi.testclient import TestClient
 from httpx import AsyncClient
-from storeapi.routers.posts import comments_table, post_table

 os.environ["ENV_STATE"] = "test"
+from storeapi.database import database  # noqa: E402
 from storeapi.main import app  # noqa: E402


@@ -22,9 +22,9 @@

 @pytest.fixture(autouse=True)
 async def db() -> AsyncGenerator:
-    post_table.clear()
-    comments_table.clear()
+    await database.connect()
     yield
+    await database.disconnect()


 @pytest.fixture()
storeapi/models/post.py
--- 
+++ 
@@ -1,4 +1,4 @@
-from pydantic import BaseModel
+from pydantic import BaseModel, ConfigDict


 class UserPostIn(BaseModel):
@@ -6,6 +6,8 @@


 class UserPost(UserPostIn):
+    model_config = ConfigDict(from_attributes=True)
+
     id: int


@@ -15,6 +17,8 @@


 class Comment(CommentIn):
+    model_config = ConfigDict(from_attributes=True)
+
     id: int