Exemple #1
0
def _combine_projection_spec(filter_list, original_filter, prefix=''):
    """Turn a flat list of projected field paths into a nested mapping.

    Example: ``['a', 'b.c', 'b.d']`` becomes ``{'a': 1, 'b': {'c': 1, 'd': 1}}``.

    Args:
        filter_list: list of (possibly dotted) field paths. Any non-list value
            is treated as an already-collapsed leaf and returned untouched.
        original_filter: the user's full projection spec, quoted in errors.
        prefix: dotted path of the enclosing level; only used in error
            messages so nested conflicts report their full path.

    Returns:
        An ``OrderedDict`` keyed by the first path component, in first-seen
        order, whose values are ``1`` or a nested spec of the same shape.

    Raises:
        OperationFailure: if a path and one of its sub-paths are both given
            (e.g. ``'b'`` and ``'b.c'``).
    """
    if not isinstance(filter_list, list):
        return filter_list

    def _conflict(first, second):
        # Both orders of a conflicting pair report the same message shape.
        return OperationFailure(
            'Invalid $project :: caused by :: specification contains two conflicting paths.'
            ' Cannot specify both %s and %s: %s' %
            (repr(prefix + first), repr(prefix + second), original_filter))

    grouped = collections.OrderedDict()
    for path in filter_list:
        head, dot, tail = path.partition('.')
        existing = grouped.get(head)
        if dot:
            # A leaf (1) already claimed this head: 'head' then 'head.x'.
            if head in grouped and not isinstance(existing, list):
                raise _conflict(head, path)
            grouped.setdefault(head, []).append(tail)
        else:
            # Sub-paths already claimed this head: 'head.x' then 'head'.
            if isinstance(existing, list):
                raise _conflict(head, head + '.' + existing[0])
            grouped[head] = 1

    nested = collections.OrderedDict()
    for name, sub_paths in six.iteritems(grouped):
        nested[name] = _combine_projection_spec(
            sub_paths, original_filter, prefix='%s%s.' % (prefix, name))
    return nested
    def _handle_array_operator(self, operator, value):
        if operator == '$size':
            if isinstance(value, list):
                if len(value) != 1:
                    raise OperationFailure('Expression $size takes exactly 1 arguments. '
                                           '%d were passed in.' % len(value))
                value = value[0]
            array_value = self.parse(value)
            if not isinstance(array_value, list):
                raise OperationFailure('The argument to $size must be an array, '
                                       'but was of type: %s' % type(array_value))
            return len(array_value)

        if operator == '$filter':
            if not isinstance(value, dict):
                raise OperationFailure('$filter only supports an object as its argument')
            extra_params = set(value) - {'input', 'cond', 'as'}
            if extra_params:
                raise OperationFailure('Unrecognized parameter to $filter: %s' % extra_params.pop())
            missing_params = {'input', 'cond'} - set(value)
            if missing_params:
                raise OperationFailure("Missing '%s' parameter to $filter" % missing_params.pop())

            input_array = self.parse(value['input'])
            fieldname = value.get('as', 'this')
            cond = value['cond']
            return [
                item for item in input_array
                if _Parser(self._doc_dict, dict(self._user_vars, **{fieldname: item})).parse(cond)
            ]

        raise NotImplementedError(
            "Although '%s' is a valid array operator for the "
            'aggregation pipeline, it is currently not implemented '
            'in Mongomock.' % operator)
Exemple #3
0
def _handle_count_stage(in_collection, database, options):
    if not isinstance(options, str) or options == '':
        raise OperationFailure('the count field must be a non-empty string')
    elif options.startswith('$'):
        raise OperationFailure('the count field cannot be a $-prefixed path')
    elif '.' in options:
        raise OperationFailure("the count field cannot contain '.'")
    return [{options: len(in_collection)}]
Exemple #4
0
    def _handle_date_operator(self, operator, values):
        """Evaluate an aggregation date operator.

        ``values`` is always parsed first with ``self.parse``. The extraction
        operators (``$year``, ``$month``, ``$hour``, ...) read the matching
        component of the parsed value, which is expected to be a
        ``datetime``. ``$dateToString`` expects ``values`` to be a dict with
        ``format`` and ``date`` keys and returns ``date.strftime(format)``.

        Raises:
            OperationFailure: if ``$dateToString`` is not given a dict with
                both ``format`` and ``date``.
            NotImplementedError: for the ``%L`` format code and the
                ``onNull`` / ``timezone`` options of ``$dateToString``, and
                for any other date operator.
        """
        out_value = self.parse(values)
        if operator == '$dayOfYear':
            return out_value.timetuple().tm_yday
        if operator == '$dayOfMonth':
            return out_value.day
        if operator == '$dayOfWeek':
            # isoweekday() is Monday=1..Sunday=7; this maps it to
            # Sunday=1, Monday=2, ..., Saturday=7.
            return (out_value.isoweekday() % 7) + 1
        if operator == '$year':
            return out_value.year
        if operator == '$month':
            return out_value.month
        if operator == '$week':
            # %U: week of the year with Sunday as first day (00-53).
            return int(out_value.strftime('%U'))
        if operator == '$hour':
            return out_value.hour
        if operator == '$minute':
            return out_value.minute
        if operator == '$second':
            return out_value.second
        if operator == '$millisecond':
            return out_value.microsecond // 1000
        if operator == '$dateToString':
            # One check covers both "not a dict" and "missing keys"; the
            # isinstance test short-circuits before set(values) is built.
            if not isinstance(values, dict) or not {'format', 'date'} <= set(values):
                raise OperationFailure(
                    '$dateToString operator must correspond a dict '
                    'that has "format" and "date" field.'
                )
            # out_value is the parsed form of the ``values`` dict, so its
            # 'format' and 'date' entries are already evaluated.
            if '%L' in out_value['format']:
                raise NotImplementedError(
                    'Although %L is a valid date format for the '
                    '$dateToString operator, it is currently not implemented '
                    ' in Mongomock.'
                )
            if 'onNull' in values:
                raise NotImplementedError(
                    'Although onNull is a valid field for the '
                    '$dateToString operator, it is currently not implemented '
                    ' in Mongomock.'
                )
            if 'timezone' in values:
                raise NotImplementedError(
                    'Although timezone is a valid field for the '
                    '$dateToString operator, it is currently not implemented '
                    ' in Mongomock.'
                )
            return out_value['date'].strftime(out_value['format'])

        raise NotImplementedError(
            "Although '%s' is a valid date operator for the "
            'aggregation pipeline, it is currently not implemented '
            ' in Mongomock.' % operator)
