def test_function_field_passed_func(self, user):
    """A positional callable is used to serialize the whole object."""
    upper_name = fields.Function(lambda obj: obj.name.upper())
    result = upper_name.serialize('key', user)
    assert result == 'FOO'
def test_function_validator(self):
    """A single validator on a Function field runs at deserialization."""
    field = fields.Function(
        lambda obj: obj.name.upper(),
        validate=lambda value: len(value) == 3,
    )
    assert field.deserialize('joe')
    with pytest.raises(ValidationError):
        field.deserialize('joseph')
def test_function_field_deserialization_with_callable(self):
    """The ``deserialize`` callable is applied to incoming values."""
    field = fields.Function(
        lambda obj: None,
        deserialize=lambda value: value.upper(),
    )
    loaded = field.deserialize('foo')
    assert loaded == 'FOO'
class TestValidation:
    """Validators attached to fields are enforced on deserialization."""

    def test_integer_with_validator(self):
        field = fields.Integer(validate=lambda x: 18 <= x <= 24)
        out = field.deserialize('20')
        assert out == 20
        with pytest.raises(ValidationError):
            field.deserialize(25)

    # Validators may be given as a list, a tuple, or a generator.
    @pytest.mark.parametrize('field', [
        fields.Integer(validate=[lambda x: x <= 24, lambda x: 18 <= x]),
        fields.Integer(validate=(lambda x: x <= 24, lambda x: 18 <= x, )),
        fields.Integer(validate=validators_gen)
    ])
    def test_integer_with_validators(self, field):
        out = field.deserialize('20')
        assert out == 20
        with pytest.raises(ValidationError):
            field.deserialize(25)

    @pytest.mark.parametrize('field', [
        fields.Float(validate=[lambda f: f <= 4.1, lambda f: f >= 1.0]),
        fields.Float(validate=(lambda f: f <= 4.1, lambda f: f >= 1.0, )),
        fields.Float(validate=validators_gen_float)
    ])
    def test_float_with_validators(self, field):
        assert field.deserialize(3.14)
        with pytest.raises(ValidationError):
            field.deserialize(4.2)

    def test_string_validator(self):
        field = fields.String(validate=lambda n: len(n) == 3)
        assert field.deserialize('Joe') == 'Joe'
        with pytest.raises(ValidationError):
            field.deserialize('joseph')

    def test_function_validator(self):
        field = fields.Function(lambda d: d.name.upper(),
                                validate=lambda n: len(n) == 3)
        assert field.deserialize('joe')
        with pytest.raises(ValidationError):
            field.deserialize('joseph')

    @pytest.mark.parametrize('field', [
        fields.Function(lambda d: d.name.upper(),
                        validate=[lambda n: len(n) == 3,
                                  lambda n: n[1].lower() == 'o']),
        fields.Function(lambda d: d.name.upper(),
                        validate=(lambda n: len(n) == 3,
                                  lambda n: n[1].lower() == 'o')),
        fields.Function(lambda d: d.name.upper(), validate=validators_gen_str)
    ])
    def test_function_validators(self, field):
        assert field.deserialize('joe')
        with pytest.raises(ValidationError):
            field.deserialize('joseph')

    def test_method_validator(self):
        class MethodSerializer(Schema):
            name = fields.Method('get_name', deserialize='get_name',
                                 validate=lambda n: len(n) == 3)

            def get_name(self, val):
                return val.upper()

        assert MethodSerializer(strict=True).load({'name': 'joe'})
        with pytest.raises(ValidationError) as excinfo:
            MethodSerializer(strict=True).load({'name': 'joseph'})
        # Inspect the raised exception itself: str() of the ExceptionInfo
        # wrapper is a pytest implementation detail that changed across
        # pytest releases.
        assert 'is False' in str(excinfo.value)
class NotificationSchema(Schema):
    """Serialize a notification with its timestamp, body and demande."""

    created_at = fields.DateTime()
    # Calendar day of ``created_at``, rendered with ``str()`` on the date.
    date = fields.Function(
        lambda notification: str(notification.created_at.date()))
    body = fields.String()
    demande = fields.Nested(DemandeSchema)
class CityDistrictSchema(Schema):
    """Expose a city district's ``id`` (rendered as a string) and ``title``."""

    # Pass the serializer positionally: the ``func=`` keyword is deprecated
    # in marshmallow 2 and no longer accepted in marshmallow 3, while the
    # positional form works in every release.
    id = fields.Function(lambda obj: str(obj.id))

    class Meta:
        fields = ['id', 'title']
