Database Operations
This guide covers common database operations, patterns, and best practices when working with SQLModel and Alembic.
tip
For comprehensive documentation, see the SQLModel and SQLAlchemy documentation.
Adding Relationships
One-to-Many Relationship
Example: Tasks belong to Categories
# In db_models.py
# ModelWithIdAndTimestamps is a parent class that adds an id (UUID) and
# created_at/updated_at/deleted_at (TIMESTAMP) columns.
class CategoryBase(ModelWithIdAndTimestamps):
    """Fields shared by every Category representation."""

    name: str = Field(max_length=100, unique=True)
    color: str = Field(max_length=7, default="#000000")  # hex color string, e.g. "#ff8800"


# A DB table must extend SQLModel and pass the `table=True` parameter.
class Category(CategoryBase, table=True):
    """Category table: the "one" side of the Category -> Task relationship."""

    # "Task" is quoted because Task is defined further down: an unquoted forward
    # reference is evaluated with the class body and raises NameError.
    tasks: list["Task"] = Relationship(back_populates="category")  # OneToMany


class Task(TaskBase, table=True):
    """Task table: each task belongs to one user and optionally one category."""

    user: "User" = Relationship(back_populates="tasks")
    # Foreign key column in the DB. ondelete="CASCADE" deletes a category's tasks
    # together with the category; use ondelete="SET NULL" instead to keep them
    # as uncategorized tasks (the column is nullable).
    category_id: UUID | None = Field(foreign_key="category.id", default=None, ondelete="CASCADE")
    category: Category | None = Relationship(back_populates="tasks")  # ManyToOne
Database Queries
To get an `AsyncSession`
(a DB connection from the pool), you have two options:
- In a FastAPI route, using the FastAPI dependency
`DatabaseDep`
:@router.get("/user/{user_id}", response_model=UserDTO)
async def get_user(user_id: str, db: DatabaseDep):
query = select(User).where(User.id == user_id)
user = (await db.exec(query)).first()
return UserDTO.model_validate(**user) - In any other part of your app using the
`DatabaseService.get_session()`
:@abraxas.activity(display_name="Step 1: Get user") # type: ignore[attr-defined]
async def get_user(params: GetUserParams) -> GetUserResult:
async with DatabaseService.get_session() as session:
user = await session.exec(select(User).where(User.id == params.id))
return GetUserResult(user=user.first())
Joins and Eager Loading
# Load tasks with their category and user
async def get_tasks_with_details() -> list[Task]:
    """Return every task with its ``category`` and ``user`` eagerly loaded.

    joinedload fetches the related rows in the same SELECT, so reading
    ``task.category`` / ``task.user`` later needs no lazy load (lazy loading
    would require implicit I/O, which an AsyncSession does not allow).
    """
    async with DatabaseService.get_session() as session:
        result = await session.exec(
            select(Task).options(
                joinedload(Task.category),
                joinedload(Task.user),
            )
        )
        # unique() de-duplicates entities produced by joined eager loading
        # (SQLAlchemy requires it for joined collections; harmless here).
        # .all() returns a Sequence, so build the list the signature promises.
        return list(result.unique().all())
Aggregations and Statistics
# Count a user's tasks: total, completed, pending and completion rate
async def get_task_stats(user_id: UUID) -> dict[str, int | float]:
    """Return task counters for one user.

    Returns:
        A dict with ``total``, ``completed`` and ``pending`` task counts (int) and
        ``completion_rate``: the completed percentage as a float in [0, 100],
        0.0 when the user has no tasks.
    """
    async with DatabaseService.get_session() as session:
        total_result = await session.exec(
            select(func.count(Task.id)).where(Task.user_id == user_id)
        )
        completed_result = await session.exec(
            select(func.count(Task.id)).where(
                # Multiple where() criteria are combined with AND.
                Task.user_id == user_id,
                # `== True` builds a SQL expression on the column; `is True` would not.
                Task.completed == True,  # noqa: E712
            )
        )
        # An aggregate COUNT always returns exactly one row.
        total = total_result.one()
        completed = completed_result.one()

    # Always a float, so API consumers see a consistent type.
    completion_rate = completed / total * 100 if total > 0 else 0.0
    return {
        "total": total,
        "completed": completed,
        "pending": total - completed,
        "completion_rate": completion_rate,
    }
