Image Generation with Background Tasks

Executing image generation in our FastAPI endpoint

Want more?

This lesson is for enrolled students only. Join the course to unlock it!

You can see the code changes implemented in this lecture below.

If you purchased the course on a different platform, you still have access to the code changes for each lecture here on Teclado. The lecture video and lecture notes remain locked.
Join course for $30

New files

storeapi/tests/routers/conftest.py
import pytest


@pytest.fixture()
def mock_generate_cute_creature_api(mocker):
    """Stub out the image-generation API helper in ``storeapi.tasks``.

    ``storeapi.tasks._generate_cute_creature_api`` is replaced with a mock
    whose return value is a fixed payload pointing at ``http://example.net``.
    Any code that resolves the helper through the ``storeapi.tasks`` module
    therefore receives this canned response instead of the real one.

    Returns the mock object so tests can make assertions on it
    (e.g. ``assert_called()``).
    """
    # Canned payload mirroring the one key the tests rely on.
    canned_payload = {"output_url": "http://example.net"}
    patched_helper = mocker.patch(
        "storeapi.tasks._generate_cute_creature_api",
        return_value=canned_payload,
    )
    return patched_helper

Modified files

storeapi/routers/post.py
--- 
+++ 
@@ -3,7 +3,7 @@
 from typing import Annotated

 import sqlalchemy
-from fastapi import APIRouter, Depends, HTTPException
+from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Request
 from storeapi.database import comment_table, database, like_table, post_table
 from storeapi.models.post import (
     Comment,
@@ -17,6 +17,7 @@
 )
 from storeapi.models.user import User
 from storeapi.security import get_current_user
+from storeapi.tasks import generate_and_add_to_post

 router = APIRouter()

@@ -41,7 +42,11 @@

 @router.post("/post", response_model=UserPost, status_code=201)
 async def create_post(
-    post: UserPostIn, current_user: Annotated[User, Depends(get_current_user)]
+    post: UserPostIn,
+    current_user: Annotated[User, Depends(get_current_user)],
+    background_tasks: BackgroundTasks,
+    request: Request,
+    prompt: str = None,
 ):
     logger.info("Creating post")

@@ -51,6 +56,15 @@
     logger.debug(query)

     last_record_id = await database.execute(query)
+    if prompt:
+        background_tasks.add_task(
+            generate_and_add_to_post,
+            current_user.email,
+            last_record_id,
+            request.url_for("get_post_with_comments", post_id=last_record_id),
+            database,
+            prompt,
+        )
     return {**data, "id": last_record_id}
storeapi/tests/routers/test_post.py
--- 
+++ 
@@ -109,6 +109,24 @@
         headers={"Authorization": f"Bearer {logged_in_token}"},
     )
     assert response.status_code == 201
+
+
+@pytest.mark.anyio
+async def test_create_post_with_prompt(
+    async_client: AsyncClient, logged_in_token: str, mock_generate_cute_creature_api
+):
+    response = await async_client.post(
+        "/post?prompt=A cat",
+        json={"body": "Test Post"},
+        headers={"Authorization": f"Bearer {logged_in_token}"},
+    )
+    assert response.status_code == 201
+    assert {
+        "id": 1,
+        "body": "Test Post",
+        "image_url": None,
+    }.items() <= response.json().items()
+    mock_generate_cute_creature_api.assert_called()


 @pytest.mark.anyio