class RecordSchemaMarcXMLV1(Schema):
    """Schema for records in MARC."""

    control_number = fields.Method('get_id')

    def get_id(self, obj):
        """Return the value of the first ``b2rec`` PID, as a string.

        Returns ``None`` when ``metadata._pid`` is absent or contains no PID
        of type ``b2rec``, instead of raising ``TypeError``/``IndexError``
        and aborting the whole dump.
        """
        pids = obj['metadata'].get('_pid') or []
        for pid in pids:
            if pid['type'] == 'b2rec':
                return str(pid['value'])
        return None

    date_and_time_of_latest_transaction = fields.Function(
        lambda o: parse(o['updated']).strftime("%Y%m%d%H%M%S.0"))

    main_entry_personal_name = fields.Method('get_main_entry_personal_name')

    added_entry_personal_name = fields.Method('get_added_entry_personal_name')

    title_statement = fields.Function(
        lambda o: o['metadata']['titles'][0])

    publication_distribution_imprint = fields.Function(
        lambda o: dict(
            name_of_publisher_distributor=o['metadata'].get('publisher'),
            date_of_publication_distribution=o['metadata'].get(
                'publication_date')))

    media_type = fields.Function(
        lambda o: dict(media_type_term=[
            x['resource_type_general']
            for x in o['metadata'].get('resource_types', [])]))

    summary = fields.Function(
        lambda o: [dict(summary=x.get('description'))
                   for x in o['metadata'].get('descriptions', [])])

    study_program_information_note = fields.Function(
        lambda o: [dict(program_name=o['metadata'].get('disciplines', []))])

    terms_governing_use_and_reproduction_note = fields.Function(
        lambda o: dict(
            terms_governing_use_and_reproduction=o['metadata'].get(
                'license', {}).get('license'),
            uniform_resource_identifier=o['metadata'].get(
                'license', {}).get('license_uri')))

    information_relating_to_copyright_status = fields.Function(
        lambda o: dict(copyright_status='open'
                       if o['metadata']['open_access'] else 'closed'))

    language_note = fields.Function(
        lambda o: [dict(language_note=o['metadata'].get('language'))])

    index_term_uncontrolled = fields.Function(
        lambda o: [dict(uncontrolled_term=x)
                   for x in o['metadata'].get('keywords', [])])

    other_standard_identifier = fields.Function(
        lambda o: [dict(standard_number_or_code=x['alternate_identifier'])
                   for x in o['metadata'].get('alternate_identifiers', [])])

    # One entry per file; files without an ``ePIC_PID`` yield ``None``.
    electronic_location_and_access = fields.Function(
        lambda o: [dict(uniform_resource_identifier=f['ePIC_PID'],
                        file_size=str(f['size']),
                        access_method="HTTP")
                   if f.get('ePIC_PID') else None
                   for f in o['metadata'].get('_files', [])])

    # Custom fields:
    embargo_date = fields.Raw(attribute='metadata.embargo_date')
    _oai = fields.Raw(attribute='metadata._oai')

    def get_main_entry_personal_name(self, o):
        """Return the first creator's name, or an empty dict if none."""
        creators = o['metadata'].get('creators', [])
        if creators:
            return dict(personal_name=creators[0]['creator_name'])
        return dict()

    def get_added_entry_personal_name(self, o):
        """Get added_entry_personal_name.

        Lists every creator after the first, followed by all contributors.
        """
        items = []
        creators = o['metadata'].get('creators', [])
        for c in creators[1:]:
            items.append(dict(personal_name=c['creator_name']))
        contributors = o['metadata'].get('contributors', [])
        for c in contributors:
            items.append(dict(personal_name=c['contributor_name']))
        return items

    @post_dump(pass_many=True)
    def remove_empty_fields(self, data, many):
        """Dump + Remove empty fields."""
        _filter_empty(data)
        return data
class WorkerSchema(Schema):
    """Serialize a worker; ``type`` is its type's ``name`` with case swapped."""

    id = fields.Integer()
    ean13 = fields.String()
    type = fields.Function(lambda worker: worker.type.name.swapcase())
class UserSchema(mm.Schema):
    """Serialize a user-info mapping: email, full name, uid and avatar URL."""

    email = fields.String()
    name = fields.Function(
        lambda user: f'{user["first_name"]} {user["last_name"]}')
    uid = fields.String()
    avatar_url = fields.Function(
        lambda user: url_for('api.user_avatar',
                             payload=avatar_payload_from_user_info(user)))
class UserSchema(mm.Schema):
    """Serialize a user-info mapping: email, full name, initials and uid."""

    email = fields.String()
    name = fields.Function(lambda u: f'{u["first_name"]} {u["last_name"]}')
    # Slice instead of indexing so an empty first/last name yields '' for
    # that initial rather than raising IndexError in the middle of a dump.
    initials = fields.Function(
        lambda u: f'{u["first_name"][:1]} {u["last_name"][:1]}')
    uid = fields.String()
