Пример #1
0
 def test_function_field_passed_func(self, user):
     """A Function field built from a callable serializes via that callable."""
     upper_name = fields.Function(lambda obj: obj.name.upper())
     serialized = upper_name.serialize('key', user)
     assert serialized == 'FOO'
 def test_function_validator(self):
     """A single validator callable gates Function-field deserialization."""
     field = fields.Function(
         lambda d: d.name.upper(),
         validate=lambda n: len(n) == 3,
     )
     # Three characters pass; six characters fail validation.
     assert field.deserialize('joe')
     with pytest.raises(ValidationError):
         field.deserialize('joseph')
 def test_function_field_deserialization_with_callable(self):
     """The ``deserialize`` callable transforms the incoming value."""
     field = fields.Function(
         lambda x: None,
         deserialize=lambda val: val.upper(),
     )
     loaded = field.deserialize('foo')
     assert loaded == 'FOO'
Пример #4
0
class TestValidation:
    """Validators attached to fields are applied during deserialization."""

    def test_integer_with_validator(self):
        field = fields.Integer(validate=lambda x: 18 <= x <= 24)
        out = field.deserialize('20')
        assert out == 20
        with pytest.raises(ValidationError):
            field.deserialize(25)

    @pytest.mark.parametrize('field', [
        fields.Integer(validate=[lambda x: x <= 24, lambda x: 18 <= x]),
        fields.Integer(validate=(lambda x: x <= 24, lambda x: 18 <= x, )),
        fields.Integer(validate=validators_gen)
    ])
    def test_integer_with_validators(self, field):
        out = field.deserialize('20')
        assert out == 20
        with pytest.raises(ValidationError):
            field.deserialize(25)

    @pytest.mark.parametrize('field', [
        fields.Float(validate=[lambda f: f <= 4.1, lambda f: f >= 1.0]),
        fields.Float(validate=(lambda f: f <= 4.1, lambda f: f >= 1.0, )),
        fields.Float(validate=validators_gen_float)
    ])
    def test_float_with_validators(self, field):
        assert field.deserialize(3.14)
        with pytest.raises(ValidationError):
            field.deserialize(4.2)

    def test_string_validator(self):
        field = fields.String(validate=lambda n: len(n) == 3)
        assert field.deserialize('Joe') == 'Joe'
        with pytest.raises(ValidationError):
            field.deserialize('joseph')

    def test_function_validator(self):
        field = fields.Function(lambda d: d.name.upper(),
                                validate=lambda n: len(n) == 3)
        assert field.deserialize('joe')
        with pytest.raises(ValidationError):
            field.deserialize('joseph')

    @pytest.mark.parametrize('field', [
        fields.Function(
            lambda d: d.name.upper(),
            validate=[lambda n: len(n) == 3, lambda n: n[1].lower() == 'o']),
        fields.Function(
            lambda d: d.name.upper(),
            validate=(lambda n: len(n) == 3, lambda n: n[1].lower() == 'o')),
        fields.Function(
            lambda d: d.name.upper(),
            validate=validators_gen_str)
    ])
    def test_function_validators(self, field):
        assert field.deserialize('joe')
        with pytest.raises(ValidationError):
            field.deserialize('joseph')

    def test_method_validator(self):
        class MethodSerializer(Schema):
            name = fields.Method('get_name', deserialize='get_name',
                                 validate=lambda n: len(n) == 3)

            def get_name(self, val):
                return val.upper()

        assert MethodSerializer(strict=True).load({'name': 'joe'})
        with pytest.raises(ValidationError) as excinfo:
            MethodSerializer(strict=True).load({'name': 'joseph'})
        # Check the raised exception's message, not the ExceptionInfo
        # wrapper: since pytest 5.0, str(excinfo) is the wrapper's repr
        # rather than the error text.
        assert 'is False' in str(excinfo.value)
Пример #5
0
class NotificationSchema(Schema):
    """Serialize a notification: timestamp, day, body and linked demande."""

    created_at = fields.DateTime()
    # str() of the calendar-day part of ``created_at`` (assumes it is a
    # datetime, giving an ISO "YYYY-MM-DD" string - confirm).
    date = fields.Function(
        lambda notification: str(notification.created_at.date())
    )
    body = fields.String()
    demande = fields.Nested(DemandeSchema)
Пример #6
0
class CityDistrictSchema(Schema):
    """Expose a city district's ``id`` (rendered as a string) and ``title``."""

    # ``serialize=`` replaces the deprecated ``func=`` keyword of
    # ``fields.Function``; ``func`` was removed in marshmallow 3.
    id = fields.Function(serialize=lambda district: str(district.id))

    class Meta:
        fields = ['id', 'title']
