コード例 #1
0
 def _has_conversion(
     tp: AnyType, conversion: Optional[AnyConversion]
 ) -> Tuple[bool, Optional[Deserialization]]:
     """Select the conversions of ``conversion`` whose target fits ``tp``.

     Only the first identity conversion is kept; its source is replaced by
     ``self_deserialization_wrapper`` of ``tp``'s origin (parametrized with
     ``tp``'s arguments when ``tp`` is a generic alias). Every kept conversion
     is re-targeted to ``tp`` after substituting, in its source, the type
     variables found by ``subtyping_substitution``.

     Returns ``(True, None)`` when the only kept conversion is the identity
     one, otherwise ``(found_any, tuple_of_conversions_or_None)``.
     """
     seen_identity = False
     matching = []
     for resolved in resolve_any_conversion(conversion):
         resolved = handle_identity_conversion(resolved, tp)
         if not is_subclass(resolved.target, tp):
             continue
         if is_identity(resolved):
             # Additional identity conversions are redundant
             if seen_identity:
                 continue
             seen_identity = True
             self_wrapper: AnyType = self_deserialization_wrapper(
                 get_origin_or_type(tp))
             type_args = get_args(tp)
             if type_args:
                 self_wrapper = self_wrapper[type_args]
             resolved = ResolvedConversion(replace(resolved,
                                                   source=self_wrapper))
         resolved = handle_dataclass_model(resolved)
         _, substitution = subtyping_substitution(tp, resolved.target)
         matching.append(
             ResolvedConversion(
                 replace(
                     resolved,
                     source=substitute_type_vars(resolved.source,
                                                 substitution),
                     target=tp,
                 )))
     if seen_identity and len(matching) == 1:
         return True, None
     return bool(matching), tuple(matching) or None
コード例 #2
0
def check_converter_type(tp: AnyType) -> AnyType:
    """Return the origin of ``tp`` after validating it for conversions.

    Raises:
        TypeError: if ``tp`` is not convertible, or if any of its type
            arguments is not a type variable (specialized generic).
    """
    origin = get_origin_or_type(tp)
    if not is_convertible(tp):
        raise TypeError(f"{origin} is not convertible")
    if any(not is_type_var(arg) for arg in get_args2(tp)):
        raise TypeError("Generic conversion doesn't support specialization")
    return origin
コード例 #3
0
 def object(self, tp: AnyType, fields: Sequence[ObjectField]) -> JsonSchema:
     """Add the serialized methods of ``tp`` to its object schema.

     Each serialized method becomes a property named ``self.aliaser(alias)``.
     Its schema comes from the method's return type, combined with the
     method's own schema via ``full_schema``. The property is required unless
     the return type is a union with ``UndefinedType``. Properties are then
     ordered with ``sort_by_annotations_position``. When the parent schema
     has an ``allOf`` key, its first element is the one updated.
     """
     result = super().object(tp, fields)
     # Schema property names are aliased with ``self.aliaser`` (presumably
     # the parent ``object`` keys its properties the same way), so the
     # name lookup used for sorting is keyed by aliased names.
     name_by_aliases = {self.aliaser(f.alias): f.name for f in fields}
     properties = {}
     required = []
     for alias, (serialized, types) in get_serialized_methods(tp).items():
         return_type = types["return"]
         prop_name = self.aliaser(alias)
         properties[prop_name] = full_schema(
             self.visit_with_conv(return_type, serialized.conversion),
             serialized.schema,
         )
         if not is_union_of(return_type, UndefinedType):
             # "required" must list property names, i.e. the aliased ones
             required.append(prop_name)
         name_by_aliases[prop_name] = serialized.func.__name__
     if "allOf" not in result:
         to_update = result
     else:
         to_update = result["allOf"][0]
     if required:
         required.extend(to_update.get("required", ()))
         to_update["required"] = sorted(required)
     if properties:
         properties.update(to_update.get("properties", {}))
         props = sort_by_annotations_position(get_origin_or_type(tp),
                                              properties,
                                              name_by_aliases.__getitem__)
         to_update["properties"] = {p: properties[p] for p in props}
     return result
コード例 #4
0
    def _visit_conversion(
        self,
        tp: AnyType,
        conversion: Serialization,
        dynamic: bool,
        next_conversion: Optional[AnyConversion],
    ) -> SerializationMethod:
        """Build the serialization method of ``tp`` through ``conversion``.

        The conversion target is visited with the conversion's own
        ``fall_back_on_any``/``exclude_unset`` settings when they are set.
        The converter is then composed with the target's serialization
        method, skipping whichever side is ``identity``. In every case the
        result goes through ``self._wrap`` for the origin of ``tp``.
        """
        with context_setter(self) as setter:
            if conversion.fall_back_on_any is not None:
                setter._fall_back_on_any = conversion.fall_back_on_any
            if conversion.exclude_unset is not None:
                setter._exclude_unset = conversion.exclude_unset
            serialize_conv = self.visit_with_conv(
                conversion.target, sub_conversion(conversion, next_conversion))

        converter = cast(Converter, conversion.converter)
        if converter is identity:
            method = serialize_conv
        elif serialize_conv is identity:
            # Still wrapped below, like the two other branches (returning the
            # bare converter here would bypass ``self._wrap``).
            method = converter
        else:

            def method(obj: Any) -> Any:
                return serialize_conv(converter(obj))

        return self._wrap(get_origin_or_type(tp), method)
コード例 #5
0
 def object(self, tp: AnyType, fields: Sequence[ObjectField]) -> Iterator[str]:
     """Yield the aliases an object of type ``tp`` accepts on deserialization.

     Flattened fields contribute the aliases computed by
     ``get_deserialization_flattened_aliases``. Other aggregate fields
     contribute nothing. Every remaining field contributes its own alias.
     """
     for obj_field in fields:
         if not obj_field.flattened:
             if not obj_field.is_aggregate:
                 yield obj_field.alias
             continue
         yield from get_deserialization_flattened_aliases(
             get_origin_or_type(tp), obj_field, self.default_conversion
         )
コード例 #6
0
 def visit(self, tp: AnyType) -> Result:
     """Visit ``tp`` through the conversion that applies to it, if any.

     Non-convertible types are visited without conversion. Otherwise the
     conversions of ``self._conversions`` are tried first (``dynamic``).
     When none applies, the default conversion of ``tp``'s origin is used.
     In that case, for ``Collection`` subclasses, ``self._conversions`` is
     forwarded as ``next_conversion``.
     """
     if not is_convertible(tp):
         return self.visit_conversion(tp, None, False, self._conversions)
     dynamic, conversion = self._has_conversion(tp, self._conversions)
     next_conversion = None
     if not dynamic:
         default = self.default_conversion(get_origin_or_type(tp))  # type: ignore
         _, conversion = self._has_conversion(tp, default)
         if is_subclass(tp, Collection):
             next_conversion = self._conversions
     return self.visit_conversion(tp, conversion, dynamic, next_conversion)
コード例 #7
0
 def object(self, tp: AnyType, fields: Sequence[ObjectField]) -> JsonSchema:
     """Build the JSON schema of the object type ``tp`` from ``fields``.

     - Flattened fields are checked and their schemas combined with the
       object schema in an ``allOf`` (with ``unevaluatedProperties: false``).
     - Pattern-properties fields feed ``patternProperties``. The pattern is
       inferred from the field type when ``pattern_properties`` is ``...``.
     - An additional-properties field overrides ``additionalProperties``.
     - Other fields become ``properties`` keyed by ``self.aliaser(alias)``
       and are listed in ``required`` when required.

     ``dependentRequired`` uses the same aliased names, so it refers to the
     generated property names.
     """
     cls = get_origin_or_type(tp)
     flattened_schemas: List[JsonSchema] = []
     pattern_properties = {}
     additional_properties: Union[bool,
                                  JsonSchema] = self.additional_properties
     properties = {}
     required = []
     for field in fields:
         if field.flattened:
             self._check_flattened_schema(cls, field)
             flattened_schemas.append(self.visit_field(field))
         elif field.pattern_properties is not None:
             if field.pattern_properties is ...:
                 pattern = infer_pattern(field.type,
                                         self.default_conversion)
             else:
                 assert isinstance(field.pattern_properties, Pattern)
                 pattern = field.pattern_properties
             pattern_properties[pattern] = self._properties_schema(field)
         elif field.additional_properties:
             additional_properties = self._properties_schema(field)
         else:
             alias = self.aliaser(field.alias)
             if is_typed_dict(cls):
                 is_required = field.required
             else:
                 is_required = self._field_required(field)
             properties[alias] = self.visit_field(field, is_required)
             if is_required:
                 required.append(alias)
     # "properties"/"required" above use ``self.aliaser``-ed names;
     # "dependentRequired" must reference those same property names.
     alias_by_names = {
         f.name: self.aliaser(f.alias) for f in fields
     }.__getitem__
     dependent_required = get_dependent_required(cls)
     result = json_schema(
         type=JsonType.OBJECT,
         properties=properties,
         required=required,
         additionalProperties=additional_properties,
         patternProperties=pattern_properties,
         dependentRequired=OrderedDict(
             (alias_by_names(f),
              sorted(map(alias_by_names, dependent_required[f])))
             for f in sorted(dependent_required, key=alias_by_names)),
     )
     if flattened_schemas:
         result = json_schema(
             type=JsonType.OBJECT,
             allOf=[result, *flattened_schemas],
             unevaluatedProperties=False,
         )
     return result
