Skip to main content

Database Operations

This guide covers common database operations, patterns, and best practices when working with SQLModel and Alembic.

tip

For comprehensive documentation, see the SQLModel and SQLAlchemy documentation.

Adding Relationships

One-to-Many Relationship

Example: Tasks belong to Categories

# In db_models.py
# ModelWithIdAndTimestamps is a parent class that adds an id (UUID) and created_at/updated_at/deleted_at (TIMESTAMP) columns
class CategoryBase(ModelWithIdAndTimestamps):
    # Display name of the category; at most 100 characters and unique.
    name: str = Field(
        unique=True,
        max_length=100,
    )
    # 7-character colour string (e.g. "#RRGGBB"); defaults to black.
    color: str = Field(
        default="#000000",
        max_length=7,
    )

# A DB table should extend SQLModel and have this `table=True` parameter
class Category(CategoryBase, table=True):
    # OneToMany. "Task" is quoted because the Task class is declared further
    # down: the bare name does not exist yet when this annotation is evaluated,
    # which would raise NameError at import time.
    tasks: list["Task"] = Relationship(back_populates="category")

class Task(TaskBase, table=True):
    # Owning user; mirrored by the `tasks` relationship on User.
    user: User = Relationship(back_populates="tasks")

    # This is the foreign key in DB (optional: a task may have no category).
    # NOTE(review): ondelete="CASCADE" combined with a nullable column means
    # deleting a category also deletes its tasks - confirm this is intended
    # rather than "SET NULL".
    category_id: UUID | None = Field(
        default=None,
        foreign_key="category.id",
        ondelete="CASCADE",
    )

    # ManyToOne
    category: Category | None = Relationship(back_populates="tasks")

Database Queries

To get an AsyncSession (a DB connection from the pool), you have 2 options:

  • In a FastAPI route using the FastAPI dependency DatabaseDep:
    @router.get("/user/{user_id}", response_model=UserDTO)
    async def get_user(user_id: str, db: DatabaseDep):
        """Return the user identified by ``user_id`` (404 when it does not exist)."""
        query = select(User).where(User.id == user_id)
        user = (await db.exec(query)).first()

        if user is None:
            # Without this guard a missing row would reach model_validate()
            # and surface as a validation error instead of a clean 404.
            raise HTTPException(status_code=404, detail="User not found")

        # model_validate() takes the object itself; `**user` would try to
        # unpack the model instance as a mapping and fail.
        return UserDTO.model_validate(user)
  • In any other part of your app using the DatabaseService.get_session():
    @abraxas.activity(display_name="Step 1: Get user")  # type: ignore[attr-defined]
    async def get_user(params: GetUserParams) -> GetUserResult:
        """Look up the user whose id is ``params.id`` and wrap it in a GetUserResult.

        The ``user`` field of the result is ``None`` when no row matches.
        """
        lookup = select(User).where(User.id == params.id)
        async with DatabaseService.get_session() as session:
            rows = await session.exec(lookup)
            return GetUserResult(user=rows.first())

Joins and Eager Loading

# Load tasks with their category and user
async def get_tasks_with_details() -> list[Task]:
    """Fetch every task with its ``category`` and ``user`` loaded in the same query."""
    eager_relations = (
        joinedload(Task.category),
        joinedload(Task.user),
    )
    statement = select(Task).options(*eager_relations)

    async with DatabaseService.get_session() as session:
        rows = await session.exec(statement)
        # unique() de-duplicates the joined rows before they are materialised.
        return rows.unique().all()

Aggregations and Statistics

# Count tasks by status
async def get_task_stats(user_id: UUID) -> dict:
    """Compute task counters for one user.

    Returns a dict with the keys ``total``, ``completed``, ``pending`` and
    ``completion_rate`` (a percentage, ``0`` when the user has no task).
    """
    owned_by_user = Task.user_id == user_id
    count_all = select(func.count(Task.id)).where(owned_by_user)
    count_done = select(func.count(Task.id)).where(
        # `== True` builds a SQL expression here, it is not a Python comparison.
        and_(owned_by_user, Task.completed == True)  # noqa: E712
    )

    async with DatabaseService.get_session() as session:
        all_rows = await session.exec(count_all)
        done_rows = await session.exec(count_done)

        total = all_rows.first() or 0
        completed = done_rows.first() or 0

        if total > 0:
            completion_rate = completed / total * 100
        else:
            completion_rate = 0

        stats = {
            "total": total,
            "completed": completed,
            "pending": total - completed,
            "completion_rate": completion_rate,
        }
        return stats

