Practical Dataloading with Strawberry GraphQL and Aioinject
Dataloaders are a necessary abstraction for efficient data fetching in GraphQL. This blog post explores how to practically implement data loading in Python, alongside integration with Aioinject, an async Python dependency injection library.
The N+1 Problem
GraphQL provides immense data loading capabilities to clients. But with great power comes great responsibility. It is easy to define queries like the one below, which can hurt database performance if the resolvers under the hood are implemented naively.
Assuming the following GraphQL schema:
type Job {
  id: ID!
  title: String!
  description: String!
}

type Organization {
  id: ID!
  name: String!
  jobs: [Job]
}

type Query {
  organizations: [Organization]
}
We can query the schema like:
query OrganizationsQuery {
  organizations {
    id
    name
    jobs {
      title
      description
    }
  }
}
A naive resolver implementation will fetch the jobs from the database separately for each organization returned by the query. Let's say it finds N organizations in the database. For each organization found, the jobs() resolver is invoked to locate the jobs associated with that organization, and each invocation triggers its own database call, for N calls. Together with the initial query for the organizations, that makes N+1 database calls in total.
This is inefficient: the number of database round trips grows linearly with the number of organizations returned.
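To make the count concrete, here is a minimal sketch of the naive pattern. FakeDatabase and its methods are hypothetical stand-ins for a real database layer, and the resolver is a plain function rather than a real GraphQL resolver.

```python
import asyncio
from dataclasses import dataclass


@dataclass
class FakeDatabase:
    """Hypothetical in-memory database that counts the queries it receives."""

    query_count: int = 0

    async def fetch_organizations(self) -> list[dict]:
        self.query_count += 1
        return [{"id": i, "name": f"Org {i}"} for i in range(1, 4)]

    async def fetch_jobs_for(self, organization_id: int) -> list[dict]:
        self.query_count += 1
        return [{"title": f"Job at org {organization_id}", "description": "..."}]


async def resolve_organizations_naively(db: FakeDatabase) -> list[dict]:
    organizations = await db.fetch_organizations()  # 1 query
    for organization in organizations:
        # One extra query per organization: N queries in total.
        organization["jobs"] = await db.fetch_jobs_for(organization["id"])
    return organizations


db = FakeDatabase()
result = asyncio.run(resolve_organizations_naively(db))
print(db.query_count)  # 4 queries for N = 3 organizations
```

With three organizations the counter ends at four: one query for the list, plus one per organization.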
Dataloaders to the rescue
Here's where dataloaders come to save the day! A dataloader can batch similar data requests made by resolvers into a single query. Think of it as an intermediate layer that converts multiple similar requests into a single batched request. Its main tasks are:
Batching: Instead of making a separate database or API call for every individual query from different resolvers, the DataLoader collects these requests and groups (or “batches”) them together. This way, a single query can retrieve multiple records at once.
Caching: It temporarily stores (caches) results from data fetches during a single request cycle. If the same data is needed again, it can be returned from the cache without making another external call.
How It Works Step-by-Step
Request Collection: When a GraphQL resolver requests data (for example, the jobs of an organization), it calls the DataLoader's load(key) function. The DataLoader collects these requests for the current event loop tick.
Batching and Execution: At the end of the tick, the DataLoader calls a single batch function (often named batchLoad in JavaScript implementations, and passed as load_fn in Strawberry) with all the collected keys. This function fetches all the requested data in one go, such as via a single database query.
Caching Results: Once the batch function returns the data, the DataLoader caches the results. Any subsequent request for the same key within the same cycle will retrieve the data from the cache, avoiding redundant work.
Resolver Return: Finally, the GraphQL resolver receives the data from the DataLoader (either directly from the cache or from the fresh batch) and returns it to the client.
You can read more on dataloaders here.
Initializing dataloaders
We need to share the dataloaders across all resolvers for a single request, so that the data can be batch-loaded and cached in the context of a single request.
Ideally, dataloaders will be initialized once per GraphQL request context. This can be done by defining the dataloaders as an attribute on the Context class in Strawberry GraphQL, or by utilizing a dependency injection library that caches the dataloaders for the scope of a single GraphQL request.
You can read more on this here.
A naive attempt at data loading
Let us implement a simple data loader, using a pseudo session instance to query an imaginary database, without any dependency injection first.
from typing import List

import strawberry
from strawberry.dataloader import DataLoader

from app.database import session


@strawberry.type
class User:
    id: strawberry.ID


async def load_users_by_id(keys: List[strawberry.ID]) -> List[User | None]:
    # Bind the opened session to a different name so it does not shadow
    # the imported `session` factory.
    async with session() as db:
        users = await db.find_many(ids=keys)
    # Index the results so they can be returned in the same order as `keys`.
    users_by_id = {user.id: user for user in users}
    return [users_by_id.get(key) for key in keys]


loader = DataLoader(load_fn=load_users_by_id)
We can use this dataloader in a resolver like this:
import strawberry

from app.dataloaders import User, loader


@strawberry.type
class Query:
    @strawberry.field
    async def get_user(self, id: strawberry.ID) -> User | None:
        return await loader.load(id)


schema = strawberry.Schema(query=Query)
This is a very simple implementation, but has a few drawbacks:
Harder Lifecycle Management: Managing the lifecycle of dependencies, including creation, reuse, and disposal is hard and error-prone.
Tight Coupling: The loader directly imports and uses a specific database session instance, making it hard to swap out or update the session for different environments or needs.
Testing Challenges: It's difficult to mock or stub dependencies such as the database session. This makes testing slower, more complex, and less isolated.
Reduced Flexibility: Changing the data source (e.g., switching to a cache or microservice) would require modifying the code rather than simply configuring a new dependency.
Potential Global State Issues: Manually managing dependencies might inadvertently introduce global state or singleton patterns, which can lead to unexpected side effects, especially in asynchronous or multi-threaded environments.
Integrating data loading with Aioinject
Let's integrate Aioinject to manage dependencies alongside our dataloaders.
Setup Guide
- Step 1: Install Aioinject
pip install -U aioinject
- Step 2: Create an Aioinject container
import aioinject
container = aioinject.Container()
Iteration 1- Directly injecting into the load function
Let us declare our dataloader load function, assuming that we have an account repository named AccountRepo that handles the data fetching for us. Let us also assume that the AccountRepo has been registered as a dependency in the container, with the appropriate scope.
Internally, the AccountRepo can use any kind of abstraction (like sessions) to access the database, which is out of the scope of this tutorial.
from typing import Annotated

from aioinject import Inject
from aioinject.ext.strawberry import inject

from app.models import Account
from app.repositories import AccountRepo


@inject
async def load_account_by_id(
    account_ids: list[str],
    account_repo: Annotated[AccountRepo, Inject],
) -> list[Account | None]:
    """Load multiple accounts by their IDs."""
    id_to_account_map = {
        account.id: account
        for account in await account_repo.get_many_by_ids(account_ids)
    }
    # Return results in the same order as the requested IDs.
    return [id_to_account_map.get(account_id) for account_id in account_ids]
Here, we are decorating the load function with the @inject decorator, and directly passing in dependencies as arguments.
Next, we need to create a custom context class to place the dataloader in.
import dataclasses
from typing import TypedDict

from fastapi import BackgroundTasks, Request, Response
from strawberry.dataloader import DataLoader
from strawberry.types import Info as StrawberryInfo

from app.models import Account


@dataclasses.dataclass(slots=True, kw_only=True)
class Dataloaders:
    account_by_id: DataLoader[str, Account | None]


class BaseContext(TypedDict):
    request: Request
    response: Response
    background_tasks: BackgroundTasks
    loaders: Dataloaders
We'll also create a function to initialize a new set of dataloaders for each request.
import dataclasses
from typing import TypedDict

from fastapi import BackgroundTasks, Request, Response
from strawberry.dataloader import DataLoader
from strawberry.types import Info as StrawberryInfo

from app.dataloaders import load_account_by_id
from app.models import Account


@dataclasses.dataclass(slots=True, kw_only=True)
class Dataloaders:
    account_by_id: DataLoader[str, Account | None]


class BaseContext(TypedDict):
    request: Request
    response: Response
    background_tasks: BackgroundTasks
    loaders: Dataloaders


def create_dataloaders() -> Dataloaders:
    # Called once per request so that every request gets fresh loaders.
    return Dataloaders(
        account_by_id=DataLoader(
            load_fn=load_account_by_id,  # type: ignore[arg-type]
        ),
    )
Now, we need to declare a custom GraphQL context getter that calls this initialization function.
from fastapi import BackgroundTasks, Request, Response
from strawberry.fastapi import GraphQLRouter

from .context import BaseContext, create_dataloaders
from .schema import schema


async def get_context(
    request: Request,
    response: Response,
    background_tasks: BackgroundTasks,
) -> BaseContext:
    # Build a new context, with new dataloaders, for every request.
    return BaseContext(
        request=request,
        response=response,
        background_tasks=background_tasks,
        loaders=create_dataloaders(),
    )


def create_graphql_router() -> GraphQLRouter:
    return GraphQLRouter(
        schema=schema,
        context_getter=get_context,
        graphql_ide="graphiql",
    )
Now, we can mount the GraphQL router (with the specified context getter) onto our FastAPI application.
✅ Benefits
- Easy and straightforward implementation.
- We are leveraging the powers of the DI framework to load resources efficiently.
❌ Limitations
- Static type checkers are not happy- the load function is supposed to take in only one argument, keys.
Iteration 2- Dependency providers to create dataloaders per request
We can implement dataloading in a better way, by creating dependency providers that create dataloaders per request.
from strawberry.dataloader import DataLoader

from app.accounts.repositories import AccountRepo

from .models import Account


async def create_account_by_id_dataloader(
    account_repo: AccountRepo,
) -> DataLoader[str, Account | None]:
    """Create a dataloader to load accounts by their IDs."""

    async def load_fn(account_ids: list[str]) -> list[Account | None]:
        """Load multiple entities by their keys."""
        id_to_account_map = {
            account.id: account
            for account in await account_repo.get_many_by_ids(account_ids)
        }
        return [id_to_account_map.get(account_id) for account_id in account_ids]

    return DataLoader(load_fn=load_fn)
Next, we need to add this dependency provider to our Aioinject container.
import aioinject
from app.dependencies import create_account_by_id_dataloader
container = aioinject.Container()
container.register(aioinject.Scoped(create_account_by_id_dataloader))
We also need to create a custom context class, like we did in the previous iteration.
import dataclasses
from typing import TypedDict

from fastapi import BackgroundTasks, Request, Response
from strawberry.dataloader import DataLoader
from strawberry.types import Info as StrawberryInfo

from app.models import Account


@dataclasses.dataclass(slots=True, kw_only=True)
class Dataloaders:
    account_by_id: DataLoader[str, Account | None]


class BaseContext(TypedDict):
    request: Request
    response: Response
    background_tasks: BackgroundTasks
    loaders: Dataloaders
But this time, we create another dependency provider that creates the dataloaders per request for us.
from strawberry.dataloader import DataLoader

from app.accounts.repositories import AccountRepo
from app.context import Dataloaders

from .models import Account


async def create_account_by_id_dataloader(
    account_repo: AccountRepo,
) -> DataLoader[str, Account | None]:
    """Create a dataloader to load accounts by their IDs."""

    async def load_fn(account_ids: list[str]) -> list[Account | None]:
        """Load multiple entities by their keys."""
        id_to_account_map = {
            account.id: account
            for account in await account_repo.get_many_by_ids(account_ids)
        }
        return [id_to_account_map.get(account_id) for account_id in account_ids]

    return DataLoader(load_fn=load_fn)


def create_dataloaders(
    account_by_id: DataLoader[str, Account | None],
) -> Dataloaders:
    """Create dataloaders for the current context."""
    return Dataloaders(
        account_by_id=account_by_id,
    )
We need to add this to the container too.
import aioinject
from app.dependencies import create_account_by_id_dataloader, create_dataloaders
container = aioinject.Container()
container.register(aioinject.Scoped(create_account_by_id_dataloader))
container.register(aioinject.Scoped(create_dataloaders))
We can plug the dataloaders into the request context in the following way:
from aioinject import Injected
from aioinject.ext.fastapi import inject
from fastapi import BackgroundTasks, Request, Response
from strawberry.fastapi import GraphQLRouter

from .context import BaseContext, Dataloaders
from .schema import schema


@inject
async def get_context(
    request: Request,
    response: Response,
    background_tasks: BackgroundTasks,
    dataloaders: Injected[Dataloaders],
) -> BaseContext:
    """Get the context for the GraphQL request."""
    return BaseContext(
        request=request,
        response=response,
        background_tasks=background_tasks,
        loaders=dataloaders,
    )


def create_graphql_router() -> GraphQLRouter:
    """Create a GraphQL router."""
    return GraphQLRouter(
        schema=schema,
        context_getter=get_context,
        graphql_ide="graphiql",
    )
Now, this GraphQL router can be mounted onto your FastAPI application.
✅ Benefits
- Static typecheckers are happy, we aren't violating any typing rules.
- We are leveraging the powers of the DI framework to load resources efficiently.
❌ Limitations
Let's say we have two dataloaders that load the same entity, with the same type of loading key:
- load_user_by_id -> loads the user by a string attribute
- load_user_by_username -> also loads the user by a string attribute
Both providers would be annotated as returning the same DataLoader type. Because Aioinject resolves dependencies based on type hints, the two providers end up with the same return type annotation, and dependency resolution will fail.
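The collision can be illustrated with a toy registry keyed by return annotation. This is only a sketch of the general idea behind type-based resolution, not Aioinject's implementation; ToyContainer is made up, and dict[str, str] stands in for the shared DataLoader type.

```python
from collections.abc import Callable
from typing import get_type_hints


class ToyContainer:
    """Toy type-keyed registry for illustration; not Aioinject's implementation."""

    def __init__(self) -> None:
        self._providers: dict[object, Callable] = {}

    def register(self, provider: Callable) -> None:
        # Providers are looked up by the type they promise to return.
        return_type = get_type_hints(provider)["return"]
        if return_type in self._providers:
            raise ValueError(f"duplicate provider for {return_type}")
        self._providers[return_type] = provider


# dict[str, str] stands in for the DataLoader type both providers return.
def create_user_by_id_loader() -> dict[str, str]:
    return {}


def create_user_by_username_loader() -> dict[str, str]:
    return {}


container = ToyContainer()
container.register(create_user_by_id_loader)
try:
    container.register(create_user_by_username_loader)
    collided = False
except ValueError:
    collided = True

print(collided)  # True
```

Two identical annotations compare equal, so a registry keyed by type has no way to tell the two providers apart.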
Iteration 3- Introduce type aliases for data loaders
Let us address the limitations in the previous iteration using the type statement for declaring type aliases, a feature added in Python 3.12 (PEP 695).
We can create type aliases for data loader dependency providers as follows:
from strawberry.dataloader import DataLoader

from app.accounts.repositories import AccountRepo
from app.context import Dataloaders

from .models import Account

type AccountByIdLoader = DataLoader[str, Account | None]


async def create_account_by_id_dataloader(
    account_repo: AccountRepo,
) -> AccountByIdLoader:
    """Create a dataloader to load accounts by their IDs."""

    async def load_fn(account_ids: list[str]) -> list[Account | None]:
        """Load multiple entities by their keys."""
        id_to_account_map = {
            account.id: account
            for account in await account_repo.get_many_by_ids(account_ids)
        }
        return [id_to_account_map.get(account_id) for account_id in account_ids]

    return DataLoader(load_fn=load_fn)


def create_dataloaders(
    account_by_id: AccountByIdLoader,
) -> Dataloaders:
    """Create dataloaders for the current context."""
    return Dataloaders(
        account_by_id=account_by_id,
    )
The use of type aliases provides a unique return type annotation for every dependency provider function, which plays nicely with the Aioinject dependency injection system.
The remaining steps are similar to the previous iteration.
We register the dependency providers in the Aioinject container:
import aioinject
from app.dependencies import create_account_by_id_dataloader, create_dataloaders
container = aioinject.Container()
container.register(aioinject.Scoped(create_account_by_id_dataloader))
container.register(aioinject.Scoped(create_dataloaders))
We can plug the dataloaders into the request context in the following way:
from aioinject import Injected
from aioinject.ext.fastapi import inject
from fastapi import BackgroundTasks, Request, Response
from strawberry.fastapi import GraphQLRouter

from .context import BaseContext, Dataloaders
from .schema import schema


@inject
async def get_context(
    request: Request,
    response: Response,
    background_tasks: BackgroundTasks,
    dataloaders: Injected[Dataloaders],
) -> BaseContext:
    """Get the context for the GraphQL request."""
    return BaseContext(
        request=request,
        response=response,
        background_tasks=background_tasks,
        loaders=dataloaders,
    )


def create_graphql_router() -> GraphQLRouter:
    """Create a GraphQL router."""
    return GraphQLRouter(
        schema=schema,
        context_getter=get_context,
        graphql_ide="graphiql",
    )
Now, this GraphQL router can be mounted onto your FastAPI application.
✅ Benefits
- Static typecheckers are happy, we aren't violating any typing rules.
- We are leveraging the powers of the DI framework to load resources efficiently.
- We can have multiple data loaders with the same type annotations, uniquely identified by type aliases.
❌ Limitations
- This approach might be a little verbose to implement
Bonus- Transforming dataloader keys
In certain cases, you might want to transform the data loader keys before passing them to the data fetching function. For example, you might want to convert your keys to an integer type, or to an ObjectId when you are using MongoDB.
Moving the key transformation logic inside your dataloaders can prevent you from defining the same redundant transformation logic in multiple places outside the loaders.
We can implement dataloader key transforms before data loading like this:
Integer Key Transformations
from strawberry.dataloader import DataLoader

from app.accounts.repositories import AccountRepo

from .models import Account

type AccountByIdLoader = DataLoader[str, Account | None]


def is_valid_integer(value: str) -> bool:
    try:
        int(value)
        return True
    except ValueError:
        return False


async def create_account_by_id_dataloader(
    account_repo: AccountRepo,
) -> AccountByIdLoader:
    """Create a dataloader to load accounts by their IDs."""

    async def load_fn(account_ids: list[str]) -> list[Account | None]:
        """Load multiple entities by their keys."""
        # Transform keys here, skipping any that are not valid integers.
        valid_keys = [int(key) for key in account_ids if is_valid_integer(key)]
        id_to_account_map = {
            account.id: account
            for account in await account_repo.get_many_by_ids(valid_keys)
        }
        # Preserve the order of the input keys; invalid keys map to None.
        return [
            id_to_account_map.get(int(account_id))
            if is_valid_integer(account_id)
            else None
            for account_id in account_ids
        ]

    return DataLoader(load_fn=load_fn)
ObjectId Key Transformations
from bson import ObjectId
from strawberry.dataloader import DataLoader

from app.accounts.repositories import AccountRepo

from .models import Account

type AccountByIdLoader = DataLoader[str, Account | None]


async def create_account_by_id_dataloader(
    account_repo: AccountRepo,
) -> AccountByIdLoader:
    """Create a dataloader to load accounts by their IDs."""

    async def load_fn(account_ids: list[str]) -> list[Account | None]:
        """Load multiple entities by their keys."""
        # Transform keys here, skipping any that are not valid ObjectIds.
        valid_keys = [ObjectId(key) for key in account_ids if ObjectId.is_valid(key)]
        id_to_account_map = {
            account.id: account
            for account in await account_repo.get_many_by_ids(valid_keys)
        }
        # Preserve the order of the input keys; invalid keys map to None.
        return [
            id_to_account_map.get(ObjectId(account_id))
            if ObjectId.is_valid(account_id)
            else None
            for account_id in account_ids
        ]

    return DataLoader(load_fn=load_fn)
Creating a reusable dataloader factory
Writing the same key conversion, data fetching and transformation logic in every dataloader is repetitive. We can make life easier by creating a reusable dataloader factory, which can be used to define all of our dataloaders:
from collections.abc import Awaitable, Callable, Hashable
from typing import TypeVar

from strawberry.dataloader import DataLoader

T = TypeVar("T")  # the entity type
U = TypeVar(
    "U", str, tuple[str, str]
)  # the original key type (input), assumed to be a string
# The transformed key type. It is used as a dictionary key below, so it only
# needs to be hashable (this also admits types such as ObjectId).
K = TypeVar("K", bound=Hashable)


async def load_many_entities(
    keys: list[U],
    data_fetcher: Callable[[list[K]], Awaitable[list[T | None]]],
    key_transform: Callable[[U], K | None],
) -> list[T | None]:
    """
    Load entities by keys (IDs, slugs, etc.).

    :param keys: A list of keys (e.g., IDs or slugs) as strings.
    :param data_fetcher: The repository method to fetch data.
    :param key_transform: Function to transform keys (e.g., convert to int).
    :return: A list of entities matching the keys, preserving the original order.
    """
    # Transform and validate keys
    valid_keys: list[K] = [
        key for key in (key_transform(key) for key in keys) if key is not None
    ]
    # Fetch data using the provided repo method
    fetched_entities = await data_fetcher(valid_keys)
    # Map results back to the transformed keys (relies on the fetcher
    # returning one result per key, in the same order)
    key_to_entity_map = dict(zip(valid_keys, fetched_entities, strict=False))
    # Return entities in the original key order, with None for invalid/missing keys
    return [
        key_to_entity_map.get(transformed_key)
        if (transformed_key := key_transform(key)) is not None
        else None
        for key in keys
    ]


def transform_default(key: str) -> str | None:
    """Return the key as is."""
    return key


def create_dataloader(
    data_fetcher: Callable[[list[K]], Awaitable[list[T | None]]],
    key_transform: Callable[[U], K | None],
) -> DataLoader[U, T | None]:
    async def load_entities(entity_keys: list[U]) -> list[T | None]:
        """Load multiple entities by their keys."""
        return await load_many_entities(
            keys=entity_keys,
            data_fetcher=data_fetcher,
            key_transform=key_transform,
        )

    return DataLoader(load_fn=load_entities)
This dataloader factory function can be used to create multiple dataloaders with custom key transformation functions and data fetchers.
NOTE
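The data fetcher function must return a list with exactly as many items as the number of keys it was passed, in the same order as those keys, with None in the position of any key that has no match. Otherwise, results will be paired with the wrong keys or the data loading mechanism will error.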
Here are a few examples:
from bson import ObjectId
from strawberry.dataloader import DataLoader

from app.accounts.repositories import AccountRepo, ProfileRepo
from app.core.dataloaders import (
    create_dataloader,
    transform_default,
)

from .documents import Account, Profile


def transform_valid_object_id(key: str) -> ObjectId | None:
    """Convert a string to an ObjectId, or return None if it is not valid."""
    return ObjectId(key) if ObjectId.is_valid(key) else None


type AccountByIdLoader = DataLoader[str, Account | None]


async def create_account_by_id_dataloader(
    account_repo: AccountRepo,
) -> AccountByIdLoader:
    """Create a dataloader to load accounts by their IDs."""
    return create_dataloader(
        data_fetcher=account_repo.get_many_by_ids,
        key_transform=transform_valid_object_id,
    )


type AccountByUsernameLoader = DataLoader[str, Account | None]


async def create_account_by_username_dataloader(
    account_repo: AccountRepo,
) -> AccountByUsernameLoader:
    """Create a dataloader to load accounts by their usernames."""
    return create_dataloader(
        data_fetcher=account_repo.get_many_by_usernames,
        key_transform=transform_default,
    )


type ProfileByIdLoader = DataLoader[str, Profile | None]


async def create_profile_by_id_dataloader(
    profile_repo: ProfileRepo,
) -> ProfileByIdLoader:
    """Create a dataloader to load profiles by their IDs."""
    return create_dataloader(
        data_fetcher=profile_repo.get_many_by_ids,
        key_transform=transform_valid_object_id,
    )