예제 #1
0
    def __init__(
        self,
        nested: typing.List[typing.Type[fields.SchemaABC]],
        mode: typing.Literal["anyOf", "allOf"] = "anyOf",
        *,
        default: typing.Any = fields.missing_,
        only: typing.Optional[types.StrSequenceOrSet] = None,
        exclude: types.StrSequenceOrSet = (),
        many: bool = False,
        unknown: typing.Optional[str] = None,
        **kwargs,
    ):
        """Create a field whose value is checked against several nested schemas.

        Args:
            nested: Schemas (anything ``common.resolve_schema_instance`` accepts);
                each is resolved to an instance and stored in ``self.nested``.
            mode: Only ``"anyOf"`` is supported.
            default: Passed through to the base field.
            only, exclude, many, unknown: Stored on the field unchanged.
            **kwargs: Passed to the base field. A ``context`` entry is merged into
                every nested schema's context. A ``metadata`` entry is merged with
                the generated ``{"anyOf": nested}`` metadata.

        Raises:
            NotImplementedError: If ``mode`` is anything other than ``"anyOf"``.
        """
        if mode != "anyOf":
            raise NotImplementedError("allOf is not yet implemented.")

        # Copy, so that merging in the field's own context never mutates the
        # parent's context dict.
        context = dict(getattr(self.parent, "context", {}))
        context.update(kwargs.get("context", {}))

        self.nested = []
        schema_inst: Schema
        for schema in nested:
            schema_inst = common.resolve_schema_instance(schema)
            schema_inst.context.update(context)
            self.nested.append(schema_inst)

        self.mode = mode
        self.only = only
        self.exclude = exclude
        self.many = many
        self.unknown = unknown

        # Merge caller-supplied metadata with ours. Passing ``metadata`` both here
        # and via ``**kwargs`` would raise a duplicate-keyword TypeError.
        metadata = dict(kwargs.pop("metadata", None) or {})
        metadata["anyOf"] = nested
        super().__init__(default=default, metadata=metadata, **kwargs)
예제 #2
0
    def _load_schemas(self, scalar: Result, partial=None) -> Result:
        """Load ``scalar`` with the nested schemas.

        Non-merged mode: each schema is tried against the complete input. The
        first successful result is returned.

        Merged mode: every schema is applied in turn. After each attempt, the keys
        that schema declares are removed from a working copy, so later schemas only
        see the remaining keys. Successful results are merged into one dict.

        Raises:
            ValidationError: With the errors of the failing schemas. In non-merged
                mode this happens only when no schema succeeded; in merged mode it
                happens when any schema failed.
        """
        rv = {}
        error_store = ErrorStore()
        value = dict(scalar)
        for schema in self._nested_schemas():
            schema_inst = common.resolve_schema_instance(schema)
            try:
                loaded = schema_inst.load(
                    value,
                    unknown=self.unknown,
                    partial=partial,
                )
            except ValidationError as exc:
                error_store.store_error({exc.field_name: exc.messages})
                if not self.merged:
                    # Alternatives are independent: the next schema must still see
                    # the complete input, so no keys are removed here.
                    continue
            else:
                if not self.merged:
                    return loaded
                rv.update(loaded)

            # Merged mode only: this schema's keys have been handled (successfully
            # or not), so hide them from the schemas that follow.
            for key in schema_inst.declared_fields:
                value.pop(key, None)

        if error_store.errors:
            raise ValidationError(error_store.errors)
        return rv
예제 #3
0
    def load(self, data, *, many=None, partial=None, unknown=None):
        """Load a dict whose values all share ``self.value_type``.

        PRE_LOAD processors run first; afterwards ``data`` must be a dict.
        - If ``self.value_type`` resolves to a schema instance, every value is
          converted with that schema's ``load``.
        - If resolution raises ``ValueError``, the values are handled by the field
          ``self.value_type[0]`` instead.

        POST_LOAD processors run last. ``unknown`` is accepted but unused.

        Raises:
            ValidationError: If ``data`` (after PRE_LOAD) is not a dict.
        """
        if self._has_processors(PRE_LOAD):
            data = self._invoke_load_processors(PRE_LOAD,
                                                data,
                                                many=many,
                                                original_data=data,
                                                partial=partial)

        if not isinstance(data, dict):
            raise ValidationError(f"Data type is invalid: {data}",
                                  field_name="_schema")

        # Only schema *resolution* chooses between the two paths. Keep the conversion
        # out of the ``try``: a ValueError raised while loading the data must
        # propagate, not silently switch to the field-based fallback.
        try:
            schema = common.resolve_schema_instance(self.value_type)
        except ValueError:
            result = self._serialize_field(
                data, field=self.value_type[0])  # type: ignore[index]
        else:
            result = self._convert_with_schema(data, schema_func=schema.load)

        if self._has_processors(POST_LOAD):
            result = self._invoke_load_processors(
                POST_LOAD,
                result,
                many=many,
                original_data=data,
                partial=partial,
            )

        return result