コード例 #8
0
    def unsupported(self, tp: AnyType) -> Result:
        """Visit ``tp`` as an object when settings provide fields for it.

        ``settings.default_object_fields`` is queried for the origin class of
        ``tp``. For a generic alias, the alias arguments are substituted into
        the field types. Non-class origins, and classes without default
        fields, are delegated to the parent ``unsupported``.
        """
        from apischema import settings

        origin = get_origin_or_type(tp)
        if not isinstance(origin, type):
            return super().unsupported(tp)
        fields = settings.default_object_fields(origin)
        if fields is None:
            return super().unsupported(tp)
        type_args = get_args(tp)
        if type_args:
            substitution = dict(zip(get_parameters(origin), type_args))
            fields = [
                replace(f, type=substitute_type_vars(f.type, substitution))
                for f in fields
            ]
        return self._object(origin, fields)
コード例 #9
0
def _get_methods(
    tp: AnyType, all_methods: Mapping[Type, Mapping[str, S]]
) -> Mapping[str, Tuple[S, Mapping[str, AnyType]]]:
    """Map each method name of ``tp`` to ``(method, method.types(base))``.

    Bases are visited in ``reversed(generic_mro(tp))`` order. An entry from
    a later base replaces one with the same name from an earlier base. If
    ``tp`` has a model origin, that origin's methods are merged in last.
    When ``tp`` is parametrized, its arguments are substituted into the
    model origin first.
    """
    methods_by_name = {}
    for base in reversed(generic_mro(tp)):
        base_methods = all_methods[get_origin_or_type(base)]
        methods_by_name.update(
            (name, (method, method.types(base)))
            for name, method in base_methods.items()
        )
    if not has_model_origin(tp):
        return methods_by_name
    model_origin = get_model_origin(tp)
    type_args = get_args(tp)
    if type_args:
        model_origin = substitute_type_vars(
            model_origin,
            dict(zip(get_parameters(get_origin(tp)), type_args)),
        )
    methods_by_name.update(_get_methods(model_origin, all_methods))
    return methods_by_name
コード例 #10
0
 def visit_conversion(
     self,
     tp: AnyType,
     conversion: Optional[Conv],
     dynamic: bool,
     next_conversion: Optional[AnyConversion] = None,
 ) -> JsonSchema:
     """Return the JSON schema of ``tp`` seen through ``conversion``.

     For non-dynamic conversions, the first non-None ``ref_schema`` among the
     types yielded by ``self.resolve_conversion(tp)`` is returned as-is.
     Otherwise the parent's result is folded with ``full_schema`` over the
     schemas registered on ``tp``'s origin (generic aliases only) and on
     ``tp`` itself. Dynamic conversions skip both steps.
     """
     if dynamic:
         schemas_to_merge = []
     else:
         candidate_refs = (
             self.ref_schema(get_type_name(ref_tp).json_schema)
             for ref_tp in self.resolve_conversion(tp)
         )
         found_ref = next((s for s in candidate_refs if s is not None), None)
         if found_ref is not None:
             return found_ref
         schemas_to_merge = (
             [get_schema(get_origin_or_type(tp))] if get_args(tp) else []
         )
         schemas_to_merge.append(get_schema(tp))
     base_schema = super().visit_conversion(tp, conversion, dynamic,
                                            next_conversion)
     return reduce(full_schema, schemas_to_merge, base_schema)
コード例 #11
0
 def _object(self, tp: AnyType, fields: Sequence[ObjectField]) -> Result:
     """Visit ``tp`` as an object after dropping fields rejected by
     ``self._skip_field``. When the origin class has its own aliaser, it is
     applied to the aliases of the remaining fields."""
     kept = [f for f in fields if not self._skip_field(f)]
     class_aliaser = get_class_aliaser(get_origin_or_type(tp))
     if class_aliaser is None:
         return self.object(tp, kept)
     return self.object(tp, [_override_alias(f, class_aliaser) for f in kept])