class ServiceSchema(AutoSchema):
    _id = fields.Integer(attribute='id', dump_only=True)
    _rev = fields.String(default='', dump_only=True)
    owned = fields.Boolean(default=False)
    owner = PrimaryKeyRelatedField('username', dump_only=True,
                                   attribute='creator')
    port = fields.Integer(
        dump_only=True, strict=True, required=True,
        validate=[
            Range(min=0, error="The value must be greater than or equal to 0")
        ])
    # Port is loaded via ports
    ports = MutableField(
        fields.Integer(
            strict=True, required=True,
            validate=[
                Range(min=0,
                      error="The value must be greater than or equal to 0")
            ]),
        fields.Method(deserialize='load_ports'),
        required=True,
        attribute='port')
    status = fields.String(default='open', validate=OneOf(Service.STATUSES),
                           required=True, allow_none=False)
    # parent is not required for updates
    parent = fields.Integer(attribute='host_id')
    host_id = fields.Integer(attribute='host_id', dump_only=True)
    vulns = fields.Integer(attribute='vulnerability_count', dump_only=True)
    credentials = fields.Integer(attribute='credentials_count', dump_only=True)
    metadata = SelfNestedField(MetadataSchema())
    type = fields.Function(lambda obj: 'Service', dump_only=True)

    def load_ports(self, value):
        """Deserialize the ``ports`` payload into the single service port.

        :param value: must be a list with exactly one non-negative port.
        :returns: the port converted with ``str()``.
        :raises ValidationError: if ``value`` is not a list, does not have
            exactly one element, or holds a negative port.

        The caller's list is read, not mutated.
        """
        if not isinstance(value, list):
            raise ValidationError('ports must be a list')
        if len(value) != 1:
            # Adjacent literals are concatenated: keep the trailing space.
            raise ValidationError('ports must be a list with exactly one '
                                  'element')
        port = value[0]
        if port < 0:
            raise ValidationError(
                'The value must be greater than or equal to 0')
        return str(port)

    @post_load
    def post_load_parent(self, data):
        """Gets the host_id from parent attribute. Pops it and tries to
        get a Host with that id in the corresponding workspace.
        """
        host_id = data.pop('host_id', None)
        if self.context['updating']:
            if host_id is None:
                # Partial update?
                return
            if host_id != self.context['object'].parent.id:
                raise ValidationError('Can\'t change service parent.')
        else:
            if not host_id:
                raise ValidationError(
                    'Parent id is required when creating a service.')
            try:
                data['host'] = Host.query.join(Workspace).filter(
                    Workspace.name == self.context['workspace_name'],
                    Host.id == host_id).one()
            except NoResultFound:
                raise ValidationError(
                    'Host with id {} not found'.format(host_id))

    class Meta:
        model = Service
        # Each field listed once (the original listed '_id' twice).
        fields = ('id', '_id', 'status', 'parent', 'type',
                  'protocol', 'description', '_rev',
                  'owned', 'owner', 'credentials', 'vulns',
                  'name', 'version', 'port', 'ports',
                  'metadata', 'summary', 'host_id')
class PaymentTableSchema(BaseSchema):
    """Schema for a payment entry with nested career, position level,
    engineer and work-report data."""

    # Work report extended with per-category lists of days.
    class WD(WorkReportSchema):
        work_days_list = fields.List(
            fields.Nested(DailyLogDefaultSchema(many=False)))
        leave_days_list = fields.List(
            fields.Nested(LeaveInfoSchema(many=False)))
        extra_work_days_list = fields.List(
            fields.Nested(ExtraWorkInfoSchema(many=False)))
        absent_days_list = fields.List(
            fields.Nested(DailyLogDefaultSchema(many=False)))
        shift_days_list = fields.List(
            fields.Nested(LeaveInfoSchema(many=False)))
        rest_days_list = fields.List(
            fields.Nested(DailyLogDefaultSchema(many=False)))

    # Career, with its nested offer.
    class C(Schema):
        # Offer attached to a career: position, level and amounts.
        class OS(Schema):
            position = fields.String()
            position_level = fields.String()
            money = fields.Float()
            s_money = fields.Float()

        offer = fields.Nested(OS(many=False))

    # Engineer projection: name and welfare rate only.
    class E(Schema):
        real_name = fields.String()
        welfare_rate = fields.String()

    # Position level: amount and name.
    class PLS(Schema):
        money = fields.Float()
        name = fields.String()

    career = fields.Nested(C(many=False))
    position_level = fields.Nested(PLS(many=False))
    year_month = fields.Int()
    engineer = fields.Nested(E(many=False))
    engineer_id = fields.Int()
    company_pay = fields.Float()
    income = fields.Float()
    tax = fields.Float()
    fee = fields.Float()
    welfare = fields.Float()
    amerce = fields.Float()
    # The object's ``status`` converted by ``PaymentStatus.int2str``.
    status = fields.Function(lambda x: PaymentStatus.int2str(x.status))
    engineer_tax = fields.Float()
    engineer_get = fields.Float()
    break_up_fee = fields.Float()
    engineer_income_with_tax = fields.Float()
    finance_fee = fields.Float()
    hr_fee = fields.Float()
    service_fee_rate = fields.Float()
    employ_type = fields.Integer()
    out_duty_days = fields.Float()
    finance_rate = fields.Float()
    tax_fee_rate = fields.Float()
    tax_rate = fields.Float()
    # NOTE(review): "servce" looks like a typo, but the field name is also
    # the serialized key / model attribute — confirm before renaming.
    use_hr_servce = fields.Integer()
    # NOTE(review): possibly meant "welfare"; same caveat as above.
    ware_fare = fields.Float()
    pm = fields.String()
    project = fields.String()
    work_report = fields.Nested(WD(many=False))
    station_salary = fields.Float()
    extra_salary = fields.Float()
    tax_free_rate = fields.Float()
class InterviewPmPutSchema(Schema):
    """Input schema for updating an interview's ``status``.

    On load, the incoming status is converted with
    ``InterviewStatus.str2int``.
    """

    # NOTE(review): the serializer returns its argument unchanged, and a
    # Function field's serializer receives the whole object being dumped,
    # so dumping through this schema would emit that entire object under
    # ``status``. Confirm this schema is only used for loading.
    status = fields.Function(
        lambda obj: obj,
        deserialize=lambda raw_status: InterviewStatus.str2int(raw_status))