예제 #4
0
def custom_name_resolver(schema):
    """
    Creates names for Marshmallow schemas in documentation.

    In case a schema is created using partial=`True`, `Patch-`
    will be added in front of its name. Otherwise, in case the schema
    is created with `only` set, `Partial-` will be added instead.

    In case the schema class name ends with `Schema`, the `Schema` part
    is removed from the name (unless the class is named exactly `Schema`).

    Adapted from https://github.com/marshmallow-code/apispec/pull/476/

    :param schema: Schema to name
    :type schema: `marshmallow.Schema`
    :return: The documented name for the schema
    :rtype: str
    """
    # Get an instance of the schema
    schema_instance = common.resolve_schema_instance(schema)
    if schema_instance.partial:
        prefix = "Patch-"
    elif schema_instance.only:
        prefix = "Partial-"
    else:
        prefix = ""

    # Strip the suffix from the class name itself, before prefixing. Otherwise a
    # class named exactly "Schema" would yield just the prefix (e.g. "Patch-").
    base_name = common.resolve_schema_cls(schema).__name__
    if base_name.endswith("Schema"):
        base_name = base_name[:-6] or base_name
    return prefix + base_name
예제 #5
0
    def _dump_schemas(self, scalar: Result) -> list[Result]:
        """Dump ``scalar`` with the nested schemas.

        Non-merged mode: the first schema whose dump succeeds wins, and its dumped
        dict is returned directly (a single result, not a list). Every schema is
        tried against the complete input.

        Merged mode:
        - Every schema dumps the value.
        - Each dump is round-tripped through ``load`` to find the keys the schema
          actually consumed; those keys are removed from the working copy before
          the next schema.
        - Dumps of schemas that are not ``MultiNested.ValidateOnDump`` instances
          are collected into the returned list.

        Raises:
            ValidationError: With all collected per-schema errors, if any occurred.
        """
        rv = []
        error_store = ErrorStore()
        value = dict(scalar)
        for schema in self._nested_schemas():
            schema_inst = common.resolve_schema_instance(schema)
            try:
                dumped = schema_inst.dump(value, many=False)
                if not self.merged:
                    return dumped
                # We check what could actually pass through the load() call, because some
                # schemas validate keys without having them defined in their _declared_fields.
                loaded = schema_inst.load(dumped)
            except ValidationError as exc:
                error_store.store_error({exc.field_name: exc.messages})
                if self.merged:
                    # When we encounter an error, we can't do anything besides remove the
                    # keys which we know about.
                    for key in schema_inst.declared_fields:
                        value.pop(key, None)
                # In non-merged mode the alternatives are independent, so the next
                # schema must still see the complete input.
                continue

            for key in loaded.keys():
                value.pop(key, None)

            if not isinstance(schema_inst, MultiNested.ValidateOnDump):
                rv.append(dumped)

        if error_store.errors:
            raise ValidationError(error_store.errors)

        return rv
