def _visit_conversion(
    self,
    tp: AnyType,
    conversion: Serialization,
    dynamic: bool,
    next_conversion: Optional[AnyConversion],
) -> SerializationMethod:
    """Build the serialization method of ``tp`` going through ``conversion``.

    The conversion target is visited with the conversion's own
    ``fall_back_on_any`` / ``exclude_unset`` settings applied (when they are
    not ``None``) for the duration of that visit only.

    Returns:
        - the wrapped target serializer when the converter is ``identity``;
        - the bare converter (not wrapped) when the target serializer is
          ``identity``;
        - otherwise the wrapped composition "convert, then serialize".
    """
    with context_setter(self) as setter:
        fall_back_on_any = conversion.fall_back_on_any
        if fall_back_on_any is not None:
            setter._fall_back_on_any = fall_back_on_any
        exclude_unset = conversion.exclude_unset
        if exclude_unset is not None:
            setter._exclude_unset = exclude_unset
        target_conversion = sub_conversion(conversion, next_conversion)
        serialize_target = self.visit_with_conv(conversion.target,
                                                target_conversion)

    converter = cast(Converter, conversion.converter)
    # Nothing to do after the conversion: the converter is the whole method.
    if converter is not identity and serialize_target is identity:
        return converter

    if converter is identity:
        composed = serialize_target
    else:

        def composed(obj: Any) -> Any:
            converted = converter(obj)
            return serialize_target(converted)

    return self._wrap(get_origin_or_type(tp), composed)
def _check_flattened_schema(self, cls: Type, field: ObjectField):
    """Verify that the schema of a flattened field has an object type.

    The field is visited with ``_ignore_first_ref`` set to ``True`` inside a
    ``context_setter`` scope.

    Raises:
        TypeError: if the ``"type"`` entry of the visited field schema is
            neither ``JsonType.OBJECT`` nor ``"object"`` (a missing entry is
            rejected too).
    """
    assert field.flattened
    with context_setter(self):
        self._ignore_first_ref = True
        schema_type = self.visit_field(field).get("type")
        if schema_type in {JsonType.OBJECT, "object"}:
            return
        raise TypeError(
            f"Flattened field {cls.__name__}.{field.name} must have an object type"
        )
def mapping(self, cls: Type[Mapping], key_type: AnyType,
            value_type: AnyType) -> JsonSchema:
    """Build the object schema of a mapping type.

    The key type is visited with ``_ignore_first_ref`` set to ``True``; the
    value type is visited normally.

    Returns:
        An object schema using ``patternProperties`` (keyed by the key
        schema's ``"pattern"``) when the key schema has a pattern, otherwise
        an object schema using ``additionalProperties``.

    Raises:
        ValueError: if the key schema is not a string schema, including the
            case where it has no ``"type"`` entry at all.
    """
    with context_setter(self):
        self._ignore_first_ref = True
        key = self.visit(key_type)
    # ``.get`` rather than indexing: a key schema without a "type" entry must
    # be rejected with the ValueError below, not leak a bare KeyError.
    if key.get("type") != JsonType.STRING:
        raise ValueError("Mapping types must string-convertible key")
    value = self.visit(value_type)
    if "pattern" in key:
        return json_schema(type=JsonType.OBJECT,
                           patternProperties={key["pattern"]: value})
    return json_schema(type=JsonType.OBJECT, additionalProperties=value)
def _visit_flattened(
    self, field: ObjectField
) -> TypeFactory[graphql.GraphQLOutputType]:
    """Visit the type of a flattened field with a chained ``get_flattened``.

    A new getter is installed on the visitor (for the duration of the visit)
    that first applies the previously installed getter — or ``identity`` when
    there is none — then reads ``field.name`` on the result and passes it
    through the field's serialization method.
    """
    parent_getter = self.get_flattened
    if parent_getter is None:
        parent_getter = identity
    attr_name = field.name
    serialize_attr = self._field_serialization_method(field)

    def get_flattened(obj):
        parent = parent_getter(obj)
        return serialize_attr(getattr(parent, attr_name))

    with context_setter(self) as setter:
        setter.get_flattened = get_flattened
        return self.visit_with_conv(field.type, field.serialization)
