Skip to content

Instantly share code, notes, and snippets.

@kmatarese
Last active January 23, 2024 06:48
Show Gist options
  • Save kmatarese/a5492f4a02449e13ea85ace8801b8dfb to your computer and use it in GitHub Desktop.
Save kmatarese/a5492f4a02449e13ea85ace8801b8dfb to your computer and use it in GitHub Desktop.
Hack to convert marshmallow schemas to pydantic models
"""WARNING: not thoroughly tested and does not support full translation
between the two libraries.
Uses a pydantic root_validator to init the marshmallow schema. It attempts
to map marshmallow field types to pydantic field types as well, but not all
field types are supported.
You can either use the pydantic_from_marshmallow function that does all of
the above or just subclass MarshmallowModel and manually define your pydantic
fields/types/etc.
"""
from datetime import date, datetime, timedelta, time
from decimal import Decimal
from typing import Any, Callable, Dict, List, Mapping, Optional, Union
from marshmallow import Schema, fields, missing
from pprint import pprint
from pydantic import (
BaseModel,
root_validator,
create_model,
AnyUrl,
EmailStr,
StrictFloat,
StrictInt,
)
from pydantic.utils import validate_field_name
CUSTOM_FIELD_DEFAULT = Any
# Fields in the marshmallow schema may fail the call to pydantic's
# validate_field_name if they conflict with base fields. To work around this
# we mark illegal fields with this and then strip it later to create an alias
# using an alias_generator. Bleh.
ALIAS_MARKER = "__alias__"
def get_dict_type(x):
"""For dicts we need to look at the key and value type"""
key_type = get_pydantic_type(x.key_field)
if x.value_field:
value_type = get_pydantic_type(x.value_field)
return Dict[key_type, value_type]
return Dict[key_type, Any]
def get_list_type(x):
"""For lists we need to look at the value type"""
if x.inner:
c_type = get_pydantic_type(x.inner)
return List[c_type]
return List
def get_nested_model(x):
"""Return a model from a nested marshmallow schema"""
return pydantic_from_marshmallow(x.schema)
FIELD_CONVERTERS = {
fields.Bool: bool,
fields.Boolean: bool,
fields.Date: date,
fields.DateTime: datetime,
fields.Decimal: Decimal,
fields.Dict: get_dict_type,
fields.Email: EmailStr,
fields.Float: float,
fields.Function: Callable,
fields.Int: int,
fields.Integer: int,
fields.List: get_list_type,
fields.Mapping: Mapping,
fields.Method: Callable,
fields.Nested: get_nested_model,
fields.Number: Union[StrictFloat, StrictInt],
fields.Str: str,
fields.String: str,
fields.Time: time,
fields.TimeDelta: timedelta,
fields.URL: AnyUrl,
fields.Url: AnyUrl,
fields.UUID: str,
}
def is_custom_field(field):
"""If this is a subclass of marshmallow's Field and not in our list, we
assume its a custom field"""
ftype = type(field)
if issubclass(ftype, fields.Field) and ftype not in FIELD_CONVERTERS:
return True
return False
def get_pydantic_type(field):
"""Get pydantic type from a marshmallow field"""
if is_custom_field(field):
conv = Any
else:
conv = FIELD_CONVERTERS[type(field)]
# TODO: Is there a cleaner way to check for annotation types?
if isinstance(conv, type) or conv.__module__ == "typing":
pyd_type = conv
else:
pyd_type = conv(field)
if not field.required:
pyd_type = Optional[pyd_type]
return pyd_type
def is_valid_field_name(bases, x):
try:
validate_field_name(bases, x)
return True
except NameError as e:
return False
def get_alias(x):
if x.endswith(ALIAS_MARKER):
return x.replace(ALIAS_MARKER, "")
return x
class MarshmallowModel(BaseModel):
"""A pydantic model that uses a marshmallow schema for object-wide validation"""
_schema = None
@root_validator(pre=True)
def _schema_validate(cls, values):
if not cls._schema:
raise AssertionError("Must define a marshmallow schema")
return cls._schema().load(values)
class Config:
alias_generator = get_alias
def pydantic_from_marshmallow(schema):
"""Convert a marshmallow schema to a pydantic model. May only
work for fairly simple cases. Barely tested. Enjoy."""
pyd_fields = {}
for field_name, field in schema._declared_fields.items():
pyd_type = get_pydantic_type(field)
default = field.default if field.default != missing else None
if not is_valid_field_name([BaseModel], field_name):
field_name = field_name + ALIAS_MARKER
pyd_fields[field_name] = (pyd_type, default)
if isinstance(schema, Schema):
name = schema.__class__.__name__
else:
name = schema.__name__
return create_model(name, _schema=schema, **pyd_fields, __base__=MarshmallowModel)
if __name__ == "__main__":
# Simple test...
def is_valid_str(val):
if not isinstance(val, str):
raise ValidationError(f"value is not a string: {val}")
return val
class MyField(fields.Field):
pass
class TestSchema(Schema):
some_str = fields.String(required=True, validate=is_valid_str)
some_dict = fields.Dict(keys=fields.Str(), default=None, missing={})
some_list = fields.List(fields.Integer)
fields = fields.Str() # illegal field name for pydantic
class TestSubSchema(TestSchema):
some_int = fields.Integer(required=False, missing=5)
some_custom_field = MyField()
for schema in [TestSchema, TestSubSchema]:
model = pydantic_from_marshmallow(schema)
x = model(some_str="a string!")
print(f"\n{model}:{x}")
pprint(model.schema())
@StephenBrown2
Copy link

Some possible suggestions for FIELD_CONVERTERS:

fields.Email: pydantic.EmailStr
fields.Function: typing.Callable
fields.Mapping: typing.Mapping / dict?
fields.Nested: (Pydantic allows you to just use the model class here)
fields.Number: typing.Union[pydantic.StrictFloat, pydantic.StrictInt]
fields.UUID: uuid.UUID / pydantic.UUID[1-5]
fields.Url: pydantic.HttpUrl / pydantic.AnyUrl / pydantic.AnyHttpUrl
fields.URL: Alias for fields.Url

@kmatarese
Copy link
Author

kmatarese commented May 14, 2020

Thanks! I'll make some edits. I also just realized this is broken with the most recent version of marshmallow (3.6) as their interface changed at some point. My quick test was against 3.0, addressing that as well.

Update: added support based on your list and fixed a few bugs.

@scott-cognizant
Copy link

Two improvements I was able to make are properly handling recursion and using the data_key name for cases when it is set (particularly in uses in dataclasses_json in which we use camel case conversion). Happy to share if there is a preferred method.

@totalhack
Copy link

@scott-cognizant not sure if gists allow an easy way to share but feel free to post a link to your fork otherwise.

@scott-cognizant
Copy link

Thank you for the suggestion (given lack of traditional PRs for gists). I have added a fork at:
https://gist.github.com/scott-cognizant/3270a19c7b1c278e440883119667cc62
The comment describes the relevant changes (including support for recursion). Thought I'd share just in case any or all seemed worth including in the original. Thanks for the helpful util!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment