예제 #1
0
    def _visit_conversion(
        self,
        tp: AnyType,
        conversion: Serialization,
        dynamic: bool,
        next_conversion: Optional[AnyConversion],
    ) -> SerializationMethod:
        """Build the serialization method of *tp* through *conversion*.

        The conversion's target type is visited with the conversion's
        ``fall_back_on_any`` / ``exclude_unset`` overrides applied (only inside
        the ``context_setter`` block). The resulting target serializer is then
        composed with the converter as ``obj -> serialize_target(converter(obj))``.

        Args:
            tp: the type being serialized.
            conversion: the serialization conversion to apply.
            dynamic: not used by this implementation.
            next_conversion: conversion(s) forwarded to the target visit via
                ``sub_conversion``.

        Returns:
            The composed method, wrapped by ``self._wrap`` for the origin type of
            *tp* in every branch.
        """
        with context_setter(self) as setter:
            if conversion.fall_back_on_any is not None:
                setter._fall_back_on_any = conversion.fall_back_on_any
            if conversion.exclude_unset is not None:
                setter._exclude_unset = conversion.exclude_unset
            serialize_conv = self.visit_with_conv(
                conversion.target, sub_conversion(conversion, next_conversion))

        converter = cast(Converter, conversion.converter)
        # Avoid an extra call per object when either side is the identity.
        if converter is identity:
            method = serialize_conv
        elif serialize_conv is identity:
            # Still routed through ``self._wrap`` below, like the other
            # branches, instead of returning the bare converter.
            method = converter
        else:

            def method(obj: Any) -> Any:
                return serialize_conv(converter(obj))

        return self._wrap(get_origin_or_type(tp), method)
예제 #2
0
 def _check_flattened_schema(self, cls: Type, field: ObjectField):
     """Ensure that the flattened *field* of *cls* has an object schema.

     The field is visited with ``_ignore_first_ref`` set to True inside a
     ``context_setter`` block (which presumably restores the visitor's
     previous state on exit — confirm against ``context_setter``).

     Raises:
         TypeError: if the visited schema's ``type`` is neither
             ``JsonType.OBJECT`` nor ``"object"``.
     """
     assert field.flattened
     with context_setter(self):
         self._ignore_first_ref = True
         visited_type = self.visit_field(field).get("type")
     accepted = {JsonType.OBJECT, "object"}
     if visited_type in accepted:
         return
     raise TypeError(
         f"Flattened field {cls.__name__}.{field.name} must have an object type"
     )
예제 #3
0
 def mapping(self, cls: Type[Mapping], key_type: AnyType,
             value_type: AnyType) -> JsonSchema:
     """Build the JSON schema of a mapping from *key_type* to *value_type*.

     The key type is visited with ``_ignore_first_ref`` set to True (inside
     ``context_setter``). Its schema must be a string schema. When the key
     schema carries a ``pattern``, the value schema is emitted under
     ``patternProperties`` for that pattern. Otherwise it is emitted as
     ``additionalProperties``.

     Args:
         cls: the mapping class (not used by this implementation).
         key_type: type of the mapping keys.
         value_type: type of the mapping values.

     Returns:
         An object ``JsonSchema``.

     Raises:
         ValueError: if the key schema's ``type`` is not ``JsonType.STRING``,
             including when the key schema has no ``type`` at all.
     """
     with context_setter(self):
         self._ignore_first_ref = True
         key = self.visit(key_type)
     # ``.get`` so that a key schema without "type" raises the intended
     # ValueError rather than a bare KeyError.
     if key.get("type") != JsonType.STRING:
         raise ValueError("Mapping types must have string-convertible keys")
     value = self.visit(value_type)
     if "pattern" in key:
         return json_schema(type=JsonType.OBJECT,
                            patternProperties={key["pattern"]: value})
     return json_schema(type=JsonType.OBJECT, additionalProperties=value)
예제 #4
0
    def _visit_flattened(
        self, field: ObjectField
    ) -> TypeFactory[graphql.GraphQLOutputType]:
        """Visit the type of a flattened *field*, chaining its accessor.

        While the field's type is visited, ``get_flattened`` is replaced (inside
        ``context_setter``) by an accessor that does three things in order:
        applies the previous accessor (``identity`` when none is set), reads
        ``field.name`` with ``getattr``, and passes the result through the
        method returned by ``_field_serialization_method``. Nested flattened
        fields thus compose their accessors.

        Returns:
            The type factory produced by ``visit_with_conv`` for the field type
            and its ``serialization`` conversion.
        """
        outer_getter = self.get_flattened
        if outer_getter is None:
            outer_getter = identity
        attribute = field.name
        serialize_partially = self._field_serialization_method(field)

        def flattened_getter(obj):
            container = outer_getter(obj)
            return serialize_partially(getattr(container, attribute))

        with context_setter(self) as ctx:
            ctx.get_flattened = flattened_getter
            return self.visit_with_conv(field.type, field.serialization)
