def __init__(self, *args, **kwargs):
    """Initialise both base classes, giving each its own keyword arguments.

    Keyword arguments named in ``_schema_kwargs`` are removed from ``kwargs``
    and forwarded only to ``MarshmallowSchema.__init__``. Everything left is
    passed to the next initializer after ``Schema`` in the MRO. ``many`` is
    read (default ``False``) for the marshmallow side. Unless it is listed in
    ``_schema_kwargs``, it also stays in ``kwargs`` for the other initializer.
    """
    marshmallow_kwargs = {"many": kwargs.get("many", False)}
    for name in _schema_kwargs:
        # Marshmallow-only options must not reach the other base class.
        if name not in kwargs:
            continue
        marshmallow_kwargs[name] = kwargs.pop(name)

    super(Schema, self).__init__(*args, **kwargs)
    MarshmallowSchema.__init__(self, **marshmallow_kwargs)
Example #2
0
def _base_request(request: Request, request_schema: Schema, response_schema: Schema,
                  method: Callable[[object, Optional[str]], object], is_paging: bool = False,
                  code=None, code_name: Optional[str] = None):
    """Request processing base method.

    Merges query parameters, body data and (optionally) a path identifier into
    one dict. The dict is loaded with ``request_schema`` and passed to
    ``method``, and the result is dumped with ``response_schema``.

    :param request: Request object; a falsy value yields HTTP 404
    :param request_schema: RequestSchema instance used to load the parameters
    :param response_schema: ResponseSchema instance used to dump the result
    :param method: callable invoked with the loaded request object (plus the
        parsed path when ``is_paging`` is true)
    :param is_paging: paging type request; ``_parse_path(request)`` is passed
        to ``method`` as a second argument
    :param code: request identifier
    :param code_name: parameter name under which ``code`` is stored
    :return: 201 with the dumped data for POST, 204 with an empty body for
        DELETE, otherwise 200 with the dumped data. Returns 400 on load errors
        or RequestParameterException, 404 on ResourceNotFoundException, and
        500 on any other exception raised by ``method``.
    """
    import logging

    if not request:
        return response.HttpResponseNotFound()
    request_param = {}
    if request.query_params:
        request_param.update(request.query_params.dict())
    if request.data:
        request_param.update(request.data)
    if code and code_name:
        # The path identifier overrides any same-named query/body value.
        request_param[code_name] = code
    # NOTE(review): tuple-unpacking ``load`` is the marshmallow 2.x API;
    # marshmallow 3 returns only the data and raises ValidationError instead.
    request_obj, errors = request_schema.load(request_param)
    if errors:
        return response.HttpResponseBadRequest(errors)
    try:
        if is_paging:
            path = _parse_path(request)
            response_obj = method(request_obj, path)
        else:
            response_obj = method(request_obj)
    except RequestParameterException:
        return response.HttpResponseBadRequest()
    except ResourceNotFoundException:
        return response.HttpResponseNotFound()
    except Exception:
        # Unexpected failure: still answer 500, but record the traceback
        # instead of silently discarding it.
        logging.getLogger(__name__).exception(
            "Unhandled error while processing %s request", request.method)
        return response.HttpResponseServerError()
    data, _ = response_schema.dump(response_obj)
    if request.method == 'POST':
        return Response(data, status.HTTP_201_CREATED)
    elif request.method == 'DELETE':
        return Response({}, status.HTTP_204_NO_CONTENT)
    else:
        return Response(data, status.HTTP_200_OK)