コード例 #12
0
 def _object(self, tp: AnyType, fields: Sequence[ObjectField]) -> Result:
     """Visit ``tp`` as an object after dropping fields whose metadata
     contains ``SKIP_METADATA``. When the origin class has its own aliaser,
     it is applied to the aliases of the remaining fields."""
     kept = [f for f in fields if SKIP_METADATA not in f.metadata]
     class_aliaser = get_class_aliaser(get_origin_or_type(tp))
     if class_aliaser is None:
         return self.object(tp, kept)
     return self.object(tp, [_override_alias(f, class_aliaser) for f in kept])
コード例 #13
0
ファイル: schema.py プロジェクト: callumforrester/apischema
    def object(
        self,
        tp: AnyType,
        fields: Sequence[ObjectField],
        resolvers: Sequence[ResolverField] = (),
    ) -> TypeFactory[graphql.GraphQLOutputType]:
        """Return the factory of the GraphQL object/interface type of ``tp``.

        Output fields are made of:
        - the non-aggregate ``fields``;
        - the explicit ``resolvers``;
        - the resolvers returned by ``get_resolvers(tp)``.

        They are ordered with ``sort_by_annotations_position`` and renamed
        with ``self.aliaser``. Flattened fields are merged lazily with
        ``merge_fields``. When the class declares interfaces, an interface
        thunk returns them together with the interfaces of the flattened
        types, sorted by name. An interface type is built when
        ``is_interface(cls)``, an object type otherwise.
        """
        cls = get_origin_or_type(tp)
        # alias -> output field (aggregate fields are excluded here)
        output_fields = {
            f.alias: self._field(f) for f in fields if not f.is_aggregate
        }
        # alias -> attribute/function name, used for annotation ordering
        names = {f.alias: f.name for f in fields}
        output_fields.update((r.alias, self._resolver(r)) for r in resolvers)
        names.update((r.alias, r.resolver.func.__name__) for r in resolvers)
        for alias, (resolver, types) in get_resolvers(tp).items():
            output_fields[alias] = self._resolver(
                ResolverField(
                    alias,
                    resolver,
                    types,
                    resolver.parameters,
                    resolver.parameters_metadata,
                ))
            names[alias] = resolver.func.__name__
        ordered_aliases = sort_by_annotations_position(
            cls, output_fields, names.__getitem__)
        visited_fields = OrderedDict(
            (self.aliaser(alias), output_fields[alias])
            for alias in ordered_aliases)
        flattened_types = {
            f.name: self._visit_flattened(f) for f in fields if f.flattened
        }
        interfaces = [self.visit(i) for i in get_interfaces(cls)]

        def fields_thunk() -> graphql.GraphQLFieldMap:
            return merge_fields(cls, visited_fields, flattened_types)

        def collect_interfaces() -> Collection[graphql.GraphQLInterfaceType]:
            collected = {
                cast(graphql.GraphQLInterfaceType, i.raw_type)
                for i in interfaces
            }
            for flattened_factory in flattened_types.values():
                flattened_type = cast(
                    Union[graphql.GraphQLObjectType,
                          graphql.GraphQLInterfaceType],
                    flattened_factory.raw_type,
                )
                collected.update(flattened_type.interfaces)
            return sorted(collected, key=lambda i: i.name)

        interfaces_thunk = collect_interfaces if interfaces else None

        def factory(
            name: Optional[str], description: Optional[str]
        ) -> Union[graphql.GraphQLObjectType, graphql.GraphQLInterfaceType]:
            type_name = unwrap_name(name, cls)
            if not is_interface(cls):
                return graphql.GraphQLObjectType(
                    type_name,
                    fields_thunk,
                    interfaces_thunk,
                    is_type_of=lambda obj, _: isinstance(obj, cls),
                    description=description,
                )
            return graphql.GraphQLInterfaceType(type_name,
                                                fields_thunk,
                                                interfaces_thunk,
                                                description=description)

        return TypeFactory(factory)