예제 #6
0
    def load(self, data, *, many=None, partial=None, unknown=None):
        """Deserialize a dict whose values are all described by ``self.value_type``.

        Steps:
        1. Run PRE_LOAD processors.
        2. Check that ``data`` is a dict.
        3. Convert the values, either through the wrapped field (``FieldWrapper``)
           or through a schema's ``load``.
        4. Run POST_LOAD processors.

        ``unknown`` is accepted but unused.

        Raises:
            ValidationError: If ``data`` is not a dict, or if ``value_type`` is
                neither a ``FieldWrapper`` nor a schema instance/class.
        """
        if self._has_processors(PRE_LOAD):
            data = self._invoke_load_processors(
                PRE_LOAD,
                data,
                many=many,
                original_data=data,
                partial=partial,
            )

        if not isinstance(data, dict):
            raise ValidationError(f"Data type is invalid: {data}", field_name="_schema")

        value_type = self.value_type
        if isinstance(value_type, FieldWrapper):
            result = self._serialize_field(data, field=value_type.field)
        elif not (
            isinstance(value_type, BaseSchema)
            or (isinstance(value_type, type) and issubclass(value_type, SchemaABC))
        ):
            raise ValidationError(
                f"Data type is not known: {type(value_type)} {value_type}"
            )
        else:
            loader = common.resolve_schema_instance(value_type).load
            result = self._convert_with_schema(data, schema_func=loader)

        if not self._has_processors(POST_LOAD):
            return result
        return self._invoke_load_processors(
            POST_LOAD,
            result,
            many=many,
            original_data=data,
            partial=partial,
        )
예제 #7
0
    def definition_helper(self, name, schema, **kwargs):
        """Definition helper that allows using a marshmallow
        :class:`Schema <marshmallow.Schema>` to provide OpenAPI metadata.

        Steps:
        1. Record ``name`` as the reference for the schema's class.
        2. Delegate to the parent helper.
        3. Fill in the nested ``properties`` of every ``Hyperlinks`` field.

        :param str name: Name to use for definition.
        :param type|Schema schema: A marshmallow Schema class or instance.
        :raises ValueError: if the schema instance has neither ``fields`` nor
            ``_declared_fields``.
        """
        schema_cls = resolve_schema_cls(schema)
        instance = resolve_schema_instance(schema)

        # Store registered refs, keyed by Schema class
        self.references[schema_cls] = name

        # Prefer the bound ``fields``; fall back to the class-level declarations.
        for attr_name in ('fields', '_declared_fields'):
            if hasattr(instance, attr_name):
                schema_fields = getattr(instance, attr_name)
                break
        else:
            raise ValueError(
                f"{instance!r} doesn't have either `fields` or `_declared_fields`")

        definition = super().definition_helper(name, instance, **kwargs)

        for field_name, field_obj in schema_fields.items():
            if not isinstance(field_obj, Hyperlinks):
                continue
            definition['properties'][field_name]['properties'] = _rapply(
                field_obj.schema, self.openapi.field2property, name=name)

        return definition
예제 #8
0
파일: base.py 프로젝트: tribe29/checkmk
    def dump(self, obj: typing.Any, *, many=None):
        """Serialize a dict whose values are all described by ``self.value_type``.

        PRE_DUMP processors run first.
        - If ``value_type`` is a ``FieldWrapper``, its ``field`` handles the values.
        - If it is a schema instance or ``SchemaABC`` subclass, the values are
          converted with that schema's ``dump``.

        POST_DUMP processors run last.

        Raises:
            ValidationError: If ``value_type`` is neither a ``FieldWrapper`` nor a
                schema instance/class.
        """
        if self._has_processors(PRE_DUMP):
            obj = self._invoke_dump_processors(PRE_DUMP,
                                               obj,
                                               many=many,
                                               original_data=obj)

        if isinstance(self.value_type, FieldWrapper):
            result = self._deserialize_field(obj, field=self.value_type.field)
        elif isinstance(
                self.value_type,
                BaseSchema) or (isinstance(self.value_type, type)
                                and issubclass(self.value_type, SchemaABC)):
            schema = common.resolve_schema_instance(self.value_type)
            result = self._convert_with_schema(obj, schema_func=schema.dump)
        else:
            # Report the unsupported value_type itself; the type of the data
            # being dumped says nothing about why this branch was reached.
            raise ValidationError(
                f"Data type is not known: {type(self.value_type)} {self.value_type}")

        if self._has_processors(POST_DUMP):
            result = self._invoke_dump_processors(POST_DUMP,
                                                  result,
                                                  many=many,
                                                  original_data=obj)

        return result
