Seamless Pydantic-Celery Integration

I've been using both Celery and Pydantic for years and went looking for existing Pydantic-Celery integrations. I found some, but was surprised that none of them were seamless, so I built on top of them and turned them into a one-line integration.
Here's the code if you want to skip the explanation:
https://github.com/jwnwilson/celery_pydantic/

You can also steal this file directly if you prefer:
https://github.com/jwnwilson/celery_pydantic/blob/main/celery_pydantic/serializer.py

First install the package:

```shell
pip install celery_pydantic
```
Then set it up like this:
```python
from celery import Celery
from celery_pydantic import pydantic_celery

# Create your Celery app
app = Celery('myapp', broker='amqp://')

# Configure the app to use Pydantic serialization
pydantic_celery(app)
```
Now you can use pydantic models as celery task arguments.
```python
from pydantic import BaseModel


class User(BaseModel):
    name: str
    age: int


@app.task
def process_user(user: User):
    return user.name


process_user.delay(user=User(name="John", age=30))
```
You can also return pydantic models from tasks.
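```python
@app.task
def process_user(user: User) -> User:
    return user


# .get() waits for the result, so the app needs a result backend,
# e.g. Celery('myapp', broker='amqp://', backend='rpc://')
user: User = process_user.delay(user=User(name="John", age=30)).get()
```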
Under the hood
The library is a single file containing a serializer that hooks into Celery's serialization. When a task is sent, the serializer records the import path of every Pydantic model it finds in the task arguments. When the task runs and the task data contains a model's import path, the serializer dynamically imports that Pydantic class and parses the data back into a model instance.
Potential improvements
I kept the standard library's `json` module for parsing; switching to a faster JSON library such as orjson or ujson could improve performance.
Comparison with existing solutions
This blog post covers most of the code above, but it requires registering each model manually, which I wanted to avoid.
Celery's official Pydantic integration (`@app.task(pydantic=True)`, added in Celery 5.5) requires Pydantic models to be converted to dicts with `model_dump()` before they are sent. That can also fail if the model contains types the default JSON serializer can't handle, such as `UUID` or `datetime`. In the end it looks something like this:
```python
import json


@app.task(pydantic=True)
def process_user(user: User):
    return user.name

process_user.delay(user=json.loads(User(name="John", age=30).model_dump_json()))
```
This works too if you prefer to keep things simple.
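To see why the `json.loads(...model_dump_json())` round trip is there, here is a small check against the standard library's `json.dumps`, assuming Pydantic v2 (`Event` is a hypothetical model). `model_dump(mode="json")` should give the same JSON-safe dict without serializing to a string and parsing it back.

```python
import json
import uuid
from datetime import datetime, timezone

from pydantic import BaseModel


class Event(BaseModel):
    id: uuid.UUID
    created: datetime


event = Event(id=uuid.uuid4(), created=datetime(2024, 1, 1, tzinfo=timezone.utc))

# Default (python) mode keeps UUID/datetime objects, which json.dumps rejects.
try:
    json.dumps(event.model_dump())
    plain_dump_failed = False
except TypeError:
    plain_dump_failed = True

# mode="json" converts them to JSON-safe strings in one step.
payload = event.model_dump(mode="json")
json.dumps(payload)  # succeeds now

# Validating the dict again restores the typed values.
restored = Event.model_validate(payload)
```

So `process_user.delay(user=User(name="John", age=30).model_dump(mode="json"))` should be an equivalent, slightly shorter form of the call above.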
Let me know what you think
Do you know a better way to solve this? I'd love to hear your thoughts and learn more!