class AttrSchema(Schema):
    """Schema for template attributes; null and empty-list values are
    stripped from the serialized output."""

    id = fields.Int()
    label = fields.Str(required=True)
    created = fields.DateTime(dump_only=True)
    updated = fields.DateTime(dump_only=True)
    type = fields.Str(required=True)
    value_type = fields.Str(required=True)
    static_value = fields.Field()
    template_id = fields.Str(dump_only=True)
    metadata = fields.Nested(MetaSchema, many=True, attribute='children')

    @post_dump
    def remove_null_values(self, data):
        """Drop keys whose value is None or an empty list from the dump."""
        cleaned = {}
        for key, value in data.items():
            if value is None:
                continue
            if isinstance(value, list) and not value:
                continue
            cleaned[key] = value
        return cleaned
def test_deserialize_fields_with_dump_only_param(self, unmarshal):
    """dump_only fields must be skipped during deserialization — their
    values are not loaded and their validators never run."""
    payload = {'name': 'Mick', 'years': '42'}
    field_map = {
        'name': fields.String(),
        'years': fields.Integer(dump_only=True),
        'always_invalid': fields.Field(validate=lambda f: False, dump_only=True),
    }

    result = unmarshal.deserialize(payload, field_map)

    assert result['name'] == 'Mick'
    assert 'years' not in result
    # The always-failing validator was never invoked for the dump_only field.
    assert 'always_invalid' not in unmarshal.errors
class UploadFileMixin:
    """Mixin for RHs using the generic file upload system.

    An RH using this mixin needs to override the ``get_file_context`` method
    to specify how the file gets stored.
    """

    @use_kwargs({'file': fields.Field(required=True)}, location='files')
    def _process(self, file):
        # Reject the upload early if subclass-specific validation fails.
        if not self.validate_file(file):
            # XXX: we could include a nicer error message, but none of the upload
            # widgets show it right now, so no need to add more (translatable) strings
            # nobody sees
            raise UnprocessableEntity
        return self._save_file(file, file.stream)

    def _save_file(self, file, stream):
        """Persist ``stream`` as a new ``File`` row and return its JSON payload.

        :param file: the uploaded file object (provides ``filename``/``mimetype``)
        :param stream: the readable stream with the file's content
        :return: a ``(response, 201)`` tuple with the serialized file
        """
        # Local import to avoid a circular dependency at module load time.
        from indico.modules.files.schemas import FileSchema
        context = self.get_file_context()
        # Prefer the extension-based guess, then the client-provided MIME
        # type, and finally a generic binary fallback.
        content_type = mimetypes.guess_type(
            file.filename)[0] or file.mimetype or 'application/octet-stream'
        f = File(filename=file.filename, content_type=content_type)
        f.save(context, stream)
        db.session.add(f)
        # Flush (not commit) so the new row gets its ID for serialization.
        db.session.flush()
        logger.info('File %r uploaded (context: %r)', f, context)
        return FileSchema().jsonify(f), 201

    def get_file_context(self):
        """The context of where the file is being uploaded.

        :return: A tuple/list of path segments to use when storing the file.
                 For example, if a file is being uploaded to an event, you'd
                 return ``['event', EVENTID]`` so the file gets stored under
                 ``event/EVENTID/...`` in the storage backend.
        """
        raise NotImplementedError

    def validate_file(self, file):
        """Validate the uploaded file.

        If this function returns false, the upload is rejected.
        """
        return True
class BotoTaskSchema(Schema):
    """
    Botocore Serialization Object for ECS 'Task' shape.

    Note that there are many more parameters, but the executor only needs
    the members listed below.
    """

    task_arn = fields.String(data_key='taskArn', required=True)
    last_status = fields.String(data_key='lastStatus', required=True)
    desired_status = fields.String(data_key='desiredStatus', required=True)
    containers = fields.List(fields.Nested(BotoContainerSchema), required=True)
    started_at = fields.Field(data_key='startedAt')
    stopped_reason = fields.String(data_key='stoppedReason')

    @post_load
    def make_task(self, data, **kwargs):
        """Post-load hook: return an EcsFargateTask built from the loaded
        data instead of a plain dictionary."""
        return EcsFargateTask(**data)

    class Meta:
        # Silently ignore the many Task-shape keys we do not model.
        unknown = EXCLUDE