# Group tasks by category
async def get_tasks_by_category(user_id: UUID) -> dict:
    """Map each category name to how many of ``user_id``'s tasks it contains.

    The join is an inner join, so categories with no task for this user are
    absent from the result.
    """
    statement = (
        select(Category.name, func.count(Task.id).label("task_count"))
        .join(Task)
        .where(Task.user_id == user_id)
        .group_by(Category.id, Category.name)
    )
    async with DatabaseService.get_session() as session:
        rows = (await session.exec(statement)).all()
        return {name: task_count for name, task_count in rows}
Database Transactions
Handling Complex Operations
from sqlmodel import Session
from sqlalchemy.exc import IntegrityError
async def create_task_with_category(
    task_data: CreateTaskDTO,
    category_name: str,
    user_id: UUID,
) -> TaskDTO:
    """Create a task for ``user_id`` in the category named ``category_name``.

    The category is created on the fly when it does not exist yet. Both inserts
    happen in one transaction: either both are committed or neither is.

    Raises:
        HTTPException: 400 when the database rejects the data with an
            IntegrityError (e.g. a concurrent request inserted the same
            category name first). Any other error is rolled back and re-raised
            unchanged instead of being disguised as a client error.
    """
    async with DatabaseService.get_session() as session:
        try:
            # Check if the category exists
            category_result = await session.exec(
                select(Category).where(Category.name == category_name)
            )
            category = category_result.first()

            # Create the category if it doesn't exist
            if category is None:
                category = Category(name=category_name)
                session.add(category)
                await session.flush()  # Send the INSERT to get the ID, without committing

            # Create the task; `update` supplies fields CreateTaskDTO does not carry,
            # so required columns are present when the model is validated.
            task = Task.model_validate(
                task_data,
                update={"user_id": user_id, "category_id": category.id},
            )
            session.add(task)  # Attach the task object to the current session
            await session.commit()  # Commit the transaction (persist all pending changes)
            await session.refresh(task)  # Reload the task with the values stored in the DB
            return TaskDTO.model_validate(task)
        except IntegrityError as err:
            await session.rollback()
            raise HTTPException(status_code=400, detail="Failed to create task") from err
        except Exception:
            await session.rollback()
            raise
Bulk Operations
# Bulk insert
async def create_multiple_tasks(
    tasks_data: list[CreateTaskDTO],
    user_id: UUID,
) -> list[TaskDTO]:
    """Insert several tasks for ``user_id`` in a single transaction.

    Returns:
        The created tasks as DTOs, in the same order as ``tasks_data``
        (an empty list when ``tasks_data`` is empty, without opening a session).
    """
    # `update` sets the required user_id while the model is validated.
    tasks = [
        Task.model_validate(task_data, update={"user_id": user_id})
        for task_data in tasks_data
    ]
    if not tasks:
        return []

    async with DatabaseService.get_session() as session:
        session.add_all(tasks)
        await session.commit()
        # If the session expires objects on commit (SQLAlchemy's default),
        # attributes must be reloaded; this costs one SELECT per task.
        for task in tasks:
            await session.refresh(task)
        return [TaskDTO.model_validate(task) for task in tasks]
# Bulk update
async def mark_all_completed(user_id: UUID) -> int:
    """Mark every not-yet-completed task of ``user_id`` as completed.

    The matching rows are loaded and modified through the ORM, then committed.
    For very large sets, a single Core ``update()`` statement avoids loading rows.

    Returns:
        The number of tasks that were changed.
    """
    # `== False` builds a SQL expression; `is False` / `not` would not.
    pending_filter = and_(Task.user_id == user_id, Task.completed == False)  # noqa: E712
    async with DatabaseService.get_session() as session:
        pending = (await session.exec(select(Task).where(pending_filter))).all()
        for pending_task in pending:
            pending_task.completed = True
        await session.commit()
        return len(pending)
