extract password hashing into separate module

This commit is contained in:
2025-11-16 16:42:20 +01:00
parent 16bf9b54f2
commit a120512baf
2 changed files with 51 additions and 28 deletions

View File

@@ -1,62 +1,76 @@
from models import User from models import User
from passlib.context import CryptContext
from settings import settings from settings import settings
from fastapi import HTTPException, status, Request from fastapi import HTTPException, status, Request
import sqlite3 import sqlite3
import jwt import jwt
import datetime import datetime
import security
connection = sqlite3.connect('database.db') connection = sqlite3.connect('database.db')
connection.row_factory = sqlite3.Row connection.row_factory = sqlite3.Row
cursor = connection.cursor() cursor = connection.cursor()
password_context = CryptContext(schemes=["sha256_crypt"], deprecated="auto")
def init() -> None: def init() -> None:
# create users table """Initializes the database."""
# Create users table
cursor.execute(''' cursor.execute('''
CREATE TABLE IF NOT EXISTS users ( CREATE TABLE IF NOT EXISTS users (
id INTEGER PRIMARY KEY, id INTEGER PRIMARY KEY,
name TEXT NOT NULL UNIQUE, name TEXT NOT NULL UNIQUE,
password TEXT NOT NULL password TEXT NOT NULL
) )
''') ''')
# Create logs table
cursor.execute('''
CREATE TABLE IF NOT EXISTS logs (
id INTEGER PRIMARY KEY,
user_id INTEGER NOT NULL,
calories DOUBLE NOT NULL,
description TEXT,
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (user_id) REFERENCES users (id)
)
''')
def close() -> None: def close() -> None:
"""Closes the database connection."""
connection.close() connection.close()
def register(user: User) -> None: def register(user: User) -> None:
password = password_context.hash(user.password) """Registers a new user in the database."""
cursor.execute( cursor.execute(
"INSERT INTO users (name, password) VALUES (?, ?)", "INSERT INTO users (name, password) VALUES (?, ?)",
(user.name, (user.name,
password)) security.hash_password(user.password))
connection.commit() connection.commit()
def get_user_by_token(request: Request) -> User: def get_user_by_token(request: Request) -> User:
token = request.headers.get("Authorization") """Retrieves a user from the database using a JWT token."""
token=request.headers.get("Authorization")
if not token or not token.startswith("Bearer "): if not token or not token.startswith("Bearer "):
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail="Not authenticated" detail="Not authenticated"
) )
token = token.split(" ")[1] token=token.split(" ")[1]
payload = jwt.decode( payload=jwt.decode(
token, token,
key=settings.jwt_secret, key=settings.jwt_secret,
algorithms=[ algorithms=[
settings.jwt_algorithm]) settings.jwt_algorithm])
connection = sqlite3.connect('database.db') connection=sqlite3.connect('database.db')
connection.row_factory = sqlite3.Row connection.row_factory=sqlite3.Row
cursor = connection.cursor() cursor=connection.cursor()
cursor.execute( cursor.execute(
"SELECT id, name, password FROM users WHERE id = ?", (payload["id"],)) "SELECT id, name, password FROM users WHERE id = ?", (payload["id"],))
row = cursor.fetchone() row=cursor.fetchone()
connection.close() connection.close()
if not row: if not row:
@@ -69,24 +83,20 @@ def get_user_by_token(request: Request) -> User:
def login(user: User) -> str: def login(user: User) -> str:
"""Logs in a user and returns a JWT token."""
cursor.execute( cursor.execute(
"SELECT id, name, password FROM users WHERE name = ?", (user.name,)) "SELECT id, name, password FROM users WHERE name = ?", (user.name,))
row = cursor.fetchone() row=cursor.fetchone()
if not row: if not row or not security.verify_password(user.password, row["password"]):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid credentials"
)
if not password_context.verify(user.password, row["password"]):
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid credentials" detail="Invalid credentials"
) )
exp = datetime.datetime.now( exp=datetime.datetime.now(
datetime.timezone.utc) + datetime.timedelta(hours=1) datetime.timezone.utc) + datetime.timedelta(hours=1)
payload = { payload={
"id": row["id"], "id": row["id"],
"exp": exp "exp": exp
} }

13
security.py Normal file
View File

@@ -0,0 +1,13 @@
from passlib.context import CryptContext
password_context = CryptContext(schemes=["sha256_crypt"], deprecated="auto")
def hash_password(password: str) -> str:
"""Hashes a plain text password."""
return password_context.hash(password)
def verify_password(plain_password: str, hashed_password: str) -> bool:
"""Verifies a plain text password against a hashed password."""
return password_context.verify(plain_password, hashed_password)