def deserialization() -> Conversion:
    """Build a conversion that deserializes ``cls`` from a tagged union.

    The generated union class (named like ``cls``) gets one ``Tagged`` field
    per class yielded by ``rec_subclasses(cls)``, keyed by the class name.
    It also gets one field per alternative constructor registered in
    ``_alternative_constructors`` for that class, keyed by the constructor
    name converted with ``to_pascal_case``. The returned conversion maps a
    union instance to ``cls`` by taking the second element of
    ``get_tagged(obj)``.
    """
    field_types: dict[str, Any] = {}
    namespace: dict[str, Any] = {"__annotations__": field_types}
    for subclass in rec_subclasses(cls):
        field_types[subclass.__name__] = Tagged[subclass]  # type: ignore
        # One extra tagged field per alternative constructor of this subclass
        for ctor in _alternative_constructors.get(subclass, ()):
            field_name = to_pascal_case(ctor.__name__)
            # object_deserialization relies on get_type_hints, but the
            # constructor's return annotation is a string naming a class that
            # is not defined yet, so it has to be set explicitly first
            ctor.__annotations__["return"] = subclass
            field_types[field_name] = Tagged[subclass]  # type: ignore
            # Wrap the constructor itself as the deserializer of this field
            ctor_deserializer = object_deserialization(ctor, type_name(field_name))
            namespace[field_name] = Tagged(conversion(deserialization=ctor_deserializer))

    def populate(class_namespace):
        class_namespace.update(namespace)

    union_cls = new_class(cls.__name__, (TaggedUnion,), exec_body=populate)
    return Conversion(lambda obj: get_tagged(obj)[1], source=union_cls, target=cls)
def conversion(self) -> Optional[Conversion]:
    """Return a self-referencing serialization conversion, or None.

    When ``self.serialization`` is falsy, no conversion is produced.
    Otherwise the conversion uses itself as its ``sub_conversion``. The lazy
    callback reads ``recursive`` only when invoked, and by then the name has
    been rebound to the conversion object.
    """
    if not self.serialization:
        return None
    recursive = None
    recursive = Conversion(
        self.serialization, sub_conversion=LazyConversion(lambda: recursive)
    )
    return recursive
def default_deserialization(tp: Any) -> Optional[AnyConversion]:
    """Deserialize pydantic models through pydantic's own validation.

    For a ``pydantic.BaseModel`` subclass, return a ``Conversion`` that builds
    the model with ``tp.parse_obj``. Its source type is the model's
    ``__root__`` annotation when present, else ``Mapping[str, Any]``.
    Any other type is delegated to ``prev_deserialization``.

    A ``pydantic.ValidationError`` raised while parsing is translated with
    ``ValidationError.from_errors`` (presumably apischema's ValidationError;
    confirm at the import site). The new error is chained to the pydantic one.
    """
    if inspect.isclass(tp) and issubclass(tp, pydantic.BaseModel):

        def deserialize_pydantic(data):
            try:
                return tp.parse_obj(data)
            except pydantic.ValidationError as error:
                # Explicit chaining records pydantic's error as the direct
                # cause rather than as an incidental exception context.
                raise ValidationError.from_errors(error.errors()) from error

        return Conversion(
            deserialize_pydantic,
            source=tp.__annotations__.get("__root__", Mapping[str, Any]),
            target=tp,
        )
    else:
        return prev_deserialization(tp)
def serialization() -> Conversion:
    """Build a conversion that serializes ``cls`` as a tagged union.

    A union class named like ``cls`` is created with one ``Tagged`` field per
    class yielded by ``rec_subclasses(cls)``, keyed by class name. Each
    serialized object is wrapped in the field named after its own class.
    """

    def populate(namespace):
        tags = {}
        for subclass in rec_subclasses(cls):
            tags[subclass.__name__] = Tagged[subclass]  # type: ignore
        namespace.update({"__annotations__": tags})

    union_cls = new_class(cls.__name__, (TaggedUnion,), exec_body=populate)
    return Conversion(
        lambda obj: union_cls(**{obj.__class__.__name__: obj}),
        source=cls,
        target=union_cls,
        # Must not be inherited: subclasses would otherwise pick this
        # conversion up again, leading to infinite recursion
        inherited=False,
    )
from apischema.conversions import Conversion


# NOTE(review): `dataclass`, `serializer`, `serialize` and `identity` are used
# below but not imported in this snippet; presumably they are imported earlier
# in the file. Confirm.
@dataclass
class RGB:
    # Color channels; the hexa format below renders each as two hex digits
    red: int
    green: int
    blue: int

    # Registered as RGB's serializer, so RGB serializes to this string
    # (see the first assert below)
    @serializer
    @property
    def hexa(self) -> str:
        return f"#{self.red:02x}{self.green:02x}{self.blue:02x}"


assert serialize(RGB, RGB(0, 0, 0)) == "#000000"
# dynamic conversion used to bypass the registered one
assert serialize(RGB, RGB(0, 0, 0), conversion=identity) == {
    "red": 0,
    "green": 0,
    "blue": 0,
}
# Expanded bypass form
assert serialize(RGB, RGB(0, 0, 0), conversion=Conversion(identity, source=RGB, target=RGB)) == {
    "red": 0, "green": 0, "blue": 0
}
PurePath, PurePosixPath, PureWindowsPath, WindowsPath, ) from typing import Deque, List, TypeVar from uuid import UUID from apischema import deserializer, schema, serializer, type_name from apischema.conversions import Conversion, as_str T = TypeVar("T") # =================== bytes ===================== deserializer(Conversion(b64decode, source=str, target=bytes)) @serializer def to_base64(b: bytes) -> str: return b64encode(b).decode() type_name(graphql="Bytes")(bytes) schema(encoding="base64")(bytes) # ================ collections ================== deserializer(Conversion(deque, source=List[T], target=Deque[T])) serializer(Conversion(list, source=Deque[T], target=List[T])) if sys.version_info < (3, 7):
def __init_subclass__(cls, **kwargs):
    """Register each new subclass as a deserialization source for ``Arg``.

    Every subclass gets an identity ``Conversion`` from itself to ``Arg``.
    Class keyword arguments are forwarded to the parent hook, so cooperative
    ``__init_subclass__`` implementations further up the MRO still run.
    """
    super().__init_subclass__(**kwargs)
    # Deserializers stack directly as a Union
    deserializer(Conversion(identity, source=cls, target=Arg))
def handle_enum(tp: AnyType) -> Optional[AnyConversion]:
    """Give Enum classes an identity conversion whose source is ``Any``.

    Anything that is not an ``Enum`` type is delegated to
    ``default_deserialization``.
    """
    if not (is_type(tp) and issubclass(tp, Enum)):
        return default_deserialization(tp)
    return Conversion(identity, source=Any, target=tp)
from base64 import b64decode

from apischema import deserialize, deserializer
from apischema.conversions import Conversion

# Register b64decode directly as the str -> bytes deserializer, using an
# explicit Conversion instead of a decorated wrapper function.
deserializer(Conversion(b64decode, source=str, target=bytes))
# Roughly equivalent to:
# def decode_bytes(source: str) -> bytes:
#     return b64decode(source)
# but saving a function call

# "Zm9v" is the base64 encoding of b"foo"
assert deserialize(bytes, "Zm9v") == b"foo"