Пример #7
0
class RecordSchemaMarcXMLV1(Schema):
    """Schema for records in MARC."""

    control_number = fields.Method('get_id')

    def get_id(self, obj):
        """Return the first ``b2rec`` PID value of the record as a string.

        Returns ``None`` when the metadata has no ``_pid`` list or no PID of
        type ``b2rec``, instead of raising ``TypeError``/``IndexError``.
        """
        pids = obj['metadata'].get('_pid') or []
        values = [p['value'] for p in pids if p['type'] == 'b2rec']
        return str(values[0]) if values else None

    date_and_time_of_latest_transaction = fields.Function(
        lambda o: parse(o['updated']).strftime("%Y%m%d%H%M%S.0"))

    main_entry_personal_name = fields.Method('get_main_entry_personal_name')

    added_entry_personal_name = fields.Method('get_added_entry_personal_name')

    title_statement = fields.Function(
        lambda o: o['metadata']['titles'][0])

    publication_distribution_imprint = fields.Function(
        lambda o: dict(name_of_publisher_distributor=o['metadata'].get('publisher'),
                       date_of_publication_distribution=o['metadata'].get('publication_date')))

    media_type = fields.Function(
        lambda o: dict(media_type_term=[x['resource_type_general']
                                        for x in o['metadata'].get('resource_types', [])]))

    summary = fields.Function(
        lambda o: [dict(summary=x.get('description'))
                   for x in o['metadata'].get('descriptions', [])])

    study_program_information_note = fields.Function(
        lambda o: [dict(program_name=o['metadata'].get('disciplines', []))])

    terms_governing_use_and_reproduction_note = fields.Function(
        lambda o: dict(terms_governing_use_and_reproduction=
                       o['metadata'].get('license', {}).get('license'),
                       uniform_resource_identifier=
                       o['metadata'].get('license', {}).get('license_uri')))

    information_relating_to_copyright_status = fields.Function(
        lambda o: dict(copyright_status='open' if o['metadata']['open_access'] else 'closed'))

    language_note = fields.Function(
        lambda o: [dict(language_note=o['metadata'].get('language'))])

    index_term_uncontrolled = fields.Function(
        lambda o: [dict(uncontrolled_term=x) for x in o['metadata'].get('keywords', [])])

    other_standard_identifier = fields.Function(
        lambda o: [dict(standard_number_or_code=x['alternate_identifier'])
                   for x in o['metadata'].get('alternate_identifiers', [])])

    # NOTE(review): files without an ``ePIC_PID`` contribute ``None`` entries
    # to this list; presumably ``_filter_empty`` strips them - confirm.
    electronic_location_and_access = fields.Function(
        lambda o: [dict(uniform_resource_identifier=f['ePIC_PID'],
                        file_size=str(f['size']),
                        access_method="HTTP")
                   if f.get('ePIC_PID') else None
                   for f in o['metadata'].get('_files', [])])

    # Custom fields:

    embargo_date = fields.Raw(attribute='metadata.embargo_date')

    _oai = fields.Raw(attribute='metadata._oai')

    def get_main_entry_personal_name(self, o):
        """Return the first creator as the main entry, or an empty dict."""
        creators = o['metadata'].get('creators', [])
        if creators:
            return dict(personal_name=creators[0]['creator_name'])
        return dict()

    def get_added_entry_personal_name(self, o):
        """Get added_entry_personal_name.

        All creators after the first, followed by every contributor.
        """
        metadata = o['metadata']
        items = [dict(personal_name=c['creator_name'])
                 for c in metadata.get('creators', [])[1:]]
        items.extend(dict(personal_name=c['contributor_name'])
                     for c in metadata.get('contributors', []))
        return items

    @post_dump(pass_many=True)
    def remove_empty_fields(self, data, many):
        """Dump + Remove empty fields."""
        _filter_empty(data)
        return data
Пример #8
0
class WorkerSchema(Schema):
    """Serialize a worker's id, EAN-13 code and case-swapped type name."""

    id = fields.Integer()

    ean13 = fields.String()
    # ``obj.type`` looks enum-like (it has ``.name``); that name is emitted
    # with the case of every letter inverted.
    type = fields.Function(lambda worker: worker.type.name.swapcase())
Пример #9
0
class UserSchema(mm.Schema):
    """Serialize a user-info mapping for API responses."""

    email = fields.String()
    # Display name built as "<first_name> <last_name>".
    name = fields.Function(
        lambda info: '{} {}'.format(info["first_name"], info["last_name"]))
    uid = fields.String()
    # URL of the avatar endpoint, parameterized by a payload derived from
    # the user info.
    avatar_url = fields.Function(
        lambda info: url_for('api.user_avatar',
                             payload=avatar_payload_from_user_info(info)))
Пример #10
0
class UserSchema(mm.Schema):
    """Serialize a user-info mapping (email, names, initials, uid)."""

    email = fields.String()
    name = fields.Function(lambda u: f'{u["first_name"]} {u["last_name"]}')
    # First letter of each name part. Slicing with ``[:1]`` instead of
    # indexing with ``[0]`` keeps an empty first or last name from raising
    # IndexError during serialization.
    initials = fields.Function(
        lambda u: f'{u["first_name"][:1]} {u["last_name"][:1]}')
    uid = fields.String()