コード例 #14
0
        def factory(constraints: Optional[Constraints],
                    validators: Sequence[Validator]) -> DeserializationMethod:
            """Build the deserialization method of the object type ``tp``.

            ``tp``, ``fields``, ``field_factories``, ``fall_back_on_default``
            and ``additional_properties`` come from the enclosing scope (not
            visible here). Fields are sorted once into four groups: normal,
            flattened, pattern-properties and additional-properties. Their
            aliases and patterns are precomputed, so the returned ``method``
            only walks the input dict.

            Args:
                constraints: object-level constraints, turned into a checker
                    of the whole input dict by ``get_constraint_errors``.
                validators: class validators. When deserialization errors
                    occur, they are run on a ``ValidatorMock`` and their errors
                    merged into the raised error. Otherwise they are run on the
                    built instance.
            """
            cls = get_origin_or_type(tp)
            # (name, aliased alias, method, required, fall back on default)
            normal_fields: List[Tuple[str, str, DeserializationMethod,
                                      Required, FallBakOnDefault]] = []
            # (name, aliased aliases consumed by the field, method, fall back)
            flattened_fields: List[Tuple[str, AbstractSet[str],
                                         DeserializationMethod,
                                         FallBakOnDefault]] = []
            # (name, key pattern, method, fall back)
            pattern_fields: List[Tuple[str, Pattern, DeserializationMethod,
                                       FallBakOnDefault]] = []
            # (name, method, fall back) of the field receiving remaining keys
            additional_field: Optional[Tuple[str, DeserializationMethod,
                                             FallBakOnDefault]] = None
            # Names of fields flagged ``post_init``: validators depending on
            # them are skipped when deserialization errors occur
            post_init_modified = {
                field.name
                for field in fields if field.post_init
            }
            alias_by_name = {
                field.name: self.aliaser(field.alias)
                for field in fields
            }
            # requiring[name] = aliases of the fields whose presence in the
            # input makes field ``name`` required (built from
            # get_dependent_required, which maps a field to what it requires)
            requiring: Dict[str, Set[str]] = defaultdict(set)
            for f, reqs in get_dependent_required(cls).items():
                for req in reqs:
                    requiring[req].add(alias_by_name[f])
            # WRITE_ONLY fields with their default factory; they are given to
            # validators as ``init`` values
            init_defaults = [(f.name, f.default_factory) for f in fields
                             if f.kind == FieldKind.WRITE_ONLY]
            for field, field_factory in zip(fields, field_factories):
                deserialize_field = field_factory.method
                field_fall_back_on_default = (field.fall_back_on_default
                                              or fall_back_on_default)
                if field.flattened:
                    flattened_aliases = get_deserialization_flattened_aliases(
                        cls, field, self.default_conversion)
                    flattened_fields.append((
                        field.name,
                        set(map(self.aliaser, flattened_aliases)),
                        deserialize_field,
                        field_fall_back_on_default,
                    ))
                elif field.pattern_properties is ...:
                    # Pattern inferred from the field type
                    pattern_fields.append((
                        field.name,
                        infer_pattern(field.type, self.default_conversion),
                        deserialize_field,
                        field_fall_back_on_default,
                    ))
                elif field.pattern_properties is not None:
                    assert isinstance(field.pattern_properties, Pattern)
                    pattern_fields.append((
                        field.name,
                        field.pattern_properties,
                        deserialize_field,
                        field_fall_back_on_default,
                    ))
                elif field.additional_properties:
                    additional_field = (
                        field.name,
                        deserialize_field,
                        field_fall_back_on_default,
                    )
                else:
                    # "required" is either True, or the (possibly empty) set
                    # of aliases whose presence requires this field
                    normal_fields.append((
                        field.name,
                        self.aliaser(field.alias),
                        deserialize_field,
                        field.required or requiring[field.name],
                        field_fall_back_on_default,
                    ))
            has_aggregate_field = (flattened_fields or pattern_fields
                                   or (additional_field is not None))
            constraint_errors = get_constraint_errors(constraints, dict)
            aliaser = self.aliaser

            def method(data: Any) -> Any:
                """Deserialize the dict ``data`` into an instance of ``cls``.

                Raises ``ValidationError`` collecting every field error,
                constraint error and (when errors occurred) validator error.
                """
                if not isinstance(data, dict):
                    raise bad_type(data, dict)
                # name -> deserialized value, passed as keyword arguments
                values: Dict[str, Any] = {}
                # input keys consumed so far
                aliases: List[str] = []
                errors = list(
                    constraint_errors(data)) if constraint_errors else []
                field_errors: Dict[ErrorKey, ValidationError] = OrderedDict()
                for (
                        name,
                        alias,
                        field_method,
                        required,
                        fall_back_on_default,
                ) in normal_fields:
                    if alias in data:
                        aliases.append(alias)
                        try:
                            values[name] = field_method(data[alias])
                        except ValidationError as err:
                            # With fall back on default, the invalid value is
                            # dropped and the field left out of ``values``
                            if not fall_back_on_default:
                                field_errors[alias] = err
                    elif not required:
                        pass
                    elif required is True:
                        field_errors[alias] = MISSING_PROPERTY
                    else:
                        # Dependent required: missing only if some requiring
                        # field is present (local name, shadows the outer
                        # ``requiring`` mapping)
                        assert isinstance(required, AbstractSet)
                        requiring = required & data.keys()
                        if requiring:
                            msg = f"missing property (required by {sorted(requiring)})"
                            field_errors[alias] = ValidationError([msg])
                if has_aggregate_field:
                    for (
                            name,
                            flattened_alias,
                            field_method,
                            fall_back_on_default,
                    ) in flattened_fields:

                        flattened = {
                            alias: data[alias]
                            for alias in flattened_alias if alias in data
                        }
                        aliases.extend(flattened)
                        try:
                            values[name] = field_method(flattened)
                        except ValidationError as err:
                            if not fall_back_on_default:
                                errors.extend(err.messages)
                                field_errors.update(err.children)
                    # Keys not consumed by normal/flattened fields
                    if len(data) != len(aliases):
                        remain = data.keys() - set(aliases)
                    else:
                        remain = set()
                    for (
                            name,
                            pattern,
                            field_method,
                            fall_back_on_default,
                    ) in pattern_fields:
                        # Each key is consumed by the first matching pattern
                        matched = {
                            key: data[key]
                            for key in remain if pattern.match(key)
                        }
                        remain -= matched.keys()
                        try:
                            values[name] = field_method(matched)
                        except ValidationError as err:
                            if not fall_back_on_default:
                                errors.extend(err.messages)
                                field_errors.update(err.children)
                    if additional_field is not None:
                        name, field_method, fall_back_on_default = additional_field
                        additional = {key: data[key] for key in remain}
                        try:
                            values[name] = field_method(additional)
                        except ValidationError as err:
                            if not fall_back_on_default:
                                errors.extend(err.messages)
                                field_errors.update(err.children)
                    elif remain and not additional_properties:
                        for key in remain:
                            field_errors[key] = UNEXPECTED_PROPERTY
                elif len(data) != len(aliases) and not additional_properties:
                    for key in data.keys() - set(aliases):
                        field_errors[key] = UNEXPECTED_PROPERTY

                validators2: Sequence[Validator]
                if validators:
                    # Values of WRITE_ONLY fields (deserialized, or defaulted
                    # unless they errored) made available to validators
                    init: Dict[str, Any] = {}
                    for name, default_factory in init_defaults:
                        if name in values:
                            init[name] = values[name]
                        elif name not in field_errors:
                            assert default_factory is not None
                            init[name] = default_factory()
                    # Don't keep validators when all dependencies are default
                    validators2 = [
                        v for v in validators
                        if v.dependencies & values.keys()
                    ]
                    if field_errors or errors:
                        error = ValidationError(errors, field_errors)
                        # NOTE(review): field_errors is keyed by (aliased)
                        # aliases while post_init_modified holds field names;
                        # confirm v.dependencies can match both when an
                        # aliaser is set.
                        invalid_fields = field_errors.keys(
                        ) | post_init_modified
                        validators2 = [
                            v for v in validators2
                            if not v.dependencies & invalid_fields
                        ]
                        try:
                            validate(
                                ValidatorMock(cls, values),
                                validators2,
                                init,
                                aliaser=aliaser,
                            )
                        except ValidationError as err:
                            error = merge_errors(error, err)
                        raise error
                elif field_errors or errors:
                    raise ValidationError(errors, field_errors)
                else:
                    validators2, init = (
                    ), ...  # type: ignore # only for linter
                try:
                    res = cls(**values)
                except (AssertionError, ValidationError):
                    raise
                except TypeError as err:
                    # presumably an unexpected-keyword error raised by the
                    # class __init__: reported as an unsupported class
                    if str(err).startswith("__init__() got"):
                        raise Unsupported(cls)
                    else:
                        raise ValidationError([str(err)])
                except Exception as err:
                    raise ValidationError([str(err)])
                if validators2:
                    validate(res, validators2, init, aliaser=aliaser)
                return res

            return method
