Compare commits

..

3 Commits

Author SHA1 Message Date
87c9ae9030 add token response 2025-11-16 17:03:35 +01:00
2c57c9cdc9 handle already existing users 2025-11-16 17:00:40 +01:00
2d751551b7 add logging to security 2025-11-16 16:59:35 +01:00
4 changed files with 79 additions and 29 deletions

View File

@@ -1,9 +1,7 @@
from models import User from models import User, TokenResponse
from settings import settings from settings import settings
from fastapi import HTTPException, status, Request from fastapi import HTTPException, status, Request
import sqlite3 import sqlite3
import jwt
import datetime
import security import security
@@ -51,27 +49,30 @@ def close() -> None:
def register(user: User) -> None: def register(user: User) -> None:
"""Registers a new user in the database.""" """Registers a new user in the database."""
try:
cursor.execute( cursor.execute(
"INSERT INTO users (name, password) VALUES (?, ?)", "INSERT INTO users (name, password) VALUES (?, ?)",
(user.name, (user.name,
security.hash_password(user.password))) security.hash_password(user.password)))
connection.commit() connection.commit()
except sqlite3.IntegrityError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="User already exists"
)
def get_user_by_token(request: Request) -> User: def get_user_by_token(request: Request) -> User:
"""Retrieves a user from the database using a JWT token.""" """Retrieves a user from the database using a JWT token."""
token = request.headers.get("Authorization")
if not token or not token.startswith("Bearer "): payload = security.decode_jwt(
request.headers.get("Authorization"))
if not payload:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail="Not authenticated" detail="Not authenticated"
) )
token = token.split(" ")[1]
payload = jwt.decode(
token,
key=settings.jwt_secret,
algorithms=[
settings.jwt_algorithm])
connection, cursor = connect() connection, cursor = connect()
cursor.execute( cursor.execute(
@@ -100,14 +101,4 @@ def login(user: User) -> str:
detail="Invalid credentials" detail="Invalid credentials"
) )
exp = datetime.datetime.now( return TokenResponse(token=security.sign_jwt(row))
datetime.timezone.utc) + datetime.timedelta(hours=1)
payload = {
"id": row["id"],
"exp": exp
}
return jwt.encode(
payload=payload,
key=settings.jwt_secret,
algorithm=settings.jwt_algorithm)

11
main.py
View File

@@ -1,7 +1,14 @@
from fastapi import FastAPI, Depends from fastapi import FastAPI, Depends, status
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
import database import database
import models import models
import logging
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
@asynccontextmanager @asynccontextmanager
@@ -18,7 +25,7 @@ async def me(user: models.User = Depends(database.get_user_by_token)):
return user return user
@app.post("/users") @app.post("/register", status_code=status.HTTP_201_CREATED)
async def register(user: models.User): async def register(user: models.User):
database.register(user) database.register(user)
return user return user

View File

@@ -4,3 +4,7 @@ from pydantic import BaseModel
class User(BaseModel): class User(BaseModel):
name: str name: str
password: str password: str
class TokenResponse(BaseModel):
token: str

View File

@@ -1,6 +1,11 @@
from passlib.context import CryptContext from passlib.context import CryptContext
from settings import settings
import jwt
import datetime
import logging
password_context = CryptContext(schemes=["sha256_crypt"], deprecated="auto") password_context = CryptContext(schemes=["sha256_crypt"], deprecated="auto")
logger = logging.getLogger(__name__)
def hash_password(password: str) -> str: def hash_password(password: str) -> str:
@@ -11,3 +16,46 @@ def hash_password(password: str) -> str:
def verify_password(plain_password: str, hashed_password: str) -> bool: def verify_password(plain_password: str, hashed_password: str) -> bool:
"""Verifies a plain text password against a hashed password.""" """Verifies a plain text password against a hashed password."""
return password_context.verify(plain_password, hashed_password) return password_context.verify(plain_password, hashed_password)
def sign_jwt(row: dict) -> str:
"""Signs a JWT token with the given payload."""
exp = datetime.datetime.now(
datetime.timezone.utc) + datetime.timedelta(hours=1)
payload = {
"id": row["id"],
"exp": exp
}
return jwt.encode(
payload,
key=settings.jwt_secret,
algorithm=settings.jwt_algorithm
)
def decode_jwt(token: str | None) -> dict | None:
"""Decodes a JWT token and returns the payload."""
if not token or not token.startswith("Bearer "):
logger.warning(
"No token provided or token does not start with 'Bearer '")
return None
try:
payload = jwt.decode(
token.replace("Bearer ", ""),
key=settings.jwt_secret,
algorithms=[settings.jwt_algorithm]
)
return payload
except jwt.ExpiredSignatureError:
logger.warning("Token has expired")
return None
except jwt.InvalidTokenError:
logger.warning("Invalid token")
return None
logger.warning("Unexpected error in token decoding")
return None