# Group tasks by category
async def get_tasks_by_category(user_id: UUID) -> dict:
    """Map each category name to the number of tasks ``user_id`` has in it."""
    task_count = func.count(Task.id).label("task_count")
    statement = (
        select(Category.name, task_count)
        .join(Task)
        .where(Task.user_id == user_id)
        .group_by(Category.id, Category.name)
    )

    async with DatabaseService.get_session() as session:
        rows = (await session.exec(statement)).all()
        # Each row is a (category name, task count) pair.
        return {name: count for name, count in rows}

Database Transactions

Handling Complex Operations

from sqlmodel import Session
from sqlalchemy.exc import IntegrityError

async def create_task_with_category(
    task_data: CreateTaskDTO,
    category_name: str,
    user_id: UUID,
) -> TaskDTO:
    """Create a task for ``user_id``, creating its category first when needed.

    The category lookup/creation and the task insertion happen in a single
    transaction: either both are persisted or neither is.

    Args:
        task_data: Payload describing the task to create.
        category_name: Name of the category to attach; created if missing.
        user_id: Owner of the new task.

    Returns:
        The persisted task, refreshed from the database, as a ``TaskDTO``.

    Raises:
        HTTPException: 400 when the database rejects the write with an
            integrity error (e.g. a concurrent insert of the same category
            name). Any other exception is re-raised unchanged after rollback
            so that genuine server errors are not disguised as client errors.
    """
    async with DatabaseService.get_session() as session:
        try:
            # Check if category exists
            category_result = await session.exec(
                select(Category).where(Category.name == category_name)
            )
            category = category_result.first()

            # Create category if it doesn't exist
            if category is None:
                category = Category(name=category_name)
                session.add(category)
                await session.flush()  # Get the ID without committing

            # Inject the server-side fields at validation time: validating the
            # bare DTO first would fail when it lacks the required user_id.
            task = Task.model_validate(
                task_data,
                update={"user_id": user_id, "category_id": category.id},
            )

            session.add(task)  # Attach the task object to the current session
            await session.commit()  # Apply the changes made on all attached objects
            await session.refresh(task)  # Reload the task with the values from the DB

            return TaskDTO.model_validate(task)

        except IntegrityError as exc:
            await session.rollback()
            # Chain the cause so the original DB error stays visible in logs.
            raise HTTPException(status_code=400, detail="Failed to create task") from exc
        except Exception:
            await session.rollback()
            raise

Bulk Operations

# Bulk insert
async def create_multiple_tasks(
    tasks_data: list[CreateTaskDTO],
    user_id: UUID,
) -> list[TaskDTO]:
    """Insert several tasks for ``user_id`` in one transaction.

    Args:
        tasks_data: Payloads of the tasks to create; may be empty.
        user_id: Owner assigned to every created task.

    Returns:
        The persisted tasks as ``TaskDTO`` objects, in input order.
    """
    if not tasks_data:
        # Nothing to insert: skip the session and the commit.
        return []

    async with DatabaseService.get_session() as session:
        # Inject user_id at validation time: validating the bare DTO first
        # would fail when it lacks the required user_id.
        tasks = [
            Task.model_validate(task_data, update={"user_id": user_id})
            for task_data in tasks_data
        ]
        session.add_all(tasks)

        await session.commit()

        # Refresh all tasks to pick up the values generated by the DB
        for task in tasks:
            await session.refresh(task)

        return [TaskDTO.model_validate(task) for task in tasks]