Example #3
0
    def get_parameters(
        self,
        req: falcon.Request,
        schema: marshmallow.Schema
    ):
        """Decode the body of an incoming request via a `marshmallow` schema.

        Args:
            req (falcon.Request): The Falcon `Request` object.
            schema (marshmallow.Schema): The marshmallow schema instance that
                will be used to decode the body.

        Returns:
            ``schema.loads(body).data`` for a non-empty body, ``{}`` for an
            empty one.

        Raises:
            falcon.HTTPError: 400 when the body cannot be read or decoded,
                422 when it violates the schema.
        """

        try:
            request_json = req.stream.read()
        except Exception as exc:
            msg_fmt = "Could not retrieve JSON body."
            self.logger.exception(msg_fmt)
            raise falcon.HTTPError(
                status=falcon.HTTP_400,
                title="InvalidRequest",
                description=msg_fmt,
            ) from exc

        try:
            if request_json:
                # If the JSON body came through as `bytes` it needs to be
                # decoded into a `str`.
                if isinstance(request_json, bytes):
                    request_json = request_json.decode("utf-8")
                # Decode the JSON body through the `marshmallow` schema.
                # NOTE(review): `.data` is the marshmallow 2.x result API.
                parameters = schema.loads(request_json).data
            else:
                parameters = {}
        except marshmallow.ValidationError as exc:
            # It is the *request* body that failed validation here.
            msg_fmt = "Request body violates schema."
            self.logger.exception(msg_fmt)
            raise falcon.HTTPError(
                status=falcon.HTTP_422,
                title="Schema violation.",
                description=msg_fmt + " Exception: {0}".format(str(exc))
            ) from exc
        except Exception as exc:
            msg_fmt = "Could not decode JSON body '{}'.".format(request_json)
            self.logger.exception(msg_fmt)
            # msg_fmt already ends with a period; don't add a second one.
            raise falcon.HTTPError(
                status=falcon.HTTP_400,
                title="InvalidRequest",
                description=msg_fmt + " Exception: {0}".format(str(exc))
            ) from exc

        return parameters
Example #4
0
def test_delimited_tuple_custom_delimiter(web_request, parser):
    """A '|'-delimited tuple round-trips through parse and dump."""
    web_request.json = {"ids": "1|2"}
    ids_field = fields.DelimitedTuple((fields.Int, fields.Int), delimiter="|")
    schema = Schema.from_dict({"ids": ids_field})()

    result = parser.parse(schema, web_request)
    assert result["ids"] == (1, 2)

    assert schema.dump(result)["ids"] == "1|2"
Example #5
0
 def _try_load(schema: Schema, response: Response) -> Any:
     """Load a successful (2xx) response's JSON body through ``schema``.

     Returns ``None`` for non-2xx responses. Also returns ``None`` when the
     body is not valid JSON or fails validation; those failures are logged.
     """
     if not 200 <= response.status_code < 300:
         return None
     try:
         return schema.load(response.json())
     except JSONDecodeError as err:
         logger.exception(err.msg)
     except ValidationError as err:
         logger.exception(str(err.messages))
     return None
Example #6
0
    def build_response(self, schema: Schema, response):
        """Strictly validate ``response`` against ``schema``, then serialize it.

        :param schema: schema instance describing the response.
        :param response: the data to validate and dump.
        :return: the result of ``schema.dump(response)``.
        :raises exceptions.ValidationError: with the errors from
            ``schema.validate`` when there are any. Per the original note, this
            is expected to be handled as a 500 internal error.
        """
        # Validating a response is unusual, but it keeps the schema honest:
        # any drift between the schema and the produced data fails loudly.
        errors = schema.validate(response)
        if not errors:
            return schema.dump(response)
        raise exceptions.ValidationError(errors)
Example #7
0
def test_delimited_list_custom_delimiter(web_request, parser):
    """A '|'-delimited list round-trips through parse and dump."""
    web_request.json = {"ids": "1|2|3"}
    ids_field = fields.DelimitedList(fields.Int(), delimiter="|")
    schema = Schema.from_dict({"ids": ids_field})()

    result = parser.parse(schema, web_request)
    assert result["ids"] == [1, 2, 3]

    assert schema.dump(result)["ids"] == "1|2|3"
