예제 #1
0
 def inner(values):
     """Build a ``term`` query from a single, choice-restricted value.

     ``choices`` (from the enclosing scope) maps each accepted request value
     to the value searched for in ``field``.

     :param values: list of raw request values for the parameter.
     :returns: a ``Q('term', ...)`` query on ``field``.
     :raises RESTValidationError: if not exactly one value is given, or the
         value is not one of the keys of ``choices``.
     """
     if len(values) != 1:
         raise RESTValidationError(
             errors=[FieldError(label, 'Multiple values specified.')])
     # Test membership rather than the truthiness of the mapped value, so a
     # choice that maps to a falsy value (False, 0, '') is still accepted
     # instead of being reported as not allowed.
     if values[0] not in choices:
         raise RESTValidationError(
             errors=[FieldError(
                 label, 'Allowed values: [{}]'.format(', '.join(choices)))])
     return Q('term', **{field: choices[values[0]]})
예제 #2
0
    def inner(values):
        """Build a bool query over custom metadata ``[vocabulary:term]:value``.

        Every value becomes one ``nested`` clause (all must match). The clause
        matches the term name on ``<es_field>.key`` and the search value on
        ``<es_field>.value``, where ``<es_field>`` depends on the term's
        ``term_type``.

        :param values: list of raw request values for the parameter.
        :returns: ``Q('bool', must=[...])`` with one nested clause per value.
        :raises RESTValidationError: if a value is malformed, its term is not
            available, or its term type has no Elasticsearch field mapping.
        """
        terms = current_custom_metadata.terms
        available_terms = current_custom_metadata.available_vocabulary_set

        # TODO: move this to a central place
        # Elasticsearch nested field that holds values of each custom term
        # type. Loop-invariant, so it is built once instead of once per value.
        custom_fields_mapping = dict(
            keyword='custom_keywords',
            text='custom_text'
        )

        must_conditions = []

        for value in values:
            # Matches this:
            #   [vocabulary:term]:value
            parsed = re.match(r'\[([-\w]+\:[-\w]+)\]\:(.+)', value)
            if not parsed:
                raise RESTValidationError(
                    errors=[FieldError(
                        field, 'The parameter should have the format: '
                               'custom=[field_name]:field_value.')])

            search_key, search_value = parsed.groups()

            if search_key not in available_terms:
                raise RESTValidationError(
                    errors=[FieldError(
                        field, 'The "{}" term is not supported.'
                        .format(search_key))])

            # TODO: check if the search value has the correct type
            # for now we have only 'keyword' and 'text'

            custom_type = terms[search_key]['term_type']
            es_field = custom_fields_mapping.get(custom_type)
            if es_field is None:
                # A term type without a mapping would otherwise raise a
                # KeyError here; report it as an unsupported term instead.
                raise RESTValidationError(
                    errors=[FieldError(
                        field, 'The "{}" term is not supported.'
                        .format(search_key))])

            must_conditions.append({
                'nested': {
                    'path': es_field,
                    'query': {
                        'bool': {'must': [
                            {'match': {es_field + '.key': search_key}},
                            {'match': {es_field + '.value': search_value}}
                            # TODO: in the future also filter ".community"
                        ]}
                    }
                }
            })
        return Q('bool', must=must_conditions)
예제 #3
0
    def inner(values):
        """Build a ``Range`` query from a single ``start--end`` value.

        Either end may be empty (open-ended), but not both. A leading ``>`` on
        the start or ``<`` on the end makes that bound strict (``gt``/``lt``);
        otherwise it is inclusive (``gte``/``lte``). When the corresponding
        ``start_date_math``/``end_date_math`` is set, it is appended to the
        bound as ``<value>||<date_math>``. The bounds are merged over a copy
        of ``kwargs``.

        :param values: list of raw request values for the parameter.
        :returns: ``Range(**{field: args})``.
        :raises RESTValidationError: on a malformed range, including a bound
            that consists only of ``>`` or ``<``.
        """
        if len(values) != 1 or values[0].count('--') != 1 or values[0] == '--':
            raise RESTValidationError(
                errors=[FieldError(field, 'Invalid range format.')])

        range_ends = values[0].split('--')
        range_args = dict()

        ineq_opers = [{'strict': 'gt', 'nonstrict': 'gte'},
                      {'strict': 'lt', 'nonstrict': 'lte'}]
        date_maths = [start_date_math, end_date_math]

        # Add the proper values to the dict
        for (range_end, strict, opers,
             date_math) in zip(range_ends, ['>', '<'], ineq_opers, date_maths):

            if range_end == '':
                # Open-ended on this side: no bound is added.
                continue

            # If first char is '>' for start or '<' for end
            if range_end[0] == strict:
                dict_key = opers['strict']
                range_end = range_end[1:]
                # A bare '>' or '<' would otherwise produce an empty bound
                # (or just '||<date_math>').
                if range_end == '':
                    raise RESTValidationError(
                        errors=[FieldError(field, 'Invalid range format.')])
            else:
                dict_key = opers['nonstrict']

            if date_math:
                range_end = '{0}||{1}'.format(range_end, date_math)

            range_args[dict_key] = range_end

        args = kwargs.copy()
        args.update(range_args)

        return Range(**{field: args})