def test_function_field_does_not_swallow_attribute_error(self, user):
    """An AttributeError raised inside the serializer must propagate."""
    def failing_serializer(obj):
        raise AttributeError()

    field = fields.Function(serialize=failing_serializer)
    with pytest.raises(AttributeError):
        field.serialize('key', user)
def test_function_field_load_only(self):
    """Supplying only ``deserialize`` makes the Function field load-only."""
    load_only_field = fields.Function(deserialize=lambda obj: None)
    assert load_only_field.load_only
class UserSchema(marshmallow.Schema):
    """Schema of a user to be used to extract API information.

    This class is a ``marshmallow`` schema which automatically gets its
    structure from the ``User`` class. It also adds some useful information
    and links. This schema is only used for administration listing.
    """

    isadmin = fields.Function(lambda user: user.is_admin())
    """ Wraps :py:meth:`collectives.models.user.User.is_admin`

    :type: boolean"""

    roles_uri = fields.Function(
        lambda user: url_for("administration.add_user_role", user_id=user.id))
    """ URI to role management page for this user

    :type: string
    """

    delete_uri = fields.Function(
        lambda user: url_for("administration.delete_user", user_id=user.id))
    """ URI to delete this user (WIP)

    :type: string
    """

    manage_uri = fields.Function(
        lambda user: url_for("administration.manage_user", user_id=user.id))
    """ URI to modify this user

    :type: string
    """

    profile_uri = fields.Function(
        lambda user: url_for("profile.show_user", user_id=user.id))
    """ URI to see user profile

    :type: string
    """

    leader_profile_uri = fields.Function(
        lambda user: url_for("profile.show_leader", leader_id=user.id)
        if user.can_create_events() else None)
    """ URI to see the leader profile of this user, or ``None`` if the user
    cannot create events

    :type: string or None
    """

    avatar_uri = fields.Function(avatar_url)
    """ URI to a resized version (30px) of user avatar

    :type: string
    """

    roles = fields.Function(
        lambda user: RoleSchema(many=True).dump(user.roles))
    """ List of roles of the User. Roles are encoded as JSON.

    :type: list(dict())"""

    full_name = fields.Function(lambda user: user.full_name())
    """ User full name

    :type: string"""

    class Meta:
        """Fields to expose"""

        # Each name listed once (the original listed "isadmin" twice).
        fields = (
            "id",
            "mail",
            "isadmin",
            "enabled",
            "roles_uri",
            "avatar_uri",
            "manage_uri",
            "profile_uri",
            "delete_uri",
            "first_name",
            "last_name",
            "roles",
            "leader_profile_uri",
            "full_name",
        )
def test_function_field_passed_uncallable_object(self):
    """Building a Function field from a non-callable raises ValueError."""
    not_callable = "uncallable"
    with pytest.raises(ValueError):
        fields.Function(not_callable)
class UserAdditionalSchema(Schema):
    """User schema declaring a lower-cased ``lowername`` plus extra fields."""

    lowername = fields.Function(lambda user: user.name.lower())

    class Meta:
        # Attributes serialized in addition to the declared fields.
        additional = ("name", "age", "created", "email")
class CategorySchema(ModelSchema):
    """Provides detailed information about a category, including all the children"""

    class Meta(ModelSchema.Meta):
        model = Category
        fields = ('id', 'name', 'children', 'parent_id', 'parent', 'level',
                  'event_count', 'location_count', 'resource_count',
                  'all_resource_count', 'study_count', 'color', '_links',
                  'last_updated', 'display_order')

    id = fields.Integer(required=False, allow_none=True)
    parent_id = fields.Integer(required=False, allow_none=True)
    children = ma.Nested(lambda: CategorySchema(), many=True, dump_only=True,
                         exclude=('parent', 'color'))
    parent = ma.Nested(ParentCategorySchema, dump_only=True)
    level = fields.Function(
        lambda obj: obj.calculate_level() if isinstance(obj, Category) else 0,
        dump_only=True)
    event_count = fields.Method('get_event_count', dump_only=True)
    location_count = fields.Method('get_location_count', dump_only=True)
    resource_count = fields.Method('get_resource_count', dump_only=True)
    all_resource_count = fields.Method('get_all_resource_count',
                                       dump_only=True)
    study_count = fields.Method('get_study_count', dump_only=True)
    _links = ma.Hyperlinks({
        'self': ma.URLFor('api.categoryendpoint', id='<id>'),
        'collection': ma.URLFor('api.categorylistendpoint')
    })

    @staticmethod
    def _scalar_count(query):
        """Return the number of rows *query* matches, via one COUNT(*)."""
        # Replace the selected columns with COUNT(*) and drop ORDER BY,
        # which is irrelevant to a count.
        # NOTE(review): the list argument to with_only_columns is the
        # pre-1.4 SQLAlchemy form; revisit if upgrading to SQLAlchemy 2.x.
        count_q = query.statement.with_only_columns(
            [func.count()]).order_by(None)
        return query.session.execute(count_q).scalar()

    def _count_resource_categories(self, obj, resource_type):
        """Count ResourceCategory rows of *resource_type* linked to *obj*."""
        query = db.session.query(ResourceCategory)\
            .filter(ResourceCategory.type == resource_type)\
            .filter(ResourceCategory.category_id == obj.id)
        return self._scalar_count(query)

    def get_event_count(self, obj):
        """Number of 'event' resources in this category (missing if no obj)."""
        if obj is None:
            return missing
        return self._count_resource_categories(obj, 'event')

    def get_location_count(self, obj):
        """Number of 'location' resources in this category (missing if no obj)."""
        if obj is None:
            return missing
        return self._count_resource_categories(obj, 'location')

    def get_resource_count(self, obj):
        """Number of 'resource' resources in this category (missing if no obj)."""
        if obj is None:
            return missing
        return self._count_resource_categories(obj, 'resource')

    def get_all_resource_count(self, obj):
        """Number of ResourceCategory rows joined to a resource, any type."""
        if obj is None:
            return missing
        query = db.session.query(ResourceCategory)\
            .join(ResourceCategory.resource)\
            .filter(ResourceCategory.category_id == obj.id)
        return self._scalar_count(query)

    def get_study_count(self, obj):
        """Number of StudyCategory rows joined to a study in this category."""
        if obj is None:
            return missing
        query = db.session.query(StudyCategory)\
            .join(StudyCategory.study)\
            .filter(StudyCategory.category_id == obj.id)
        return self._scalar_count(query)