Example #8
0
class ProjectConfig(BaseConfig):
    """Configuration describing a windmill project's folder layout."""

    schema = Schema.from_dict(
        {
            # CLI Arguments
            "name": fields.Str(
                missing=ProjectDefaults.PROJECT_NAME,
                help="Name of the root project folder",
            ),
            "wml_dir": fields.Str(
                missing=ProjectDefaults.WML_FOLDER,
                help="Folder to store windmill WML files",
            ),
            "dags_dir": fields.Str(
                missing=ProjectDefaults.DAGS_FOLDER,
                help="Folder to store generated YML DAG files",
            ),
            "operators_dir": fields.Str(
                missing=ProjectDefaults.OPERATORS_FOLDER,
                help="Folder to store custom operator files",
            ),
        }
    )

    def __init__(
        self,
        name: str,
        wml_dir: str,
        dags_dir: str,
        operators_dir: str,
        conf_file: str = ProjectDefaults.PROJECT_CONF,
        *args,
        **kwargs,
    ):
        """Handler for project file.

        Args:
            name (str): Name of the root project folder.
            wml_dir (str): Folder to store windmill WML files.
            dags_dir (str): Folder to store generated YML DAG files.
            operators_dir (str): Folder to store custom operator files.
            conf_file (str, optional): Default project config filename.
                Defaults to ProjectDefaults.PROJECT_CONF.
            *args, **kwargs: Accepted and ignored (not forwarded to BaseConfig).
        """
        self.name = name
        self.wml_dir = wml_dir
        self.dags_dir = dags_dir
        self.operators_dir = operators_dir
        self.conf_file = conf_file

    @staticmethod
    def from_conf_file(filename):
        """Build a ProjectConfig from the YAML mapping stored in ``filename``.

        Raises:
            InitError: if the file cannot be opened or parsed, or its contents
                do not match the constructor's parameters.
        """
        try:
            # Read-only access is enough; "r+" would fail on read-only files.
            # safe_load: bare yaml.load() without a Loader can construct
            # arbitrary objects and raises TypeError on PyYAML >= 6.
            with open(filename, "r") as f:
                return ProjectConfig(**yaml.safe_load(f))
        except Exception as e:
            raise InitError("This directory is not a valid windmill project") from e
Example #9
0
class DocSchema(OpenAPISchema):
    """Verifiable doc schema."""

    class Meta:
        """Keep unknown values."""

        # INCLUDE keeps keys not declared on the schema instead of rejecting them.
        unknown = INCLUDE

    # Optional proof block; each of its members is an optional string.
    proof = fields.Nested(
        Schema.from_dict(
            dict(
                creator=fields.Str(required=False),
                verificationMethod=fields.Str(required=False),
                proofPurpose=fields.Str(required=False),
            )
        )
    )
Example #10
0
def test_delimited_tuple_passed_invalid_type(web_request, parser):
    """A non-string JSON value for a DelimitedTuple field is rejected."""
    web_request.json = {"ids": 1}
    schema = Schema.from_dict({"ids": fields.DelimitedTuple((fields.Int,))})()

    expected = {"json": {"ids": ["Not a valid delimited tuple."]}}
    with pytest.raises(ValidationError) as excinfo:
        parser.parse(schema, web_request)
    assert excinfo.value.messages == expected
Example #11
0
    def json(cls, *, response: Response,
             schema: m.Schema) -> 'PaddockResponse[Any]':
        """
        Wrap ``response`` so its JSON body is deserialized lazily.

        The body text is captured now. ``schema.loads`` runs only when the
        wrapper's converter is invoked (via its ``body()`` function).

        :param response: the response to wrap
        :param schema: the schema used to deserialize the body
        :return: the wrapped response
        """
        body_text = response.text

        def convert():
            return schema.loads(body_text)

        return PaddockResponse(response=response, converter=convert)
Example #12
0
    def dump(self, obj, *, many=None):
        """Dump ``obj`` and merge each object's catch-all (undefined) parameters.

        :param obj: a single object, or an iterable of objects when dumping many.
        :param many: override for ``self.many``; ``None`` means "use the
            schema's own setting".
        :return: the dumped dict, or a list of dicts when ``many`` is in effect.
        """
        dumped = Schema.dump(self, obj, many=many)
        # Resolve the effective flag the same way marshmallow's Schema.dump
        # does. Testing the raw argument treats many=None as "single object"
        # even for a schema built with many=True, and then calls .update() on
        # a list.
        many = self.many if many is None else bool(many)
        # TODO This is hacky, but the other option I can think of is to generate a different schema
        #  depending on dump and load, which is even more hacky

        # The only problem is the catch all field, we can't statically create a schema for it
        # so we just update the dumped dict
        if many:
            for dumped_item, item in zip(dumped, obj):
                dumped_item.update(
                    _handle_undefined_parameters_safe(cls=item, kvs={}, usage="dump"))
        else:
            dumped.update(_handle_undefined_parameters_safe(cls=obj, kvs={}, usage="dump"))
        return dumped