예제 #9
0
    def _deserialize(
        self,
        value: typing.Any,
        attr: typing.Optional[str],
        data: typing.Optional[typing.Mapping[str, typing.Any]],
        **kwargs,
    ):
        """Load ``value`` with the first nested schema that accepts it.

        With ``self.many`` set, ``value`` must be a collection, and every entry is
        checked individually via ``self._check_schemas``. If the caller supplies
        ``partial`` in ``kwargs``, it is forwarded to the nested loads instead of
        being dropped.

        Raises:
            ValidationError: A "type" error if ``self.many`` is set and ``value``
                is not a collection; otherwise the collected errors when no nested
                schema accepts the value.
        """
        partial = kwargs.get("partial")
        if self.many:
            if not utils.is_collection(value):
                raise self.make_error("type",
                                      input=value,
                                      type=value.__class__.__name__)
            return [
                self._check_schemas(collection_entry, partial=partial)
                for collection_entry in value
            ]

        error_store = ErrorStore()
        for schema in self.nested:
            try:
                return common.resolve_schema_instance(schema).load(
                    value, unknown=self.unknown, partial=partial)
            except ValidationError as exc:
                error_store.store_error(exc.messages,
                                        field_name=exc.field_name)

        raise ValidationError(error_store.errors, data=value)
예제 #10
0
def resolver(schema):
    """Return the documentation name for a marshmallow schema.

    The class name is used with a trailing ``Schema`` removed (unless the class is
    named exactly ``Schema``). It is prefixed with ``Partial-`` when the resolved
    schema instance is partial.
    """
    schema_instance = common.resolve_schema_instance(schema)
    prefix = "Partial-" if schema_instance.partial else ""
    # Strip the suffix before prefixing, so a class named "Schema" does not
    # collapse to just the prefix.
    base_name = common.resolve_schema_cls(schema).__name__
    if base_name.endswith("Schema"):
        base_name = base_name[:-6] or base_name
    return prefix + base_name
예제 #11
0
 def dump(self, obj, many=None, update_fields=True, **kwargs):
     """Dump an iterable of entries into a dict keyed by each entry's ``key_name``.

     Each entry is dumped with the schema resolved from ``self.value_type``. The
     dumped ``key_name`` value becomes the dict key, and it is removed from the
     stored value unless ``self.keep_key`` is set. Only ``.data`` of each dump is
     used; the returned ``MarshalResult`` always carries an empty error list.
     """
     entry_schema = common.resolve_schema_instance(self.value_type)
     keyed = {}
     for item in obj:
         dumped = entry_schema.dump(item).data
         keyed[dumped[self.key_name]] = dumped
         if self.keep_key:
             continue
         del dumped[self.key_name]
     return MarshalResult(keyed, [])
예제 #12
0
 def dump(self, obj: Any, *, many=None):
     """Dump an iterable of entries into a dict keyed by each entry's ``key_name``.

     Each entry is dumped with the schema resolved from ``self.value_type``. The
     dumped ``key_name`` value becomes the dict key, and it is removed from the
     stored value unless ``self.keep_key`` is set. ``many`` is accepted but unused.
     """
     entry_schema = common.resolve_schema_instance(self.value_type)
     keyed = {}
     for item in obj:
         dumped = entry_schema.dump(item)
         keyed[dumped[self.key_name]] = dumped
         if self.keep_key:
             continue
         del dumped[self.key_name]
     return keyed
예제 #13
0
def resolver(schema):
    """Default schema name resolver function that strips 'Schema' from the end of the class name.

    Partial schemas get a ``_partial`` suffix. Schemas restricted with ``only``
    get the selected field names appended, sorted and joined with ``-``, so the
    generated name is deterministic even when ``only`` is an unordered set.
    """
    schema_cls = resolve_schema_cls(schema)
    name = schema_cls.__name__
    if name.endswith("Schema"):
        name = name[:-6] or name
    schema_inst = resolve_schema_instance(schema)
    if schema_inst.partial:
        return f'{name}_partial'
    if schema_inst.only:
        # ``only`` may be a set, whose iteration order over strings depends on the
        # hash seed; sort so the component name is stable across processes.
        return f'{name}.{"-".join(sorted(schema_inst.only))}'
    return name