def test_from_wtforms_multi(self):
    """With several wrapped WTForms validators, each failure raises its own
    message, and a value violating both yields both messages at once."""
    field = fields.Field(
        validate=from_wtforms([Length(min=4), NoneOf(["nil", "null", "NULL"])])
    )

    # A value passing both validators deserializes unchanged.
    assert field.deserialize("thisisfine") == "thisisfine"

    with pytest.raises(MarshmallowValidationError,
                       match="Field must be at least 4 characters long"):
        field.deserialize("bad")

    with pytest.raises(
        MarshmallowValidationError,
        match="Invalid value, can't be any of: nil, null, NULL.",
    ):
        field.deserialize("null")

    with pytest.raises(MarshmallowValidationError) as excinfo:
        field.deserialize("nil")
    # "nil" is both too short and blacklisted: both errors are returned.
    messages = excinfo.value.messages
    assert "Invalid value, can't be any of: nil, null, NULL." in messages
    assert "Field must be at least 4 characters long." in messages
def __filter_fields(self, field_names, obj, many=False):
    """Return only those field_name:field_obj pairs specified by
    ``field_names``.

    :param set field_names: Field names to include in the final return
        dictionary.
    :param obj: the object (or, with ``many``, collection) used to infer
        field types for names with no declared field.
    :param bool many: whether ``obj`` is a collection of objects.
    :returns: An dict of field_name:field_obj pairs.
    """
    if obj and many:
        try:
            # Homogeneous collection
            # Prefer getitem over iter to prevent breaking serialization
            # of objects for which iter will modify position in the collection
            # e.g. Pymongo cursors
            if hasattr(obj, '__getitem__') and callable(getattr(obj, '__getitem__')):
                obj_prototype = obj[0]
            else:
                obj_prototype = next(iter(obj))
        except (StopIteration, IndexError):
            # Nothing to serialize
            return self.declared_fields
        # Use the first element as the type-inference prototype.
        obj = obj_prototype
    ret = self.dict_class()
    for key in field_names:
        if key in self.declared_fields:
            # Explicitly declared fields win over inferred ones.
            ret[key] = self.declared_fields[key]
        else:
            # Implicit field creation (class Meta 'fields' or 'additional')
            if obj:
                attribute_type = None
                try:
                    if isinstance(obj, Mapping):
                        attribute_type = type(obj[key])
                    else:
                        attribute_type = type(getattr(obj, key))
                except (AttributeError, KeyError) as err:
                    # Re-raise with the same exception type but a clearer message.
                    err_type = type(err)
                    raise err_type(
                        '"{0}" is not a valid field for {1}.'.format(key, obj))
                field_obj = self.TYPE_MAPPING.get(attribute_type, fields.Field)()
            else:
                # Object is None
                field_obj = fields.Field()
            # map key -> field (default to Raw)
            ret[key] = field_obj
    return ret
class ObservingRunGet(ObservingRunPost):
    """Read schema for observing runs; attaches computed ephemeris data."""

    owner_id = fields.Integer(
        description='The User ID of the owner of this run.')
    ephemeris = fields.Field(description='Observing run ephemeris data.')
    id = fields.Integer(description='Unique identifier for the run.')

    @pre_dump
    def serialize(self, data, **kwargs):
        """Populate ``data.ephemeris`` with ISO-format (``isot``) UTC times
        before the run is dumped."""
        data.ephemeris = {
            'sunrise_utc': data.sunrise.isot,
            'sunset_utc': data.sunset.isot,
            'twilight_evening_nautical_utc': data.twilight_evening_nautical.isot,
            'twilight_morning_nautical_utc': data.twilight_morning_nautical.isot,
            'twilight_evening_astronomical_utc':
                data.twilight_evening_astronomical.isot,
            'twilight_morning_astronomical_utc':
                data.twilight_morning_astronomical.isot,
        }
        return data