예제 #4
0
def search_factory(self, search, query_parser=None):
    """Parse query using elasticsearch DSL query.

    Requires the ``id``, ``scheme`` and ``relation`` request parameters and
    applies the default facets and sorting for the search's first index. It
    filters on ``Grouping='identity'`` unless ``groupBy`` is given, and
    excludes ``*.SearchIdentifier`` from the returned source.

    :param self: REST view.
    :param search: Elastic search DSL search instance.
    :param query_parser: Accepted for interface compatibility; not used.
    :returns: Tuple with search instance and URL arguments.
    :raises RESTValidationError: listing every missing required parameter.
    """
    from invenio_records_rest.facets import default_facets_factory
    from invenio_records_rest.sorter import default_sorter_factory
    search_index = search._index[0]

    # TODO: make "scheme" optional?
    # Collect every missing required parameter so the client sees all of them
    # in one response instead of discovering them one request at a time.
    errors = [
        FieldError(field, 'Required field.')
        for field in ('id', 'scheme', 'relation')
        if field not in request.values
    ]
    if errors:
        raise RESTValidationError(errors=errors)

    search, urlkwargs = default_facets_factory(search, search_index)
    search, sortkwargs = default_sorter_factory(search, search_index)
    for key, value in sortkwargs.items():
        urlkwargs.add(key, value)

    # Apply 'identity' grouping by default
    if 'groupBy' not in request.values:
        search = search.filter(Q('term', Grouping='identity'))
        urlkwargs['groupBy'] = 'identity'

    # Exclude the identifiers by which the search was made (large aggregate)
    search = search.source(exclude=['*.SearchIdentifier'])
    return search, urlkwargs
예제 #5
0
def default_role_json_patch_loader(role=None):
    """Create JSON PATCH data loaders for role modifications.

    Rejects patches that modify fields outside ``_role_fields``, then applies
    the patch to the role's ``name`` and ``description``.

    :param role: the modified role.
    :returns: a JSON corresponding to the patched role.
    :raises RESTValidationError: if the patch modifies unknown or immutable
        fields.
    :raises PatchJSONFailureRESTError: if the patch cannot be applied.
    """
    data = request.get_json(force=True)
    # A JSON Patch must be an array of operation objects. Anything else (e.g.
    # a JSON object) would be iterated key by key below and could crash with
    # a TypeError instead of yielding a 400.
    if (not isinstance(data, list)
            or not all(isinstance(cmd, dict) for cmd in data)):
        abort(400)
    # Strip the leading '/' of each JSON pointer to get the field name.
    modified_fields = {
        cmd['path'][1:]
        for cmd in data
        if 'path' in cmd and 'op' in cmd and cmd['op'] != 'test'
    }
    errors = [
        FieldError(field, 'Unknown or immutable field {}.'.format(field))
        for field in modified_fields.difference(_role_fields)
    ]
    if errors:
        raise RESTValidationError(errors=errors)

    original = {'name': role.name, 'description': role.description}
    try:
        patched = apply_patch(original, data)
    except (JsonPatchException, JsonPointerException):
        raise PatchJSONFailureRESTError()
    return patched