예제 #14
0
    def load(self, data, many=None, partial=None):
        """Load a mapping of ``key -> entry`` into a list of loaded entries.

        Each entry is copied, the mapping key is injected under ``self.key_name``,
        and the copy is loaded with the schema resolved from ``self.value_type``.

        Returns:
            ``UnmarshalResult(list_of_loaded_entries, errors)``. Errors of individual
            entries are reported keyed by their mapping key. If no entry failed,
            ``errors`` is an empty list, as before. For non-dict ``data``, returns
            ``({}, {'_schema': ...})``.
        """
        if not isinstance(data, dict):
            # Wrap in a 1-tuple: if ``data`` itself were a tuple, ``%`` would treat
            # it as the argument list and raise or format the wrong thing.
            return UnmarshalResult({}, {'_schema': 'Invalid data type: %s' % (data,)})

        schema = common.resolve_schema_instance(self.value_type)
        res = []
        errors = {}
        for key, value in data.items():
            payload = value.copy()
            payload[self.key_name] = key
            result = schema.load(payload)
            res.append(result.data)
            # Propagate per-entry validation errors instead of silently dropping them.
            if result.errors:
                errors[key] = result.errors

        return UnmarshalResult(res, errors or [])
예제 #15
0
    def load(self, data, *, many=None, partial=None, unknown=None):
        """Load a mapping of ``key -> entry`` into a list of loaded entries.

        Each entry is copied, the mapping key is injected under ``self.key_name``,
        and the copy is loaded with the schema resolved from ``self.value_type``.
        ``many``, ``partial`` and ``unknown`` are accepted but unused.

        Raises:
            ValidationError: If ``data`` or one of its values is not a dict, or if
                the schema rejects an entry.
        """
        if not isinstance(data, dict):
            raise ValidationError({'_schema': f'Data type is invalid: {data}'})

        schema = common.resolve_schema_instance(self.value_type)
        res = []
        for key, value in data.items():
            if not isinstance(value, dict):
                # Report malformed entries as validation errors, rather than letting
                # ``.copy()`` or item assignment fail with AttributeError/TypeError.
                raise ValidationError({key: {'_schema': f'Data type is invalid: {value}'}})
            payload = value.copy()
            payload[self.key_name] = key
            result = schema.load(payload)
            res.append(result)

        return res
예제 #16
0
    def _check_schemas(self, scalar, partial=None):
        """Return ``scalar`` loaded by the first nested schema that accepts it.

        Schemas in ``self.nested`` are tried in order with ``self.unknown`` and
        ``partial``. The errors of every rejecting schema are accumulated.

        Raises:
            ValidationError: With the collected errors if every schema rejects
                ``scalar``.
        """
        collected = ErrorStore()
        for candidate in self.nested:
            try:
                loaded = common.resolve_schema_instance(candidate).load(
                    scalar,
                    unknown=self.unknown,
                    partial=partial,
                )
            except ValidationError as err:
                collected.store_error(err.messages, field_name=err.field_name)
            else:
                return loaded

        raise ValidationError(collected.errors, data=scalar)
예제 #17
0
def schema_name_resolver(schema):
    """Resolve the OpenAPI component name for ``schema``.

    Returns ``False`` (the schema is inlined) when the schema class has
    ``opts.register`` disabled, or when the instance was created with ``only``.
    Otherwise the class name is used with a trailing ``Schema`` removed. Partial
    schemas additionally get a ``Partial`` prefix.
    """
    cls = resolve_schema_cls(schema)
    instance = resolve_schema_instance(schema)
    if not cls.opts.register:
        # Unregistered schemas are put inline.
        return False
    if instance.only:
        # If schema includes only select fields, treat it as nonce
        return False
    name = cls.__name__
    if name.endswith("Schema"):
        name = name[:-6] or name
    # Apply the prefix after stripping the suffix. An early return in the suffix
    # branch would let partial "...Schema" classes collide with the non-partial name.
    if instance.partial:
        name = "Partial" + name
    return name
예제 #18
0
파일: plugins.py 프로젝트: petrows/checkmk
 def dump(self, obj: Any, *, many=None):
     """Dump an iterable of entries into a dict keyed by each entry's ``key_name``.

     Each entry is dumped with the schema resolved from ``self.value_type``. The
     dumped ``key_name`` value becomes the dict key, and it is removed from the
     stored value unless ``self.keep_key`` is set.

     Raises:
         ValidationError: When a dump returns a ``(None, {'_schema': ...})`` error
             tuple instead of raising.
     """
     entry_schema = common.resolve_schema_instance(self.value_type)

     def _raise_if_error_tuple(candidate):
         # HACK. marshmallow_oneofschema returns errors instead of raising them. :-(
         # See https://github.com/marshmallow-code/marshmallow-oneofschema/issues/48
         looks_like_error = (isinstance(candidate, tuple) and len(candidate) == 2
                             and candidate[0] is None
                             and isinstance(candidate[1], dict)
                             and '_schema' in candidate[1])
         if looks_like_error:
             raise ValidationError(candidate[1]['_schema'])

     keyed = {}
     for item in obj:
         dumped = entry_schema.dump(item)
         _raise_if_error_tuple(dumped)
         keyed[dumped[self.key_name]] = dumped
         if self.keep_key:
             continue
         del dumped[self.key_name]
     return keyed
