diff --git a/database.py b/database.py index c1c0a48..e4906fc 100644 --- a/database.py +++ b/database.py @@ -12,69 +12,86 @@ cursor = connection.cursor() password_context = CryptContext(schemes=["sha256_crypt"], deprecated="auto") + def init() -> None: # create users table - cursor.execute(''' + cursor.execute(''' CREATE TABLE IF NOT EXISTS users ( id INTEGER PRIMARY KEY, name TEXT NOT NULL UNIQUE, password TEXT NOT NULL ) ''') - + + def close() -> None: - connection.close() - + connection.close() + def register(user: User) -> None: - password = password_context.hash(user.password) - cursor.execute("INSERT INTO users (name, password) VALUES (?, ?)", (user.name, password)) - connection.commit() + password = password_context.hash(user.password) + cursor.execute( + "INSERT INTO users (name, password) VALUES (?, ?)", + (user.name, + password)) + connection.commit() + def get_user_by_token(request: Request) -> User: token = request.headers.get("Authorization") if not token or not token.startswith("Bearer "): - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Not authenticated" - ) - token = token.split(" ")[1] - payload = jwt.decode(token, key=settings.jwt_secret, algorithms=[settings.jwt_algorithm]) - + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Not authenticated" + ) + token = token.split(" ")[1] + payload = jwt.decode( + token, + key=settings.jwt_secret, + algorithms=[ + settings.jwt_algorithm]) + connection = sqlite3.connect('database.db') connection.row_factory = sqlite3.Row cursor = connection.cursor() - cursor.execute("SELECT id, name, password FROM users WHERE id = ?", (payload["id"],)) + cursor.execute( + "SELECT id, name, password FROM users WHERE id = ?", (payload["id"],)) row = cursor.fetchone() connection.close() if not row: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Not authenticated" - ) + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Not authenticated" + ) return User(**row) + def login(user: User) -> str: - cursor.execute("SELECT id, name, password FROM users WHERE name = ?", (user.name,)) - row = cursor.fetchone() - - if not row: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Invalid credentials" - ) - if not password_context.verify(user.password, row["password"]): - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Invalid credentials" - ) - - exp = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(hours=1) - payload = { - "id": row["id"], - "exp": exp - } - - return jwt.encode(payload=payload, key=settings.jwt_secret, algorithm=settings.jwt_algorithm) + cursor.execute( + "SELECT id, name, password FROM users WHERE name = ?", (user.name,)) + row = cursor.fetchone() + + if not row: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid credentials" + ) + if not password_context.verify(user.password, row["password"]): + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid credentials" + ) + + exp = datetime.datetime.now( + datetime.timezone.utc) + datetime.timedelta(hours=1) + payload = { + "id": row["id"], + "exp": exp + } + + return jwt.encode( + payload=payload, + key=settings.jwt_secret, + algorithm=settings.jwt_algorithm) diff --git a/format.sh b/format.sh new file mode 100755 index 0000000..7f2ab18 --- /dev/null +++ b/format.sh @@ -0,0 +1,6 @@ +#!/usr/bin/env bash +which autopep8 &> /dev/null || { echo "autopep8 not found, please install it."; exit 1; } + +autopep8 --in-place --aggressive --aggressive --recursive --exclude .venv,.git,__pycache__ . + +echo "Code formatted with autopep8." diff --git a/main.py b/main.py index 3651f04..83828cc 100644 --- a/main.py +++ b/main.py @@ -3,6 +3,7 @@ from contextlib import asynccontextmanager import database import models + @asynccontextmanager async def lifespan(app: FastAPI): database.init() @@ -12,19 +13,17 @@ async def lifespan(app: FastAPI): app = FastAPI(lifespan=lifespan) -@app.get("/") -async def root(): - return {"message": "Hello World"} - @app.get("/me") async def me(user: models.User = Depends(database.get_user_by_token)): return user + @app.post("/users") async def register(user: models.User): database.register(user) return user + @app.post("/login") async def login(user: models.User): - return database.login(user) \ No newline at end of file + return database.login(user) diff --git a/models.py b/models.py index 8f458c7..bc275ee 100644 --- a/models.py +++ b/models.py @@ -1,5 +1,6 @@ from pydantic import BaseModel + class User(BaseModel): name: str password: str diff --git a/requirements.txt b/requirements.txt index e1a0537..c18d1da 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,3 +2,4 @@ fastapi[standard]==0.121.2 passlib==1.7.4 pyjwt==2.10.1 pydantic-settings==2.12.0 +autopep8==2.3.2 diff --git a/settings.py b/settings.py index b5b957f..6cb5f5a 100644 --- a/settings.py +++ b/settings.py @@ -1,9 +1,12 @@ from pydantic_settings import BaseSettings, SettingsConfigDict + class Settings(BaseSettings): - model_config = SettingsConfigDict(env_file='.env', env_file_encoding='utf-8') + model_config = SettingsConfigDict( + env_file='.env', env_file_encoding='utf-8') jwt_secret: str jwt_algorithm: str + settings = Settings()