예제 #6
0
    def json_patch_loader(user=None):
        """Load a JSON Patch for an account and apply it to the user's fields.

        Rejects patches that modify fields outside ``fields`` (from the
        enclosing scope). The patch is then applied to the user's ``email``,
        ``active`` and a ``None`` ``password`` placeholder, plus the profile's
        ``full_name``/``username`` when ``fields`` contains ``full_name``.

        :param user: the modified account.
        :returns: the patched fields; ``password`` is omitted unless the patch
            set it to a non-``None`` value.
        :raises RESTValidationError: if the patch modifies unknown fields.
        :raises PatchJSONFailureRESTError: if the patch cannot be applied.
        """
        data = request.get_json(force=True)
        # A JSON Patch must be an array of operation objects; anything else
        # would be iterated incorrectly below (or crash with a TypeError).
        if (not isinstance(data, list)
                or not all(isinstance(cmd, dict) for cmd in data)):
            abort(400)
        # Strip the leading '/' of each JSON pointer so paths compare against
        # plain field names.
        modified_fields = {
            cmd['path'][1:] for cmd in data
            if 'path' in cmd and 'op' in cmd and cmd['op'] != 'test'
        }
        # Reject the modified fields that are NOT among the allowed ones.
        errors = [
            FieldError(field, 'Unknown field {}.'.format(field))
            for field in modified_fields.difference(fields)
        ]
        if errors:
            raise RESTValidationError(errors=errors)

        original = {
            'email': user.email, 'active': user.active, 'password': None
        }
        # if invenio-userprofiles is loaded add profile's fields
        if 'full_name' in fields:
            original.update({
                'full_name': user.profile.full_name,
                'username': user.profile.username
            })
        try:
            patched = apply_patch(original, data)
        except (JsonPatchException, JsonPointerException):
            raise PatchJSONFailureRESTError()
        # Drop the placeholder unless a new password was set. Use .get/.pop
        # because a 'remove' op on /password deletes the key entirely.
        if patched.get('password') is None:
            patched.pop('password', None)

        if 'full_name' in fields:
            _fix_profile(patched)
        return patched
예제 #7
0
 def json_loader(**kwargs):
     """Default data loader when Invenio Userprofiles is not installed.

     :returns: the request's JSON object, unchanged.
     :raises RESTValidationError: if the body is not a JSON object, or lists
         every key that is not in ``allowed_fields``.
     """
     data = request.get_json(force=True)
     # Iterating a non-object body (None, a list, ...) would either crash with
     # a TypeError or validate list elements instead of field names.
     if not isinstance(data, dict):
         raise RESTValidationError(
             description='Request body must be a JSON object.')
     errors = [FieldError(key, 'Unknown field {}'.format(key))
               for key in data if key not in allowed_fields]
     if errors:
         raise RESTValidationError(errors=errors)
     return data
예제 #8
0
 def handle_error(self, error, *args, **kwargs):
     """Turn a ``ValidationError`` into a ``RESTValidationError``.

     Each message listed under a field in ``error.messages`` becomes one
     ``FieldError`` for that field. Any other error is delegated to the
     ``handle_error`` found after ``FlaskParser`` in the MRO.
     """
     if not isinstance(error, ValidationError):
         super(FlaskParser, self).handle_error(error, *args, **kwargs)
         return
     raise RESTValidationError(errors=[
         FieldError(field, msg)
         for field, messages in error.messages.items()
         for msg in messages
     ])