Exemple #5
0
def _handle_sample_stage(in_collection, unused_database, options):
    """Implement ``$sample``: return up to ``size`` randomly ordered documents.

    The caller's ``options`` dict is not modified, so the same pipeline can
    be executed more than once.

    Args:
        in_collection: iterable of input documents.
        unused_database: ignored.
        options: the stage spec; must be ``{'size': <int>}``.

    Returns:
        A list of at most ``size`` documents in shuffled order.

    Raises:
        OperationFailure: if ``options`` is not a dict, ``size`` is missing or
            negative, or an unrecognized option is present.
    """
    if not isinstance(options, dict):
        raise OperationFailure('the $sample stage specification must be an object')
    # Pop from a copy: popping from the caller's dict would strip 'size' from
    # the stored pipeline and make a second run fail.
    remaining = dict(options)
    size = remaining.pop('size', None)
    if size is None:
        raise OperationFailure('$sample stage must specify a size')
    if remaining:
        raise OperationFailure('unrecognized option to $sample: %s' % set(remaining).pop())
    if size < 0:
        # A negative slice bound would silently drop documents from the end.
        raise OperationFailure('size argument to $sample must not be negative')
    shuffled = list(in_collection)
    _random.shuffle(shuffled)
    # Slicing caps the result at len(in_collection) when size is larger.
    return shuffled[:size]
Exemple #6
0
def _handle_project_stage(in_collection, unused_database, options):
    """Implement the ``$project`` stage.

    Each entry of ``options`` is either an inclusion/exclusion flag (a value
    equal to 0, 1, True or False) or an expression computing a new field.
    Flags are collected and applied in one final pass through
    ``_combine_projection_spec`` and ``_project_by_spec``. Computed fields are
    evaluated per input document and merged over the projected documents.
    ``_id`` is kept unless it is explicitly set to 0/False.

    Raises:
        OperationFailure: if the specification is empty, or mixes inclusions
            and exclusions of fields other than ``_id``.
    """
    if not options:
        # Previously an empty spec fell through and returned None, which
        # broke every following stage.
        raise OperationFailure(
            'Invalid $project :: caused by :: projection specification must have at least one field')
    filter_list = []
    method = None
    include_id = options.get('_id')
    # Compute new values for each field, except inclusion/exclusions that are
    # handled in one final step.
    new_fields_collection = None
    for field, value in six.iteritems(options):
        if method is None and (field != '_id' or value):
            method = 'include' if value else 'exclude'
        elif method == 'include' and not value and field != '_id':
            raise OperationFailure(
                'Bad projection specification, cannot exclude fields '
                "other than '_id' in an inclusion projection: %s" % options)
        elif method == 'exclude' and value:
            raise OperationFailure(
                'Bad projection specification, cannot include fields '
                'or add computed fields during an exclusion projection: %s' %
                options)
        if value in (0, 1, True, False):
            if field != '_id':
                filter_list.append(field)
            continue
        if not new_fields_collection:
            new_fields_collection = [{} for unused_doc in in_collection]

        for in_doc, out_doc in zip(in_collection, new_fields_collection):
            try:
                out_doc[field] = _parse_expression(value,
                                                   in_doc,
                                                   ignore_missing_keys=True)
            except KeyError:
                # A computed field whose input is missing is simply omitted.
                pass
    # Compare by equality (not identity) so that 0.0 also excludes _id, the
    # same way the flag test above treats flag values.
    keep_id = include_id not in (0, False)
    # _id must be listed when its treatment matches the projection mode:
    # kept in an inclusion projection, or dropped in an exclusion one.
    if (method == 'include') == keep_id:
        filter_list.append('_id')

    if not filter_list:
        return new_fields_collection

    # Final steps: include or exclude fields and merge with newly created fields.
    projection_spec = _combine_projection_spec(filter_list,
                                               original_filter=options)
    out_collection = [
        _project_by_spec(doc,
                         projection_spec,
                         is_include=(method == 'include'))
        for doc in in_collection
    ]
    if new_fields_collection:
        return [
            dict(a, **b) for a, b in zip(out_collection, new_fields_collection)
        ]
    return out_collection
