From f5741ecc62a860114f35fb93cb55128d27168f8b Mon Sep 17 00:00:00 2001 From: Nemanja Latkovic Date: Sun, 16 Nov 2025 15:27:51 +0100 Subject: [PATCH] added authorization --- database.py | 36 ++++++++++++++++++++++++++++++++++-- main.py | 6 +++++- 2 files changed, 39 insertions(+), 3 deletions(-) diff --git a/database.py b/database.py index d3837b5..c1c0a48 100644 --- a/database.py +++ b/database.py @@ -1,6 +1,7 @@ from models import User from passlib.context import CryptContext from settings import settings +from fastapi import HTTPException, status, Request import sqlite3 import jwt import datetime @@ -30,14 +31,45 @@ def register(user: User) -> None: cursor.execute("INSERT INTO users (name, password) VALUES (?, ?)", (user.name, password)) connection.commit() +def get_user_by_token(request: Request) -> User: + token = request.headers.get("Authorization") + if not token or not token.startswith("Bearer "): + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Not authenticated" + ) + token = token.split(" ")[1] + payload = jwt.decode(token, key=settings.jwt_secret, algorithms=[settings.jwt_algorithm]) + + connection = sqlite3.connect('database.db') + connection.row_factory = sqlite3.Row + cursor = connection.cursor() + cursor.execute("SELECT id, name, password FROM users WHERE id = ?", (payload["id"],)) + row = cursor.fetchone() + connection.close() + + if not row: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Not authenticated" + ) + + return User(**row) + def login(user: User) -> str: cursor.execute("SELECT id, name, password FROM users WHERE name = ?", (user.name,)) row = cursor.fetchone() if not row: - raise Exception('User not found') + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid credentials" + ) if not password_context.verify(user.password, row["password"]): - raise Exception('Invalid password') + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid credentials" + ) exp = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(hours=1) payload = { diff --git a/main.py b/main.py index 70c4d17..3651f04 100644 --- a/main.py +++ b/main.py @@ -1,4 +1,4 @@ -from fastapi import FastAPI +from fastapi import FastAPI, Depends from contextlib import asynccontextmanager import database import models @@ -16,6 +16,10 @@ app = FastAPI(lifespan=lifespan) async def root(): return {"message": "Hello World"} +@app.get("/me") +async def me(user: models.User = Depends(database.get_user_by_token)): + return user + @app.post("/users") async def register(user: models.User): database.register(user)