コード例 #15
0
    def object(self, tp: AnyType,
               fields: Sequence[ObjectField]) -> SerializationMethod:
        """Build the serialization method of the object type ``tp``.

        The output is filled in three steps:
        1. Aggregate fields are serialized first and their result merged
           into the output.
        2. Normal fields are written under their aliased name unless their
           value is ``Undefined``.
        3. The serialized methods of ``tp`` are written unless they return
           ``Undefined``.

        TypedDict classes are read by key instead of by attribute. When
        ``self._exclude_unset`` is set (never for TypedDict), objects having
        a fields set only get the fields listed in it serialized.
        """
        # Field/method serializers are built with ``_allow_undefined`` set
        # (presumably restored by ``context_setter`` on exit)
        with context_setter(self) as setter:
            setter._allow_undefined = True
            normal_fields, aggregate_fields = [], []
            for field in fields:
                serialize_field = self.visit_with_conv(field.type,
                                                       field.serialization)
                if field.is_aggregate:
                    aggregate_fields.append((field.name, serialize_field))
                else:
                    normal_fields.append(
                        (field.name, self.aliaser(field.alias),
                         serialize_field))
            # (aliased name, method function, return value serializer)
            serialized_methods = [(
                self.aliaser(name),
                method.func,
                self.visit_with_conv(types["return"], method.conversion),
            ) for name, (method, types) in get_serialized_methods(tp).items()]
        exclude_unset = self._exclude_unset

        # The keyword defaults let the wrappers below call this function again
        # with another attribute getter and/or filtered field lists.
        def method(
            obj: Any,
            attr_getter=getattr,
            normal_fields=normal_fields,
            aggregate_fields=aggregate_fields,
        ) -> Any:
            result = {}
            # aggregate before normal fields to avoid overloading
            for name, field_method in aggregate_fields:
                attr = attr_getter(obj, name)
                result.update(field_method(attr))
            for name, alias, field_method in normal_fields:
                attr = attr_getter(obj, name)
                if attr is not Undefined:
                    result[alias] = field_method(attr)
            for alias, func, method in serialized_methods:
                res = func(obj)
                if res is not Undefined:
                    result[alias] = method(res)
            return result

        cls = get_origin_or_type(tp)
        if is_typed_dict(cls):
            # TypedDict instances are mappings: read fields by key (a missing
            # key is not checked here) and never apply exclude_unset
            cls, exclude_unset = Mapping, False
            wrapped_attr_getter = method

            def method(
                obj: Any,
                attr_getter=getattr,
                normal_fields=normal_fields,
                aggregate_fields=aggregate_fields,
            ) -> Any:
                return wrapped_attr_getter(obj,
                                           type(obj).__getitem__,
                                           normal_fields, aggregate_fields)

        if exclude_unset:
            wrapped_exclude_unset = method

            def method(
                obj: Any,
                attr_getter=getattr,
                normal_fields=normal_fields,
                aggregate_fields=aggregate_fields,
            ) -> Any:
                # Only objects carrying a fields set are filtered; serialized
                # methods are always evaluated
                if hasattr(obj, FIELDS_SET_ATTR):
                    fields_set_ = fields_set(obj)
                    normal_fields = [(name, alias, method)
                                     for (name, alias, method) in normal_fields
                                     if name in fields_set_]
                    aggregate_fields = [(name, method)
                                        for (name, method) in aggregate_fields
                                        if name in fields_set_]
                return wrapped_exclude_unset(obj, attr_getter, normal_fields,
                                             aggregate_fields)

        return self._wrap_type_check(cls, method)