예제 #19
0
    def schema2jsonschema(self, schema):
        """Render a value-typed dict schema as an object with ``additionalProperties``.

        For OpenAPI 3+ and value-typed dict schemas, the value schema is registered
        as a component (named via ``schema_name_resolver``) if its key is not yet in
        ``self.refs``, and then referenced. Everything else goes to the parent
        implementation.
        """
        if self.openapi_version.major >= 3 and is_value_typed_dict(schema):
            value_type = schema.value_type
            instance = common.resolve_schema_instance(value_type)
            if common.make_schema_key(instance) not in self.refs:
                self.spec.components.schema(
                    self.schema_name_resolver(value_type), schema=instance)
            return {
                'type': 'object',
                'additionalProperties': self.get_ref_dict(instance),
            }
        return super().schema2jsonschema(schema)
예제 #20
0
 def _add_examples(self, ref_schema, endpoint_schema, example):
     """Attach ``example`` to a schema in the generated spec.

     A falsy ``example`` is a no-op. For OpenAPI < 3, nothing is done unless the
     name that ``schema_name_resolver`` gives for ``ref_schema`` is already a
     registered component.
     - If ``example['add_to_refs']`` is truthy, the example is stored on that
       component schema.
     - Otherwise the endpoint's ``$ref`` schema is wrapped in ``allOf`` so the
       example can sit next to it.

     The caller's ``example`` dict is not modified. The ``add_to_refs`` control
     key is stripped from the stored example.
     """
     if not example:
         return
     schema_instance = common.resolve_schema_instance(ref_schema)
     name = self.plugin.converter.schema_name_resolver(schema_instance)
     # Copy before removing the control key, so the caller's dict stays intact and
     # a repeated call with the same dict does not fail with KeyError.
     example = dict(example)
     add_to_refs = example.pop('add_to_refs')

     if self.spec.components.openapi_version.major < 3:
         if not (name and name in self.spec.components._schemas):
             return

     if add_to_refs:
         self.spec.components._schemas[name]["example"] = example
     else:
         endpoint = endpoint_schema[0]['schema']
         # ``allOf`` items must be schema objects, so re-wrap the popped ref string
         # as ``{'$ref': ...}``; a bare string there is not a valid schema.
         endpoint['allOf'] = [{'$ref': endpoint.pop('$ref')}]
         endpoint["example"] = example
예제 #21
0
    def _serialize(
        self,
        value: typing.Any,
        attr: str,
        obj: typing.Any,
        **kwargs,
    ):
        """Dump ``value`` with the first nested schema that does not raise.

        ``None`` is returned unchanged. Each schema in ``self.nested`` dumps
        ``value`` with ``many=self.many``, and the errors of failing schemas are
        collected.

        Raises:
            ValidationError: With the collected errors if every schema fails.
        """
        if value is None:
            return None

        failures = ErrorStore()
        for candidate in self.nested:
            try:
                dumped = common.resolve_schema_instance(candidate).dump(
                    value, many=self.many)
            except ValidationError as err:
                failures.store_error(err.messages, field_name=err.field_name)
            else:
                return dumped

        raise ValidationError(failures.errors, data=value)
예제 #22
0
    def dump(self, obj: typing.Any, *, many=None):
        """Serialize a dict whose values all share ``self.value_type``.

        PRE_DUMP processors run first.
        - If ``self.value_type`` resolves to a schema instance, the values are
          converted with that schema's ``dump``.
        - If resolution raises ``ValueError``, the field ``self.value_type[0]``
          handles them instead.

        POST_DUMP processors run last.
        """
        if self._has_processors(PRE_DUMP):
            obj = self._invoke_dump_processors(PRE_DUMP,
                                               obj,
                                               many=many,
                                               original_data=obj)

        # Only schema *resolution* chooses between the two paths. Keep the conversion
        # out of the ``try``: a ValueError raised while dumping the data must
        # propagate, not silently switch to the field-based fallback.
        try:
            schema = common.resolve_schema_instance(self.value_type)
        except ValueError:
            result = self._deserialize_field(
                obj, field=self.value_type[0])  # type: ignore[index]
        else:
            result = self._convert_with_schema(obj, schema_func=schema.dump)

        if self._has_processors(POST_DUMP):
            result = self._invoke_dump_processors(POST_DUMP,
                                                  result,
                                                  many=many,
                                                  original_data=obj)

        return result