Exemple #7
0
    def parse(self, expression):
        """Parse a MongoDB expression.

        Non-dict expressions are delegated to ``_parse_basic_expression``.
        For a dict, a key naming a known operator family is dispatched to the
        matching ``_handle_*`` method and its result returned. Any other dict
        is treated as a literal sub-document whose values are parsed
        recursively. When ``self._ignore_missing_keys`` is set, entries whose
        value raises ``KeyError`` are left out of that sub-document.

        Raises:
            OperationFailure: if a dict with a ``$``-prefixed key has more
                than one field, or a ``$``-prefixed key is not recognized.
            NotImplementedError: for boolean and text-search operators, which
                mongomock does not support yet.
        """
        if not isinstance(expression, dict):
            return self._parse_basic_expression(expression)

        field_count = len(expression)
        if field_count > 1 and any(key.startswith('$') for key in expression):
            raise OperationFailure(
                'an expression specification must contain exactly one field, '
                'the name of the expression. Found %d fields in %s' %
                (field_count, expression))

        # Families are tried in this order and the first match wins. Handler
        # methods are looked up by name only when actually needed.
        handlers = (
            (arithmetic_operators, '_handle_arithmetic_operator'),
            (project_operators, '_handle_project_operator'),
            (projection_operators, '_handle_projection_operator'),
            (comparison_operators, '_handle_comparison_operator'),
            (date_operators, '_handle_date_operator'),
            (array_operators, '_handle_array_operator'),
            (conditional_operators, '_handle_conditional_operator'),
            (set_operators, '_handle_set_operator'),
            (string_operators, '_handle_string_operator'),
            (type_convertion_operators, '_handle_type_convertion_operator'),
        )

        parsed = {}
        for key, sub_expression in six.iteritems(expression):
            for operators, handler_name in handlers:
                if key in operators:
                    return getattr(self, handler_name)(key, sub_expression)
            # projection_operators were already dispatched above, so only
            # these two families can still be reported as unsupported.
            if key in boolean_operators or key in text_search_operators:
                raise NotImplementedError(
                    "'%s' is a valid operation but it is not supported by Mongomock yet."
                    % key)
            if key.startswith('$'):
                raise OperationFailure("Unrecognized expression '%s'" % key)
            try:
                parsed[key] = self.parse(sub_expression)
            except KeyError:
                if not self._ignore_missing_keys:
                    raise

        return parsed
Exemple #8
0
    def parse(self, expression):
        """Parse a MongoDB expression.

        Non-dict expressions are delegated to ``_parse_basic_expression``.
        For a dict, a key naming a known operator family is dispatched to the
        matching ``_handle_*`` method and its result returned. Any other dict
        is treated as a literal sub-document whose values are parsed
        recursively.

        Raises:
            OperationFailure: for an unrecognized ``$``-prefixed key.
            NotImplementedError: for boolean, string and text-search
                operators, which this parser does not support yet.
        """
        if not isinstance(expression, dict):
            return self._parse_basic_expression(expression)

        # Families are tried in this order and the first match wins. Handler
        # methods are looked up by name only when actually needed.
        handlers = (
            (arithmetic_operators, '_handle_arithmetic_operator'),
            (project_operators, '_handle_project_operator'),
            (projection_operators, '_handle_projection_operator'),
            (comparison_operators, '_handle_comparison_operator'),
            (date_operators, '_handle_date_operator'),
            (array_operators, '_handle_array_operator'),
            (conditional_operators, '_handle_conditional_operator'),
            (set_operators, '_handle_set_operator'),
        )

        parsed = {}
        for key, sub_expression in six.iteritems(expression):
            for operators, handler_name in handlers:
                if key in operators:
                    return getattr(self, handler_name)(key, sub_expression)
            # projection_operators were already dispatched above.
            unsupported = (key in boolean_operators or key in string_operators or
                           key in text_search_operators)
            if unsupported:
                raise NotImplementedError(
                    "'%s' is a valid operation but it is not supported by Mongomock yet."
                    % key)
            if key.startswith('$'):
                raise OperationFailure("Unrecognized expression '%s'" % key)
            parsed[key] = self.parse(sub_expression)

        return parsed
Exemple #9
0
def _handle_replace_root_stage(in_collection, unused_database, options):
    """Implement ``$replaceRoot``: replace each document by ``newRoot``.

    ``options['newRoot']`` is evaluated against every input document with
    missing keys ignored. If evaluation raises ``KeyError``, the result is
    ``NOTHING``, which then fails the dict check.

    Returns:
        A new list holding the evaluated root of every input document.

    Raises:
        OperationFailure: if ``newRoot`` is absent, or does not evaluate to a
            dict for some document.
    """
    if 'newRoot' not in options:
        raise OperationFailure("Parameter 'newRoot' is missing for $replaceRoot operation.")
    new_root = options['newRoot']

    def _evaluate(document):
        try:
            return _parse_expression(new_root, document, ignore_missing_keys=True)
        except KeyError:
            return NOTHING

    replaced = []
    for document in in_collection:
        root = _evaluate(document)
        if isinstance(root, dict):
            replaced.append(root)
            continue
        raise OperationFailure(
            "'newRoot' expression must evaluate to an object, but resulting value was: {}"
            .format(root))
    return replaced
Exemple #10
0
 def _get_default_bucket():
     """Return the ``'default'`` entry of the enclosing ``options`` mapping.

     ``options`` is a free variable, presumably the ``$bucket`` stage spec
     of the enclosing function.

     Raises:
         OperationFailure: if ``options`` has no ``'default'`` entry.
     """
     if 'default' not in options:
         raise OperationFailure(
             '$bucket could not find a matching branch for '
             'an input, and no default was specified.')
     return options['default']
Exemple #11
0
 def _handle_array_operator(self, operator, value):
     if operator == '$size':
         if isinstance(value, list):
             if len(value) != 1:
                 raise OperationFailure(
                     'Expression $size takes exactly 1 arguments. '
                     '%d were passed in.' % len(value))
             value = value[0]
         array_value = self.parse(value)
         if not isinstance(array_value, list):
             raise OperationFailure(
                 'The argument to $size must be an array, '
                 'but was of type: %s' % type(array_value))
         return len(array_value)
     raise NotImplementedError(
         "Although '%s' is a valid array operator for the "
         'aggregation pipeline, it is currently not implemented '
         'in Mongomock.' % operator)