Пример #11
0
class ServiceSchema(AutoSchema):
    """(De)serialize ``Service`` objects for the API."""

    _id = fields.Integer(attribute='id', dump_only=True)
    _rev = fields.String(default='', dump_only=True)
    owned = fields.Boolean(default=False)
    owner = PrimaryKeyRelatedField('username',
                                   dump_only=True,
                                   attribute='creator')
    port = fields.Integer(
        dump_only=True,
        strict=True,
        required=True,
        validate=[
            Range(min=0, error="The value must be greater than or equal to 0")
        ])  # Port is loaded via ports
    # Dumped as an integer; loaded through ``load_ports`` into ``port``.
    ports = MutableField(fields.Integer(
        strict=True,
        required=True,
        validate=[
            Range(min=0, error="The value must be greater than or equal to 0")
        ]),
                         fields.Method(deserialize='load_ports'),
                         required=True,
                         attribute='port')
    status = fields.String(default='open',
                           validate=OneOf(Service.STATUSES),
                           required=True,
                           allow_none=False)
    parent = fields.Integer(
        attribute='host_id')  # parent is not required for updates
    host_id = fields.Integer(attribute='host_id', dump_only=True)
    vulns = fields.Integer(attribute='vulnerability_count', dump_only=True)
    credentials = fields.Integer(attribute='credentials_count', dump_only=True)
    metadata = SelfNestedField(MetadataSchema())
    type = fields.Function(lambda obj: 'Service', dump_only=True)

    def load_ports(self, value):
        """Deserialize ``ports``: a list holding exactly one port number.

        :param value: raw ``ports`` value from the request payload.
        :returns: the single port, converted with ``str``.
        :raises ValidationError: if ``value`` is not a list, does not have
            exactly one element, or the port is negative.
        """
        if not isinstance(value, list):
            raise ValidationError('ports must be a list')
        if len(value) != 1:
            # Note the trailing space: the two literals are concatenated.
            raise ValidationError('ports must be a list with exactly one '
                                  'element')
        # Index instead of pop() so the caller's payload is not mutated.
        port = value[0]
        if port < 0:
            raise ValidationError(
                'The value must be greater than or equal to 0')

        return str(port)

    @post_load
    def post_load_parent(self, data):
        """Gets the host_id from parent attribute. Pops it and tries to
        get a Host with that id in the corresponding workspace.

        On update, the parent may be omitted but cannot be changed. On
        create, it is required and must name a Host in the workspace given
        by ``self.context['workspace_name']``.
        """
        host_id = data.pop('host_id', None)
        if self.context['updating']:
            if host_id is None:
                # Partial update?
                return

            if host_id != self.context['object'].parent.id:
                raise ValidationError('Can\'t change service parent.')

        else:
            if not host_id:
                raise ValidationError(
                    'Parent id is required when creating a service.')

            try:
                data['host'] = Host.query.join(Workspace).filter(
                    Workspace.name == self.context['workspace_name'],
                    Host.id == host_id).one()
            except NoResultFound:
                raise ValidationError(
                    'Host with id {} not found'.format(host_id))

    class Meta:
        model = Service
        fields = ('id', '_id', 'status', 'parent', 'type', 'protocol',
                  'description', '_rev', 'owned', 'owner', 'credentials',
                  'vulns', 'name', 'version', 'port', 'ports',
                  'metadata', 'summary', 'host_id')
Пример #12
0
class PaymentTableSchema(BaseSchema):
    """Serialize a payment record with its nested career, engineer, level
    and work-report data."""

    class WD(WorkReportSchema):
        """Work report extended with lists of per-day entries by category."""

        work_days_list = fields.List(fields.Nested(DailyLogDefaultSchema(many=False)))
        leave_days_list = fields.List(fields.Nested(LeaveInfoSchema(many=False)))
        extra_work_days_list = fields.List(fields.Nested(ExtraWorkInfoSchema(many=False)))
        absent_days_list = fields.List(fields.Nested(DailyLogDefaultSchema(many=False)))
        shift_days_list = fields.List(fields.Nested(LeaveInfoSchema(many=False)))
        rest_days_list = fields.List(fields.Nested(DailyLogDefaultSchema(many=False)))

    class C(Schema):
        """Career: a salary amount plus the offer it is based on."""

        class OS(Schema):
            """Offer details: position, level and amount."""

            position = fields.String()
            position_level = fields.String()
            money = fields.Float()

        s_money = fields.Float()

        offer = fields.Nested(OS(many=False))

    class E(Schema):
        """Engineer summary."""

        real_name = fields.String()
        welfare_rate = fields.String()

    class PLS(Schema):
        """Position level: name and amount."""

        money = fields.Float()
        name = fields.String()

    career = fields.Nested(C(many=False))
    position_level = fields.Nested(PLS(many=False))
    year_month = fields.Integer()
    engineer = fields.Nested(E(many=False))
    engineer_id = fields.Integer()
    company_pay = fields.Float()
    income = fields.Float()
    tax = fields.Float()
    fee = fields.Float()
    welfare = fields.Float()
    amerce = fields.Float()
    # Integer status code rendered through PaymentStatus.int2str.
    status = fields.Function(
        lambda payment: PaymentStatus.int2str(payment.status))
    engineer_tax = fields.Float()
    engineer_get = fields.Float()
    break_up_fee = fields.Float()
    engineer_income_with_tax = fields.Float()
    finance_fee = fields.Float()
    hr_fee = fields.Float()
    service_fee_rate = fields.Float()
    employ_type = fields.Integer()
    out_duty_days = fields.Float()
    finance_rate = fields.Float()
    tax_fee_rate = fields.Float()
    tax_rate = fields.Float()
    # Field names are also the attribute names read (no ``attribute=``), so
    # the spellings of ``use_hr_servce`` and ``ware_fare`` must match the
    # source object; do not "fix" them here alone.
    use_hr_servce = fields.Integer()
    ware_fare = fields.Float()
    pm = fields.String()
    project = fields.String()
    work_report = fields.Nested(WD(many=False))
    station_salary = fields.Float()
    extra_salary = fields.Float()
    tax_free_rate = fields.Float()