예제 #23
0
    def schema2jsonschema(self, schema):
        """Render a one-of schema as an OpenAPI 3 ``oneOf`` with a discriminator.

        Each type schema is registered as a component if its key is not yet in
        ``self.refs``. The name comes from ``schema_name_resolver``, falling back
        to the type name. Each type schema is then referenced, and its ``$ref`` is
        recorded in the discriminator mapping. Pre-3 specs and non-one-of schemas
        go to the parent implementation.
        """
        if self.openapi_version.major < 3 or not is_oneof(schema):
            return super(OneofOpenAPIConverter, self).schema2jsonschema(schema)

        mapping = {}
        variants = []
        for type_name, type_schema in schema.type_schemas.items():
            instance = common.resolve_schema_instance(type_schema)
            if common.make_schema_key(instance) not in self.refs:
                self.spec.components.schema(
                    self.schema_name_resolver(type_schema) or type_name,
                    schema=type_schema,
                )
            ref = self.get_ref_dict(instance)
            mapping[type_name] = ref['$ref']
            variants.append(ref)

        discriminator = {'propertyName': schema.type_field, 'mapping': mapping}
        return {'oneOf': variants, 'discriminator': discriminator}
예제 #24
0
파일: openapi.py 프로젝트: m3rlinux/checkmk
    def schema2jsonschema(self, schema):
        """Render a value-typed dict schema as an object with ``additionalProperties``.

        Non-value-typed-dict schemas go to the parent implementation.
        - A ``FieldWrapper`` value type is rendered inline via ``field_properties``.
        - A schema value type is registered as a component (if its key is not yet
          in ``self.refs``) and then referenced.

        Raises:
            RuntimeError: For any other kind of ``value_type``.
        """
        if not is_value_typed_dict(schema):
            return super().schema2jsonschema(schema)

        value_type = schema.value_type
        if isinstance(value_type, FieldWrapper):
            additional = field_properties(value_type.field)
        elif not (isinstance(value_type, base.SchemaABC) or (
                isinstance(value_type, type)
                and issubclass(value_type, base.SchemaABC))):
            raise RuntimeError(f"Unsupported value_type: {value_type}")
        else:
            instance = common.resolve_schema_instance(value_type)
            if common.make_schema_key(instance) not in self.refs:
                self.spec.components.schema(self.schema_name_resolver(value_type),
                                            schema=instance)
            additional = self.get_ref_dict(instance)

        return {"type": "object", "additionalProperties": additional}
예제 #25
0
파일: openapi.py 프로젝트: PLUTEX/checkmk
    def schema2jsonschema(self, schema):
        """Render a value-typed dict schema as an object with ``additionalProperties``.

        Non-value-typed-dict schemas go to the parent implementation.
        - If ``schema.value_type`` resolves to a schema instance, that schema is
          registered as a component (if its key is not yet in ``self.refs``) and
          referenced.
        - If resolution raises ``ValueError``, ``value_type[0]`` is rendered inline
          with ``field_properties``.
        """
        if not is_value_typed_dict(schema):
            return super().schema2jsonschema(schema)

        value_type = schema.value_type
        try:
            instance = common.resolve_schema_instance(value_type)
        except ValueError:
            instance = None

        if instance is not None:
            if common.make_schema_key(instance) not in self.refs:
                self.spec.components.schema(self.schema_name_resolver(value_type),
                                            schema=instance)
            additional = self.get_ref_dict(instance)
        else:
            additional = field_properties(value_type[0])

        return {"type": "object", "additionalProperties": additional}