Example #13
0
def test_delimited_tuple_load_list_errors(web_request, parser):
    """A JSON list (rather than a delimited string) is rejected for a DelimitedTuple."""
    web_request.json = {"ids": [1, 2]}
    schema = Schema.from_dict(
        {"ids": fields.DelimitedTuple((fields.Int, fields.Int))}
    )()

    with pytest.raises(ValidationError) as excinfo:
        parser.parse(schema, web_request)
    error = excinfo.value
    assert isinstance(error, ValidationError)
    assert error.args[0]["ids"] == ["Not a valid delimited tuple."]
Example #14
0
 def get_json_body(self, schema: Schema = None) -> dict:
     """Return the request's JSON body, optionally loaded through ``schema``.

     An empty dict is returned when the Content-Type header is missing or
     does not start with ``application/json``. An empty body decodes as ``{}``.

     :raises WidgetsParameterError: when the body is not valid JSON.
     """
     content_type = self.request.headers.get('Content-Type')
     if not (content_type and content_type.startswith('application/json')):
         return dict()
     try:
         payload = orjson.loads(self.request.body or b'{}')
     except orjson.JSONDecodeError:
         raise WidgetsParameterError
     return schema.load(payload) if schema else payload
Example #15
0
def conditional_crypto_deserialize(object_dict, parent_object_dict):
    """Return the WebPush Crypto Schema if there's a data payload

    Without a truthy ``body`` in the parent, a plain ``Schema()`` is returned.
    Otherwise the ``content-encoding`` value picks the crypto header schema,
    and unrecognised encodings get ``WebPushInvalidContentEncodingSchema``.
    """
    if not parent_object_dict.get("body"):
        return Schema()

    # Validate the crypto headers appropriately
    encoding = object_dict.get("content-encoding")
    if encoding == "aesgcm128":
        return WebPushCrypto01HeaderSchema()
    if encoding == "aesgcm":
        return WebPushCrypto04HeaderSchema()
    return WebPushInvalidContentEncodingSchema()
Example #16
0
def test_whitespace_stripping_parser_example(web_request):
    """A parser ``pre_load`` hook can strip whitespace from query/form data only."""

    def _strip_whitespace(value):
        # Recursively strip strings found inside mappings, lists and tuples.
        if isinstance(value, str):
            return value.strip()
        if isinstance(value, typing.Mapping):
            return {key: _strip_whitespace(value[key]) for key in value}
        if isinstance(value, (list, tuple)):
            return type(value)(_strip_whitespace(item) for item in value)
        return value

    class WhitespaceStrippingParser(MockRequestParser):
        def pre_load(self, location_data, *, schema, req, location):
            if location not in ("query", "form"):
                return location_data
            return _strip_whitespace(location_data)

    parser = WhitespaceStrippingParser()

    # The same mock data is served for query, form and json.
    web_request.form = web_request.query = web_request.json = {"value": " hello "}
    argmap = {"value": fields.Str()}

    # The default (json) location is left alone by the hook.
    assert parser.parse(argmap, web_request) == {"value": " hello "}

    # query and form data are stripped by pre_load.
    for location in ("query", "form"):
        assert parser.parse(argmap, web_request, location=location) == {"value": "hello"}

    # List-typed fields are stripped element by element; tuples take the
    # same code path.
    web_request.form = web_request.query = web_request.json = {
        "ids": [" 1", "3", " 4"],
        "values": [" foo  ", " bar"],
    }
    schema = Schema.from_dict(
        {"ids": fields.List(fields.Int), "values": fields.List(fields.Str)}
    )
    for location in ("query", "form"):
        assert parser.parse(schema, web_request, location=location) == {
            "ids": [1, 3, 4],
            "values": ["foo", "bar"],
        }

    # json is not stripped: string values keep their whitespace, while
    # " 1" still loads fine as an Int.
    assert parser.parse(schema, web_request, location="json") == {
        "ids": [1, 3, 4],
        "values": [" foo  ", " bar"],
    }