Пример #13
0
class InterviewPmPutSchema(Schema):
    """Load an interview ``status`` string through ``InterviewStatus.str2int``."""

    # NOTE(review): the serialize callable receives the whole object, so
    # dumping with this schema would emit the entire object as ``status``;
    # the schema appears intended for loading only - confirm.
    status = fields.Function(
        serialize=lambda obj: obj,
        deserialize=lambda raw: InterviewStatus.str2int(raw),
    )
Пример #14
0
 def test_function_field_does_not_swallow_attribute_error(self, user):
     """An AttributeError raised inside the serializer must propagate."""
     def failing_serializer(obj):
         raise AttributeError()

     field = fields.Function(serialize=failing_serializer)
     with pytest.raises(AttributeError):
         field.serialize('key', user)
Пример #15
0
 def test_function_field_load_only(self):
     """Giving only ``deserialize`` makes the Function field load-only."""
     loader_only = fields.Function(deserialize=lambda value: None)
     assert loader_only.load_only
Пример #16
0
class UserSchema(marshmallow.Schema):
    """Schema of a user to be used to extract API information.

    The names listed in ``Meta.fields`` but not declared below (``id``,
    ``mail``, ``enabled``, ``first_name``, ``last_name``) are read directly
    from the serialized ``User`` object; the declared fields add computed
    information and links. This schema is only used for administration
    listing.
    """

    isadmin = fields.Function(lambda user: user.is_admin())
    """ Wraps :py:meth:`collectives.models.user.User.is_admin`

    :type: boolean"""
    roles_uri = fields.Function(
        lambda user: url_for("administration.add_user_role", user_id=user.id))
    """ URI to role management page for this user

    :type: string
    """
    delete_uri = fields.Function(
        lambda user: url_for("administration.delete_user", user_id=user.id))
    """ URI to delete this user (WIP)

    :type: string
    """
    manage_uri = fields.Function(
        lambda user: url_for("administration.manage_user", user_id=user.id))
    """ URI to modify this user

    :type: string
    """
    profile_uri = fields.Function(
        lambda user: url_for("profile.show_user", user_id=user.id))
    """ URI to see user profile

    :type: string
    """
    leader_profile_uri = fields.Function(
        lambda user: url_for("profile.show_leader", leader_id=user.id)
        if user.can_create_events() else None)
    """ URI to see the user's leader profile, or ``None`` when the user
    cannot create events

    :type: string or None
    """
    avatar_uri = fields.Function(avatar_url)
    """ URI to a resized version (30px) of user avatar

    :type: string
    """
    roles = fields.Function(
        lambda user: RoleSchema(many=True).dump(user.roles))
    """ List of roles of the User.

    Roles are encoded as JSON.

    :type: list(dict())"""
    full_name = fields.Function(lambda user: user.full_name())
    """ User full name

    :type: string"""
    class Meta:
        """Fields to expose"""

        # Each name appears once; "isadmin" was previously listed twice.
        fields = (
            "id",
            "mail",
            "isadmin",
            "enabled",
            "roles_uri",
            "avatar_uri",
            "manage_uri",
            "profile_uri",
            "delete_uri",
            "first_name",
            "last_name",
            "roles",
            "leader_profile_uri",
            "full_name",
        )
Пример #17
0
 def test_function_field_passed_uncallable_object(self):
     """A non-callable serializer argument is rejected with ValueError."""
     not_callable = "uncallable"
     with pytest.raises(ValueError):
         fields.Function(not_callable)
Пример #18
0
class UserAdditionalSchema(Schema):
    """User schema adding a lower-cased name to extra inferred fields."""

    # The user's ``name`` attribute, lower-cased.
    lowername = fields.Function(lambda user: user.name.lower())

    class Meta:
        # Attributes serialized in addition to the declared fields.
        additional = ("name", "age", "created", "email")
