Modified files
storeapi/database.py
---
+++
@@ -9,7 +9,8 @@
"posts",
metadata,
sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True),
- sqlalchemy.Column("body", sqlalchemy.String)
+ sqlalchemy.Column("body", sqlalchemy.String),
+ sqlalchemy.Column("user_id", sqlalchemy.ForeignKey("users.id"), nullable=False),
)
user_table = sqlalchemy.Table(
@@ -20,12 +21,14 @@
sqlalchemy.Column("password", sqlalchemy.String),
)
+
comment_table = sqlalchemy.Table(
"comments",
metadata,
sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True),
sqlalchemy.Column("body", sqlalchemy.String),
- sqlalchemy.Column("post_id", sqlalchemy.ForeignKey("posts.id"), nullable=False)
+ sqlalchemy.Column("post_id", sqlalchemy.ForeignKey("posts.id"), nullable=False),
+ sqlalchemy.Column("user_id", sqlalchemy.ForeignKey("users.id"), nullable=False),
)
engine = sqlalchemy.create_engine(
storeapi/security.py
---
+++
@@ -1,25 +1,25 @@
import datetime
+import logging
from typing import Annotated
-import logging
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from jose import ExpiredSignatureError, JWTError, jwt
from passlib.context import CryptContext
+
from storeapi.database import database, user_table
logger = logging.getLogger(__name__)
SECRET_KEY = "9b73f2a1bdd7ae163444473d29a6885ffa22ab26117068f72a5a56a74d12d1fc"
ALGORITHM = "HS256"
-
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
pwd_context = CryptContext(schemes=["bcrypt"])
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
- headers={"WWW-Authenticate": "Bearer"}, # Tells client to authenticate with JWT
+ headers={"WWW-Authenticate": "Bearer"},
)
@@ -29,11 +29,11 @@
def create_access_token(email: str):
logger.debug("Creating access token", extra={"email": email})
- expire = datetime.datetime.utcnow() + datetime.timedelta(
+ expire = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(
minutes=access_token_expire_minutes()
)
jwt_data = {"sub": email, "exp": expire}
- encoded_jwt = jwt.encode(jwt_data, SECRET_KEY, algorithm=ALGORITHM)
+ encoded_jwt = jwt.encode(jwt_data, key=SECRET_KEY, algorithm=ALGORITHM)
return encoded_jwt
@@ -65,7 +65,7 @@
async def get_current_user(token: Annotated[str, Depends(oauth2_scheme)]):
try:
- payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
+ payload = jwt.decode(token, key=SECRET_KEY, algorithms=[ALGORITHM])
email = payload.get("sub")
if email is None:
raise credentials_exception
storeapi/logging_conf.py
---
+++
@@ -1,11 +1,10 @@
import logging
from logging.config import dictConfig
-from storeapi.config import DevConfig, config
+from storeapi.config import DevConfig, ProdConfig, config
-def obfuscated(email: str, obfuscated_length: int):
- """Obfuscate email address for logging purposes."""
+def obfuscated(email: str, obfuscated_length: int) -> str:
characters = email[:obfuscated_length]
first, last = email.split("@")
return characters + ("*" * (len(first) - obfuscated_length)) + "@" + last
@@ -23,7 +22,7 @@
handlers = ["default", "rotating_file"]
-if config.ENV_STATE == "prod":
+if isinstance(config, ProdConfig):
handlers = ["default", "rotating_file", "logtail"]
@@ -52,46 +51,43 @@
"file": {
"class": "pythonjsonlogger.jsonlogger.JsonFormatter",
"datefmt": "%Y-%m-%dT%H:%M:%S",
- # For JsonFormatter, the format string just defines what keys are included in the log record
- # It's a bit clunky, but it's the way to do it for now
"format": "%(asctime)s %(msecs)03d %(levelname)s %(correlation_id)s %(name)s %(lineno)d %(message)s",
},
},
"handlers": {
"default": {
- "class": "rich.logging.RichHandler", # could use logging.StreamHandler instead
+ "class": "rich.logging.RichHandler",
"level": "DEBUG",
"formatter": "console",
- "filters": ["correlation_id", "email_obfuscation"],
- },
- "logtail": {
- # https://betterstack.com/docs/logs/python/
- "class": "logtail.LogtailHandler",
- "level": "DEBUG",
- "formatter": "console",
- "filters": ["correlation_id", "email_obfuscation"],
- "source_token": config.LOGTAIL_API_KEY, # gets passed to LogtailHandler constructor as kwargs
+ "filters": ["correlation_id", "email_obfuscation"]
},
"rotating_file": {
"class": "logging.handlers.RotatingFileHandler",
"level": "DEBUG",
"formatter": "file",
+ "filename": "storeapi.log",
+ "maxBytes": 1024 * 1024, # 1MB
+ "backupCount": 5,
+ "encoding": "utf8",
+ "filters": ["correlation_id", "email_obfuscation"]
+ },
+ "logtail": {
+ "class": "logtail.LogtailHandler",
+ "level": "DEBUG",
+ "formatter": "console",
"filters": ["correlation_id", "email_obfuscation"],
- "filename": "storeapi.log",
- "maxBytes": 1024 * 1024, # 1 MB
- "backupCount": 2,
- "encoding": "utf8",
- },
+ "source_token": config.LOGTAIL_API_KEY
+ }
},
"loggers": {
"uvicorn": {"handlers": ["default", "rotating_file"], "level": "INFO"},
"storeapi": {
"handlers": handlers,
"level": "DEBUG" if isinstance(config, DevConfig) else "INFO",
- "propagate": False,
+ "propagate": False
},
"databases": {"handlers": ["default"], "level": "WARNING"},
- "aiosqlite": {"handlers": ["default"], "level": "WARNING"},
- },
+ "aiosqlite": {"handlers": ["default"], "level": "WARNING"}
+ }
}
- )
+ )
storeapi/main.py
---
+++
@@ -4,6 +4,7 @@
from asgi_correlation_id import CorrelationIdMiddleware
from fastapi import FastAPI, HTTPException
from fastapi.exception_handlers import http_exception_handler
+
from storeapi.database import database
from storeapi.logging_conf import configure_logging
from storeapi.routers.post import router as post_router
@@ -20,13 +21,12 @@
app = FastAPI(lifespan=lifespan)
app.add_middleware(CorrelationIdMiddleware)
+
+app.include_router(post_router)
app.include_router(user_router)
-app.include_router(post_router)
@app.exception_handler(HTTPException)
async def http_exception_handle_logging(request, exc):
logger.error(f"HTTPException: {exc.status_code} {exc.detail}")
- return await http_exception_handler(request, exc)
-
-
+ return await http_exception_handler(request, exc)
storeapi/routers/post.py
---
+++
@@ -30,12 +30,11 @@
@router.post("/post", response_model=UserPost, status_code=201)
async def create_post(
- post: UserPostIn,
- current_user: Annotated[User, Depends(get_current_user)],
+ post: UserPostIn, current_user: Annotated[User, Depends(get_current_user)]
):
logger.info("Creating post")
- data = post.model_dump() # previously .dict()
+ data = {**post.model_dump(), "user_id": current_user.id}
query = post_table.insert().values(data)
logger.debug(query)
@@ -62,11 +61,10 @@
logger.info("Creating comment")
post = await find_post(comment.post_id)
-
if not post:
raise HTTPException(status_code=404, detail="Post not found")
- data = comment.model_dump() # previously .dict()
+ data = {**comment.model_dump(), "user_id": current_user.id}
query = comment_table.insert().values(data)
logger.debug(query)
@@ -91,7 +89,6 @@
logger.info("Getting post and its comments")
post = await find_post(post_id)
-
if not post:
raise HTTPException(status_code=404, detail="Post not found")
storeapi/routers/user.py
---
+++
@@ -1,6 +1,7 @@
import logging
from fastapi import APIRouter, HTTPException, status
+
from storeapi.database import database, user_table
from storeapi.models.user import UserIn
from storeapi.security import (
storeapi/tests/test_security.py
---
+++
@@ -1,5 +1,6 @@
import pytest
from jose import jwt
+
from storeapi import security
@@ -10,7 +11,7 @@
def test_create_access_token():
token = security.create_access_token("123")
assert {"sub": "123"}.items() <= jwt.decode(
- token, security.SECRET_KEY, algorithms=[security.ALGORITHM]
+ token, key=security.SECRET_KEY, algorithms=[security.ALGORITHM]
).items()
@@ -22,6 +23,7 @@
@pytest.mark.anyio
async def test_get_user(registered_user: dict):
user = await security.get_user(registered_user["email"])
+
assert user.email == registered_user["email"]
@@ -42,7 +44,7 @@
@pytest.mark.anyio
async def test_authenticate_user_not_found():
with pytest.raises(security.HTTPException):
- await security.authenticate_user("test@example.com", "1234")
+ await security.authenticate_user("test@example.net", "1234")
@pytest.mark.anyio
storeapi/tests/conftest.py
---
+++
@@ -34,7 +34,7 @@
@pytest.fixture()
-async def registered_user(async_client: AsyncClient):
+async def registered_user(async_client: AsyncClient) -> dict:
user_details = {"email": "test@example.net", "password": "1234"}
await async_client.post("/register", json=user_details)
query = user_table.select().where(user_table.c.email == user_details["email"])
@@ -44,6 +44,6 @@
@pytest.fixture()
-async def logged_in_token(async_client: AsyncClient, registered_user: dict):
+async def logged_in_token(async_client: AsyncClient, registered_user: dict) -> str:
response = await async_client.post("/token", json=registered_user)
return response.json()["access_token"]
storeapi/tests/routers/test_user.py
---
+++
@@ -16,7 +16,7 @@
@pytest.mark.anyio
-async def test_register_already_exists(
+async def test_register_user_already_exists(
async_client: AsyncClient, registered_user: dict
):
response = await register_user(
storeapi/tests/routers/test_post.py
---
+++
@@ -1,5 +1,6 @@
import pytest
from httpx import AsyncClient
+
from storeapi import security
@@ -40,7 +41,9 @@
@pytest.mark.anyio
-async def test_create_post(async_client: AsyncClient, logged_in_token: str):
+async def test_create_post(
+ async_client: AsyncClient, registered_user: dict, logged_in_token: str
+):
body = "Test Post"
response = await async_client.post(
@@ -48,8 +51,13 @@
json={"body": body},
headers={"Authorization": f"Bearer {logged_in_token}"},
)
+
assert response.status_code == 201
- assert {"id": 1, "body": body}.items() <= response.json().items()
+ assert {
+ "id": 1,
+ "body": body,
+ "user_id": registered_user["id"],
+ }.items() <= response.json().items()
@pytest.mark.anyio
@@ -63,6 +71,7 @@
json={"body": "Test Post"},
headers={"Authorization": f"Bearer {token}"},
)
+
assert response.status_code == 401
assert "Token has expired" in response.json()["detail"]
@@ -72,10 +81,9 @@
async_client: AsyncClient, logged_in_token: str
):
response = await async_client.post(
- "/post",
- json={},
- headers={"Authorization": f"Bearer {logged_in_token}"},
+ "/post", json={}, headers={"Authorization": f"Bearer {logged_in_token}"}
)
+
assert response.status_code == 422
@@ -91,6 +99,7 @@
async def test_create_comment(
async_client: AsyncClient,
created_post: dict,
+ registered_user: dict,
logged_in_token: str,
):
body = "Test Comment"
@@ -105,6 +114,7 @@
"id": 1,
"body": body,
"post_id": created_post["id"],
+ "user_id": registered_user["id"],
}.items() <= response.json().items()
storeapi/models/post.py
---
+++
@@ -9,6 +9,7 @@
model_config = ConfigDict(from_attributes=True)
id: int
+ user_id: int
class CommentIn(BaseModel):
@@ -20,6 +21,7 @@
model_config = ConfigDict(from_attributes=True)
id: int
+ user_id: int
class UserPostWithComments(BaseModel):