# Bulk update
async def mark_all_completed(user_id: UUID) -> int:
    """Flag every not-yet-completed task of ``user_id`` as completed.

    The matching rows are loaded and modified through the ORM, then committed
    together. Returns the number of tasks that were updated.
    """
    # `== False` builds a SQL expression here, it is not a Python comparison.
    still_open = and_(Task.user_id == user_id, Task.completed == False)  # noqa: E712

    async with DatabaseService.get_session() as session:
        open_tasks = (await session.exec(select(Task).where(still_open))).all()

        for open_task in open_tasks:
            open_task.completed = True

        await session.commit()
        return len(open_tasks)

Performance Tips

Database Indexing

# Add indexes in your models
class Task(TaskBase, table=True):
    # Already indexed
    user_id: UUID = Field(
        index=True,
        foreign_key="user.id",
    )
    # Add index for filtering
    completed: bool = Field(
        index=True,
        default=False,
    )
    # Add index for sorting
    created_at: datetime = Field(
        index=True,
    )

For more complex indexes (multi-column indexes, or index types other than B-Tree), use an Alembic migration script.

Query Optimization

# Good: Use select_related equivalent (joinedload)
async def get_tasks_with_users() -> list[Task]:
    """Fetch all tasks together with their user in a single joined query."""
    statement = select(Task).options(joinedload(Task.user))

    async with DatabaseService.get_session() as session:
        rows = await session.exec(statement)
        return rows.unique().all()

# Bad: N+1 queries
async def get_tasks_with_users_bad() -> list[dict]:
    """Anti-pattern example: pair each task with its user using one query per task.

    Returns a list of ``{"task": ..., "user": ...}`` dicts.
    """
    async with DatabaseService.get_session() as session:
        all_tasks = (await session.exec(select(Task))).all()

        # This creates N additional queries!
        pairs = []
        for current in all_tasks:
            owner_query = select(User).where(User.id == current.user_id)
            owner = (await session.exec(owner_query)).first()
            pairs.append({"task": current, "user": owner})

        return pairs

Pagination

async def get_paginated_tasks(
    user_id: UUID,
    page: int = 1,
    size: int = 20,
) -> dict:
    """Return one page of a user's tasks, newest first, plus pagination metadata.

    Args:
        user_id: Owner of the tasks to list.
        page: 1-based page number.
        size: Number of tasks per page.

    Returns:
        ``{"tasks": [TaskDTO, ...], "pagination": {"page", "size", "total", "pages"}}``.

    Raises:
        ValueError: If ``page`` or ``size`` is lower than 1. Rejected up front
            because they would otherwise produce a negative OFFSET/LIMIT or a
            division by zero when computing the page count.
    """
    if page < 1:
        raise ValueError("page must be >= 1")
    if size < 1:
        raise ValueError("size must be >= 1")

    offset = (page - 1) * size

    async with DatabaseService.get_session() as session:
        # Get total count
        count_result = await session.exec(
            select(func.count(Task.id)).where(Task.user_id == user_id)
        )
        total = count_result.first() or 0

        # Get paginated results. Task.id is a tie-breaker: without it, rows
        # sharing the same created_at could be repeated or skipped across pages.
        result = await session.exec(
            select(Task)
            .where(Task.user_id == user_id)
            .order_by(Task.created_at.desc(), Task.id)
            .offset(offset)
            .limit(size)
        )
        tasks = result.all()

        return {
            "tasks": [TaskDTO.model_validate(task) for task in tasks],
            "pagination": {
                "page": page,
                "size": size,
                "total": total,
                # Ceiling division: a partially filled last page still counts.
                "pages": (total + size - 1) // size,
            },
        }

Troubleshooting

Common Issues

Foreign Key Errors:

# Ensure foreign key relationships are properly defined
class Task(TaskBase, table=True):
    # The target is written "<table name>.<column>": table name, not class name.
    user_id: UUID = Field(
        foreign_key="user.id",
    )

Query Performance:

# Use EXPLAIN to analyze query performance
query = select(Task).where(Task.user_id == user_id)
print(f"Query: {query}")

# For PostgreSQL, you can check the query plan
result = await db.exec(text(f"EXPLAIN ANALYZE {query}"))