class VulnerabilityFilterSet(FilterSet):
    """Filter set for querying ``VulnerabilityWeb`` rows from request args."""

    class Meta(FilterSetMeta):
        model = VulnerabilityWeb  # It has all the fields
        # TODO migration: Check if we should add fields owner,
        # command, impact, issuetracker, tags, date, host
        # evidence, policy violations, hostnames
        fields = (
            "id", "status", "website", "pname", "query", "path", "service",
            "data", "severity", "confirmed", "name", "request", "response",
            "parameters", "params", "resolution", "ease_of_resolution",
            "description", "command_id", "target", "creator", "method",
            "easeofresolution", "query_string", "parameter_name", "service_id",
            "status_code"
        )

        # Fields filtered through _strict_filtering (see column_overrides)
        # instead of the default operator.
        strict_fields = (
            "severity", "confirmed", "method", "status", "easeofresolution",
            "ease_of_resolution", "service_id",
        )

        default_operator = CustomILike
        # next line uses dict comprehensions!
        # In a class body only the outermost iterable (strict_fields) is
        # looked up in class scope; _strict_filtering must therefore be a
        # module-level name, since class-scope names are not visible inside
        # the comprehension body.
        column_overrides = {
            field: _strict_filtering for field in strict_fields
        }
        # The right-hand ``operators`` resolves to the module-level name,
        # because the class attribute is not bound until this line finishes.
        operators = (CustomILike, operators.Equal)

    id = IDFilter(fields.Int())
    target = TargetFilter(fields.Str())
    type = TypeFilter(fields.Str(validate=[OneOf(['Vulnerability',
                                                  'VulnerabilityWeb'])]))
    creator = CreatorFilter(fields.Str())
    service = ServiceFilter(fields.Str())
    severity = Filter(SeverityField())
    easeofresolution = Filter(fields.String(
        attribute='ease_of_resolution',
        validate=OneOf(Vulnerability.EASE_OF_RESOLUTIONS),
        allow_none=True))
    pname = Filter(fields.String(attribute='parameter_name'))
    query = Filter(fields.String(attribute='query_string'))
    status_code = StatusCodeFilter(fields.Int())
    params = Filter(fields.String(attribute='parameters'))
    # Accepts 'opened' as an alias and maps it to 'open' on deserialization.
    status = Filter(fields.Function(
        deserialize=lambda val: 'open' if val == 'opened' else val,
        validate=OneOf(Vulnerability.STATUSES + ['opened'])
    ))
    hostnames = HostnamesFilter(fields.Str())
    confirmed = Filter(fields.Boolean())

    def filter(self):
        """Generate a filtered query from request parameters.

        In addition to the declared filters, applies the ``command_id``
        request argument against ``creator_command_id``.

        :returns: Filtered SQLAlchemy query
        """
        # TODO migration: this can became a normal filter instead of a custom
        # one, since now we can use creator_command_id
        command_id = request.args.get('command_id')
        query = super(VulnerabilityFilterSet, self).filter()
        if command_id:
            # query = query.filter(CommandObject.command_id == int(command_id))
            # NOTE(review): a non-numeric command_id makes int() raise
            # ValueError here; it is not converted into a client error.
            query = query.filter(VulnerabilityGeneric.creator_command_id ==
                                 int(command_id))
            # TODO migration: handle invalid int()
        return query