class DetectionImage(Schema):
    """Schema that validates an uploaded image file and loads it as a
    PIL image."""

    image = fields.Field(required=True)

    @validates("image")
    def validate_image(self, image):
        """Reject uploads that are not parseable images, have a disallowed
        MIME type, or exceed the configured size limit.

        :param image: the uploaded file object (seekable stream with a
            ``mimetype`` attribute).
        :raises ValidationError: if any check fails.
        """
        try:
            Image.open(image)
        except Exception as e:  # noqa: F841
            raise ValidationError("Unsupported file type")
        finally:
            # BUG FIX: Image.open() advances the stream; rewind so the size
            # check below and the post_load re-open see the whole file.
            image.seek(0)
        if image.mimetype not in consts.ALLOWED_MIMETYPE:
            raise ValidationError("Unsupported image type")
        # BUG FIX: measure size via seek/tell instead of read(): the old
        # read()-based check both left the stream at EOF (breaking the
        # post_load Image.open) and loaded the entire body into memory.
        image.seek(0, 2)  # 2 == os.SEEK_END
        size = image.tell()
        image.seek(0)
        if size > config.MAX_IMAGE_SIZE:
            raise ValidationError("Image size exceeds the limit")

    @post_load
    def process(self, data, **kwargs):
        """Replace the raw upload with a decoded ``PIL.Image`` instance."""
        # The stream was rewound at the end of validation, so open() reads
        # from the beginning of the file.
        data["image"] = Image.open(data["image"])
        return data
def get_parameters(url, handler, spec):
    """Build the OpenAPI parameter list for ``handler``.

    Arguments annotated with a marshmallow ``Field`` become query/path
    parameters; a ``body`` argument annotated with a ``Schema`` instance or
    class becomes a body parameter referencing a registered definition.

    :param url: the route pattern, used to decide where a parameter lives.
    :param handler: wrapped handler; its ``_original_handler`` is introspected.
    :param spec: apispec object on which body schema definitions are registered.
    :returns: a list of OpenAPI parameter dicts.
    """
    defaults = get_default_args(handler._original_handler)
    args_spec = getfullargspec(handler._original_handler)
    parameters = []
    for name in args_spec.args:
        kind = args_spec.annotations.get(name, fields.Field())
        # Directives and the request object are injected by the framework,
        # not supplied by the client.
        if name in get_available_directives() or name == "request":
            continue
        if isinstance(kind, fields.Field):
            # Compute the location once and reuse it (the original called
            # where_is_parameter() twice with identical arguments).
            parameter_place = where_is_parameter(name, url)
            kind.metadata = {"location": parameter_place}
            # Arguments without a default value are mandatory.
            kind.required = name not in defaults
            parameter = converter.field2parameter(
                kind, name=name, default_in=parameter_place, use_refs=False
            )
            if name in defaults:
                parameter["default"] = defaults[name]
            parameters.append(parameter)
        # body
        elif name == "body" and isinstance(kind, (Schema, SchemaMeta)):
            if isinstance(kind, Schema):
                schema_name = kind.__class__.__name__
                schema = kind
            else:
                # SchemaMeta: a Schema class — instantiate it for the spec.
                schema_name = kind.__name__
                schema = kind()
            spec.definition(schema_name, schema=schema)
            ref_definition = "#/definitions/{}".format(schema_name)
            ref_schema = {"$ref": ref_definition}
            parameters.append(
                {"in": "body", "name": "body", "required": True, "schema": ref_schema}
            )
    return parameters