예제 #9
0
    def inner(values):
        """Build a ``geo_bounding_box`` query from a single bounds value.

        The value must be four comma-separated numbers:
        ``bottom_left_lon,bottom_left_lat,top_right_lon,top_right_lat``.
        Latitudes must lie in [-90, 90] and longitudes in [-180, 180]. The
        top-right latitude must be strictly greater than the bottom-left one.
        When ``type`` (from the enclosing scope) is truthy, it is added to the
        query.

        :param values: list of raw request values for the parameter.
        :returns: ``Q('geo_bounding_box', ...)`` on ``field``.
        :raises RESTValidationError: on any malformed or out-of-range bound.
        """
        if len(values) != 1:
            raise RESTValidationError(
                errors=[FieldError(name, 'Only one parameter is allowed.')])

        raw_bounds = [chunk.strip() for chunk in values[0].split(',')]
        if len(raw_bounds) != 4:
            raise RESTValidationError(errors=[
                FieldError(
                    name,
                    'Invalid bounds: four comma-separated numbers required. '
                    'Example: 143.37158,-38.99357,146.90918,-37.35269')
            ])

        try:
            bl_lon, bl_lat, tr_lon, tr_lat = [Decimal(raw) for raw in raw_bounds]
        except InvalidOperation:
            raise RESTValidationError(
                errors=[FieldError(name, 'Invalid number in bounds.')])

        # Ordering comparisons against a Decimal NaN raise InvalidOperation,
        # which is translated into a dedicated validation error below.
        try:
            latitudes_ok = (-90 <= bl_lat <= 90) and (-90 <= tr_lat <= 90)
            if not latitudes_ok:
                raise RESTValidationError(errors=[
                    FieldError(name, 'Latitude must be between -90 and 90.')
                ])
            longitudes_ok = (-180 <= bl_lon <= 180) and (-180 <= tr_lon <= 180)
            if not longitudes_ok:
                raise RESTValidationError(errors=[
                    FieldError(name, 'Longitude must be between -180 and 180.')
                ])
            if not tr_lat > bl_lat:
                raise RESTValidationError(errors=[
                    FieldError(
                        name, 'Top-right latitude must be greater than '
                        'bottom-left latitude.')
                ])
        except InvalidOperation:
            raise RESTValidationError(errors=[
                FieldError(name,
                           'Invalid number: "NaN" is not a permitted value.')
            ])

        top_right = {'lat': tr_lat, 'lon': tr_lon}
        bottom_left = {'lat': bl_lat, 'lon': bl_lon}
        query = {field: {'top_right': top_right, 'bottom_left': bottom_left}}

        if type:
            query['type'] = type
        return Q('geo_bounding_box', **query)
예제 #10
0
def account_json_loader(**kwargs):
    """Accounts REST API data loader for JSON input.

    Only the ``active`` field may be modified through this loader.

    :returns: the request's JSON object, unchanged.
    :raises RESTValidationError: if the body is not a JSON object, or lists
        every key other than ``active``.
    """
    data = request.get_json(force=True)
    # Iterating a non-object body (None, a list, ...) would either crash with
    # a TypeError or check list elements instead of field names.
    if not isinstance(data, dict):
        raise RESTValidationError(
            description='Request body must be a JSON object.')
    # "active" is the only *mutable* field; every other key is rejected.
    errors = [FieldError(key, 'Field {} is immutable'.format(key))
              for key in data if key != 'active']
    if errors:
        raise RESTValidationError(errors=errors)
    return data
예제 #11
0
    def serialize(pid, record, links_factory=None):
        """Serialize a single record and persistent identifier.

        Only records whose ``$schema`` equals ``Video.get_record_schema()``
        are supported; they are rendered with ``VTT(record=record).format()``.

        :param pid: Persistent identifier instance (not used).
        :param record: Record instance.
        :param links_factory: Factory function for record links (not used).
        :returns: the output of ``VTT(record=record).format()``.
        :raises RESTValidationError: if the record is not a video record.
        """
        # .get() so a record without '$schema' is reported as an unsupported
        # format instead of escaping as a KeyError (assumes a dict-like
        # record).
        if record.get('$schema') != Video.get_record_schema():
            raise RESTValidationError(
                errors=[FieldError(str(record.id), 'Unsupported format')])
        return VTT(record=record).format()
예제 #12
0
파일: loaders.py 프로젝트: haatveit/b2share
def check_patch_input_loader(record, immutable_paths):
    """Load a JSON Patch from the request and reject changes to immutable paths.

    :param record: the record being patched (not used here).
    :param immutable_paths: iterable of paths the patch must not modify. They
        are compared as-is with each operation's ``path`` value.
    :returns: the JSON Patch operations, unchanged.
    :raises RESTValidationError: if the patch touches an immutable path.
    """
    data = request.get_json(force=True)
    # A JSON Patch must be an array of operation objects; anything else would
    # be iterated incorrectly below (or crash with a TypeError).
    if (not isinstance(data, list)
            or not all(isinstance(cmd, dict) for cmd in data)):
        abort(400)
    modified_fields = {cmd['path'] for cmd in data
                       if 'path' in cmd and 'op' in cmd and cmd['op'] != 'test'}
    # set.intersection accepts any iterable, so immutable_paths does not have
    # to be a set (a list or tuple works too).
    errors = [FieldError(field, 'The field "{}" is immutable.'.format(field))
              for field in modified_fields.intersection(immutable_paths)]
    if errors:
        raise RESTValidationError(errors=errors)
    return data