Example #17
0
    def read_json_request(self, schema: Schema):
        """Decode the JSON request body and validate it against ``schema``.

        :param schema: instance of a marshmallow ``Schema`` subclass.
        :type schema: Schema descendant
        :return: ``schema.dump(schema.load(data))`` for the decoded body.
        :raises BaseApiError: when the body is not JSON or fails validation.
            In both cases the error status is set and ``write_error`` is
            called before raising.
        """
        try:
            payload = json.loads(self.request.body)
        except JSONDecodeError as err:
            self.set_status(self.STATUS_ERROR_EXTERNAL, reason=str(err))
            self.write_error()
            raise BaseApiError(
                "Expected request body to be JSON. Received '{}'".format(self.request.body))

        problems = schema.validate(payload)
        if problems:
            self.error_messages = problems
            self.set_status(self.STATUS_ERROR_EXTERNAL, reason="Failed request schema validation")
            self.write_error()
            raise BaseApiError("Failed schema validation: {}".format(str(problems)))

        loaded = schema.load(payload)
        return schema.dump(loaded)
Example #18
0
def test_notebooks_field_invalid_keys(monkeypatch, key, message):
    """
    NotebooksField rejects a notebook key that is not a string or has a
    disallowed value, raising ValidationError.
    """
    monkeypatch.setattr("pathlib.Path.exists", lambda self: True)
    field = NotebooksField()
    # Context can't be set directly on a Field; it must live on the parent Schema.
    parent = Schema(context={"inputs_dir": "DUMMY_INPUTS_DIR"})
    field._bind_to_schema("notebooks", parent)
    raw_notebooks = {key: {"filename": "NOTEBOOK1.ipynb"}}
    with pytest.raises(ValidationError) as exc_info:
        field.deserialize(raw_notebooks)
    assert message in exc_info.value.messages[key]["key"]
Example #19
0
    def deserialize(json_string: str,
                    class_type: Type[T],
                    schema: Schema = None) -> T:
        """Turn ``json_string`` into ``class_type`` instance(s).

        Without a schema, ``jsons.loads`` does the conversion. With one, the
        string is loaded through the schema and the resulting mapping is
        splatted into ``class_type``. A list result yields a list of
        instances, one per item.
        """
        if schema is None:
            return jsons.loads(json_string, class_type)
        loaded = schema.loads(json_data=json_string)
        if isinstance(loaded, list):
            return [class_type(**item) for item in loaded]
        return class_type(**loaded)
Example #20
0
def schema_from_base():
    """Build a marshmallow Schema class covering every endpoint parameter.

    Parameter names come from the 'base' service parameters first, then from
    each registered service; a later service overrides an earlier name. Each
    parameter's type (class name) is mapped to a field through ``MAP``.

    :return: the Schema class produced by ``Schema.from_dict``.
    """
    # Fetch the base container once instead of once per parameter.
    base_params = app_containers.service_parameters.get('base')
    endpoint_params = {
        name: base_params.parameter(name).__class__.__name__
        for name in base_params.parameters()
    }

    for service_name in app_containers.services.keys():
        # In-place update: same result as rebuilding via {**a, **b} each time.
        endpoint_params.update(get_endpoint_params(service_name))
    # NOTE(review): MAP.get returns None for an unmapped parameter type, and
    # such a name presumably won't become a usable field — confirm every
    # type is covered by MAP.
    return Schema.from_dict({
        name: MAP.get(parameter)
        for name, parameter in endpoint_params.items()
    })
Example #21
0
    def to_representation(self, instance):
        """Serialize ``instance``'s readable fields through a one-off schema.

        Each name in ``self.readable_fields`` is read from ``instance``, and
        peewee model values are replaced by their ``id``. A ``ValueError``
        during collection is appended to ``self._errors`` and ends the loop,
        so only the fields gathered so far are dumped.
        """
        values = {}
        try:
            for name in self.readable_fields:
                value = getattr(instance, name)
                values[name] = value.id if isinstance(value, peewee.Model) else value
        except ValueError as err:
            self._errors.append(err)

        return Schema.from_dict(self.readable_fields)().dump(values)