Пример #19
0
class CategorySchema(ModelSchema):
    """Provides detailed information about a category, including all the children"""
    class Meta(ModelSchema.Meta):
        model = Category
        fields = ('id', 'name', 'children', 'parent_id', 'parent', 'level', 'event_count',
                  'location_count', 'resource_count', 'all_resource_count', 'study_count',
                  'color', '_links', 'last_updated', 'display_order')
    id = fields.Integer(required=False, allow_none=True)
    parent_id = fields.Integer(required=False, allow_none=True)
    children = ma.Nested(lambda: CategorySchema(), many=True, dump_only=True, exclude=('parent', 'color'))
    parent = ma.Nested(ParentCategorySchema, dump_only=True)
    level = fields.Function(lambda obj: obj.calculate_level() if isinstance(obj, Category) else 0, dump_only=True)
    event_count = fields.Method('get_event_count', dump_only=True)
    location_count = fields.Method('get_location_count', dump_only=True)
    resource_count = fields.Method('get_resource_count', dump_only=True)
    all_resource_count = fields.Method('get_all_resource_count', dump_only=True)
    study_count = fields.Method('get_study_count', dump_only=True)
    _links = ma.Hyperlinks({
        'self': ma.URLFor('api.categoryendpoint', id='<id>'),
        'collection': ma.URLFor('api.categorylistendpoint')
    })

    @staticmethod
    def _scalar_count(query):
        """Run ``count(*)`` over *query* (ORDER BY dropped) and return it."""
        count_q = query.statement.with_only_columns([func.count()]).order_by(None)
        return query.session.execute(count_q).scalar()

    def _count_resource_category_type(self, obj, type_):
        """Count ResourceCategory rows of *type_* linked to category *obj*."""
        query = db.session.query(ResourceCategory).filter(ResourceCategory.type == type_)\
            .filter(ResourceCategory.category_id == obj.id)
        return self._scalar_count(query)

    def get_event_count(self, obj):
        """Number of linked 'event' resources, or ``missing`` without *obj*."""
        if obj is None:
            return missing
        return self._count_resource_category_type(obj, 'event')

    def get_location_count(self, obj):
        """Number of linked 'location' resources, or ``missing`` without *obj*."""
        if obj is None:
            return missing
        return self._count_resource_category_type(obj, 'location')

    def get_resource_count(self, obj):
        """Number of linked 'resource' resources, or ``missing`` without *obj*."""
        if obj is None:
            return missing
        return self._count_resource_category_type(obj, 'resource')

    def get_all_resource_count(self, obj):
        """Number of linked resources of any type, or ``missing`` without *obj*."""
        if obj is None:
            return missing
        query = db.session.query(ResourceCategory).join(ResourceCategory.resource)\
            .filter(ResourceCategory.category_id == obj.id)
        return self._scalar_count(query)

    def get_study_count(self, obj):
        """Number of linked studies, or ``missing`` without *obj*."""
        if obj is None:
            return missing
        query = db.session.query(StudyCategory).join(StudyCategory.study)\
            .filter(StudyCategory.category_id == obj.id)
        return self._scalar_count(query)
Пример #20
0
class VulnerabilityFilterSet(FilterSet):
    class Meta(FilterSetMeta):
        model = VulnerabilityWeb  # It has all the fields
        # TODO migration: Check if we should add fields owner,
        # command, impact, issuetracker, tags, date, host
        # evidence, policy violations, hostnames
        fields = (
            "id", "status", "website", "pname", "query", "path", "service",
            "data", "severity", "confirmed", "name", "request", "response",
            "parameters", "params", "resolution", "ease_of_resolution",
            "description", "command_id", "target", "creator", "method",
            "easeofresolution", "query_string", "parameter_name", "service_id",
            "status_code"
        )

        strict_fields = (
            "severity", "confirmed", "method", "status", "easeofresolution",
            "ease_of_resolution", "service_id",
        )

        default_operator = CustomILike
        # Columns in strict_fields are mapped to _strict_filtering,
        # presumably overriding the default CustomILike operator for them.
        column_overrides = {
            field: _strict_filtering for field in strict_fields
        }
        operators = (CustomILike, operators.Equal)
    id = IDFilter(fields.Int())
    target = TargetFilter(fields.Str())
    type = TypeFilter(fields.Str(validate=[OneOf(['Vulnerability',
                                                  'VulnerabilityWeb'])]))
    creator = CreatorFilter(fields.Str())
    service = ServiceFilter(fields.Str())
    severity = Filter(SeverityField())
    easeofresolution = Filter(fields.String(
        attribute='ease_of_resolution',
        validate=OneOf(Vulnerability.EASE_OF_RESOLUTIONS),
        allow_none=True))
    pname = Filter(fields.String(attribute='parameter_name'))
    query = Filter(fields.String(attribute='query_string'))
    status_code = StatusCodeFilter(fields.Int())
    params = Filter(fields.String(attribute='parameters'))
    # Accepts 'opened' as an alias that is loaded as 'open'.
    status = Filter(fields.Function(
        deserialize=lambda val: 'open' if val == 'opened' else val,
        validate=OneOf(Vulnerability.STATUSES + ['opened'])
    ))
    hostnames = HostnamesFilter(fields.Str())
    confirmed = Filter(fields.Boolean())

    def filter(self):
        """Generate a filtered query from request parameters.

        In addition to the declared filters, an optional ``command_id``
        query-string argument restricts results to vulnerabilities whose
        ``creator_command_id`` equals it. A ``command_id`` that is not a
        valid integer is ignored instead of raising ``ValueError``.

        :returns: Filtered SQLALchemy query
        """
        # TODO migration: this can became a normal filter instead of a custom
        # one, since now we can use creator_command_id
        # With type=int, MultiDict.get returns the default (None) when the
        # value cannot be converted, so bad input no longer raises.
        command_id = request.args.get('command_id', type=int)
        query = super(VulnerabilityFilterSet, self).filter()

        if command_id is not None:
            query = query.filter(
                VulnerabilityGeneric.creator_command_id == command_id)
        return query
