Compare commits

...

7 Commits

7 changed files with 186 additions and 44 deletions

View File

@@ -1,48 +1,104 @@
from models import User
from passlib.context import CryptContext
from models import User, TokenResponse
from settings import settings
from fastapi import HTTPException, status, Request
import sqlite3
import jwt
import datetime
import security
connection = sqlite3.connect('database.db')
connection.row_factory = sqlite3.Row
cursor = connection.cursor()
password_context = CryptContext(schemes=["sha256_crypt"], deprecated="auto")
def connect() -> (sqlite3.Connection, sqlite3.Cursor):
"""Connects to the database and returns the connection and cursor."""
connection = sqlite3.connect('database.db')
connection.row_factory = sqlite3.Row
cursor = connection.cursor()
return connection, cursor
connection, cursor = connect()
def init() -> None:
# create users table
cursor.execute('''
CREATE TABLE IF NOT EXISTS users (
id INTEGER PRIMARY KEY,
name TEXT NOT NULL UNIQUE,
password TEXT NOT NULL
)
''')
"""Initializes the database."""
# Create users table
cursor.execute('''
CREATE TABLE IF NOT EXISTS users (
id INTEGER PRIMARY KEY,
name TEXT NOT NULL UNIQUE,
password TEXT NOT NULL
)
''')
# Create logs table
cursor.execute('''
CREATE TABLE IF NOT EXISTS logs (
id INTEGER PRIMARY KEY,
user_id INTEGER NOT NULL,
calories DOUBLE NOT NULL,
description TEXT,
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (user_id) REFERENCES users (id)
)
''')
def close() -> None:
connection.close()
"""Closes the database connection."""
connection.close()
def register(user: User) -> None:
password = password_context.hash(user.password)
cursor.execute("INSERT INTO users (name, password) VALUES (?, ?)", (user.name, password))
connection.commit()
"""Registers a new user in the database."""
try:
cursor.execute(
"INSERT INTO users (name, password) VALUES (?, ?)",
(user.name,
security.hash_password(user.password)))
connection.commit()
except sqlite3.IntegrityError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="User already exists"
)
def get_user_by_token(request: Request) -> User:
"""Retrieves a user from the database using a JWT token."""
payload = security.decode_jwt(
request.headers.get("Authorization"))
if not payload:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Not authenticated"
)
connection, cursor = connect()
cursor.execute(
"SELECT id, name, password FROM users WHERE id = ?", (payload["id"],))
row = cursor.fetchone()
connection.close()
if not row:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Not authenticated"
)
return User(**row)
def login(user: User) -> str:
cursor.execute("SELECT id, name, password FROM users WHERE name = ?", (user.name,))
row = cursor.fetchone()
if not row:
raise Exception('User not found')
if not password_context.verify(user.password, row["password"]):
raise Exception('Invalid password')
exp = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(hours=1)
payload = {
"id": row["id"],
"exp": exp
}
return jwt.encode(payload=payload, key=settings.jwt_secret, algorithm=settings.jwt_algorithm)
"""Logs in a user and returns a JWT token."""
cursor.execute(
"SELECT id, name, password FROM users WHERE name = ?", (user.name,))
row = cursor.fetchone()
if not row or not security.verify_password(user.password, row["password"]):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid credentials"
)
return TokenResponse(token=security.sign_jwt(row))

6
format.sh Executable file
View File

@@ -0,0 +1,6 @@
#!/usr/bin/env bash
which autopep8 &> /dev/null || { echo "autopep8 not found, please install it."; exit 1; }
autopep8 --in-place --aggressive --aggressive --recursive --exclude .venv,.git,__pycache__ .
echo "Code formatted with autopep8."

22
main.py
View File

@@ -1,7 +1,15 @@
from fastapi import FastAPI
from fastapi import FastAPI, Depends, status
from contextlib import asynccontextmanager
import database
import models
import logging
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
@asynccontextmanager
async def lifespan(app: FastAPI):
@@ -12,15 +20,17 @@ async def lifespan(app: FastAPI):
app = FastAPI(lifespan=lifespan)
@app.get("/")
async def root():
return {"message": "Hello World"}
@app.get("/me")
async def me(user: models.User = Depends(database.get_user_by_token)):
return user
@app.post("/users")
@app.post("/register", status_code=status.HTTP_201_CREATED)
async def register(user: models.User):
database.register(user)
return user
@app.post("/login")
async def login(user: models.User):
return database.login(user)
return database.login(user)

View File

@@ -1,5 +1,10 @@
from pydantic import BaseModel
class User(BaseModel):
name: str
password: str
class TokenResponse(BaseModel):
token: str

View File

@@ -2,3 +2,4 @@ fastapi[standard]==0.121.2
passlib==1.7.4
pyjwt==2.10.1
pydantic-settings==2.12.0
autopep8==2.3.2

61
security.py Normal file
View File

@@ -0,0 +1,61 @@
from passlib.context import CryptContext
from settings import settings
import jwt
import datetime
import logging
password_context = CryptContext(schemes=["sha256_crypt"], deprecated="auto")
logger = logging.getLogger(__name__)
def hash_password(password: str) -> str:
"""Hashes a plain text password."""
return password_context.hash(password)
def verify_password(plain_password: str, hashed_password: str) -> bool:
"""Verifies a plain text password against a hashed password."""
return password_context.verify(plain_password, hashed_password)
def sign_jwt(row: dict) -> str:
"""Signs a JWT token with the given payload."""
exp = datetime.datetime.now(
datetime.timezone.utc) + datetime.timedelta(hours=1)
payload = {
"id": row["id"],
"exp": exp
}
return jwt.encode(
payload,
key=settings.jwt_secret,
algorithm=settings.jwt_algorithm
)
def decode_jwt(token: str | None) -> dict | None:
"""Decodes a JWT token and returns the payload."""
if not token or not token.startswith("Bearer "):
logger.warning(
"No token provided or token does not start with 'Bearer '")
return None
try:
payload = jwt.decode(
token.replace("Bearer ", ""),
key=settings.jwt_secret,
algorithms=[settings.jwt_algorithm]
)
return payload
except jwt.ExpiredSignatureError:
logger.warning("Token has expired")
return None
except jwt.InvalidTokenError:
logger.warning("Invalid token")
return None
logger.warning("Unexpected error in token decoding")
return None

View File

@@ -1,9 +1,12 @@
from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings):
model_config = SettingsConfigDict(env_file='.env', env_file_encoding='utf-8')
model_config = SettingsConfigDict(
env_file='.env', env_file_encoding='utf-8')
jwt_secret: str
jwt_algorithm: str
settings = Settings()