예제 #5
0
 def _properties_schema(self, field: ObjectField) -> JsonSchema:
     """Extract the value schema of a pattern/additional-properties field.

     The field is visited with ``_ignore_first_ref`` set to True (inside
     ``context_setter``).

     - If the schema has ``patternProperties``, the single pattern's schema is
       returned, but only when there is exactly one pattern and no
       ``additionalProperties``.
     - Otherwise, if it has ``additionalProperties`` whose value is a
       ``JsonSchema`` instance, that value is returned.
     - In every other case no merge is attempted and an empty ``JsonSchema()``
       is returned.

     Raises:
         TypeError: if the visited schema's type is not ``JsonType.OBJECT``.
     """
     assert field.pattern_properties is not None or field.additional_properties
     with context_setter(self):
         self._ignore_first_ref = True
         schema = self.visit_field(field)
     if schema.get("type") != JsonType.OBJECT:
         raise TypeError("properties field must have an 'object' type")
     if "patternProperties" in schema:
         patterns = schema["patternProperties"]
         # Unwrap only a lone pattern with nothing else to merge it with.
         if len(patterns) == 1 and "additionalProperties" not in schema:
             return next(iter(patterns.values()))
     elif "additionalProperties" in schema:
         extra = schema["additionalProperties"]
         if isinstance(extra, JsonSchema):
             return extra
     return JsonSchema()
예제 #6
0
 def _replace_conversion(self, conversion: Optional[AnyConversion]):
     """Yield once while ``_conversions`` holds the resolved *conversion*.

     The assignment is made inside ``context_setter(self)``, which presumably
     restores the previous value when the generator resumes or is closed.
     NOTE(review): the decorator is not visible here; this is presumably used
     through ``contextlib.contextmanager`` — confirm.
     """
     resolved = resolve_any_conversion(conversion)
     with context_setter(self) as ctx:
         ctx._conversions = resolved
         yield
예제 #7
0
    def object(self, tp: AnyType,
               fields: Sequence[ObjectField]) -> SerializationMethod:
        """Build the serialization method of an object type *tp*.

        Serializers for every field and for every method returned by
        ``get_serialized_methods(tp)`` are built once, up front, with
        ``_allow_undefined`` set to True inside ``context_setter``.

        The returned method produces a dict in three steps:

        1. Aggregate fields are merged in first (``dict.update``).
        2. Normal fields are then stored under their aliased name, skipped when
           the attribute is ``Undefined``.
        3. Serialized methods come last, skipped when they return ``Undefined``.

        For TypedDict classes, values are read with ``type(obj).__getitem__``
        instead of ``getattr``, the type check uses ``Mapping``, and
        ``exclude_unset`` is disabled. Otherwise, when ``exclude_unset`` is
        enabled and the object has ``FIELDS_SET_ATTR``, only fields named in
        ``fields_set(obj)`` are serialized.

        Args:
            tp: the object type being serialized.
            fields: the fields of *tp*.

        Returns:
            The method wrapped by ``self._wrap_type_check`` for the resolved
            class.
        """
        with context_setter(self) as setter:
            setter._allow_undefined = True
            normal_fields, aggregate_fields = [], []
            for field in fields:
                serialize_field = self.visit_with_conv(field.type,
                                                       field.serialization)
                if field.is_aggregate:
                    aggregate_fields.append((field.name, serialize_field))
                else:
                    normal_fields.append(
                        (field.name, self.aliaser(field.alias),
                         serialize_field))
            # (alias, function called with the object, serializer of its
            # return value) for each serialized method of tp.
            serialized_methods = [(
                self.aliaser(name),
                method.func,
                self.visit_with_conv(types["return"], method.conversion),
            ) for name, (method, types) in get_serialized_methods(tp).items()]
        # May be forced to False below for TypedDict classes.
        exclude_unset = self._exclude_unset

        # The getter and the field lists are default arguments so that the
        # wrappers defined below can call this function with item access
        # instead of getattr, or with filtered field lists.
        def method(
            obj: Any,
            attr_getter=getattr,
            normal_fields=normal_fields,
            aggregate_fields=aggregate_fields,
        ) -> Any:
            result = {}
            # Aggregate fields are merged first so that normal fields (and
            # serialized methods), written afterwards, win on key clashes.
            for name, field_method in aggregate_fields:
                attr = attr_getter(obj, name)
                result.update(field_method(attr))
            for name, alias, field_method in normal_fields:
                attr = attr_getter(obj, name)
                if attr is not Undefined:
                    result[alias] = field_method(attr)
            # NOTE: the loop variable ``method`` (the return-value serializer)
            # shadows this function's own name inside its body.
            for alias, func, method in serialized_methods:
                res = func(obj)
                if res is not Undefined:
                    result[alias] = method(res)
            return result

        cls = get_origin_or_type(tp)
        if is_typed_dict(cls):
            # Type-check against Mapping, read values by key, and never apply
            # exclude_unset to TypedDict instances.
            cls, exclude_unset = Mapping, False
            wrapped_attr_getter = method

            # The ``attr_getter`` parameter is ignored: item access is always
            # used.
            # NOTE(review): ``type(obj).__getitem__`` raises KeyError for a
            # missing key, so an instance lacking one of *fields* (e.g. a
            # non-total TypedDict) would fail here unless such fields are
            # filtered upstream — confirm.
            def method(
                obj: Any,
                attr_getter=getattr,
                normal_fields=normal_fields,
                aggregate_fields=aggregate_fields,
            ) -> Any:
                return wrapped_attr_getter(obj,
                                           type(obj).__getitem__,
                                           normal_fields, aggregate_fields)

        if exclude_unset:
            wrapped_exclude_unset = method

            # When the object carries FIELDS_SET_ATTR, keep only the normal and
            # aggregate fields named in fields_set(obj). Serialized methods are
            # not filtered.
            def method(
                obj: Any,
                attr_getter=getattr,
                normal_fields=normal_fields,
                aggregate_fields=aggregate_fields,
            ) -> Any:
                if hasattr(obj, FIELDS_SET_ATTR):
                    fields_set_ = fields_set(obj)
                    normal_fields = [(name, alias, method)
                                     for (name, alias, method) in normal_fields
                                     if name in fields_set_]
                    aggregate_fields = [(name, method)
                                        for (name, method) in aggregate_fields
                                        if name in fields_set_]
                return wrapped_exclude_unset(obj, attr_getter, normal_fields,
                                             aggregate_fields)

        return self._wrap_type_check(cls, method)