Performance Tips
Database Indexing
# Add indexes in your models
class Task(TaskBase, table=True):
    # index=True creates the index; PostgreSQL does not index foreign key
    # columns automatically, so index FKs you filter or join on.
    user_id: UUID = Field(foreign_key="user.id", index=True)
    completed: bool = Field(default=False, index=True)  # Index for filtering on completion
    # Index for sorting by creation date.
    # NOTE(review): if a parent class (e.g. ModelWithIdAndTimestamps) already declares
    # created_at with a default, redeclaring it here without one may make it a
    # required field — confirm.
    created_at: datetime = Field(index=True)
For more complex indexes (multi-column indexes, or index types other than B-tree), create them in an Alembic migration script.
Query Optimization
# Good: Use select_related equivalent (joinedload)
async def get_tasks_with_users() -> list[Task]:
    """Return all tasks with their user loaded in the same query (1 query total)."""
    async with DatabaseService.get_session() as session:
        result = await session.exec(
            select(Task).options(joinedload(Task.user))
        )
        # unique() de-duplicates joined-eager-loaded entities; .all() returns a
        # Sequence, so build the list the signature promises.
        return list(result.unique().all())
# Bad: N+1 queries
async def get_tasks_with_users_bad() -> list[dict]:
    """Anti-pattern: load tasks, then run one extra query per task for its user.

    Deliberately inefficient (1 + N queries); compare get_tasks_with_users.
    """
    pairs = []
    async with DatabaseService.get_session() as session:
        all_tasks = (await session.exec(select(Task))).all()
        # This creates N additional queries — one round trip per task!
        for current in all_tasks:
            owner_rows = await session.exec(select(User).where(User.id == current.user_id))
            pairs.append({"task": current, "user": owner_rows.first()})
        return pairs
Pagination
async def get_paginated_tasks(
    user_id: UUID,
    page: int = 1,
    size: int = 20,
) -> dict:
    """Return one page of a user's tasks (newest first) with pagination metadata.

    Args:
        user_id: Owner of the tasks.
        page: 1-based page number.
        size: Maximum number of tasks per page.

    Returns:
        ``{"tasks": [TaskDTO, ...], "pagination": {"page", "size", "total", "pages"}}``.

    Raises:
        ValueError: if ``page`` or ``size`` is smaller than 1 (a negative OFFSET is
            rejected by the database, and ``size == 0`` would divide by zero).
    """
    if page < 1 or size < 1:
        raise ValueError(f"page and size must be >= 1 (got page={page}, size={size})")
    offset = (page - 1) * size

    async with DatabaseService.get_session() as session:
        # Get total count (an aggregate COUNT always returns exactly one row)
        count_result = await session.exec(
            select(func.count(Task.id)).where(Task.user_id == user_id)
        )
        total = count_result.one()

        # Get paginated results. Task.id breaks ties between equal created_at
        # values: without a total order, rows may repeat or be skipped across pages.
        result = await session.exec(
            select(Task)
            .where(Task.user_id == user_id)
            .order_by(Task.created_at.desc(), Task.id.desc())
            .offset(offset)
            .limit(size)
        )
        tasks = result.all()

        return {
            "tasks": [TaskDTO.model_validate(task) for task in tasks],
            "pagination": {
                "page": page,
                "size": size,
                "total": total,
                "pages": (total + size - 1) // size,  # ceil(total / size)
            },
        }
Troubleshooting
Common Issues
Foreign Key Errors:
# Ensure foreign key relationships are properly defined
class Task(TaskBase, table=True):
    # foreign_key is "<table name>.<column>". SQLModel names a table after the
    # lowercased class name by default, so class User maps to table "user".
    user_id: UUID = Field(foreign_key="user.id")  # Table name, not class name
Query Performance:
# Use EXPLAIN to analyze query performance
query = select(Task).where(Task.user_id == user_id)
print(f"Query: {query}")
# For PostgreSQL, you can check the query plan
result = await db.exec(text(f"EXPLAIN ANALYZE {query}"))