예제 #13
0
    def json_patch_loader(user=None):
        """JSON patch loader.

        Rejects patches that modify fields outside ``fields`` (from the
        enclosing scope). The patch is then applied to the user's ``email``,
        ``active`` and a ``None`` ``password`` placeholder, plus the profile's
        ``full_name``/``username`` when ``fields`` contains ``full_name``.

        :param user: the modified account.
        :returns: a JSON corresponding to the patched account; ``password`` is
            omitted unless the patch set it to a non-``None`` value.
        :raises RESTValidationError: if the patch modifies unknown or immutable
            fields.
        :raises PatchJSONFailureRESTError: if the patch cannot be applied.
        """
        data = request.get_json(force=True)
        # A JSON Patch must be an array of operation objects; anything else
        # would be iterated incorrectly below (or crash with a TypeError).
        if (not isinstance(data, list)
                or not all(isinstance(cmd, dict) for cmd in data)):
            abort(400)
        # Strip the leading '/' of each JSON pointer to get the field name.
        modified_fields = {
            cmd['path'][1:]
            for cmd in data
            if 'path' in cmd and 'op' in cmd and cmd['op'] != 'test'
        }
        errors = [
            FieldError(field, 'Unknown or immutable field {}.'.format(field))
            for field in modified_fields.difference(fields)
        ]
        if errors:
            raise RESTValidationError(errors=errors)

        original = {
            'email': user.email,
            'active': user.active,
            'password': None
        }
        # if invenio-userprofiles is loaded add profile's fields
        if 'full_name' in fields:
            original.update({
                'full_name': user.profile.full_name,
                'username': user.profile.username
            })
        try:
            patched = apply_patch(original, data)
        except (JsonPatchException, JsonPointerException):
            raise PatchJSONFailureRESTError()
        # Drop the placeholder unless a new password was set. Use .get/.pop
        # because a 'remove' op on /password deletes the key entirely.
        if patched.get('password') is None:
            patched.pop('password', None)

        if 'full_name' in fields:
            _fix_profile(patched)
        return patched
예제 #14
0
def account_json_patch_loader(user=None, **kwargs):
    """Accounts REST API data loader for JSON Patch input.

    Only the ``active`` field may be modified. The patch is applied to
    ``{'active': user.active}``.

    :param user: the modified account.
    :returns: the patched ``{'active': ...}`` dictionary.
    :raises RESTValidationError: if the patch touches any field other than
        ``active``.
    :raises PatchJSONFailureRESTError: if the patch cannot be applied.
    """
    data = request.get_json(force=True)
    # A JSON Patch must be an array of operation objects; anything else would
    # be iterated incorrectly below (or crash with a TypeError).
    if (not isinstance(data, list)
            or not all(isinstance(cmd, dict) for cmd in data)):
        abort(400)
    # Strip the leading '/' of each JSON pointer to get the field name.
    modified_fields = {
        cmd['path'][1:]
        for cmd in data
        if 'path' in cmd and 'op' in cmd and cmd['op'] != 'test'
    }
    # "active" is the only *mutable* field; every other path is rejected.
    errors = [
        FieldError(field, 'Unknown or immutable field {}.'.format(field))
        for field in modified_fields if field != 'active'
    ]
    if errors:
        raise RESTValidationError(errors=errors)

    original = {'active': user.active}
    try:
        patched = apply_patch(original, data)
    except (JsonPatchException, JsonPointerException):
        raise PatchJSONFailureRESTError()
    return patched
예제 #15
0
def _abort(message, field=None, status=None):
    """Raise a ``RESTValidationError`` carrying ``message``.

    :param message: the error message.
    :param field: when truthy, the message is attached to this field as a
        single ``FieldError``; otherwise it becomes the error description.
    :param status: accepted but not used by this function.
    :raises RESTValidationError: always.
    """
    if not field:
        raise RESTValidationError(description=message)
    raise RESTValidationError([FieldError(field, message)])
예제 #16
0
 def validation_error(error):
     """Catch validation errors.

     The caught ``error`` is not inspected. The response is always that of a
     ``RESTValidationError`` constructed without arguments.
     """
     response = RESTValidationError().get_response()
     return response