def __filter_fields(self, field_names, obj, many=False): """Return only those field_name:field_obj pairs specified by ``field_names``. :param set field_names: Field names to include in the final return dictionary. :returns: An dict of field_name:field_obj pairs. """ # Convert obj to a dict obj_marshallable = utils.to_marshallable_type(obj, field_names=field_names) if obj_marshallable and many: try: # Homogeneous collection obj_prototype = obj_marshallable[0] except IndexError: # Nothing to serialize return self.declared_fields obj_dict = utils.to_marshallable_type(obj_prototype, field_names=field_names) else: obj_dict = obj_marshallable ret = self.dict_class() for key in field_names: if key in self.declared_fields: ret[key] = self.declared_fields[key] else: if obj_dict: try: attribute_type = type(obj_dict[key]) except KeyError: raise AttributeError( '"{0}" is not a valid field for {1}.'.format( key, obj)) field_obj = self.TYPE_MAPPING.get(attribute_type, fields.Field)() else: # Object is None field_obj = fields.Field() # map key -> field (default to Raw) ret[key] = field_obj return ret
def __filter_fields(self, field_names, obj, many=False):
    """Return only those field_name:field_obj pairs specified by
    ``field_names``.

    :param set field_names: Field names to include in the final return
        dictionary.
    :param object|Mapping|list obj The object to base filtered fields on.
    :param bool many: whether ``obj`` is a list of objects.
    :returns: An dict of field_name:field_obj pairs.
    """
    if obj and many:
        try:  # list
            # Use the first element as the type-inference prototype.
            obj = obj[0]
        except IndexError:  # Nothing to serialize
            # Empty collection: fall back to the declared fields that were
            # actually requested.
            return dict((k, v) for k, v in self.declared_fields.items()
                        if k in field_names)
    ret = self.dict_class()
    for key in field_names:
        if key in self.declared_fields:
            # Explicitly declared fields win over inferred ones.
            ret[key] = self.declared_fields[key]
        else:
            # Implicit field creation (class Meta 'fields' or 'additional')
            if obj:
                attribute_type = None
                try:
                    if isinstance(obj, Mapping):
                        attribute_type = type(obj[key])
                    else:
                        attribute_type = type(getattr(obj, key))
                except (AttributeError, KeyError) as err:
                    # Re-raise with the same exception type but a clearer message.
                    err_type = type(err)
                    raise err_type(
                        '"{0}" is not a valid field for {1}.'.format(
                            key, obj))
                field_obj = self.TYPE_MAPPING.get(attribute_type,
                                                  fields.Field)()
            else:  # Object is None
                field_obj = fields.Field()
            # map key -> field (default to Raw)
            ret[key] = field_obj
    return ret
class BaseParamSchema(Schema):
    """
    Defines a base parameter schema. This specifies the required fields and
    their types.
    {
        "title": str,
        "description": str,
        "notes": str,
        "type": str (limited to 'int', 'float', 'bool', 'str'),
        "number_dims": int,
        "value": `BaseValidatorSchema`, "value" type depends on "type" key,
        "range": range schema ({"min": ..., "max": ..., "other ops": ...}),
        "out_of_range_minmsg": str,
        "out_of_range_maxmsg": str,
        "out_of_range_action": str (limited to 'stop' or 'warn')
    }

    This class is defined further by a JSON file indicating extra fields that
    are required by the implementer of the schema.
    """

    title = fields.Str(required=True)
    description = fields.Str(required=True)
    notes = fields.Str(required=True)
    # Exposed and stored as "type"; named "_type" here to avoid shadowing
    # the builtin.
    _type = fields.Str(
        required=True,
        validate=validate.OneOf(
            choices=["str", "float", "int", "bool", "date"]),
        attribute="type",
        data_key="type",
    )
    number_dims = fields.Integer(required=True)
    value = fields.Field(required=True)  # will be specified later
    validators = fields.Nested(ValueValidatorSchema(), required=True)
    # Optional out-of-range messages/action.
    out_of_range_minmsg = fields.Str(required=False)
    out_of_range_maxmsg = fields.Str(required=False)
    out_of_range_action = fields.Str(
        required=False, validate=validate.OneOf(choices=["stop", "warn"]))
class AttrSchema(Schema):
    """Attribute schema with import-id support; null and empty-list values
    are stripped from the serialized output."""

    id = fields.Int()
    import_id = fields.Int(load_only=True)
    label = fields.Str(required=True, validate=validate_attr_label,
                       allow_none=False, missing=None)
    created = fields.DateTime(dump_only=True)
    updated = fields.DateTime(dump_only=True)
    type = fields.Str(required=True)
    value_type = fields.Str(required=True)
    static_value = fields.Field(allow_none=True)
    template_id = fields.Str(dump_only=True)
    metadata = fields.Nested(MetaSchema, many=True, attribute='children',
                             validate=validate_children_attr_label)

    @post_load
    def set_import_id(self, data):
        """Promote a supplied import_id into the attribute's id."""
        return set_id_with_import_id(data)

    @post_dump
    def remove_null_values(self, data):
        """Drop keys whose value is None or an empty list from the dump."""
        cleaned = {}
        for key, value in data.items():
            if value is None or (isinstance(value, list) and not value):
                continue
            cleaned[key] = value
        return cleaned