Пример #21
0
class RecordSchemaMARC21(Schema):
    """Schema for records in MARC."""

    control_number = fields.Function(lambda o: str(o['metadata'].get('recid')))

    date_and_time_of_latest_transaction = fields.Function(
        lambda obj: parse(obj['updated']).strftime("%Y%m%d%H%M%S.0"))

    index_term_uncontrolled = fields.Function(lambda o: [
        dict(uncontrolled_term=kw) for kw in o['metadata'].get('keywords', [])
    ])

    subject_added_entry_topical_term = fields.Method(
        'get_subject_added_entry_topical_term')

    terms_governing_use_and_reproduction_note = fields.Function(
        lambda o: dict(uniform_resource_identifier=o['metadata'].get(
            'license', {}).get('url'),
                       terms_governing_use_and_reproduction=o['metadata'].get(
                           'license', {}).get('title')))

    title_statement = fields.Function(
        lambda o: dict(title=o['metadata'].get('title')))

    general_note = fields.Function(
        lambda o: dict(general_note=o['metadata'].get('notes')))

    # Declared once; a missing ``access_right`` yields None, not KeyError.
    information_relating_to_copyright_status = fields.Function(
        lambda o: dict(copyright_status=o['metadata'].get('access_right')))

    publication_distribution_imprint = fields.Method(
        'get_publication_distribution_imprint')

    funding_information_note = fields.Function(lambda o: [
        dict(text_of_note=v.get('title'), grant_number=v.get('code'))
        for v in o['metadata'].get('grants', [])
    ])

    other_standard_identifier = fields.Method('get_other_standard_identifier')

    added_entry_meeting_name = fields.Method('get_added_entry_meeting_name')

    main_entry_personal_name = fields.Method('get_main_entry_personal_name')

    added_entry_personal_name = fields.Method('get_added_entry_personal_name')

    summary = fields.Function(
        lambda o: dict(summary=o['metadata'].get('description')))

    host_item_entry = fields.Method('get_host_item_entry')

    dissertation_note = fields.Function(
        lambda o: dict(name_of_granting_institution=o['metadata'].get(
            'thesis', {}).get('university')))

    language_code = fields.Function(
        lambda o: dict(language_code_of_text_sound_track_or_separate_title=(
            o['metadata'].get('language'))))

    # Custom
    # ======

    resource_type = fields.Raw(attribute='metadata.resource_type')

    communities = fields.Raw(attribute='metadata.communities')

    references = fields.Raw(attribute='metadata.references')

    embargo_date = fields.Raw(attribute='metadata.embargo_date')

    journal = fields.Raw(attribute='metadata.journal')

    #_oai = fields.Raw(attribute='metadata._oai')

    _files = fields.Method('get_files')

    leader = fields.Method('get_leader')

    conference_url = fields.Raw(attribute='metadata.meeting.url')

    def get_leader(self, o):
        """Return the leader information.

        ``type_of_record`` depends on the record's resource type; types not
        in the mapping fall back to ``'language_material'``. All other
        entries are fixed values.
        """
        rt = o['metadata']['resource_type']['type']
        rec_types = {
            'image': 'two-dimensional_nonprojectable_graphic',
            'video': 'projected_medium',
            'dataset': 'computer_file',
            'software': 'computer_file',
        }
        type_of_record = rec_types.get(rt, 'language_material')
        res = {
            'record_length': '00000',
            'record_status': 'new',
            'type_of_record': type_of_record,
            'bibliographic_level': 'monograph_item',
            'type_of_control': 'no_specified_type',
            'character_coding_scheme': 'marc-8',
            'indicator_count': 2,
            'subfield_code_count': 2,
            'base_address_of_data': '00000',
            'encoding_level': 'unknown',
            'descriptive_cataloging_form': 'unknown',
            'multipart_resource_record_level':
            'not_specified_or_not_applicable',
            'length_of_the_length_of_field_portion': 4,
            'length_of_the_starting_character_position_portion': 5,
            'length_of_the_implementation_defined_portion': 0,
            'undefined': 0,
        }
        return res

    def get_files(self, o):
        """Get the files provided the record is open access.

        Returns ``missing`` for non-open records or records without files.
        """
        if o['metadata']['access_right'] != 'open':
            return missing
        res = []
        for f in o['metadata'].get('_files', []):
            res.append(
                dict(
                    uri=u'{0}/record/{1}/files/{2}'.format(
                        current_app.config.get('THEME_SITEURL', 'zenodo.org'),
                        o['metadata'].get('recid', ''), f['key']),
                    size=f['size'],
                    checksum=f['checksum'],
                    type=f['type'],
                ))
        return res or missing

    def get_host_item_entry(self, o):
        """Get host items.

        One entry per related identifier, plus one built from ``imprint``
        and ``part_of`` when both are present. ``missing`` when empty.
        """
        res = []
        for v in o['metadata'].get('related_identifiers', []):
            res.append(
                dict(
                    main_entry_heading=v.get('identifier'),
                    relationship_information=v.get('relation'),
                    note=v.get('scheme'),
                ))

        imprint = o['metadata'].get('imprint', {})
        part_of = o['metadata'].get('part_of', {})
        if part_of and imprint:
            res.append(
                dict(
                    main_entry_heading=imprint.get('place'),
                    edition=imprint.get('publisher'),
                    title=part_of.get('title'),
                    related_parts=part_of.get('pages'),
                    international_standard_book_number=imprint.get('isbn'),
                ))

        return res or missing

    def get_publication_distribution_imprint(self, o):
        """Get publication date and imprint.

        The imprint entry is added only when the record is not ``part_of``
        another work. ``missing`` when empty.
        """
        res = []

        pubdate = o['metadata'].get('publication_date')
        if pubdate:
            res.append(dict(date_of_publication_distribution=pubdate))

        imprint = o['metadata'].get('imprint')
        part_of = o['metadata'].get('part_of')
        if not part_of and imprint:
            res.append(
                dict(
                    place_of_publication_distribution=imprint.get('place'),
                    name_of_publisher_distributor=imprint.get('publisher'),
                    date_of_publication_distribution=pubdate,
                ))

        return res or missing

    def get_subject_added_entry_topical_term(self, o):
        """Get licenses and subjects.

        ``missing`` when the record has neither a license id nor subjects.
        """
        res = []

        license_id = o['metadata'].get('license', {}).get('id')
        if license_id:
            res.append(
                dict(
                    # Emit the record's own license id (an opendefinition.org
                    # term) rather than a fixed 'cc-by' for every license.
                    topical_term_or_geographic_name_entry_element=license_id,
                    source_of_heading_or_term='opendefinition.org',
                    level_of_subject='Primary',
                    thesaurus='Source specified in subfield $2',
                ))

        def _subject(term, id_, scheme):
            return dict(
                topical_term_or_geographic_name_entry_element=term,
                authority_record_control_number_or_standard_number=(
                    '({0}){1}'.format(scheme, id_)),
                level_of_subject='Primary',
            )

        for s in o['metadata'].get('subjects', []):
            res.append(
                _subject(
                    s.get('term'),
                    s.get('identifier'),
                    s.get('scheme'),
                ))

        return res or missing

    def get_other_standard_identifier(self, o):
        """Get other standard identifiers (DOI first, then alternates)."""
        res = []

        def stdid(pid, scheme, q=None):
            return dict(
                standard_number_or_code=pid,
                source_of_number_or_code=scheme,
                qualifying_information=q,
            )

        m = o['metadata']
        if m.get('doi'):
            res.append(stdid(m['doi'], 'doi'))

        for id_ in m.get('alternate_identifiers', []):
            res.append(
                stdid(id_.get('identifier'),
                      id_.get('scheme'),
                      q='alternateidentifier'))

        return res or missing

    def _get_personal_name(self, v, relator_code=None):
        """Build a MARC personal-name entry from a creator-like mapping.

        GND and ORCID identifiers, when present, are rendered as
        ``"(scheme)identifier"`` strings.
        """
        ids = [(scheme, v[scheme]) for scheme in ('gnd', 'orcid')
               if v.get(scheme)]

        return dict(personal_name=v.get('name'),
                    affiliation=v.get('affiliation'),
                    authority_record_control_number_or_standard_number=[
                        "({0}){1}".format(scheme, identifier)
                        for (scheme, identifier) in ids
                    ],
                    relator_code=[relator_code] if relator_code else [])

    def get_main_entry_personal_name(self, o):
        """Get main_entry_personal_name (first creator), or None."""
        creators = o['metadata'].get('creators', [])
        if creators:
            return self._get_personal_name(creators[0])
        return None

    def get_added_entry_personal_name(self, o):
        """Get added_entry_personal_name.

        Creators after the first, then contributors (with a mapped relator
        code), then thesis supervisors (relator code ``'ths'``).
        """
        metadata = o['metadata']
        items = [self._get_personal_name(c)
                 for c in metadata.get('creators', [])[1:]]

        for c in metadata.get('contributors', []):
            items.append(
                self._get_personal_name(
                    c, relator_code=self._map_contributortype(c.get('type'))))

        supervisors = metadata.get('thesis', {}).get('supervisors', [])
        for s in supervisors:
            items.append(self._get_personal_name(s, relator_code='ths'))

        return items

    def _map_contributortype(self, type_):
        """Map a contributor type to a MARC relator code via app config.

        Returns ``None`` for a missing or unmapped type, so the contributor
        is still emitted (with an empty relator code) instead of raising
        ``KeyError``. Assumes the config value is a dict.
        """
        return current_app.config['DEPOSIT_CONTRIBUTOR_DATACITE2MARC'].get(type_)

    def get_added_entry_meeting_name(self, o):
        """Get added_entry_meeting_name."""
        v = o['metadata'].get('meeting', {})
        return [
            dict(
                meeting_name_or_jurisdiction_name_as_entry_element=v.get(
                    'title'),
                location_of_meeting=v.get('place'),
                date_of_meeting=v.get('dates'),
                miscellaneous_information=v.get('acronym'),
                number_of_part_section_meeting=v.get('session'),
                name_of_part_section_of_a_work=v.get('session_part'),
            )
        ]

    @post_dump(pass_many=True)
    def remove_empty_fields(self, data, many):
        """Dump + Remove empty fields."""
        _filter_empty(data)
        return data
