diff --git a/database.py b/database.py index 40fdc3e..9de338a 100644 --- a/database.py +++ b/database.py @@ -2,8 +2,6 @@ from models import User from settings import settings from fastapi import HTTPException, status, Request import sqlite3 -import jwt -import datetime import security @@ -60,18 +58,15 @@ def register(user: User) -> None: def get_user_by_token(request: Request) -> User: """Retrieves a user from the database using a JWT token.""" - token = request.headers.get("Authorization") - if not token or not token.startswith("Bearer "): + + payload = security.decode_jwt( + request.headers.get("Authorization")) + + if not payload: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Not authenticated" ) - token = token.split(" ")[1] - payload = jwt.decode( - token, - key=settings.jwt_secret, - algorithms=[ - settings.jwt_algorithm]) connection, cursor = connect() cursor.execute( @@ -100,14 +95,4 @@ def login(user: User) -> str: detail="Invalid credentials" ) - exp = datetime.datetime.now( - datetime.timezone.utc) + datetime.timedelta(hours=1) - payload = { - "id": row["id"], - "exp": exp - } - - return jwt.encode( - payload=payload, - key=settings.jwt_secret, - algorithm=settings.jwt_algorithm) + return security.sign_jwt(row) diff --git a/main.py b/main.py index 83828cc..695da3a 100644 --- a/main.py +++ b/main.py @@ -2,6 +2,13 @@ from fastapi import FastAPI, Depends from contextlib import asynccontextmanager import database import models +import logging + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", +) @asynccontextmanager diff --git a/security.py b/security.py index 159d748..c35f79a 100644 --- a/security.py +++ b/security.py @@ -1,6 +1,11 @@ from passlib.context import CryptContext +from settings import settings +import jwt +import datetime +import logging password_context = CryptContext(schemes=["sha256_crypt"], deprecated="auto") +logger = logging.getLogger(__name__) def hash_password(password: str) -> str: @@ -11,3 +16,46 @@ def hash_password(password: str) -> str: def verify_password(plain_password: str, hashed_password: str) -> bool: """Verifies a plain text password against a hashed password.""" return password_context.verify(plain_password, hashed_password) + + +def sign_jwt(row: dict) -> str: + """Signs a JWT token with the given payload.""" + exp = datetime.datetime.now( + datetime.timezone.utc) + datetime.timedelta(hours=1) + + payload = { + "id": row["id"], + "exp": exp + } + + return jwt.encode( + payload, + key=settings.jwt_secret, + algorithm=settings.jwt_algorithm + ) + + +def decode_jwt(token: str | None) -> dict | None: + """Decodes a JWT token and returns the payload.""" + if not token or not token.startswith("Bearer "): + logger.warning( + "No token provided or token does not start with 'Bearer '") + return None + + try: + payload = jwt.decode( + token.replace("Bearer ", ""), + key=settings.jwt_secret, + algorithms=[settings.jwt_algorithm] + ) + + return payload + except jwt.ExpiredSignatureError: + logger.warning("Token has expired") + return None + except jwt.InvalidTokenError: + logger.warning("Invalid token") + return None + + logger.warning("Unexpected error in token decoding") + return None