class ProcessGraphShortSchema(BaseSchema):
    """Schema including basic information about a process graph."""

    id_internal = fields.String(attribute='id', load_only=True)
    id_ = fields.String(required=True, data_key="id", attribute='id_openeo',
                        validate=validate.Regexp(regex='^\\w+$'))
    summary = fields.String()
    description = fields.String()
    categories = fields.Pluck(CategorySchema, 'name', many=True)
    deprecated = fields.Boolean(default=False)
    experimental = fields.Boolean(default=False)
    returns = fields.Nested(ReturnSchema)
    parameters = fields.List(fields.Nested(ParameterSchema))
    process_definition = fields.Field(load_only=True)
    user_id = fields.String(load_only=True)

    @pre_load
    def add_process_graph_id(self, in_data: dict, **kwargs: dict) -> dict:
        """Generate and store an internal process_graph_id if none is set."""
        if not in_data.get('id_internal'):
            in_data['id_internal'] = 'pg-' + str(uuid4())
        return in_data

    @post_dump
    def fix_old_process_graph_ids(self, data: dict, **kwargs: dict) -> dict:
        """Reformat id_openeo to match required regex.

        Due to backward compatibility some process_graph ids may contain '-'
        which are not allowed. '-' are replaced by '_' and the ids are
        prefixed with 'regex_'
        """
        if re.match('^\\w+$', data['id']) is None:
            data['id'] = 'regex_' + data['id'].replace('-', '_')
        return data
class ObservationPlanPost(_Schema):
    """Input schema for creating an observation-plan followup request."""

    gcnevent_id = fields.Integer(
        required=True,
        metadata={'description': "ID of the GcnEvent."},
    )
    # Free-form request content; structure is not validated here.
    payload = fields.Field(
        required=False,
        metadata={'description': "Content of the followup request."}
    )
    # Defaults to "pending submission" when absent from the input.
    status = fields.String(
        missing="pending submission",
        metadata={'description': "The status of the request."},
        required=False,
    )
    allocation_id = fields.Integer(
        required=True,
        metadata={'description': "Followup request allocation ID."},
    )
    localization_id = fields.Integer(
        required=True,
        metadata={'description': "Localization ID."},
    )
    target_group_ids = fields.List(
        fields.Integer,
        required=False,
        metadata={
            'description': (
                'IDs of groups to share the results of the followup request with.'
            )
        },
    )
class PatternConfig(Schema):
    """Schema for one pattern configuration; loads into a PatternCollector."""

    conf_name = fields.Str()
    issue_category = fields.Str()
    ic_idx = fields.Integer()
    name = fields.Str()
    function = fields.Str()
    operation = fields.Field()
    keyword = fields.Field()
    # parse pattern use for-loop is faster than regex union.
    key_value = fields.List(fields.Field())
    sequence = fields.Field()
    sequence_idx = fields.Field()
    sequence_status = fields.Field()
    sequence_order = fields.List(fields.Integer())

    @post_load
    def make_pattern(self, data, **kwargs):
        """Post-load hook: build a PatternCollector from the validated data."""
        return PatternCollector(**data)
def test_allow_none_is_true_if_missing_is_true(self):
    """A field with missing=None implicitly allows None, and deserializing
    None yields None."""
    field = fields.Field(missing=None)
    assert field.allow_none is True
    # BUG FIX: the original final line was a bare expression with no
    # `assert`, so the deserialization check was a no-op.
    assert field.deserialize(None) is None
class MySchema(Schema):
    """Minimal schema with two untyped (Raw) fields."""

    foo = fields.Field()  # accepts any value
    bar = fields.Field()  # accepts any value
class MySchema(Schema):
    """Schema whose three untyped (Raw) fields are all mandatory."""

    foo = fields.Field(required=True)
    bar = fields.Field(required=True)
    baz = fields.Field(required=True)
class MySchema(Schema):
    """Schema with a required field whose validator always fails."""

    foo = fields.Field(required=True, validate=lambda f: False)
class MySchema(Schema):
    """Schema with multiple validators on a single required field."""

    # Validators run in order; both are defined in the enclosing scope.
    foo = fields.Field(required=True, validate=[
        validate_with_bool,
        validate_with_error,
    ])