def _properties_schema(self, field: ObjectField) -> JsonSchema:
    """Extract the schema of the values held by a properties field.

    The field is visited with ``_ignore_first_ref`` set to ``True``.

    Returns:
        - the single ``patternProperties`` value schema, when there is exactly
          one pattern and no ``additionalProperties`` next to it;
        - the ``additionalProperties`` schema, when there is no
          ``patternProperties`` and it is a ``JsonSchema`` instance;
        - an empty ``JsonSchema`` in every other case (no merge is attempted).

    Raises:
        TypeError: if the visited schema's ``"type"`` is not
            ``JsonType.OBJECT``.
    """
    assert field.pattern_properties is not None or field.additional_properties
    with context_setter(self):
        self._ignore_first_ref = True
        props_schema = self.visit_field(field)
    if props_schema.get("type") != JsonType.OBJECT:
        raise TypeError("properties field must have an 'object' type")

    if "patternProperties" in props_schema:
        patterns = props_schema["patternProperties"]
        # Several patterns, or patterns combined with additionalProperties,
        # would require merging schemas: not attempted.
        if len(patterns) == 1 and "additionalProperties" not in props_schema:
            return next(iter(patterns.values()))
    elif "additionalProperties" in props_schema:
        additional = props_schema["additionalProperties"]
        # Anything else (not a schema): there is maybe only properties.
        if isinstance(additional, JsonSchema):
            return additional
    return JsonSchema()
def _replace_conversion(self, conversion: Optional[AnyConversion]):
    """Yield once with ``_conversions`` replaced by the resolved ``conversion``.

    This is a generator; the replacement is made through ``context_setter``
    and is in effect while the generator is suspended at its ``yield``.
    NOTE(review): presumably used as a context manager via a decorator that
    is not visible in this view — confirm.
    """
    with context_setter(self) as setter:
        resolved = resolve_any_conversion(conversion)
        setter._conversions = resolved
        yield
def object(self, tp: AnyType,
           fields: Sequence[ObjectField]) -> SerializationMethod:
    """Build the serialization method of an object type.

    Field types and serialized-method return types are visited with
    ``_allow_undefined`` set to ``True``. The produced method returns a dict
    built from, in order: aggregate fields (merged with ``dict.update``),
    normal fields (stored under their alias, skipped when ``Undefined``) and
    serialized methods (stored under their alias, skipped when ``Undefined``).

    For typed dicts, values are read with ``type(obj).__getitem__`` instead of
    ``getattr``, the type check is done against ``Mapping`` and
    ``exclude_unset`` is disabled. When ``exclude_unset`` is enabled, fields
    are restricted to ``fields_set(obj)`` for objects carrying
    ``FIELDS_SET_ATTR``.
    """
    with context_setter(self) as setter:
        setter._allow_undefined = True
        normal_fields, aggregate_fields = [], []
        for field in fields:
            field_serializer = self.visit_with_conv(field.type,
                                                    field.serialization)
            if not field.is_aggregate:
                normal_fields.append(
                    (field.name, self.aliaser(field.alias), field_serializer))
            else:
                aggregate_fields.append((field.name, field_serializer))
        serialized_methods = []
        for name, (serialized, types) in get_serialized_methods(tp).items():
            serialized_methods.append((
                self.aliaser(name),
                serialized.func,
                self.visit_with_conv(types["return"], serialized.conversion),
            ))
    exclude_unset = self._exclude_unset

    def serialize_fields(
        obj: Any,
        attr_getter=getattr,
        normal_fields=normal_fields,
        aggregate_fields=aggregate_fields,
    ) -> Any:
        result = {}
        # Aggregate fields go first, so that normal fields and serialized
        # methods overwrite them in case of key collision.
        for name, serializer in aggregate_fields:
            result.update(serializer(attr_getter(obj, name)))
        for name, alias, serializer in normal_fields:
            value = attr_getter(obj, name)
            if value is not Undefined:
                result[alias] = serializer(value)
        for alias, func, serializer in serialized_methods:
            returned = func(obj)
            if returned is not Undefined:
                result[alias] = serializer(returned)
        return result

    method = serialize_fields
    cls = get_origin_or_type(tp)
    if is_typed_dict(cls):
        cls, exclude_unset = Mapping, False

        def serialize_typed_dict(
            obj: Any,
            attr_getter=getattr,
            normal_fields=normal_fields,
            aggregate_fields=aggregate_fields,
        ) -> Any:
            # The received getter is ignored: items are read by key.
            return serialize_fields(obj, type(obj).__getitem__,
                                    normal_fields, aggregate_fields)

        method = serialize_typed_dict

    if exclude_unset:
        serialize_all = method

        def serialize_set_fields(
            obj: Any,
            attr_getter=getattr,
            normal_fields=normal_fields,
            aggregate_fields=aggregate_fields,
        ) -> Any:
            if hasattr(obj, FIELDS_SET_ATTR):
                fields_set_ = fields_set(obj)
                normal_fields = [
                    (name, alias, serializer)
                    for (name, alias, serializer) in normal_fields
                    if name in fields_set_
                ]
                aggregate_fields = [
                    (name, serializer)
                    for (name, serializer) in aggregate_fields
                    if name in fields_set_
                ]
            return serialize_all(obj, attr_getter, normal_fields,
                                 aggregate_fields)

        method = serialize_set_fields

    return self._wrap_type_check(cls, method)