Exemple #12
0
 def _handle_string_operator(self, operator, values):
     """Evaluate an aggregation string operator.

     Null (``None``) handling:
     - ``$toLower``, ``$toUpper`` and ``$substr`` return ``''``.
     - ``$strcasecmp`` compares ``None`` as ``''``.
     - ``$concat`` and ``$toString`` return ``None``.

     Raises:
         OperationFailure: if ``$substr`` / ``$strcasecmp`` receive the wrong
             number of arguments.
         NotImplementedError: safety fallback for an unhandled operator.
     """
     if operator == '$toLower':
         parsed = self.parse(values)
         return str(parsed).lower() if parsed is not None else ''
     if operator == '$toUpper':
         parsed = self.parse(values)
         return str(parsed).upper() if parsed is not None else ''
     if operator == '$concat':
         parsed_list = [self.parse(value) for value in values]
         return None if None in parsed_list else ''.join(
             [str(x) for x in parsed_list])
     if operator == '$substr':
         if len(values) != 3:
             raise OperationFailure('substr must have 3 items')
         parsed_string = self.parse(values[0])
         first = self.parse(values[1])
         length = self.parse(values[2])
         # Check for null before converting: str(None) would be 'None'.
         if parsed_string is None:
             return ''
         string = str(parsed_string)
         if first < 0:
             warnings.warn(
                 'Negative starting point given to $substr is accepted only until '
                 'MongoDB 3.7. This behavior will change in the future.')
             return ''
         if length < 0:
             warnings.warn(
                 'Negative length given to $substr is accepted only until '
                 'MongoDB 3.7. This behavior will change in the future.')
         # A negative length means "up to the end of the string".
         second = len(string) if length < 0 else first + length
         return string[first:second]
     if operator == '$strcasecmp':
         if len(values) != 2:
             raise OperationFailure('strcasecmp must have 2 items')
         # $strcasecmp is case-insensitive: both sides are upper-cased before
         # comparing, and null is compared as the empty string.
         a, b = [
             '' if parsed is None else str(parsed).upper()
             for parsed in (self.parse(values[0]), self.parse(values[1]))]
         return 0 if a == b else -1 if a < b else 1
     if operator == '$toString':
         parsed = self.parse(values)
         return str(parsed) if parsed is not None else None
     # This should never happen: it is only a safe fallback if something went wrong.
     raise NotImplementedError(  # pragma: no cover
         "Although '%s' is a valid string operator for the aggregation "
         'pipeline, it is currently not implemented  in Mongomock.' %
         operator)
Exemple #13
0
def _handle_lookup_stage(in_collection, database, options):
    for operator in ('let', 'pipeline'):
        if operator in options:
            raise NotImplementedError(
                "Although '%s' is a valid lookup operator for the "
                'aggregation pipeline, it is currently not '
                'implemented in Mongomock.' % operator)
    for operator in ('from', 'localField', 'foreignField', 'as'):
        if operator not in options:
            raise OperationFailure(
                "Must specify '%s' field for a $lookup" % operator)
        if not isinstance(options[operator], six.string_types):
            raise OperationFailure(
                'Arguments to $lookup must be strings')
        if operator in ('as', 'localField', 'foreignField') and \
                options[operator].startswith('$'):
            raise OperationFailure(
                "FieldPath field names may not start with '$'")
        if operator == 'as' and \
                '.' in options[operator]:
            raise NotImplementedError(
                "Although '.' is valid in the 'as' "
                'parameters for the lookup stage of the aggregation '
                'pipeline, it is currently not implemented in Mongomock.')

    foreign_name = options['from']
    local_field = options['localField']
    foreign_field = options['foreignField']
    local_name = options['as']
    foreign_collection = database.get_collection(foreign_name)
    for doc in in_collection:
        try:
            query = helpers.get_value_by_dot(doc, local_field)
        except KeyError:
            query = None
        if isinstance(query, list):
            query = {'$in': query}
        matches = foreign_collection.find({foreign_field: query})
        doc[local_name] = [foreign_doc for foreign_doc in matches]

    return in_collection
Exemple #14
0
def _handle_add_fields_stage(in_collection, unused_database, options):
    """Implement ``$addFields``: return copies of the documents with new fields.

    Each ``options`` entry maps a (possibly dotted) field path to an
    expression evaluated against the *input* document. Intermediate path
    components that are missing or not dicts are replaced by fresh dicts.
    If an expression raises ``KeyError`` for a document (missing input), that
    field is skipped for that document. Input documents, including their
    nested sub-documents, are never modified.

    Raises:
        OperationFailure: if ``options`` is empty.
    """
    if not options:
        raise OperationFailure(
            'Invalid $addFields :: caused by :: specification must have at least one field')
    out_collection = [dict(doc) for doc in in_collection]
    for field, value in six.iteritems(options):
        parts = field.split('.')
        for in_doc, out_doc in zip(in_collection, out_collection):
            try:
                out_value = _parse_expression(value, in_doc, ignore_missing_keys=True)
            except KeyError:
                continue
            target = out_doc
            for subfield in parts[:-1]:
                child = target.get(subfield)
                # dict(doc) above is a shallow copy, so a nested dict is still
                # shared with the input document; copy it before writing so
                # the input (and later expressions reading it) stay unchanged.
                target[subfield] = child.copy() if isinstance(child, dict) else {}
                target = target[subfield]
            target[parts[-1]] = out_value
    return out_collection