class RecordSchemaMARC21(Schema):
    """Schema for records in MARC."""

    control_number = fields.Function(
        lambda o: str(o['metadata'].get('recid')))

    date_and_time_of_latest_transaction = fields.Function(
        lambda obj: parse(obj['updated']).strftime("%Y%m%d%H%M%S.0"))

    index_term_uncontrolled = fields.Function(lambda o: [
        dict(uncontrolled_term=kw)
        for kw in o['metadata'].get('keywords', [])
    ])

    subject_added_entry_topical_term = fields.Method(
        'get_subject_added_entry_topical_term')

    terms_governing_use_and_reproduction_note = fields.Function(
        lambda o: dict(
            uniform_resource_identifier=o['metadata'].get(
                'license', {}).get('url'),
            terms_governing_use_and_reproduction=o['metadata'].get(
                'license', {}).get('title')))

    title_statement = fields.Function(
        lambda o: dict(title=o['metadata'].get('title')))

    general_note = fields.Function(
        lambda o: dict(general_note=o['metadata'].get('notes')))

    # Declared once. A second declaration of this name earlier in the class
    # body would be silently overridden; this is the effective definition.
    information_relating_to_copyright_status = fields.Function(
        lambda o: dict(copyright_status=o['metadata'].get('access_right')))

    publication_distribution_imprint = fields.Method(
        'get_publication_distribution_imprint')

    funding_information_note = fields.Function(lambda o: [
        dict(text_of_note=v.get('title'), grant_number=v.get('code'))
        for v in o['metadata'].get('grants', [])
    ])

    other_standard_identifier = fields.Method('get_other_standard_identifier')

    added_entry_meeting_name = fields.Method('get_added_entry_meeting_name')

    main_entry_personal_name = fields.Method('get_main_entry_personal_name')

    added_entry_personal_name = fields.Method('get_added_entry_personal_name')

    summary = fields.Function(
        lambda o: dict(summary=o['metadata'].get('description')))

    host_item_entry = fields.Method('get_host_item_entry')

    dissertation_note = fields.Function(
        lambda o: dict(name_of_granting_institution=o['metadata'].get(
            'thesis', {}).get('university')))

    language_code = fields.Function(
        lambda o: dict(language_code_of_text_sound_track_or_separate_title=(
            o['metadata'].get('language'))))

    # Custom
    # ======
    resource_type = fields.Raw(attribute='metadata.resource_type')

    communities = fields.Raw(attribute='metadata.communities')

    references = fields.Raw(attribute='metadata.references')

    embargo_date = fields.Raw(attribute='metadata.embargo_date')

    journal = fields.Raw(attribute='metadata.journal')

    # _oai = fields.Raw(attribute='metadata._oai')

    _files = fields.Method('get_files')

    leader = fields.Method('get_leader')

    conference_url = fields.Raw(attribute='metadata.meeting.url')

    def get_leader(self, o):
        """Return the leader information.

        ``type_of_record`` is derived from the record's resource type and
        defaults to ``'language_material'``; the other keys are constants.
        """
        rt = o['metadata']['resource_type']['type']
        rec_types = {
            'image': 'two-dimensional_nonprojectable_graphic',
            'video': 'projected_medium',
            'dataset': 'computer_file',
            'software': 'computer_file',
        }
        type_of_record = rec_types.get(rt, 'language_material')
        res = {
            'record_length': '00000',
            'record_status': 'new',
            'type_of_record': type_of_record,
            'bibliographic_level': 'monograph_item',
            'type_of_control': 'no_specified_type',
            'character_coding_scheme': 'marc-8',
            'indicator_count': 2,
            'subfield_code_count': 2,
            'base_address_of_data': '00000',
            'encoding_level': 'unknown',
            'descriptive_cataloging_form': 'unknown',
            'multipart_resource_record_level':
                'not_specified_or_not_applicable',
            'length_of_the_length_of_field_portion': 4,
            'length_of_the_starting_character_position_portion': 5,
            'length_of_the_implementation_defined_portion': 0,
            'undefined': 0,
        }
        return res

    def get_files(self, o):
        """Get the files provided the record is open access."""
        if o['metadata']['access_right'] != 'open':
            return missing
        res = []
        for f in o['metadata'].get('_files', []):
            res.append(
                dict(
                    uri=u'{0}/record/{1}/files/{2}'.format(
                        current_app.config.get('THEME_SITEURL', 'zenodo.org'),
                        o['metadata'].get('recid', ''), f['key']),
                    size=f['size'],
                    checksum=f['checksum'],
                    type=f['type'],
                ))
        return res or missing

    def get_host_item_entry(self, o):
        """Get host items."""
        res = []
        for v in o['metadata'].get('related_identifiers', []):
            res.append(
                dict(
                    main_entry_heading=v.get('identifier'),
                    relationship_information=v.get('relation'),
                    note=v.get('scheme'),
                ))

        imprint = o['metadata'].get('imprint', {})
        part_of = o['metadata'].get('part_of', {})
        if part_of and imprint:
            res.append(
                dict(
                    main_entry_heading=imprint.get('place'),
                    edition=imprint.get('publisher'),
                    title=part_of.get('title'),
                    related_parts=part_of.get('pages'),
                    international_standard_book_number=imprint.get('isbn'),
                ))

        return res or missing

    def get_publication_distribution_imprint(self, o):
        """Get publication date and imprint."""
        res = []

        pubdate = o['metadata'].get('publication_date')
        if pubdate:
            res.append(dict(date_of_publication_distribution=pubdate))

        imprint = o['metadata'].get('imprint')
        part_of = o['metadata'].get('part_of')
        if not part_of and imprint:
            res.append(
                dict(
                    place_of_publication_distribution=imprint.get('place'),
                    name_of_publisher_distributor=imprint.get('publisher'),
                    date_of_publication_distribution=pubdate,
                ))

        return res or missing

    def get_subject_added_entry_topical_term(self, o):
        """Get licenses and subjects."""
        res = []

        # Named license_id to avoid shadowing the ``license`` builtin.
        license_id = o['metadata'].get('license', {}).get('id')
        if license_id:
            # NOTE(review): the term is hard-coded to 'cc-by' whatever the
            # actual license id is; confirm whether license_id was intended.
            res.append(
                dict(
                    topical_term_or_geographic_name_entry_element='cc-by',
                    source_of_heading_or_term='opendefinition.org',
                    level_of_subject='Primary',
                    thesaurus='Source specified in subfield $2',
                ))

        def _subject(term, id_, scheme):
            return dict(
                topical_term_or_geographic_name_entry_element=term,
                authority_record_control_number_or_standard_number=(
                    '({0}){1}'.format(scheme, id_)),
                level_of_subject='Primary',
            )

        for s in o['metadata'].get('subjects', []):
            res.append(
                _subject(
                    s.get('term'),
                    s.get('identifier'),
                    s.get('scheme'),
                ))

        return res or missing

    def get_other_standard_identifier(self, o):
        """Get other standard identifiers."""
        res = []

        def stdid(pid, scheme, q=None):
            return dict(
                standard_number_or_code=pid,
                source_of_number_or_code=scheme,
                qualifying_information=q,
            )

        m = o['metadata']
        if m.get('doi'):
            res.append(stdid(m['doi'], 'doi'))
        for id_ in m.get('alternate_identifiers', []):
            res.append(
                stdid(id_.get('identifier'), id_.get('scheme'),
                      q='alternateidentifier'))
        return res or missing

    def _get_personal_name(self, v, relator_code=None):
        """Build a personal-name entry with GND/ORCID identifiers if present."""
        ids = []
        for scheme in [
                'gnd',
                'orcid',
        ]:
            if v.get(scheme):
                ids.append((scheme, v[scheme]))

        return dict(personal_name=v.get('name'),
                    affiliation=v.get('affiliation'),
                    authority_record_control_number_or_standard_number=[
                        "({0}){1}".format(scheme, identifier)
                        for (scheme, identifier) in ids
                    ],
                    relator_code=[relator_code] if relator_code else [])

    def get_main_entry_personal_name(self, o):
        """Get main_entry_personal_name (first creator, or None)."""
        creators = o['metadata'].get('creators', [])
        if creators:
            return self._get_personal_name(creators[0])
        return None

    def get_added_entry_personal_name(self, o):
        """Get added_entry_personal_name.

        Lists creators after the first, then contributors (with a mapped
        relator code), then thesis supervisors (relator code 'ths').
        """
        items = []
        creators = o['metadata'].get('creators', [])
        for c in creators[1:]:
            items.append(self._get_personal_name(c))

        contributors = o['metadata'].get('contributors', [])
        for c in contributors:
            items.append(
                self._get_personal_name(
                    c, relator_code=self._map_contributortype(c.get('type'))))

        supervisors = o['metadata'].get('thesis', {}).get('supervisors', [])
        for s in supervisors:
            items.append(self._get_personal_name(s, relator_code='ths'))

        return items

    def _map_contributortype(self, type_):
        """Map a DataCite contributor type to a MARC relator code via config."""
        return current_app.config['DEPOSIT_CONTRIBUTOR_DATACITE2MARC'][type_]

    def get_added_entry_meeting_name(self, o):
        """Get added_entry_meeting_name."""
        v = o['metadata'].get('meeting', {})
        return [
            dict(
                meeting_name_or_jurisdiction_name_as_entry_element=v.get(
                    'title'),
                location_of_meeting=v.get('place'),
                date_of_meeting=v.get('dates'),
                miscellaneous_information=v.get('acronym'),
                number_of_part_section_meeting=v.get('session'),
                name_of_part_section_of_a_work=v.get('session_part'),
            )
        ]

    @post_dump(pass_many=True)
    def remove_empty_fields(self, data, many):
        """Dump + Remove empty fields."""
        _filter_empty(data)
        return data