예제 #26
0
    def resolve_nested_schema(self, schema):
        """Resolve ``schema``, expanding a ``OneOfSchema`` into a discriminated ``oneOf``.

        Anything that cannot be resolved to an instance, or that resolves to
        something other than a ``OneOfSchema``, goes to the parent implementation.
        Otherwise each type schema is resolved recursively, and its ``$ref`` is
        recorded in the discriminator mapping.
        """
        try:
            instance = resolve_schema_instance(schema)
        except Exception:
            # Let base class handle it
            instance = None

        if not (instance and isinstance(instance, OneOfSchema)):
            return super().resolve_nested_schema(schema)

        mapping = {}
        variants = []
        for type_name, type_schema in instance.type_schemas.items():
            resolved = self.resolve_nested_schema(type_schema)
            variants.append(resolved)
            mapping[type_name] = resolved['$ref']

        discriminator = {'mapping': mapping, 'propertyName': instance.type_field}
        return {'discriminator': discriminator, 'oneOf': variants}
예제 #27
0
    def __init__(
        self,
        nested: typing.List[typing.Type[fields.SchemaABC]],
        mode: typing.Literal["anyOf", "allOf"] = "anyOf",
        *,
        default: typing.Any = fields.missing_,
        only: typing.Optional[types.StrSequenceOrSet] = None,
        exclude: types.StrSequenceOrSet = (),
        many: bool = False,
        unknown: typing.Optional[str] = None,
        # In this loop we do the following:
        #  1) we try to dump all the keys of a model
        #  2) when the dump succeeds, we remove all the dumped keys from the source
        #  3) we go to the next model and goto 1
        #
        # For this we assume the schema is always symmetrical (i.e. a round trip is
        # idempotent) to get at the original keys. If this is not true, there may be bugs.
        merged: bool = False,
        **kwargs,
    ):
        """Create a field whose value is checked against several nested schemas.

        Args:
            nested: Schemas (anything ``common.resolve_schema_instance`` accepts);
                each is resolved to an instance and stored in ``self._nested``.
            mode: Only ``"anyOf"`` is supported.
            default: Passed through to the base field.
            only, exclude, many: Stored on the field unchanged.
            unknown: Unknown-key policy. In merged mode it is always ``EXCLUDE``,
                because each nested schema sees keys belonging to its siblings. In
                standard mode the given value is used, defaulting to ``RAISE``.
            merged: Apply all nested schemas to the same object and merge the
                results. Requires the schemas' declared keys to be pairwise
                disjoint; cannot be combined with ``many``.
            **kwargs: Passed to the base field. ``metadata`` is copied and
                extended with ``"anyOf"``; its ``context`` entry is merged into
                every nested schema's context.

        Raises:
            NotImplementedError: For ``merged`` together with ``many``, or for a
                ``mode`` other than ``"anyOf"``.
            RuntimeError: In merged mode, if two schemas declare the same key.
        """
        if merged and many:
            raise NotImplementedError("merged=True with many=True is not supported.")

        if mode != "anyOf":
            raise NotImplementedError("allOf is not yet implemented.")

        # Copy, so that adding "anyOf" below never mutates a dict the caller may
        # share between fields; also tolerate an explicit ``metadata=None``.
        metadata = dict(kwargs.pop("metadata", None) or {})
        # Copy as well, so merging in our context never mutates the parent's dict.
        context = dict(getattr(self.parent, "context", {}))
        context.update(metadata.get("context", {}))

        self._nested = []
        schema_inst: Schema
        for schema in nested:
            schema_inst = common.resolve_schema_instance(schema)
            schema_inst.context.update(context)
            self._nested.append(schema_inst)

        metadata["anyOf"] = self._nested

        # We need to check that the key names of all schemas are completely disjoint, because
        # we can't represent multiple schemas with the same key in merge-mode.
        if merged:
            seen_keys: set[str] = set()
            for schema_inst in self._nested:
                keys = set(schema_inst.declared_fields.keys())
                overlap = seen_keys & keys
                if overlap:
                    wrong_keys = ", ".join(repr(key) for key in sorted(overlap))
                    raise RuntimeError(
                        f"Schemas {self._nested} are not disjoint. "
                        f"Keys {wrong_keys} occur more than once."
                    )
                seen_keys |= keys

        self.mode = mode
        self.only = only
        self.exclude = exclude
        self.many = many
        self.merged = merged
        if merged:
            # When we are merging, we don't want to have errors due to cross-schema validation.
            self.unknown = EXCLUDE
        else:
            # When we operate in standard mode, we really want to know these errors,
            # unless the caller explicitly asked for a different policy.
            self.unknown = RAISE if unknown is None else unknown
        super().__init__(default=default, metadata=metadata, **kwargs)