Exemple #15
0
    def _handle_type_convertion_operator(self, operator, values):
        """Evaluate a type-conversion aggregation operator.

        Supported: ``$toString``, ``$toInt``, ``$toDecimal`` and
        ``$arrayToObject``. For every one of them, a missing input field
        (``KeyError`` from ``self.parse``) or a ``None`` input yields
        ``None``.

        Raises:
            OperationFailure: for strings ``$toDecimal`` cannot parse, or
                invalid ``$arrayToObject`` input.
            NotImplementedError: when decimal128 support (pymongo) is needed
                but unavailable, or for an unsupported operator.
            TypeError: for ``$toDecimal`` inputs of unsupported types.
        """
        if operator == '$toString':
            try:
                parsed = self.parse(values)
            except KeyError:
                return None
            if parsed is None:
                return None
            if isinstance(parsed, bool):
                return str(parsed).lower()
            if isinstance(parsed, datetime.datetime):
                # Always emit millisecond precision. isoformat() omits the
                # fractional part when microsecond == 0, so slicing its output
                # used to cut off the seconds instead.
                # NOTE(review): tzinfo is ignored; assumes naive UTC datetimes.
                return '%04d-%02d-%02dT%02d:%02d:%02d.%03dZ' % (
                    parsed.year, parsed.month, parsed.day,
                    parsed.hour, parsed.minute, parsed.second,
                    parsed.microsecond // 1000)
            return str(parsed)

        if operator == '$toInt':
            try:
                parsed = self.parse(values)
            except KeyError:
                return None
            if parsed is None:
                return None
            # decimal128 is only needed (and only available) for Decimal128
            # inputs; other values convert without pymongo.
            if decimal_support and isinstance(parsed, decimal128.Decimal128):
                return int(parsed.to_decimal())
            return int(parsed)

        # Document: https://docs.mongodb.com/manual/reference/operator/aggregation/toDecimal/
        if operator == '$toDecimal':
            if not decimal_support:
                raise NotImplementedError(
                    'You need to import the pymongo library to support decimal128 type.'
                )
            try:
                parsed = self.parse(values)
            except KeyError:
                return None
            if parsed is None:
                return None
            # bool is tested before int because bool is a subclass of int.
            if isinstance(parsed, bool):
                parsed = '1' if parsed is True else '0'
                decimal_value = decimal128.Decimal128(parsed)
            elif isinstance(parsed, int):
                decimal_value = decimal128.Decimal128(str(parsed))
            elif isinstance(parsed, float):
                exp = decimal.Decimal('.00000000000000')
                decimal_value = decimal.Decimal(str(parsed)).quantize(exp)
                decimal_value = decimal128.Decimal128(decimal_value)
            elif isinstance(parsed, decimal128.Decimal128):
                decimal_value = parsed
            elif isinstance(parsed, str):
                try:
                    decimal_value = decimal128.Decimal128(parsed)
                except decimal.InvalidOperation:
                    raise OperationFailure(
                        "Failed to parse number '%s' in $convert with no onError value:"
                        'Failed to parse string to decimal' % parsed)
            elif isinstance(parsed, datetime.datetime):
                # Dates convert to milliseconds since the Unix epoch.
                epoch = datetime.datetime.utcfromtimestamp(0)
                string_micro_seconds = str((parsed - epoch).total_seconds() * 1000).split('.')[0]
                decimal_value = decimal128.Decimal128(string_micro_seconds)
            else:
                raise TypeError("'%s' type is not supported" % type(parsed))
            return decimal_value

        # Document: https://docs.mongodb.com/manual/reference/operator/aggregation/arrayToObject/
        if operator == '$arrayToObject':
            try:
                parsed = self.parse(values)
            except KeyError:
                return None

            if parsed is None:
                return None

            if not isinstance(parsed, (list, tuple)):
                raise OperationFailure(
                    '$arrayToObject requires an array input, found: {}'.format(type(parsed))
                )

            if all(isinstance(x, dict) and set(x.keys()) == {'k', 'v'} for x in parsed):
                return {d['k']: d['v'] for d in parsed}

            if all(isinstance(x, (list, tuple)) and len(x) == 2 for x in parsed):
                return dict(parsed)

            raise OperationFailure(
                'arrays used with $arrayToObject must contain documents '
                'with k and v fields or two-element arrays'
            )

        # Previously unhandled operators fell off the end and silently
        # returned None.
        raise NotImplementedError(
            "Although '%s' is a valid type conversion operator for the "
            'aggregation pipeline, it is currently not implemented '
            'in Mongomock.' % operator)
Exemple #16
0
 def _handle_arithmetic_operator(self, operator, values):
     unary_operators = {
         '$abs': abs,
         '$ceil': math.ceil,
         '$exp': math.exp,
         '$floor': math.floor,
         '$ln': math.log,
         '$log10': math.log10,
         '$sqrt': math.sqrt,
         '$trunc': math.trunc
     }
     binary_operators = {
         '$divide': lambda x, y: x / y,
         '$log': math.log,
         '$mod': math.fmod,
         '$pow': math.pow,
         '$subtract': lambda x, y: x - y,
     }
     if operator in unary_operators:
         if isinstance(values, list):
             if len(values) != 1:
                 raise OperationFailure(
                     'Expression %s takes exactly 1 arguments. %d were passed in.'
                     % (operator, len(values)))
             values = values[0]
         try:
             return None if self.parse(
                 values) is None else unary_operators[operator](
                     self.parse(values))
         except KeyError:
             return None
     if operator in binary_operators:
         if not isinstance(values, list):
             raise OperationFailure(
                 'Expression %s takes exactly 2 arguments. 1 were passed in.'
                 % operator)
         if len(values) != 2:
             raise OperationFailure(
                 'Expression %s takes exactly 2 arguments. %d were passed in.'
                 % (operator, len(values)))
         try:
             values = list(map(self.parse, values))
         except KeyError:
             return None
         if None in values:
             return None
         if operator == '$subtract' and isinstance(values[0],
                                                   datetime.datetime):
             if isinstance(values[1], datetime.datetime):
                 return round(
                     (values[0] - values[1]).total_seconds() * 1000)
             return values[0] - datetime.timedelta(milliseconds=values[1])
         return binary_operators[operator](values[0], values[1])
     if operator == '$add':
         if not isinstance(values, list):
             values = [values]
         parsed_dates = []
         parsed_numbers = []
         for value in values:
             try:
                 value = self.parse(value)
                 if value is None:
                     return None
                 elif isinstance(value, datetime.datetime):
                     parsed_dates.append(value)
                 else:
                     parsed_numbers.append(value)
             except KeyError:
                 return None
         if len(parsed_dates) > 1:
             raise OperationFailure(
                 'only one date allowed in an $add expression')
         result = sum(parsed_numbers)
         if len(parsed_dates) == 1:
             result = parsed_dates[0] + datetime.timedelta(
                 milliseconds=result)
         return result
     if operator == '$multiply':
         if not isinstance(values, list):
             values = [values]
         parsed_values = []
         for value in values:
             try:
                 parsed_values.append(self.parse(value))
             except KeyError:
                 return None
         if None in parsed_values:
             return None
         return moves.reduce(lambda x, y: x * y, parsed_values, 1)
     # This should never happen: it is only a safe fallback if something went wrong.
     raise NotImplementedError(  # pragma: no cover
         "Although '%s' is a valid aritmetic operator for the aggregation "
         'pipeline, it is currently not implemented  in Mongomock.' %
         operator)
Exemple #17
0
    def _handle_type_convertion_operator(self, operator, values):
        """Evaluate a type-conversion aggregation operator.

        Supported: ``$toString``, ``$toInt`` and ``$toDecimal``. For each of
        them, a missing input field (``KeyError`` from ``self.parse``) or a
        ``None`` input yields ``None``.

        Raises:
            OperationFailure: for strings ``$toDecimal`` cannot parse.
            NotImplementedError: when decimal128 support (pymongo) is needed
                but unavailable, or for an unsupported operator.
            TypeError: for ``$toDecimal`` inputs of unsupported types.
        """
        if operator == '$toString':
            try:
                parsed = self.parse(values)
            except KeyError:
                return None
            if parsed is None:
                return None
            if isinstance(parsed, bool):
                return str(parsed).lower()
            if isinstance(parsed, datetime.datetime):
                # Always emit millisecond precision. isoformat() omits the
                # fractional part when microsecond == 0, so slicing its output
                # used to cut off the seconds instead.
                # NOTE(review): tzinfo is ignored; assumes naive UTC datetimes.
                return '%04d-%02d-%02dT%02d:%02d:%02d.%03dZ' % (
                    parsed.year, parsed.month, parsed.day,
                    parsed.hour, parsed.minute, parsed.second,
                    parsed.microsecond // 1000)
            return str(parsed)

        if operator == '$toInt':
            try:
                parsed = self.parse(values)
            except KeyError:
                return None
            if parsed is None:
                return None
            # decimal128 is only needed (and only available) for Decimal128
            # inputs; other values convert without pymongo.
            if decimal_support and isinstance(parsed, decimal128.Decimal128):
                return int(parsed.to_decimal())
            return int(parsed)

        # Document: https://docs.mongodb.com/manual/reference/operator/aggregation/toDecimal/
        if operator == '$toDecimal':
            if not decimal_support:
                raise NotImplementedError(
                    'You need to import the pymongo library to support decimal128 type.'
                )
            try:
                parsed = self.parse(values)
            except KeyError:
                return None
            if parsed is None:
                return None
            # bool is tested before int because bool is a subclass of int.
            if isinstance(parsed, bool):
                parsed = '1' if parsed is True else '0'
                decimal_value = decimal128.Decimal128(parsed)
            elif isinstance(parsed, int):
                decimal_value = decimal128.Decimal128(str(parsed))
            elif isinstance(parsed, float):
                exp = decimal.Decimal('.00000000000000')
                decimal_value = decimal.Decimal(str(parsed)).quantize(exp)
                decimal_value = decimal128.Decimal128(decimal_value)
            elif isinstance(parsed, decimal128.Decimal128):
                decimal_value = parsed
            elif isinstance(parsed, str):
                try:
                    decimal_value = decimal128.Decimal128(parsed)
                except decimal.InvalidOperation:
                    raise OperationFailure(
                        "Failed to parse number '%s' in $convert with no onError value:"
                        'Failed to parse string to decimal' % parsed)
            elif isinstance(parsed, datetime.datetime):
                # Dates convert to milliseconds since the Unix epoch.
                epoch = datetime.datetime.utcfromtimestamp(0)
                string_micro_seconds = str(
                    (parsed - epoch).total_seconds() * 1000).split('.')[0]
                decimal_value = decimal128.Decimal128(string_micro_seconds)
            else:
                raise TypeError("'%s' type is not supported" % type(parsed))
            return decimal_value

        # Previously unhandled operators fell off the end and silently
        # returned None.
        raise NotImplementedError(
            "Although '%s' is a valid type conversion operator for the "
            'aggregation pipeline, it is currently not implemented '
            'in Mongomock.' % operator)
Exemple #18
0
def _handle_bucket_stage(in_collection, unused_database, options):
    """Implement the ``$bucket`` stage.

    A document whose ``groupBy`` value lies in ``[boundaries[i],
    boundaries[i + 1])`` goes to the bucket with ``_id == boundaries[i]``.
    All other documents go to the ``default`` bucket. This covers a missing
    value, a value outside every interval, and a value not comparable with
    the boundaries (e.g. ``None`` among numbers). Each output document holds
    ``_id`` plus the ``output`` accumulators, which default to
    ``{'count': {'$sum': 1}}``.

    Raises:
        OperationFailure: on unknown options, missing ``groupBy`` or
            ``boundaries``, boundaries that are not a sorted list of at least
            two values, or a document needing the default bucket when no
            ``default`` was given.
    """
    unknown_options = set(options) - {
        'groupBy', 'boundaries', 'output', 'default'
    }
    if unknown_options:
        raise OperationFailure('Unrecognized option to $bucket: %s.' %
                               unknown_options.pop())
    if 'groupBy' not in options or 'boundaries' not in options:
        raise OperationFailure(
            "$bucket requires 'groupBy' and 'boundaries' to be specified.")
    group_by = options['groupBy']
    boundaries = options['boundaries']
    if not isinstance(boundaries, list):
        raise OperationFailure(
            "The $bucket 'boundaries' field must be an array, but found type: %s"
            % type(boundaries))
    if len(boundaries) < 2:
        raise OperationFailure(
            "The $bucket 'boundaries' field must have at least 2 values, but "
            'found %d value(s).' % len(boundaries))
    if sorted(boundaries) != boundaries:
        raise OperationFailure(
            "The 'boundaries' option to $bucket must be sorted in ascending order"
        )
    output_fields = options.get('output', {'count': {'$sum': 1}})
    default_value = options.get('default', None)
    # Decides whether the default bucket sorts after the regular buckets.
    # A default that cannot be compared with the boundaries is placed last.
    try:
        is_default_last = default_value >= boundaries[-1]
    except TypeError:
        is_default_last = True

    def _get_default_bucket():
        try:
            return options['default']
        except KeyError:
            raise OperationFailure(
                '$bucket could not find a matching branch for '
                'an input, and no default was specified.')

    def _get_bucket_id(doc):
        """Get the bucket ID for a document.

        Note that it actually returns a tuple with the first
        param being a sort key to sort the default bucket even
        if it's not the same type as the boundaries.
        """
        try:
            value = _parse_expression(group_by, doc)
        except KeyError:
            return (is_default_last, _get_default_bucket())
        try:
            index = bisect.bisect_right(boundaries, value)
        except TypeError:
            # The value cannot be ordered against the boundaries (e.g. None
            # or a string among numbers), so it belongs to no interval.
            return (is_default_last, _get_default_bucket())
        if 0 < index < len(boundaries):
            return (False, boundaries[index - 1])
        return (is_default_last, _get_default_bucket())

    keyed_docs = ((_get_bucket_id(doc), doc) for doc in in_collection)
    # groupby only merges adjacent items, so sort by the same key first.
    sorted_docs = sorted(keyed_docs, key=lambda kv: kv[0])
    grouped = itertools.groupby(sorted_docs, lambda kv: kv[0])

    out_collection = []
    for (unused_key, doc_id), group in grouped:
        group_list = [kv[1] for kv in group]
        doc_dict = _accumulate_group(output_fields, group_list)
        doc_dict['_id'] = doc_id
        out_collection.append(doc_dict)
    return out_collection
Exemple #19
0
def _handle_graph_lookup_stage(in_collection, database, options):
    """Emulate the ``$graphLookup`` aggregation stage.

    For every input document, the ``startWith`` expression seeds a query on
    ``connectToField`` in the ``from`` collection; the ``connectFromField``
    value of each match then seeds the next level of the search. All matches
    are collected, in discovery order, under the ``as`` field of a deep copy
    of the input document.

    Args:
        in_collection: the documents flowing into this stage.
        database: database used to resolve the ``from`` collection.
        options: the ``$graphLookup`` stage specification.

    Returns:
        A new list of documents; the input documents are not mutated.

    Raises:
        OperationFailure: if the stage specification is invalid.
        NotImplementedError: if ``connectFromField`` or ``as`` contains '.'.
    """
    max_depth_option = options.get('maxDepth', 0)
    # bool is a subclass of int in Python, but it is not a valid numeric
    # maxDepth for MongoDB, so reject it explicitly.
    if isinstance(max_depth_option, bool) or \
            not isinstance(max_depth_option, six.integer_types):
        raise OperationFailure(
            "Argument 'maxDepth' to $graphLookup must be a number")
    # MongoDB requires a non-negative recursion depth; a negative value used
    # to be silently treated like 0.
    if max_depth_option < 0:
        raise OperationFailure(
            "Argument 'maxDepth' to $graphLookup must be nonnegative")
    if not isinstance(options.get('restrictSearchWithMatch', {}), dict):
        raise OperationFailure(
            "Argument 'restrictSearchWithMatch' to $graphLookup must be a Dictionary"
        )
    if not isinstance(options.get('depthField', ''), six.string_types):
        raise OperationFailure(
            "Argument 'depthField' to $graphlookup must be a string")
    if 'startWith' not in options:
        raise OperationFailure(
            "Must specify 'startWith' field for a $graphLookup")
    for operator in ('as', 'connectFromField', 'connectToField', 'from'):
        if operator not in options:
            raise OperationFailure(
                "Must specify '%s' field for a $graphLookup" % operator)
        if not isinstance(options[operator], six.string_types):
            raise OperationFailure(
                "Argument '%s' to $graphLookup must be string" % operator)
        if options[operator].startswith('$'):
            raise OperationFailure(
                "FieldPath field names may not start with '$'")
        if operator in ('connectFromField', 'as') and \
                '.' in options[operator]:
            raise NotImplementedError(
                "Although '.' is valid in the '%s' "
                'parameter for the $graphLookup stage of the aggregation '
                'pipeline, it is currently not implemented in Mongomock.' %
                operator)

    foreign_name = options['from']
    start_with = options['startWith']
    connect_from_field = options['connectFromField']
    connect_to_field = options['connectToField']
    local_name = options['as']
    # None means "no limit": keep recursing until no new match is found.
    max_depth = options.get('maxDepth', None)
    depth_field = options.get('depthField', None)
    restrict_search_with_match = options.get('restrictSearchWithMatch', {})
    foreign_collection = database.get_collection(foreign_name)
    out_doc = copy.deepcopy(in_collection)  # TODO(pascal): speed the deep copy

    def _find_matches_for_depth(query, depth, found_ids):
        """Return foreign docs matching ``query`` not yet seen in ``found_ids``.

        A list ``query`` matches any of its values. Matches must also satisfy
        ``restrictSearchWithMatch``. When ``depthField`` is set, each match is
        copied with that field set to ``depth``. The ``_id`` of every returned
        match is added to ``found_ids`` so deeper levels skip it.
        """
        if isinstance(query, list):
            query = {'$in': query}
        matches = foreign_collection.find({connect_to_field: query})
        new_matches = []
        for new_match in matches:
            if filtering.filter_applies(restrict_search_with_match, new_match) \
                    and new_match['_id'] not in found_ids:
                if depth_field is not None:
                    new_match = collections.OrderedDict(
                        new_match, **{depth_field: depth})
                new_matches.append(new_match)
                found_ids.add(new_match['_id'])
        return new_matches

    for doc in out_doc:
        # Deduplication is per input document: a foreign doc may appear in
        # several outputs, but at most once in each.
        found_ids = set()
        depth = 0
        start_value = _parse_expression(start_with, doc)
        origin_matches = doc[local_name] = _find_matches_for_depth(
            start_value, depth, found_ids)
        while origin_matches and (max_depth is None or depth < max_depth):
            depth += 1
            newly_discovered_matches = []
            for match in origin_matches:
                match_target = match.get(connect_from_field)
                newly_discovered_matches += _find_matches_for_depth(
                    match_target, depth, found_ids)
            doc[local_name] += newly_discovered_matches
            origin_matches = newly_discovered_matches
    return out_doc
Exemple #20
0
    def _handle_array_operator(self, operator, value):
        """Evaluate an array aggregation operator.

        Supports ``$size``, ``$filter`` and ``$slice``.

        Args:
            operator: the operator name, e.g. ``'$slice'``.
            value: the operator's raw (unparsed) argument.

        Returns:
            The operator's result: an int for ``$size``, a list for
            ``$slice``, and a list (or None for a null input) for ``$filter``.

        Raises:
            OperationFailure: if the arguments are malformed.
            NotImplementedError: for array operators not supported here.
        """
        if operator == '$size':
            if isinstance(value, list):
                if len(value) != 1:
                    raise OperationFailure('Expression $size takes exactly 1 arguments. '
                                           '%d were passed in.' % len(value))
                value = value[0]
            array_value = self.parse(value)
            if not isinstance(array_value, list):
                raise OperationFailure('The argument to $size must be an array, '
                                       'but was of type: %s' % type(array_value))
            return len(array_value)

        if operator == '$filter':
            if not isinstance(value, dict):
                raise OperationFailure('$filter only supports an object as its argument')
            extra_params = set(value) - {'input', 'cond', 'as'}
            if extra_params:
                raise OperationFailure('Unrecognized parameter to $filter: %s' % extra_params.pop())
            missing_params = {'input', 'cond'} - set(value)
            if missing_params:
                raise OperationFailure("Missing '%s' parameter to $filter" % missing_params.pop())

            input_array = self.parse(value['input'])
            # MongoDB: a null input yields null; any other non-array input is
            # an error. Without this check a string would be filtered
            # character by character and a dict key by key.
            if input_array is None:
                return None
            if not isinstance(input_array, list):
                raise OperationFailure(
                    'input to $filter must be an array not %s' % type(input_array))
            fieldname = value.get('as', 'this')
            cond = value['cond']
            # Each item is exposed to `cond` as a user variable named
            # `fieldname` ('this' by default).
            return [
                item for item in input_array
                if _Parser(
                    self._doc_dict,
                    dict(self._user_vars, **{fieldname: item}),
                    ignore_missing_keys=self._ignore_missing_keys,
                ).parse(cond)
            ]

        if operator == '$slice':
            if not isinstance(value, list):
                raise OperationFailure('$slice only supports a list as its argument')
            if len(value) < 2 or len(value) > 3:
                raise OperationFailure('Expression $slice takes at least 2 arguments, and at most '
                                       '3, but {} were passed in'.format(len(value)))
            array_value = self.parse(value[0])
            if not isinstance(array_value, list):
                raise OperationFailure(
                    'First argument to $slice must be an array, but is of type: {}'
                    .format(type(array_value)))
            for num, v in zip(('Second', 'Third'), value[1:]):
                if not isinstance(v, six.integer_types):
                    raise OperationFailure(
                        '{} argument to $slice must be numeric, but is of type: {}'
                        .format(num, type(v)))
            if len(value) > 2 and value[2] <= 0:
                raise OperationFailure('Third argument to $slice must be '
                                       'positive: {}'.format(value[2]))

            if len(value) == 2:
                # Two-argument form: value[1] is a count. Positive takes from
                # the front of the array, negative from the back.
                count = value[1]
                if count < 0:
                    return array_value[count:]
                return array_value[:count]

            # Three-argument form: value[1] is a starting position and
            # value[2] a (positive) number of elements. A negative position
            # counts from the end; if it reaches past the front of the array,
            # the start is clamped to 0 (e.g. [[1, 2, 3], -20, 2] -> [1, 2]).
            start = value[1]
            if start < 0:
                start = max(len(array_value) + start, 0)
            return array_value[start:start + value[2]]

        raise NotImplementedError(
            "Although '%s' is a valid array operator for the "
            'aggregation pipeline, it is currently not implemented '
            'in Mongomock.' % operator)