예제 #17
0
    def inner(values):
        """Build a bool query over custom metadata ``[term]:[value]`` pairs.

        Each value must look like ``[vocabulary:term]:[value]`` and becomes
        one ``nested`` clause (all must match). Every clause requires
        ``<es_field>.key`` to equal the term.

        - For ``text``/``keyword`` terms, the value is run as a
          ``query_string`` on ``<es_field>.value``.
        - For ``relationship`` terms, the value must be ``<sub>:<obj>``. Each
          non-empty side is run as a ``query_string`` on
          ``<es_field>.subject`` / ``<es_field>.object``.

        :param values: list of raw request values for the parameter.
        :returns: ``Q('bool', must=[...])`` with one nested clause per value.
        :raises RESTValidationError: if a value is malformed, its term is not
            available or has an unmapped type, or a relationship value lacks
            a ``:`` separator.
        """
        terms = current_custom_metadata.terms
        available_terms = current_custom_metadata.available_vocabulary_set

        # Elasticsearch nested field that holds values of each custom term
        # type. Loop-invariant, so it is built once instead of once per value.
        custom_fields_mapping = dict(
            keyword='custom_keywords',
            text='custom_text',
            relationship='custom_relationships',
        )

        conditions = []

        for value in values:
            # Matches this:
            #   [vocabulary:term]:[value]
            parsed = re.match(
                r'^\[(?P<key>[-\w]+\:[-\w]+)\]\:\[(?P<val>.+)\]$', value)
            if not parsed:
                raise RESTValidationError(errors=[
                    FieldError(
                        field, 'The parameter should have the format: '
                        'custom=[term]:[value].')
                ])

            parsed = parsed.groupdict()
            search_key = parsed['key']
            search_value = parsed['val']

            if search_key not in available_terms:
                raise RESTValidationError(errors=[
                    FieldError(
                        field, u'The "{}" term is not supported.'.format(
                            search_key))
                ])

            term_type = terms[search_key]['type']
            es_field = custom_fields_mapping.get(term_type)
            if es_field is None:
                # A term type without a mapping would otherwise raise a
                # KeyError here; report it as an unsupported term instead.
                raise RESTValidationError(errors=[
                    FieldError(
                        field, u'The "{}" term is not supported.'.format(
                            search_key))
                ])

            nested_clauses = [
                {'term': {'{}.key'.format(es_field): search_key}},
            ]

            if term_type in ('text', 'keyword'):
                nested_clauses.append({
                    'query_string': {
                        'fields': ['{}.value'.format(es_field)],
                        'query': search_value,
                    }
                })
            elif term_type == 'relationship':
                if ':' not in search_value:
                    raise RESTValidationError(errors=[
                        FieldError(field, (
                            'Relationship terms search values should '
                            'follow the format "<sub>:<obj>".'))
                    ])

                # Split on the first ':' only; an empty side adds no clause,
                # so it leaves that side unconstrained.
                sub, obj = search_value.split(':', 1)
                if sub:
                    nested_clauses.append({
                        'query_string': {
                            'fields': ['{}.subject'.format(es_field)],
                            'query': sub
                        }
                    })
                if obj:
                    nested_clauses.append({
                        'query_string': {
                            'fields': ['{}.object'.format(es_field)],
                            'query': obj
                        }
                    })

            conditions.append({
                'nested': {
                    'path': es_field,
                    'query': {
                        'bool': {
                            'must': nested_clauses
                        }
                    },
                }
            })
        return Q('bool', must=conditions)
예제 #18
0
 def inner(values):
     """Build a nested range query from a single value.

     The value becomes the ``op`` bound (from the enclosing scope) of a
     ``Range`` on ``field``, wrapped in a ``nested`` query on ``path``.

     :param values: list of raw request values for the parameter.
     :returns: ``Q('nested', path=path, query=Range(...))``.
     :raises RESTValidationError: unless exactly one value is given.
     """
     if len(values) == 1:
         bound = {op: values[0]}
         return Q('nested', path=path, query=Range(**{field: bound}))
     raise RESTValidationError(
         errors=[FieldError(label, 'Multiple values specified.')])
예제 #19
0
 def test_validation_error():
     """Raise a ``RESTValidationError`` holding a single ``FieldError``.

     The field error is for ``myfield`` with message ``mymessage`` and
     code 10.
     """
     field_error = FieldError('myfield', 'mymessage', code=10)
     raise RESTValidationError(errors=[field_error])