def test_function_with_uncallable_param(self):
    """A non-callable positional argument is rejected with ValueError."""
    bad_arg = "uncallable"
    with pytest.raises(ValueError):
        fields.Function(bad_arg)
class UserAdditionalSchema(Schema):
    """User schema declaring ``lowername`` and serializing extra attributes."""

    lowername = fields.Function(lambda user: user.name.lower())

    class Meta:
        # Attributes serialized in addition to the declared fields.
        additional = ('name', 'age', 'created', 'email')
def test_function_field_passed_serialize_only_is_dump_only(self, user):
    """Supplying only ``serialize`` makes the Function field dump-only."""
    dump_only_field = fields.Function(serialize=lambda obj: obj.name.upper())
    assert dump_only_field.dump_only is True
class TestValidation:
    """Validators attached to fields are enforced on deserialization."""

    def test_integer_with_validator(self):
        field = fields.Integer(validate=lambda x: 18 <= x <= 24)
        out = field.deserialize('20')
        assert out == 20
        with pytest.raises(ValidationError):
            field.deserialize(25)

    # Validators may be given as a list, a tuple, or a generator.
    @pytest.mark.parametrize('field', [
        fields.Integer(validate=[lambda x: x <= 24, lambda x: 18 <= x]),
        fields.Integer(validate=(
            lambda x: x <= 24,
            lambda x: 18 <= x,
        )),
        fields.Integer(validate=validators_gen)
    ])
    def test_integer_with_validators(self, field):
        out = field.deserialize('20')
        assert out == 20
        with pytest.raises(ValidationError):
            field.deserialize(25)

    @pytest.mark.parametrize('field', [
        fields.Float(validate=[lambda f: f <= 4.1, lambda f: f >= 1.0]),
        fields.Float(validate=(
            lambda f: f <= 4.1,
            lambda f: f >= 1.0,
        )),
        fields.Float(validate=validators_gen_float)
    ])
    def test_float_with_validators(self, field):
        assert field.deserialize(3.14)
        with pytest.raises(ValidationError):
            field.deserialize(4.2)

    def test_string_validator(self):
        field = fields.String(validate=lambda n: len(n) == 3)
        assert field.deserialize('Joe') == 'Joe'
        with pytest.raises(ValidationError):
            field.deserialize('joseph')

    def test_function_validator(self):
        field = fields.Function(lambda d: d.name.upper(),
                                validate=lambda n: len(n) == 3)
        assert field.deserialize('joe')
        with pytest.raises(ValidationError):
            field.deserialize('joseph')

    @pytest.mark.parametrize('field', [
        fields.Function(
            lambda d: d.name.upper(),
            validate=[lambda n: len(n) == 3, lambda n: n[1].lower() == 'o']),
        fields.Function(
            lambda d: d.name.upper(),
            validate=(lambda n: len(n) == 3, lambda n: n[1].lower() == 'o')),
        fields.Function(lambda d: d.name.upper(), validate=validators_gen_str)
    ])
    def test_function_validators(self, field):
        assert field.deserialize('joe')
        with pytest.raises(ValidationError):
            field.deserialize('joseph')

    def test_method_validator(self):
        class MethodSerializer(Schema):
            name = fields.Method('get_name', deserialize='get_name',
                                 validate=lambda n: len(n) == 3)

            def get_name(self, val):
                return val.upper()

        assert MethodSerializer(strict=True).load({'name': 'joe'})
        with pytest.raises(ValidationError) as excinfo:
            MethodSerializer(strict=True).load({'name': 'joseph'})
        # Inspect the raised exception itself: str() of the ExceptionInfo
        # wrapper is a pytest implementation detail that changed across
        # pytest releases.
        assert 'Invalid value.' in str(excinfo.value)

    # Regression test for https://github.com/marshmallow-code/marshmallow/issues/269
    def test_nested_data_is_stored_when_validation_fails(self):
        class SchemaA(Schema):
            x = fields.Integer()
            y = fields.Integer(validate=lambda n: n > 0)
            z = fields.Integer()

        class SchemaB(Schema):
            w = fields.Integer()
            n = fields.Nested(SchemaA)

        sch = SchemaB()

        # Invalid nested 'z': the valid sibling values are still loaded.
        data, errors = sch.load({'w': 90, 'n': {'x': 90, 'y': 89, 'z': None}})
        assert 'z' in errors['n']
        assert data == {'w': 90, 'n': {'x': 90, 'y': 89}}

        # Nested 'y' fails its validator: again the other values survive.
        data, errors = sch.load({'w': 90, 'n': {'x': 90, 'y': -1, 'z': 180}})
        assert 'y' in errors['n']
        assert data == {'w': 90, 'n': {'x': 90, 'z': 180}}