def _visit_conversion(
    self,
    tp: AnyType,
    conversion: Deserialization,
    dynamic: bool,
    next_conversion: Optional[AnyConversion],
) -> DeserializationMethodFactory:
    """Build the deserialization method factory of a (multi-)conversion.

    Each conversion source is visited with that conversion's
    ``additional_properties`` / ``fall_back_on_default`` (when not ``None``)
    and coercion settings applied for the duration of the visit.

    The produced method deserializes the data with the conversion source
    method, then applies the converter. With several conversions, they are
    tried in order and the first whose source deserialization succeeds is
    used; if all fail, the merged ``ValidationError`` is raised.

    Raises (from the produced method):
        ValidationError: on deserialization failure, or wrapping any other
            exception raised in the converter call (the original exception
            is chained as ``__cause__``). ``ValidationError`` and
            ``AssertionError`` are propagated unchanged.
    """
    assert conversion
    conv_factories = []
    for conv in conversion:
        with context_setter(self) as setter:
            if conv.additional_properties is not None:
                setter._additional_properties = conv.additional_properties
            if conv.fall_back_on_default is not None:
                setter._fall_back_on_default = conv.fall_back_on_default
            setter._coerce, setter._coercer = get_coercer(
                conv.coerce, self._coerce, self._coercer)
            sub_conv = sub_conversion(conv, next_conversion)
            conv_factories.append(
                self.visit_with_conv(conv.source, sub_conv))

    def factory(constraints: Optional[Constraints],
                validators: Sequence[Validator]) -> DeserializationMethod:
        conv_factories2 = conv_factories
        # Constraints/validators are merged into the source factories only
        # for non-dynamic conversions.
        if not dynamic:
            conv_factories2 = [
                fact.merge(constraints, validators) for fact in conv_factories
            ]
        conv_deserializers = [
            (fact.method, conv.converter)
            for conv, fact in zip(conversion, conv_factories2)
        ]
        method: DeserializationMethod
        if len(conv_deserializers) > 1:

            def method(data: Any) -> Any:
                error: Optional[ValidationError] = None
                for deserialize_conv, converter in conv_deserializers:
                    try:
                        value = deserialize_conv(data)
                        break
                    except ValidationError as err:
                        error = merge_errors(error, err)
                else:
                    # No conversion source accepted the data.
                    assert error is not None
                    raise error
                # ``converter`` is the one paired with the source that
                # succeeded (loop variable at the time of ``break``).
                try:
                    return converter(value)  # type: ignore
                except (ValidationError, AssertionError):
                    raise
                except Exception as err:
                    raise ValidationError([str(err)]) from err

        elif conv_deserializers[0][1] is identity:
            # Single conversion without converter: use the source method.
            method, _ = conv_deserializers[0]
        else:
            conv_deserializer, converter = conv_deserializers[0]

            def method(data: Any) -> Any:
                try:
                    return converter(
                        conv_deserializer(data))  # type: ignore
                except (ValidationError, AssertionError):
                    raise
                except Exception as err:
                    raise ValidationError([str(err)]) from err

        return method

    return DeserializationMethodFactory(factory)