Example #22
0
def process_schema(schema):
    """
    Normalise a schema passed in as a view deriver.

    ``None`` and existing ``Schema`` instances are returned unchanged. A dict
    of fields becomes an instance of a nonce Schema class built from it.
    Anything else raises ``TypeError``.
    """
    if schema is None or isinstance(schema, Schema):
        return schema
    if isinstance(schema, dict):
        return Schema.from_dict(schema)()
    raise TypeError("Schema is of invalid type.")
Example #23
0
    def verify_data(self,
                    data: dict,
                    schema: Schema,
                    context: dict = None) -> dict:
        """
        Validate ``data`` with a schema instance; raise ValidateError on errors.

        :param dict data: the data to validate
        :param Schema schema: schema instance
        :param dict context: extra data for the schema, stored on its
            ``context`` attribute (only when truthy)
        :return: the result of ``schema.load(data)``
        :raises ValidateError: built from the validation messages, with
            duplicate messages for the same field removed
        """

        def _dedupe(messages):
            # Several validators on one field can report the same message.
            # Deduplicate lists while keeping their order, and recurse into
            # dicts (e.g. Nested/List field errors) instead of collapsing
            # them to their keys, as set(dict) would.
            if isinstance(messages, dict):
                return {key: _dedupe(value) for key, value in messages.items()}
            if isinstance(messages, list):
                unique = []
                for message in messages:
                    if message not in unique:
                        unique.append(message)
                return unique
            return messages

        # Extra data made available to the schema.
        # NOTE(review): this mutates the passed-in schema, and an empty or
        # None context leaves any previous context in place.
        if context:
            schema.context = context
        try:
            data = schema.load(data)
        except ValidationError as e:
            raise ValidateError(_dedupe(e.messages), replace=True) from e
        return data
Example #24
0
def gen_schemas(tables: List[Table],
                dialect: str,
                to_file: bool = True) -> Dict[str, Schema]:
    """Generate a marshmallow Schema for each table, optionally also as a py file.

    Args:
        tables: tables obtained through SQLAlchemy reflection
        dialect: dialect name, passed through to ``analyse_table``
        to_file: when exactly ``True``, also write ``schemas.py`` (in the
            current working directory) with one ``<Name>Schema`` class per table

    Returns:
        dict mapping each table name to an instance of its generated Schema
    """
    import contextlib

    write_file = to_file is True
    schemas = {}  # type: Dict[str, Schema]
    last_table_num = len(tables) - 1

    # ExitStack closes schemas.py even if analysing a table raises part-way
    # through; the file is only opened when it is actually written.
    with contextlib.ExitStack() as stack:
        f = stack.enter_context(open("schemas.py", "wb")) if write_file else None

        def emit(text):
            f.write(text.encode("utf-8"))

        if write_file:
            emit("# -*- coding: utf-8 -*-\n\n\n"
                 "from marshmallow import fields, Schema\n\n\n")

        for i, table in enumerate(tables):
            schema_dic = {}  # type: Dict[str, fields.Field]
            result = analyse_table(table, dialect)
            if write_file:
                schema_name = "".join(
                    x.capitalize() for x in result["table"].split("_"))
                emit(f"class {schema_name}Schema(Schema):\n\n")
            for column in result["columns"]:
                if write_file:
                    emit(" " * 4 + column["name"] +
                         f" = {column['type'].to_marshmallow_str()}\n")
                schema_dic[column["name"]] = column["type"].to_marshmallow()

            if write_file and not result["columns"]:
                # An empty class body is a syntax error in the generated file.
                emit(" " * 4 + "pass\n")

            if write_file and i != last_table_num:
                emit("\n\n")

            schemas[result["table"]] = (
                Schema.from_dict(schema_dic)()  # type: ignore[arg-type]
            )

    return schemas