예제 #8
0
    def _visit_conversion(
        self,
        tp: AnyType,
        conversion: Deserialization,
        dynamic: bool,
        next_conversion: Optional[AnyConversion],
    ) -> DeserializationMethodFactory:
        """Build a deserialization method factory through *conversion*.

        Each alternative in *conversion* is visited from its ``source`` type.
        The visit runs with that alternative's ``additional_properties``,
        ``fall_back_on_default`` and coercion overrides applied inside
        ``context_setter``.

        The returned factory, given constraints and validators, builds a method
        that:

        - deserializes the data with the first alternative that does not raise
          ``ValidationError``, then applies that alternative's converter;
        - raises the merged ``ValidationError`` when every alternative fails;
        - lets ``ValidationError``/``AssertionError`` raised by a converter
          propagate unchanged;
        - wraps any other converter exception in a ``ValidationError`` carrying
          its message, with the original exception chained as the cause.

        Args:
            tp: the deserialized type (not used by this implementation).
            conversion: non-empty sequence of deserialization conversions.
            dynamic: when true, constraints and validators are not merged into
                the per-conversion factories.
            next_conversion: forwarded to each source visit via
                ``sub_conversion``.
        """
        assert conversion
        conv_factories = []
        for conv in conversion:
            with context_setter(self) as setter:
                if conv.additional_properties is not None:
                    setter._additional_properties = conv.additional_properties
                if conv.fall_back_on_default is not None:
                    setter._fall_back_on_default = conv.fall_back_on_default
                setter._coerce, setter._coercer = get_coercer(
                    conv.coerce, self._coerce, self._coercer)
                sub_conv = sub_conversion(conv, next_conversion)
                conv_factories.append(
                    self.visit_with_conv(conv.source, sub_conv))

        def factory(constraints: Optional[Constraints],
                    validators: Sequence[Validator]) -> DeserializationMethod:
            conv_factories2 = conv_factories
            if not dynamic:
                conv_factories2 = [
                    fact.merge(constraints, validators)
                    for fact in conv_factories
                ]
            conv_deserializers = [
                (fact.method, conv.converter)
                for conv, fact in zip(conversion, conv_factories2)
            ]
            method: DeserializationMethod
            if len(conv_deserializers) > 1:

                def method(data: Any) -> Any:
                    error: Optional[ValidationError] = None
                    # Only the first successful source is converted; a later
                    # converter failure does not fall back to other sources.
                    for deserialize_conv, converter in conv_deserializers:
                        try:
                            value = deserialize_conv(data)
                            break
                        except ValidationError as err:
                            error = merge_errors(error, err)
                    else:
                        assert error is not None
                        raise error
                    try:
                        return converter(value)  # type: ignore
                    except (ValidationError, AssertionError):
                        raise
                    except Exception as err:
                        raise ValidationError([str(err)]) from err

            elif conv_deserializers[0][1] is identity:
                # Single conversion with an identity converter: no wrapper.
                method, _ = conv_deserializers[0]
            else:
                conv_deserializer, converter = conv_deserializers[0]

                def method(data: Any) -> Any:
                    try:
                        return converter(
                            conv_deserializer(data))  # type: ignore
                    except (ValidationError, AssertionError):
                        raise
                    except Exception as err:
                        raise ValidationError([str(err)]) from err

            return method

        return DeserializationMethodFactory(factory)