Modified files
storeapi/security.py
---
+++
@@ -1,12 +1,11 @@
import datetime
import logging
-from typing import Annotated
+from typing import Annotated, Literal
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from jose import ExpiredSignatureError, JWTError, jwt
from passlib.context import CryptContext
-
from storeapi.database import database, user_table
logger = logging.getLogger(__name__)
@@ -16,11 +15,13 @@
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
pwd_context = CryptContext(schemes=["bcrypt"])
-credentials_exception = HTTPException(
- status_code=status.HTTP_401_UNAUTHORIZED,
- detail="Could not validate credentials",
- headers={"WWW-Authenticate": "Bearer"},
-)
+
+def create_unauthorized_exception(detail: str) -> HTTPException:
+    """Build a 401 Unauthorized ``HTTPException`` carrying *detail*.
+
+    The exception is returned rather than raised, so call sites decide how
+    to raise it (for example with ``raise ... from err``).
+    """
+    bearer_challenge = {"WWW-Authenticate": "Bearer"}
+    return HTTPException(
+        status_code=status.HTTP_401_UNAUTHORIZED,
+        detail=detail,
+        headers=bearer_challenge,
+    )
def access_token_expire_minutes() -> int:
@@ -51,6 +52,29 @@
return encoded_jwt
+def get_subject_for_token_type(
+    token: str, type: Literal["access", "confirmation"]
+) -> str:
+    """Decode *token* and return its ``sub`` claim if its ``type`` claim matches.
+
+    Raises:
+        HTTPException: 401 when the token has expired, cannot be decoded,
+            carries no ``sub`` claim, or its ``type`` claim is absent or
+            differs from *type*.
+    """
+    try:
+        claims = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
+    except ExpiredSignatureError as expired:
+        # Handled before JWTError so expiry gets its own message.
+        raise create_unauthorized_exception("Token has expired") from expired
+    except JWTError as invalid:
+        raise create_unauthorized_exception("Invalid token") from invalid
+
+    subject = claims.get("sub")
+    if subject is None:
+        raise create_unauthorized_exception("Token is missing 'sub' field")
+
+    claimed_type = claims.get("type")
+    type_matches = claimed_type is not None and claimed_type == type
+    if not type_matches:
+        raise create_unauthorized_exception(
+            f"Token has incorrect type, expected '{type}'"
+        )
+
+    return subject
+
+
def get_password_hash(password: str) -> str:
return pwd_context.hash(password)
@@ -71,31 +95,15 @@
logger.debug("Authenticating user", extra={"email": email})
user = await get_user(email)
if not user:
- raise credentials_exception
+ raise create_unauthorized_exception("Invalid email or password")
if not verify_password(password, user.password):
- raise credentials_exception
+ raise create_unauthorized_exception("Invalid email or password")
return user
async def get_current_user(token: Annotated[str, Depends(oauth2_scheme)]):
+ """Return the user whose email is the subject of the given access token.
+
+ The subject is obtained from get_subject_for_token_type(token, "access");
+ a 401 HTTPException is raised here when get_user finds no matching user.
+ """
- try:
- payload = jwt.decode(token, key=SECRET_KEY, algorithms=[ALGORITHM])
- email = payload.get("sub")
- if email is None:
- raise credentials_exception
-
- type = payload.get("type")
- if type is None or type != "access":
- raise credentials_exception
- except ExpiredSignatureError as e:
- raise HTTPException(
- status_code=status.HTTP_401_UNAUTHORIZED,
- detail="Token has expired",
- headers={"WWW-Authenticate": "Bearer"},
- ) from e
- except JWTError as e:
- raise credentials_exception from e
+ email = get_subject_for_token_type(token, "access")
user = await get_user(email=email)
if user is None:
- raise credentials_exception
+ raise create_unauthorized_exception("Could not find user for this token")
return user
storeapi/tests/test_security.py
---
+++
@@ -1,6 +1,7 @@
+import time
+
import pytest
from jose import jwt
-
from storeapi import security
@@ -24,6 +25,56 @@
assert {"sub": "123", "type": "confirmation"}.items() <= jwt.decode(
token, key=security.SECRET_KEY, algorithms=[security.ALGORITHM]
).items()
+
+
+def test_get_subject_for_token_type_valid_confirmation():
+    """A confirmation token checked as 'confirmation' yields its subject."""
+    expected_email = "test@example.com"
+    confirmation_token = security.create_confirmation_token(expected_email)
+    subject = security.get_subject_for_token_type(
+        confirmation_token, "confirmation"
+    )
+    assert subject == expected_email
+
+
+def test_get_subject_for_token_type_valid_access():
+    """An access token checked as 'access' yields its subject."""
+    expected_email = "test@example.com"
+    access_token = security.create_access_token(expected_email)
+    subject = security.get_subject_for_token_type(access_token, "access")
+    assert subject == expected_email
+
+
+def test_get_subject_for_token_type_expired(mocker):
+    """A token created with a negative lifetime is reported as expired."""
+    # Presumably create_access_token reads this value, so -1 minutes mints a
+    # token that is already past its expiry.
+    mocker.patch("storeapi.security.access_token_expire_minutes", return_value=-1)
+    expired_token = security.create_access_token("test@example.com")
+    with pytest.raises(security.HTTPException) as raised:
+        security.get_subject_for_token_type(expired_token, "access")
+    assert raised.value.detail == "Token has expired"
+
+
+def test_get_subject_for_token_type_invalid_token():
+    """A string that is not a JWT at all is rejected as an invalid token."""
+    garbage = "invalid token"
+    with pytest.raises(security.HTTPException) as raised:
+        security.get_subject_for_token_type(garbage, "access")
+    assert raised.value.detail == "Invalid token"
+
+
+def test_get_subject_for_token_type_missing_sub():
+    """A correctly signed token with its ``sub`` claim removed is rejected."""
+    signing = {"key": security.SECRET_KEY}
+    original_token = security.create_access_token("test@example.com")
+    claims = jwt.decode(original_token, algorithms=[security.ALGORITHM], **signing)
+    # Drop the subject, then re-sign so only the missing claim is at fault.
+    claims.pop("sub")
+    token_without_sub = jwt.encode(claims, algorithm=security.ALGORITHM, **signing)
+
+    with pytest.raises(security.HTTPException) as raised:
+        security.get_subject_for_token_type(token_without_sub, "access")
+    assert raised.value.detail == "Token is missing 'sub' field"
+
+
+def test_get_subject_for_token_type_wrong_type():
+    """A confirmation token is rejected when an access token is expected."""
+    confirmation_token = security.create_confirmation_token("test@example.com")
+    with pytest.raises(security.HTTPException) as raised:
+        security.get_subject_for_token_type(confirmation_token, "access")
+    assert raised.value.detail == "Token has incorrect type, expected 'access'"
def test_password_hashes():
@@ -81,4 +132,4 @@
async def test_get_current_user_wrong_type_token(registered_user: dict):
token = security.create_confirmation_token(registered_user["email"])
with pytest.raises(security.HTTPException):
- await security.get_current_user(token)
+ await security.get_current_user(token)