class UserPodResources(
        Schema.from_dict(
            # Memory and CPU resources that should be present in the response to creating a
            # jupyter notebooks server.
            {
                "cpu": fields.Str(required=True),
                "memory": fields.Str(required=True),
                "storage": fields.Str(required=False),
                "gpu": fields.Str(required=False),
            })):
    """CPU/memory (plus optional storage/GPU) resources of a notebooks server."""

    @pre_load
    def resolve_gpu_fieldname(self, in_data, **kwargs):
        """Rename the ``nvidia.com/gpu`` key to ``gpu`` before loading.

        A shallow copy is made so the caller's input mapping is not modified.
        """
        if "nvidia.com/gpu" not in in_data:
            return in_data
        data = dict(in_data)
        data["gpu"] = data.pop("nvidia.com/gpu")
        return data
Example #26
0
    def validate_primitive_value(self, field_name, value):
        """Validate a single value with the marshmallow field configured for it.

        A ``None`` value raises ``ValidationError`` when the field's extra
        kwargs mark it ``required``. Otherwise ``None`` is returned untouched.
        Any other value is loaded through a one-field schema built from
        ``self._writable_fields[field_name]``. Errors raised by that load
        propagate to the caller.

        :return: the loaded value, or ``None`` for an optional ``None``.
        """
        is_required = self.get_extra_kwargs_of(field_name).get('required')

        if value is None:
            if not is_required:
                return None
            raise ValidationError('{} is required'.format(field_name))

        single_field_schema = Schema.from_dict(
            {field_name: self._writable_fields[field_name]})()
        return single_field_schema.load({field_name: value})[field_name]
Example #27
0
class CredDefValuePrimarySchema(OpenAPISchema):
    """Cred def value primary schema."""

    # Every member is a string field configured with the shared NUM_STR_WHOLE kwargs.
    n = fields.Str(**NUM_STR_WHOLE)
    s = fields.Str(**NUM_STR_WHOLE)
    r = fields.Nested(
        Schema.from_dict(
            dict(
                master_secret=fields.Str(**NUM_STR_WHOLE),
                number=fields.Str(**NUM_STR_WHOLE),
                remainder=fields.Str(**NUM_STR_WHOLE),
            )
        ),
        name="CredDefValuePrimaryRSchema",
    )
    rctxt = fields.Str(**NUM_STR_WHOLE)
    z = fields.Str(**NUM_STR_WHOLE)
Example #28
0
def validate_data(raw_data: dict, schema: Schema) -> dict:
    """Load and validate ``raw_data`` with ``schema``.

    Failures go through ``flask.abort``: 400 when no data was provided, and
    422 (carrying the validation messages) when the schema rejects it.

    :param raw_data: dict of raw data from request body
    :param schema: marshmallow.Schema
    :return: validated data
    """
    if raw_data:
        try:
            return schema.load(raw_data)
        except ValidationError as err:
            return abort(422, {'message': err.messages})
    return abort(400, {'message': 'No input data provided'})