コード例 #16
0
    def object(self, tp: AnyType,
               fields: Sequence[ObjectField]) -> SerializationMethod:
        """Build the serialization method of the object type ``tp``.

        Fields are split into four groups:

        * aggregate fields: their serialized value is merged into the result
          unless their ``skip_if`` predicate (if any) is true;
        * ``skip_if`` fields: omitted when the predicate is true;
        * identity fields (serializer is ``identity``): copied as-is;
        * normal fields: passed through their serializer.

        Output keys of non-aggregate fields are preallocated to keep their
        declaration order. Serialized methods of ``tp`` are evaluated last
        and written unless they return ``Undefined``.

        Regular classes are read by attribute. When ``self._exclude_unset``
        is set and supported by the class, objects carrying a fields set
        only get those fields serialized. TypedDict classes are read by key,
        and missing keys are omitted. With ``self.additional_properties``,
        keys not backing any field are serialized with ``self.any()``.
        """
        cls = get_origin_or_type(tp)
        identity_fields, normal_fields, skipped_if_fields, aggregate_fields = (
            [],
            [],
            [],
            [],
        )
        for field in fields:
            serialize_field = self.visit_with_conv(field.type,
                                                   field.serialization)
            alias = self.aliaser(field.alias)
            if field.is_aggregate:
                aggregate_fields.append(
                    (field.name, serialize_field, field.skip_if))
            elif field.skip_if is not None:
                skipped_if_fields.append(
                    (field.name, alias, serialize_field, field.skip_if))
            elif serialize_field is identity:
                identity_fields.append((field.name, alias))
            else:
                normal_fields.append((field.name, alias, serialize_field))
        # (aliased name, method function, return value serializer)
        serialized_methods = [(
            self.aliaser(name),
            serialized.func,
            self.visit_with_conv(types["return"], serialized.conversion),
        ) for name, (serialized, types) in get_serialized_methods(tp).items()]
        # Preallocate keys, notably to keep their order
        make_result = dict.fromkeys(
            self.aliaser(f.alias) for f in fields if not f.is_aggregate).copy
        method: Callable
        if not is_typed_dict(cls):

            # Field groups are keyword defaults so that the exclude_unset
            # wrapper below can pass filtered groups instead.
            def method(
                obj: Any,
                aggregate_fields=tuple(aggregate_fields),
                identity_fields=tuple(identity_fields),
                normal_fields=tuple(normal_fields),
                skipped_if_fields=tuple(skipped_if_fields),
                make_result=make_result,
            ) -> Any:
                result = make_result()
                # aggregate before normal fields to avoid overloading
                for name, field_method, skip_if in aggregate_fields:
                    attr = getattr(obj, name)
                    if skip_if is None or not skip_if(attr):
                        result.update(field_method(attr))
                for name, alias, field_method, skip_if in skipped_if_fields:
                    attr = getattr(obj, name)
                    if skip_if(attr):
                        result.pop(alias, ...)
                    else:
                        result[alias] = field_method(attr)
                for name, alias in identity_fields:
                    result[alias] = getattr(obj, name)
                for name, alias, field_method in normal_fields:
                    result[alias] = field_method(getattr(obj, name))
                return result

            if self._exclude_unset and support_fields_set(cls):
                wrapped_exclude_unset = method

                def method(obj: Any) -> Any:
                    if hasattr(obj, FIELDS_SET_ATTR):
                        fields_set_ = fields_set(obj)
                        # Same group order as the wrapped method's parameters
                        new_fields = [
                            [(name, *_) for name, *_ in fields
                             if name in fields_set_]  # type: ignore
                            for fields in [
                                aggregate_fields,
                                identity_fields,
                                normal_fields,
                                skipped_if_fields,
                            ]
                        ]
                        # No key preallocation: unset fields must not appear
                        return wrapped_exclude_unset(obj, *new_fields,
                                                     lambda: {})
                    return wrapped_exclude_unset(obj)

        else:

            def method(obj: Mapping) -> dict:
                result = make_result()
                # aggregate before normal fields to avoid overloading
                for name, field_method, skip_if in aggregate_fields:
                    if name in obj and (skip_if is None
                                        or not skip_if(obj[name])):
                        result.update(field_method(obj[name]))
                # Preallocated keys of absent/skipped fields are removed
                for name, alias, field_method, skip_if in skipped_if_fields:
                    if name in obj and not skip_if(obj[name]):
                        result[alias] = field_method(obj[name])
                    else:
                        del result[alias]
                for name, alias in identity_fields:
                    if name in obj:
                        result[alias] = obj[name]
                    else:
                        del result[alias]
                for name, alias, field_method in normal_fields:
                    if name in obj:
                        result[alias] = field_method(obj[name])
                    else:
                        del result[alias]
                return result

            if self.additional_properties:
                wrapped_additional = method
                # Every key backing a field is already handled above --
                # including aggregate fields, whose value has been merged into
                # the result -- so none of them is an additional property.
                field_names = {f.name for f in fields}
                any_method = self.any()

                def method(obj: Mapping) -> Mapping:
                    result = wrapped_additional(obj)
                    for key, value in obj.items():
                        if key not in field_names:
                            result[key] = any_method(value)
                    return result

        if serialized_methods:
            wrapped_serialized = method

            def method(obj: Any) -> Any:
                result = wrapped_serialized(obj)
                for alias, func, serialized_method in serialized_methods:
                    res = func(obj)
                    if res is not Undefined:
                        result[alias] = serialized_method(res)
                return result

        return self._wrap(cls, method)
コード例 #17
0
def is_convertible(tp: AnyType) -> bool:
    """Return whether ``tp`` can be handled by conversions.

    This is true for NewTypes. It is also true for types whose origin
    passes ``is_type`` and is not listed in ``INVALID_CONVERSION_TYPES``.
    """
    origin = get_origin_or_type(tp)
    if is_new_type(tp):
        return True
    return is_type(origin) and origin not in INVALID_CONVERSION_TYPES