class RequestsHeaderSchema(ma.SQLAlchemySchema):
    """Serialization schema for name-request header data."""

    class Meta:
        model = Request
        # sqla_session = db.scoped_session
        # additional = ['stateCd']
        # Explicit whitelist of attributes to serialize from the model.
        fields = (
            'additionalInfo', 'applicants', 'checkedOutBy', 'checkedOutDt',
            'comments', 'consentFlag', 'consent_dt', 'corpNum', 'details',
            'entity_type_cd', 'expirationDate', 'furnished', 'hasBeenReset',
            'homeJurisNum', 'id', 'lastUpdate', 'names', 'natureBusinessInfo',
            'nrNum', 'nroLastUpdate', 'nwpta', 'previousNr',
            'previousRequestId', 'previousStateCd', 'priorityCd',
            'priorityDate', 'requestTypeCd', 'request_action_cd', 'source',
            'state', 'stateCd', 'tradeMark', 'submitter_userid',
            'submitCount', 'submittedDate', 'userId', 'xproJurisdiction'
        )

    # Explicit field overrides: every attribute below is nullable.
    additionalInfo = fields.String(allow_none=True)
    applicants = fields.Field(allow_none=True)
    checkedOutBy = fields.String(allow_none=True)
    checkedOutDt = fields.Field(allow_none=True)
    comments = fields.Field(allow_none=True)
    consentFlag = fields.String(allow_none=True)
    consent_dt = fields.Field(allow_none=True)
    corpNum = fields.String(allow_none=True)
    details = fields.Field(allow_none=True)
    entity_type_cd = fields.String(allow_none=True)
    expirationDate = fields.Field(allow_none=True)
    furnished = fields.String(allow_none=True)
    hasBeenReset = fields.Boolean(allow_none=True)
    homeJurisNum = fields.String(allow_none=True)
    lastUpdate = fields.Field(allow_none=True)
    natureBusinessInfo = fields.String(allow_none=True)
    nroLastUpdate = fields.Field(allow_none=True)
    nwpta = fields.Field(allow_none=True)
    previousNr = fields.String(allow_none=True)
    previousRequestId = fields.Integer(allow_none=True)
    previousStateCd = fields.String(allow_none=True)
    priorityCd = fields.String(allow_none=True)
    priorityDate = fields.Field(allow_none=True)
    requestTypeCd = fields.String(allow_none=True)
    request_action_cd = fields.String(allow_none=True)
    source = fields.String(allow_none=True)
    stateCd = fields.String(allow_none=True)
    tradeMark = fields.String(allow_none=True)
    submitter_userid = fields.String(allow_none=True)
    userId = fields.String(allow_none=True)
    xproJurisdiction = fields.String(allow_none=True)
def test_callable_default(self, user):
    """Serializing a missing attribute falls back to calling the default."""
    field = fields.Field(default=lambda: 'nan')
    result = field.serialize('age', {})
    assert result == 'nan'
def test_default(self, user):
    """Serializing a missing attribute falls back to the static default."""
    field = fields.Field(default='nan')
    result = field.serialize('age', {})
    assert result == 'nan'
def test_error_raised_if_uncallable_validator_passed(self):
    """Field() must reject a ``validate`` argument that is not callable."""
    with pytest.raises(ValueError):
        fields.Field(validate='notcallable')
class MySchema(Schema):
    """Schema exercising ``post_load`` with ``pass_original=True``."""

    foo = fields.Field()

    @post_load(pass_original=True)
    def post_load(self, data, input_data):
        # Copies the raw input's 'post_load' key into the loaded result
        # under '_post_load'.
        # NOTE(review): this hook does not return `data`; marshmallow
        # replaces the result with a hook's return value, so the load
        # result becomes None — confirm this is intended (likely test-only).
        data['_post_load'] = input_data['post_load']
def test_missing_data_are_skipped(self, marshal):
    """Keys absent from the input never appear in the marshalled output,
    whatever the field type."""
    field_variants = [
        fields.Field(),
        fields.Str(),
        fields.Int(),
        fields.Int(as_string=True),
        fields.Decimal(as_string=True),
    ]
    for field in field_variants:
        assert marshal({}, {'foo': field}) == {}
class MySchema(Schema):
    """Schema with a single validated untyped field."""

    # 'validator' is defined in the enclosing scope; runs on deserialization.
    foo = fields.Field(validate=validator)
class InvitationReceiveRequestSchema(InvitationMessageSchema):
    """Invitation request schema."""

    # Untyped service entry; accepted as-is without validation here.
    service = fields.Field()
def test_serialize_does_not_apply_validators(self, user):
    """Validators run on deserialization only; serialize must not invoke
    them even when they would always fail."""
    field = fields.Field(validate=lambda x: False)
    # No validation error raised
    serialized = field.serialize('age', user)
    assert serialized == user.age