diff --git a/database.py b/database.py index 1500269..40fdc3e 100644 --- a/database.py +++ b/database.py @@ -6,9 +6,17 @@ import jwt import datetime import security -connection = sqlite3.connect('database.db') -connection.row_factory = sqlite3.Row -cursor = connection.cursor() + +def connect() -> (sqlite3.Connection, sqlite3.Cursor): + """Connects to the database and returns the connection and cursor.""" + connection = sqlite3.connect('database.db') + connection.row_factory = sqlite3.Row + cursor = connection.cursor() + + return connection, cursor + + +connection, cursor = connect() def init() -> None: @@ -46,31 +54,29 @@ def register(user: User) -> None: cursor.execute( "INSERT INTO users (name, password) VALUES (?, ?)", (user.name, - security.hash_password(user.password)) + security.hash_password(user.password))) connection.commit() def get_user_by_token(request: Request) -> User: """Retrieves a user from the database using a JWT token.""" - token=request.headers.get("Authorization") + token = request.headers.get("Authorization") if not token or not token.startswith("Bearer "): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Not authenticated" ) - token=token.split(" ")[1] - payload=jwt.decode( + token = token.split(" ")[1] + payload = jwt.decode( token, key=settings.jwt_secret, algorithms=[ settings.jwt_algorithm]) - connection=sqlite3.connect('database.db') - connection.row_factory=sqlite3.Row - cursor=connection.cursor() + connection, cursor = connect() cursor.execute( "SELECT id, name, password FROM users WHERE id = ?", (payload["id"],)) - row=cursor.fetchone() + row = cursor.fetchone() connection.close() if not row: @@ -86,7 +92,7 @@ def login(user: User) -> str: """Logs in a user and returns a JWT token.""" cursor.execute( "SELECT id, name, password FROM users WHERE name = ?", (user.name,)) - row=cursor.fetchone() + row = cursor.fetchone() if not row or not security.verify_password(user.password, row["password"]): raise HTTPException( @@ -94,9 +100,9 @@ def login(user: User) -> str: detail="Invalid credentials" ) - exp=datetime.datetime.now( + exp = datetime.datetime.now( datetime.timezone.utc) + datetime.timedelta(hours=1) - payload={ + payload = { "id": row["id"], "exp": exp }