New files
storeapi/tests/routers/conftest.py
import pytest
@pytest.fixture()
def mock_generate_cute_creature_api(mocker):
    """Patch ``storeapi.tasks._generate_cute_creature_api`` with a canned reply.

    The patched callable returns a fixed ``{"output_url": ...}`` payload
    instead of running the real implementation. The mock object itself is
    handed back to the requesting test so it can assert on how it was called.

    NOTE(review): ``mocker`` is presumably the pytest-mock fixture -- confirm
    that the plugin is installed in the test environment.
    """
    canned_response = {"output_url": "http://example.net"}
    patched_api = mocker.patch(
        "storeapi.tasks._generate_cute_creature_api",
        return_value=canned_response,
    )
    return patched_api
Modified files
storeapi/routers/post.py
---
+++
@@ -3,7 +3,7 @@
from typing import Annotated
import sqlalchemy
-from fastapi import APIRouter, Depends, HTTPException
+from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Request
from storeapi.database import comment_table, database, like_table, post_table
from storeapi.models.post import (
Comment,
@@ -17,6 +17,7 @@
)
from storeapi.models.user import User
from storeapi.security import get_current_user
+from storeapi.tasks import generate_and_add_to_post
router = APIRouter()
@@ -41,7 +42,11 @@
@router.post("/post", response_model=UserPost, status_code=201)
async def create_post(
- post: UserPostIn, current_user: Annotated[User, Depends(get_current_user)]
+ post: UserPostIn,
+ current_user: Annotated[User, Depends(get_current_user)],
+ background_tasks: BackgroundTasks,
+ request: Request,
+ prompt: str = None,
):
logger.info("Creating post")
@@ -51,6 +56,15 @@
logger.debug(query)
last_record_id = await database.execute(query)
+ if prompt:
+ background_tasks.add_task(
+ generate_and_add_to_post,
+ current_user.email,
+ last_record_id,
+ request.url_for("get_post_with_comments", post_id=last_record_id),
+ database,
+ prompt,
+ )
return {**data, "id": last_record_id}
storeapi/tests/routers/test_post.py
---
+++
@@ -109,6 +109,24 @@
headers={"Authorization": f"Bearer {logged_in_token}"},
)
assert response.status_code == 201
+
+
+@pytest.mark.anyio
+async def test_create_post_with_prompt(
+ async_client: AsyncClient, logged_in_token: str, mock_generate_cute_creature_api
+):
+ response = await async_client.post(
+ "/post?prompt=A cat",
+ json={"body": "Test Post"},
+ headers={"Authorization": f"Bearer {logged_in_token}"},
+ )
+ assert response.status_code == 201
+ assert {
+ "id": 1,
+ "body": "Test Post",
+ "image_url": None,
+ }.items() <= response.json().items()
+ mock_generate_cute_creature_api.assert_called()
@pytest.mark.anyio