Пример #22
0
 def test_function_with_uncallable_param(self):
     """Constructing a Function field from a plain string raises ValueError."""
     bogus = "uncallable"
     with pytest.raises(ValueError):
         fields.Function(bogus)
Пример #23
0
class UserAdditionalSchema(Schema):
    """User schema with a computed ``lowername`` plus extra plain fields.

    ``Meta.additional`` names attributes that are included in addition to
    the explicitly declared fields.
    """

    # Lower-cased copy of the object's ``name`` attribute.
    lowername = fields.Function(lambda user: user.name.lower())

    class Meta:
        additional = ('name', 'age', 'created', 'email')
Пример #24
0
 def test_function_field_passed_serialize_only_is_dump_only(self, user):
     """A Function field given only ``serialize`` is marked dump-only."""
     upper_name = fields.Function(serialize=lambda u: u.name.upper())
     assert upper_name.dump_only is True
class TestValidation:
    """Tests for validators attached to fields through ``validate=``."""

    def test_integer_with_validator(self):
        field = fields.Integer(validate=lambda x: 18 <= x <= 24)
        out = field.deserialize('20')
        assert out == 20
        with pytest.raises(ValidationError):
            field.deserialize(25)

    # ``validate`` may be a list, a tuple, or the iterable produced by a
    # helper such as ``validators_gen``. The value 25 passes ``18 <= x``
    # but fails ``x <= 24``, so a single failing validator must be enough
    # to raise.
    @pytest.mark.parametrize('field', [
        fields.Integer(validate=[lambda x: x <= 24, lambda x: 18 <= x]),
        fields.Integer(validate=(
            lambda x: x <= 24,
            lambda x: 18 <= x,
        )),
        fields.Integer(validate=validators_gen)
    ])
    def test_integer_with_validators(self, field):
        out = field.deserialize('20')
        assert out == 20
        with pytest.raises(ValidationError):
            field.deserialize(25)

    @pytest.mark.parametrize('field', [
        fields.Float(validate=[lambda f: f <= 4.1, lambda f: f >= 1.0]),
        fields.Float(validate=(
            lambda f: f <= 4.1,
            lambda f: f >= 1.0,
        )),
        fields.Float(validate=validators_gen_float)
    ])
    def test_float_with_validators(self, field):
        # Pin the value, not just its truthiness.
        assert field.deserialize(3.14) == 3.14
        with pytest.raises(ValidationError):
            field.deserialize(4.2)

    def test_string_validator(self):
        field = fields.String(validate=lambda n: len(n) == 3)
        assert field.deserialize('Joe') == 'Joe'
        with pytest.raises(ValidationError):
            field.deserialize('joseph')

    def test_function_validator(self):
        field = fields.Function(lambda d: d.name.upper(),
                                validate=lambda n: len(n) == 3)
        # Without a ``deserialize`` callable, Function fields pass the input
        # through unchanged, so the exact value can be asserted.
        assert field.deserialize('joe') == 'joe'
        with pytest.raises(ValidationError):
            field.deserialize('joseph')

    @pytest.mark.parametrize('field', [
        fields.Function(
            lambda d: d.name.upper(),
            validate=[lambda n: len(n) == 3, lambda n: n[1].lower() == 'o']),
        fields.Function(
            lambda d: d.name.upper(),
            validate=(lambda n: len(n) == 3, lambda n: n[1].lower() == 'o')),
        fields.Function(lambda d: d.name.upper(), validate=validators_gen_str)
    ])
    def test_function_validators(self, field):
        assert field.deserialize('joe') == 'joe'
        with pytest.raises(ValidationError):
            field.deserialize('joseph')

    def test_method_validator(self):
        class MethodSerializer(Schema):
            name = fields.Method('get_name',
                                 deserialize='get_name',
                                 validate=lambda n: len(n) == 3)

            def get_name(self, val):
                return val.upper()

        # ``load`` returns a (data, errors) result that is always truthy, so
        # asserting on the result itself checks nothing; check the data.
        result = MethodSerializer(strict=True).load({'name': 'joe'})
        assert result.data == {'name': 'JOE'}
        with pytest.raises(ValidationError) as excinfo:
            MethodSerializer(strict=True).load({'name': 'joseph'})

        # ``str(excinfo)`` is the ExceptionInfo repr on modern pytest; the
        # message lives on the raised exception itself.
        assert 'Invalid value.' in str(excinfo.value)

    # Regression test for https://github.com/marshmallow-code/marshmallow/issues/269
    def test_nested_data_is_stored_when_validation_fails(self):
        class SchemaA(Schema):
            x = fields.Integer()
            y = fields.Integer(validate=lambda n: n > 0)
            z = fields.Integer()

        class SchemaB(Schema):
            w = fields.Integer()
            n = fields.Nested(SchemaA)

        sch = SchemaB()

        data, errors = sch.load({'w': 90, 'n': {'x': 90, 'y': 89, 'z': None}})
        assert 'z' in errors['n']
        assert data == {'w': 90, 'n': {'x': 90, 'y': 89}}

        data, errors = sch.load({'w': 90, 'n': {'x': 90, 'y': -1, 'z': 180}})
        assert 'y' in errors['n']
        assert data == {'w': 90, 'n': {'x': 90, 'z': 180}}
Пример #26
0
 def test_function_field_passed_deserialize_and_serialize_is_not_dump_only(
         self):
     """Giving both callables keeps a Function field loadable."""
     field = fields.Function(
         serialize=lambda value: value.lower(),
         deserialize=lambda value: value.upper(),
     )
     assert field.dump_only is False
 def test_function_field_deserialization_is_noop_by_default(self):
     """Without ``deserialize``, a Function field returns its input as-is."""
     field = fields.Function(lambda x: None)
     for raw in ('foo', 42):
         assert field.deserialize(raw) == raw
Пример #28
0
 def test_function_field_passed_serialize(self, user):
     """The ``serialize`` callable's result is the serialized value."""
     field = fields.Function(serialize=lambda obj: obj.name.upper())
     serialized = field.serialize("key", user)
     assert serialized == "FOO"
 def test_deserialization_function_must_be_callable(self):
     """A non-callable ``deserialize`` argument raises ``ValueError``."""
     bad_deserialize = 'notvalid'
     with pytest.raises(ValueError):
         fields.Function(lambda x: None, deserialize=bad_deserialize)
Пример #30
0
class BaseModelSchema(ModelSchema):
    created = fields.Function(serialize_datetime('created'))