Example #29
0
def accepts_logic(payload: dict = None,
                  temp: dict = None,
                  schema: Schema = None,
                  many=False):
    """Load ``payload`` through a schema, exposing ``temp`` to it as ``'temp'``.

    ``temp`` is injected into a copy of the payload under the ``'temp'`` key,
    with ``unknown=INCLUDE`` so it survives loading. It is then removed from
    the loaded result.

    :param payload: input data; defaults to an empty dict.
    :param temp: auxiliary data for the schema; defaults to an empty dict.
    :param schema: schema to use, resolved via ``_get_or_create_schema``.
    :param many: forwarded to ``_get_or_create_schema``.
    :return: the loaded data without the ``'temp'`` key.
    :raises LogicalValidationException: on validation errors, with the error
        keys converted by ``_parse_error_to_camel_key``.
    """
    # None sentinels replace the shared mutable ``{}`` defaults.
    payload = {} if payload is None else payload
    temp = {} if temp is None else temp
    schema = _get_or_create_schema(schema, many=many)

    try:
        # Work on a copy so the caller's dict never gains a stray 'temp' key
        # (the old code left it behind when loading failed).
        data = {**payload, 'temp': temp}
        data = schema.load(data, unknown=INCLUDE)
        del data['temp']
    except ValidationError as err:
        err.messages = _parse_error_to_camel_key(err.messages)
        raise LogicalValidationException(extra=err.messages)

    return data
    def __init__(self, app, db, schema, record_name=None, logger_service: LoggerService = None,
                 resource_protector: CustomResourceProtector = None, create_scope: Scope = None,
                 fetch_scope: Scope = None, create_permissions=None, fetch_permissions=None):
        """
        Wire a generic resource handler around one model schema.

        :param app: Flask application reference
        :param db: Flask SQLAlchemy reference
        :param schema: Current model Marshmallow Schema with model reference;
            its ``Meta.model`` is handed to ``ChassisService``
        :param record_name: display name for records; defaults to "Resource"
        :param logger_service: stored as ``self.logger_service``
        :param resource_protector: stored as ``self.resource_protector``
        :param create_scope: stored as ``self.create_scopes``
        :param fetch_scope: stored as ``self.fetch_scopes``
        :param create_permissions: stored as ``self.create_permissions``
        :param fetch_permissions: stored as ``self.fetch_permissions``
        """
        model = schema.Meta.model
        self.app = app
        self.service = ChassisService(app, db, model)
        self.db = db
        self.record_name = "Resource" if record_name is None else record_name

        self.schema = schema

        class RecordPageSchema(DjangoPageSchema):
            results = fields.List(fields.Nested(schema))

        self.page_response_schema = RecordPageSchema()
        self.logger_service = logger_service
        self.resource_protector = resource_protector
        self.create_scopes = create_scope
        self.fetch_scopes = fetch_scope
        self.create_permissions = create_permissions
        self.fetch_permissions = fetch_permissions

        # Fetch schema fields: paging/search parameters are always accepted.
        fetch_fields = dict(page_size=fields.Int(required=False), page=fields.Int(required=False),
                            ordering=fields.Str(required=False), q=fields.Str(required=False))
        # Date-range filters only when the model has the matching timestamp.
        if hasattr(model, "created_at"):
            fetch_fields["created_after"] = fields.Date(required=False)
            fetch_fields["created_before"] = fields.Date(required=False)
        if hasattr(model, "updated_at"):
            fetch_fields["updated_after"] = fields.Date(required=False)
            fetch_fields["updated_before"] = fields.Date(required=False)
        # Every other non-primary-key column becomes an optional string
        # filter; these bookkeeping columns are excluded.
        excluded_columns = {"created_at", "updated_at", "is_deleted", "created_by_id"}
        for column in model.__table__.c:
            if not column.primary_key and column.name not in excluded_columns:
                fetch_fields[column.name] = fields.Str(required=False)

        self.fetch_schema = Schema.from_dict(fetch_fields)
Example #31
0
class ActionResultScalar(ActionResultBase):
    """Action result whose ``result`` is a scalar value plus links."""

    resultType = fields.Constant('scalar')
    result = fields.Nested(
        Schema.from_dict(
            dict(
                links=fields.List(fields.Nested(LinkSchema), required=True),
                value=fields.String(required=True, example="Done."),
            ),
            name='ActionResultScalarValue',
        ))
Example #32
0
class ActionResultScalar(ActionResultBase):
    """Action result whose ``result`` is a scalar value plus links."""

    result = fields.Nested(
        Schema.from_dict(
            dict(
                links=fields.List(fields.Nested(LinkSchema), required=True),
                value=fields.String(required=True, example="Done."),
            ),
            name="ActionResultScalarValue",
        ),
        description="The scalar result of the action.",
    )
Example #33
0
def schema_from_umongo_get_attribute(self, attr, obj, default):
    """
    Replacement for `Schema.get_attribute` that reads umongo missing fields
    instead of returning `None`.

    example::

        class MySchema(marshmallow.Schema):
            get_attribute = schema_from_umongo_get_attribute

            # Define the rest of your schema
            ...

    """
    value = MaSchema.get_attribute(self, attr, obj, default)
    # Fall back to the raw umongo data only when the normal lookup produced
    # None, None was not itself the default, and attr is a schema field.
    if value is not None or default is None or attr not in obj.schema.fields:
        return value
    raw_value = obj._data.get(attr)
    return default if raw_value is missing else raw_value