published 2024-05-28
If you prefer, the code for this article is also available in a gist [1].
GraphQL is great because it lets consumers define the shape of their requests. However, that makes it somewhat harder to optimize than typical REST requests. If you implement GraphQL naively, you will quickly hit the infamous N+1 problem. Let me demonstrate the issue with an example project using Quart [2], Quart-DB [3], PostgreSQL [4] and Strawberry [5].
Let us start by defining a simple schema for a music collection and populating the database with a migration:
```
from quart_db import Connection


async def create_schema(cnx: Connection) -> None:
    await cnx.execute(
        """
        CREATE TABLE bands (
            id bigint PRIMARY KEY,
            name text NOT NULL
        );
        CREATE TABLE albums (
            id bigint PRIMARY KEY,
            name text NOT NULL,
            band_id bigint REFERENCES bands(id)
        );
        CREATE TABLE songs (
            id bigint GENERATED ALWAYS AS IDENTITY PRIMARY KEY,
            name text NOT NULL,
            album_id bigint REFERENCES albums(id)
        );
        """
    )


async def populate(cnx: Connection) -> None:
    await cnx.execute(
        """
        INSERT INTO bands (id, name) VALUES
            (1, 'Dark Tranquillity'),
            (2, 'Pineapple Thief'),
            (3, 'Wintersun');
        INSERT INTO albums (id, name, band_id) VALUES
            (1, 'Haven', 1),
            (2, 'Fiction', 1),
            (3, 'Atoma', 1),
            (4, 'Time I', 3),
            (5, 'Your Wilderness', 2),
            (6, 'Versions of the Truth', 2);
        INSERT INTO songs (name, album_id) VALUES
            ('The Wonders at Your Feet', 1),
            ('Not Built to Last', 1),
            ('Indifferent Suns', 1),
            ('At Loss for Words', 1),
            ('Terminus', 2),
            ('Inside the Particle Storm', 2),
            ('Focus Shift', 2),
            ('Forward Momentum', 3),
            ('Caves and Embers', 3),
            ('When Mountains Fall', 4),
            ('Sons of Winter and Stars', 4),
            ('Land of Snow and Sorrow', 4),
            ('Time', 4),
            ('The Final Thing on My Mind', 5),
            ('Tear You Up', 5);
        """
    )


async def migrate(cnx: Connection) -> None:
    await create_schema(cnx)
    await populate(cnx)
```
We expose bands, albums and songs; now let us build a GraphQL interface to query them the naive way:
```
import logging
from typing import Any

import strawberry
from quart_cors import cors
from quart_db import QuartDB
from strawberry.quart.views import GraphQLView as QuartGraphQLView

from quart import Quart, Request, Response

app = Quart("sfdl")
app.logger.setLevel(logging.INFO)
app.config["QUART_DB_DATABASE_URL"] = "postgresql://postgres@localhost/cwl_sfdl"
app.config["QUART_DB_AUTO_REQUEST_CONNECTION"] = False
db = QuartDB(app)
cors(app, allow_origin="*", allow_methods=["GET", "POST"])


@strawberry.type
class Song:
    id: int
    name: str
    album_id: int


@strawberry.type
class Album:
    id: int
    name: str
    band_id: int

    @strawberry.field
    async def songs(self) -> list[Song]:
        query = """
            SELECT id, name, album_id
            FROM songs
            WHERE album_id = :album_id
        """
        async with db.connection() as cnx:
            result = await cnx.fetch_all(query, {"album_id": self.id})
        songs = [Song(**row) for row in result]
        app.logger.info(f"Got {len(songs)} songs.")
        return songs


@strawberry.type
class Band:
    id: int
    name: str

    @strawberry.field
    async def albums(self) -> list[Album]:
        query = """
            SELECT id, name, band_id
            FROM albums
            WHERE band_id = :band_id
        """
        async with db.connection() as cnx:
            result = await cnx.fetch_all(query, {"band_id": self.id})
        albums = [Album(**row) for row in result]
        app.logger.info(f"Got {len(albums)} albums.")
        return albums


@strawberry.type
class Query:
    @strawberry.field
    async def bands(self) -> list[Band]:
        query = """
            SELECT id, name
            FROM bands
        """
        async with db.connection() as cnx:
            result = await cnx.fetch_all(query)
        bands = [Band(**row) for row in result]
        app.logger.info(f"Got {len(bands)} bands.")
        return bands


class GraphQLView(QuartGraphQLView):
    async def get_context(self, request: Request, response: Response) -> dict[str, Any]:
        return {"request": request, "response": response}


view = GraphQLView.as_view(
    "graphql_view",
    schema=strawberry.Schema(query=Query),
    graphql_ide="graphiql",
)
app.add_url_rule("/", view_func=view)
```
We can try it with GraphiQL, to check that it works:
music collection request in GraphiQL
It does, but if we look at the logs we see this:
```
INFO in __init__: Got 3 bands.
INFO in __init__: Got 2 albums.
INFO in __init__: Got 1 albums.
INFO in __init__: Got 3 albums.
INFO in __init__: Got 2 songs.
INFO in __init__: Got 4 songs.
INFO in __init__: Got 0 songs.
INFO in __init__: Got 4 songs.
INFO in __init__: Got 2 songs.
INFO in __init__: Got 3 songs.
[59520] [INFO] 127.0.0.1:52944 POST / 1.1 200 806 5311
```
That is too many queries: first it fetches all the bands, then for each band it fetches the albums, then for each album it fetches the songs... Here that adds up to 10 SQL queries (1 + 3 + 6) for a single GraphQL request. That is the N+1 problem.
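To make the arithmetic concrete, here is a small sketch that counts the queries the naive resolvers would issue for the whole collection. The `ALBUMS_BY_BAND` mapping and the `naive_query_count` function are hypothetical helpers written for this illustration; the mapping simply mirrors the rows inserted by the migration.

```python
# Hypothetical in-memory mirror of the albums table: band id -> album ids.
ALBUMS_BY_BAND = {1: [1, 2, 3], 2: [5, 6], 3: [4]}


def naive_query_count(albums_by_band: dict[int, list[int]]) -> int:
    """Count the SQL queries issued by the naive resolvers for a full request."""
    bands_queries = 1  # Query.bands runs once
    albums_queries = len(albums_by_band)  # Band.albums runs once per band
    # Album.songs runs once per album
    songs_queries = sum(len(albums) for albums in albums_by_band.values())
    return bands_queries + albums_queries + songs_queries


print(naive_query_count(ALBUMS_BY_BAND))
```

The count grows with the size of the collection, whereas the number of levels in the GraphQL request (bands, albums, songs) stays at three.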
If you look it up online, you will quickly find that a popular solution is dataloaders. However, many of the examples you will find do not quite match this problem, because they only demonstrate how to use dataloaders on primary keys, whereas here we want to load on a foreign key. This is the case for the official Strawberry documentation [6], which you should still read if you use that library.
The important thing to understand about dataloaders is that they take a list of keys and return a list of results of the same size, corresponding to those keys in the same order. So you cannot, for instance, have a dataloader that takes a list of album IDs and returns a flat list of all the songs in those albums. However, the trick is that you can have a dataloader that returns a list of lists of songs, one list per album!
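Here is a minimal sketch of that contract in plain Python, without any database or GraphQL library. The `group_for_keys` helper and the sample rows are hypothetical names made up for this illustration; the point is only that the output has one entry per key, in key order, including an empty list for keys with no rows.

```python
from collections import defaultdict
from typing import Callable, Hashable, Iterable, TypeVar

K = TypeVar("K", bound=Hashable)
R = TypeVar("R")


def group_for_keys(
    keys: list[K], rows: Iterable[R], key_of: Callable[[R], K]
) -> list[list[R]]:
    """Group rows by foreign key, then align the groups with the requested keys."""
    by_key: defaultdict[K, list[R]] = defaultdict(list)
    for row in rows:
        by_key[key_of(row)].append(row)
    # One list per requested key, in the same order; missing keys get [].
    return [by_key.get(key, []) for key in keys]


# Hypothetical rows, as a single "WHERE album_id = ANY(...)" query might return them.
rows = [
    {"name": "Terminus", "album_id": 2},
    {"name": "Time", "album_id": 4},
    {"name": "Focus Shift", "album_id": 2},
]
grouped = group_for_keys([4, 6, 2], rows, key_of=lambda row: row["album_id"])
print([[row["name"] for row in group] for group in grouped])
```

Album 6 has no songs, so its slot is an empty list rather than being skipped: the result must always have the same length as the list of keys.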
Now let us rewrite our example with that pattern:
```
import logging
from collections import defaultdict
from functools import cached_property
from typing import Any

import strawberry
from quart_cors import cors
from quart_db import QuartDB
from strawberry.dataloader import DataLoader
from strawberry.quart.views import GraphQLView as QuartGraphQLView
from strawberry.types import Info

from quart import Quart, Request, Response

app = Quart("sfdl")
app.logger.setLevel(logging.INFO)
app.config["QUART_DB_DATABASE_URL"] = "postgresql://postgres@localhost/cwl_sfdl"
app.config["QUART_DB_AUTO_REQUEST_CONNECTION"] = False
db = QuartDB(app)
cors(app, allow_origin="*", allow_methods=["GET", "POST"])


@strawberry.type
class Song:
    id: int
    name: str
    album_id: int


@strawberry.type
class Album:
    id: int
    name: str
    band_id: int

    @strawberry.field
    async def songs(self, info: Info) -> list[Song]:
        dl = info.context["dataloaders"].songs_for_albums
        return await dl.load(self.id)


@strawberry.type
class Band:
    id: int
    name: str

    @strawberry.field
    async def albums(self, info: Info) -> list[Album]:
        dl = info.context["dataloaders"].albums_for_bands
        return await dl.load(self.id)


@strawberry.type
class Query:
    @strawberry.field
    async def bands(self) -> list[Band]:
        query = """
            SELECT id, name
            FROM bands
        """
        async with db.connection() as cnx:
            result = await cnx.fetch_all(query)
        bands = [Band(**row) for row in result]
        app.logger.info(f"Got {len(bands)} bands.")
        return bands


class DataLoaders:
    @staticmethod
    async def load_songs_for_albums(keys: list[int]) -> list[list[Song]]:
        query = """
            SELECT id, name, album_id
            FROM songs
            WHERE album_id = ANY(:keys)
        """
        async with db.connection() as cnx:
            result = await cnx.fetch_all(query, {"keys": keys})
        songs = [Song(**row) for row in result]
        app.logger.info(f"Got {len(songs)} songs.")
        by_key: defaultdict[int, list[Song]] = defaultdict(list)
        for song in songs:
            by_key[song.album_id].append(song)
        return [by_key[k] for k in keys]

    @staticmethod
    async def load_albums_for_bands(keys: list[int]) -> list[list[Album]]:
        query = """
            SELECT id, name, band_id
            FROM albums
            WHERE band_id = ANY(:keys)
        """
        async with db.connection() as cnx:
            result = await cnx.fetch_all(query, {"keys": keys})
        albums = [Album(**row) for row in result]
        app.logger.info(f"Got {len(albums)} albums.")
        by_key: defaultdict[int, list[Album]] = defaultdict(list)
        for album in albums:
            by_key[album.band_id].append(album)
        return [by_key[k] for k in keys]

    @cached_property
    def songs_for_albums(self) -> DataLoader[int, list[Song]]:
        return DataLoader(self.load_songs_for_albums)

    @cached_property
    def albums_for_bands(self) -> DataLoader[int, list[Album]]:
        return DataLoader(self.load_albums_for_bands)


class GraphQLView(QuartGraphQLView):
    async def get_context(self, request: Request, response: Response) -> dict[str, Any]:
        return {"request": request, "response": response, "dataloaders": DataLoaders()}


view = GraphQLView.as_view(
    "graphql_view",
    schema=strawberry.Schema(query=Query),
    graphql_ide="graphiql",
)
app.add_url_rule("/", view_func=view)
```
The request for the whole collection still works, but now if we look at the logs we can see this:
```
INFO in __init__: Got 3 bands.
INFO in __init__: Got 6 albums.
INFO in __init__: Got 15 songs.
[59520] [INFO] 127.0.0.1:60244 POST / 1.1 200 806 4509
```
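If you wonder why each level now produces a single query, the sketch below shows the general batching idea with plain asyncio. It is not Strawberry's implementation: `TinyLoader`, `SONGS_BY_ALBUM` and the batch function are hypothetical, and real dataloaders typically do more (caching, deduplicating keys, error handling). It only illustrates that `load()` calls made before the event loop gets control back end up in one call to the batch function.

```python
import asyncio
from typing import Awaitable, Callable


class TinyLoader:
    """Hypothetical, simplified dataloader: batches keys, no cache, no dedup."""

    def __init__(self, batch_fn: Callable[[list[int]], Awaitable[list]]) -> None:
        self.batch_fn = batch_fn
        self.pending: list[tuple[int, asyncio.Future]] = []
        self._task: asyncio.Task | None = None

    def load(self, key: int) -> asyncio.Future:
        loop = asyncio.get_running_loop()
        future = loop.create_future()
        if not self.pending:
            # First key of a new batch: the task only starts running once the
            # current coroutine yields, so later load() calls join this batch.
            self._task = loop.create_task(self._dispatch())
        self.pending.append((key, future))
        return future

    async def _dispatch(self) -> None:
        batch, self.pending = self.pending, []
        results = await self.batch_fn([key for key, _ in batch])
        # Results are aligned with keys, so zip hands each future its answer.
        for (_, future), result in zip(batch, results):
            future.set_result(result)


# Hypothetical in-memory data standing in for the songs table.
SONGS_BY_ALBUM = {
    3: ["Forward Momentum", "Caves and Embers"],
    5: ["The Final Thing on My Mind", "Tear You Up"],
}
batches: list[list[int]] = []


async def load_songs_for_albums(keys: list[int]) -> list[list[str]]:
    batches.append(list(keys))  # record each call, like the log lines above
    return [SONGS_BY_ALBUM.get(key, []) for key in keys]


async def main() -> list[list[str]]:
    loader = TinyLoader(load_songs_for_albums)
    return await asyncio.gather(loader.load(3), loader.load(5), loader.load(6))


results = asyncio.run(main())
print(batches)
print(results)
```

Three `load()` calls result in a single entry in `batches`, which is the same effect as the single "Got 15 songs." line in the logs.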
No more N+1: three queries instead of ten. Our job is done here.
1: https://gist.github.com/catwell/e47e70b47550ba2fb07d04a41bb8baf0
2: https://quart.palletsprojects.com/