def test_function_field_passed_deserialize_and_serialize_is_not_dump_only(
        self):
    """A Function field given both callables is not dump-only."""
    field = fields.Function(
        serialize=lambda value: value.lower(),
        deserialize=lambda value: value.upper(),
    )
    assert field.dump_only is False
def test_function_field_deserialization_is_noop_by_default(self):
    """Without a ``deserialize`` callable, values load unchanged."""
    field = fields.Function(lambda obj: None)
    for raw in ('foo', 42):
        assert field.deserialize(raw) == raw
def test_function_field_passed_serialize(self, user):
    """The ``serialize`` keyword callable receives the whole object."""
    field = fields.Function(serialize=lambda obj: obj.name.upper())
    serialized = field.serialize("key", user)
    assert serialized == "FOO"
def test_deserialization_function_must_be_callable(self):
    """A non-callable ``deserialize`` argument raises ValueError."""
    bogus_deserializer = 'notvalid'
    with pytest.raises(ValueError):
        fields.Function(lambda obj: None, deserialize=bogus_deserializer)
class BaseModelSchema(ModelSchema):
    """Base model schema whose ``created`` field uses ``serialize_datetime``."""

    # serialize_datetime('created') is evaluated once, at class creation;
    # presumably it returns the serializer callable for the field — confirm.
    created = fields.Function(serialize_datetime('created'))