Compare commits

..

7 Commits

7 changed files with 186 additions and 44 deletions

View File

@@ -1,48 +1,104 @@
from models import User from models import User, TokenResponse
from passlib.context import CryptContext
from settings import settings from settings import settings
from fastapi import HTTPException, status, Request
import sqlite3 import sqlite3
import jwt import security
import datetime
connection = sqlite3.connect('database.db')
connection.row_factory = sqlite3.Row
cursor = connection.cursor()
password_context = CryptContext(schemes=["sha256_crypt"], deprecated="auto") def connect() -> (sqlite3.Connection, sqlite3.Cursor):
"""Connects to the database and returns the connection and cursor."""
connection = sqlite3.connect('database.db')
connection.row_factory = sqlite3.Row
cursor = connection.cursor()
return connection, cursor
connection, cursor = connect()
def init() -> None: def init() -> None:
# create users table """Initializes the database."""
cursor.execute('''
CREATE TABLE IF NOT EXISTS users ( # Create users table
id INTEGER PRIMARY KEY, cursor.execute('''
name TEXT NOT NULL UNIQUE, CREATE TABLE IF NOT EXISTS users (
password TEXT NOT NULL id INTEGER PRIMARY KEY,
) name TEXT NOT NULL UNIQUE,
''') password TEXT NOT NULL
)
''')
# Create logs table
cursor.execute('''
CREATE TABLE IF NOT EXISTS logs (
id INTEGER PRIMARY KEY,
user_id INTEGER NOT NULL,
calories DOUBLE NOT NULL,
description TEXT,
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (user_id) REFERENCES users (id)
)
''')
def close() -> None: def close() -> None:
connection.close() """Closes the database connection."""
connection.close()
def register(user: User) -> None: def register(user: User) -> None:
password = password_context.hash(user.password) """Registers a new user in the database."""
cursor.execute("INSERT INTO users (name, password) VALUES (?, ?)", (user.name, password)) try:
connection.commit() cursor.execute(
"INSERT INTO users (name, password) VALUES (?, ?)",
(user.name,
security.hash_password(user.password)))
connection.commit()
except sqlite3.IntegrityError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="User already exists"
)
def get_user_by_token(request: Request) -> User:
"""Retrieves a user from the database using a JWT token."""
payload = security.decode_jwt(
request.headers.get("Authorization"))
if not payload:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Not authenticated"
)
connection, cursor = connect()
cursor.execute(
"SELECT id, name, password FROM users WHERE id = ?", (payload["id"],))
row = cursor.fetchone()
connection.close()
if not row:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Not authenticated"
)
return User(**row)
def login(user: User) -> str: def login(user: User) -> str:
cursor.execute("SELECT id, name, password FROM users WHERE name = ?", (user.name,)) """Logs in a user and returns a JWT token."""
row = cursor.fetchone() cursor.execute(
"SELECT id, name, password FROM users WHERE name = ?", (user.name,))
if not row: row = cursor.fetchone()
raise Exception('User not found')
if not password_context.verify(user.password, row["password"]): if not row or not security.verify_password(user.password, row["password"]):
raise Exception('Invalid password') raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
exp = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(hours=1) detail="Invalid credentials"
payload = { )
"id": row["id"],
"exp": exp return TokenResponse(token=security.sign_jwt(row))
}
return jwt.encode(payload=payload, key=settings.jwt_secret, algorithm=settings.jwt_algorithm)

6
format.sh Executable file
View File

@@ -0,0 +1,6 @@
#!/usr/bin/env bash
which autopep8 &> /dev/null || { echo "autopep8 not found, please install it."; exit 1; }
autopep8 --in-place --aggressive --aggressive --recursive --exclude .venv,.git,__pycache__ .
echo "Code formatted with autopep8."

22
main.py
View File

@@ -1,7 +1,15 @@
from fastapi import FastAPI from fastapi import FastAPI, Depends, status
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
import database import database
import models import models
import logging
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
@@ -12,15 +20,17 @@ async def lifespan(app: FastAPI):
app = FastAPI(lifespan=lifespan) app = FastAPI(lifespan=lifespan)
@app.get("/") @app.get("/me")
async def root(): async def me(user: models.User = Depends(database.get_user_by_token)):
return {"message": "Hello World"} return user
@app.post("/users")
@app.post("/register", status_code=status.HTTP_201_CREATED)
async def register(user: models.User): async def register(user: models.User):
database.register(user) database.register(user)
return user return user
@app.post("/login") @app.post("/login")
async def login(user: models.User): async def login(user: models.User):
return database.login(user) return database.login(user)

View File

@@ -1,5 +1,10 @@
from pydantic import BaseModel from pydantic import BaseModel
class User(BaseModel): class User(BaseModel):
name: str name: str
password: str password: str
class TokenResponse(BaseModel):
token: str

View File

@@ -2,3 +2,4 @@ fastapi[standard]==0.121.2
passlib==1.7.4 passlib==1.7.4
pyjwt==2.10.1 pyjwt==2.10.1
pydantic-settings==2.12.0 pydantic-settings==2.12.0
autopep8==2.3.2

61
security.py Normal file
View File

@@ -0,0 +1,61 @@
from passlib.context import CryptContext
from settings import settings
import jwt
import datetime
import logging
password_context = CryptContext(schemes=["sha256_crypt"], deprecated="auto")
logger = logging.getLogger(__name__)
def hash_password(password: str) -> str:
"""Hashes a plain text password."""
return password_context.hash(password)
def verify_password(plain_password: str, hashed_password: str) -> bool:
"""Verifies a plain text password against a hashed password."""
return password_context.verify(plain_password, hashed_password)
def sign_jwt(row: dict) -> str:
"""Signs a JWT token with the given payload."""
exp = datetime.datetime.now(
datetime.timezone.utc) + datetime.timedelta(hours=1)
payload = {
"id": row["id"],
"exp": exp
}
return jwt.encode(
payload,
key=settings.jwt_secret,
algorithm=settings.jwt_algorithm
)
def decode_jwt(token: str | None) -> dict | None:
"""Decodes a JWT token and returns the payload."""
if not token or not token.startswith("Bearer "):
logger.warning(
"No token provided or token does not start with 'Bearer '")
return None
try:
payload = jwt.decode(
token.replace("Bearer ", ""),
key=settings.jwt_secret,
algorithms=[settings.jwt_algorithm]
)
return payload
except jwt.ExpiredSignatureError:
logger.warning("Token has expired")
return None
except jwt.InvalidTokenError:
logger.warning("Invalid token")
return None
logger.warning("Unexpected error in token decoding")
return None

View File

@@ -1,9 +1,12 @@
from pydantic_settings import BaseSettings, SettingsConfigDict from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings): class Settings(BaseSettings):
model_config = SettingsConfigDict(env_file='.env', env_file_encoding='utf-8') model_config = SettingsConfigDict(
env_file='.env', env_file_encoding='utf-8')
jwt_secret: str jwt_secret: str
jwt_algorithm: str jwt_algorithm: str
settings = Settings() settings = Settings()