예제 #1
0
 def __init__(self, document, collection):
     """Prepare an unevaluated query over ``collection`` for ``document``.

     :param document: the document class results are built with
     :param collection: the collection object that queries are run against
     """
     self._document = document
     self._collection_obj = collection
     self._accessed_collection = False

     # Accumulated filter state.
     self._query = {}
     self._where_clause = None
     self._loaded_fields = []
     self._ordering = []
     self.transform = TransformDjango()

     # NOTE: restricting results to ``_types`` of the document class when
     # inheritance is allowed is currently disabled.

     # Cursor state; the cursor itself is created lazily.
     self._cursor_obj = self._limit = self._skip = None
예제 #2
0
class QuerySet(object):
    """A set of results returned from a query. Wraps a MongoDB cursor,
    providing :class:`~mongoengine.Document` objects as the results.
    """

    def __init__(self, document, collection):
        """Create a lazy query over ``collection`` returning ``document`` objects.

        Nothing is sent to the database here; the cursor is only built the
        first time the ``_cursor`` property is read.

        :param document: the document class results are built with
        :param collection: the collection object queries run against
            (presumably a pymongo ``Collection`` -- it is used via ``find``,
            ``find_one`` and ``map_reduce``)
        """
        self._document = document
        self._collection_obj = collection
        self._accessed_collection = False

        # Filter state accumulated by __call__/filter/find/exclude.
        self._query = {}
        self._where_clause = None
        self._loaded_fields = []
        self._ordering = []
        self.transform = TransformDjango()

        # NOTE: restricting results to the document's ``_types`` when
        # inheritance is allowed is currently disabled.

        # Cursor state: the cached cursor plus recorded skip/limit values.
        self._cursor_obj = self._limit = self._skip = None

        # NOTE: a Django-compatibility ``self.model`` attribute
        # (``InternalModel(document)``) is currently disabled.

    def __call__(self, q_obj=None, **query):
        """Filter the selected documents by calling the
        :class:`~mongoengine.queryset.QuerySet` with a query.

        :param q_obj: a :class:`~mongoengine.queryset.Q` object to be used in
            the query; the :class:`~mongoengine.queryset.QuerySet` is filtered
            multiple times with different :class:`~mongoengine.queryset.Q`
            objects, only the last one will be used
        :param query: Django-style query keyword arguments
        """
        if q_obj:
            self._where_clause = q_obj.as_js(self._document)
        query = QuerySet._transform_query(_doc_cls=self._document, **query)
        self._query.update(query)
        return self

    def filter(self, *q_objs, **query):
        """An alias of :meth:`~mongoengine.queryset.QuerySet.__call__`
        """
        return self.__call__(*q_objs, **query)

    def find(self, query):
        self._query.update(self.transform.transform_incoming(query, self._collection))
        return self

    def exclude(self, *q_objs, **query):
        """An alias of :meth:`~mongoengine.queryset.QuerySet.__call__`
        """
        query["not"] = True
        return self.__call__(*q_objs, **query)

    def all(self):
        """An alias of :meth:`~mongoengine.queryset.QuerySet.__call__`
        """
        return self.__call__()
    
    def distinct(self, *args, **kwargs):
        """
        Distinct method
        """
        return self._cursor.distinct(*args, **kwargs)

    @property
    def _collection(self):
        """Property that returns the collection object. This allows us to
        perform operations only if the collection is accessed.
        """
        return self._collection_obj
    
    def values(self, *args):
        return (args and [dict(zip(args,[getattr(doc, key) for key in args])) for doc in self]) or [obj for obj in self._cursor.clone()]
        
    def values_list(self, *args, **kwargs):
        flat = kwargs.pop("flat", False)
        if flat and len(args) != 1:
            raise Exception("args len must be 1 when flat=True")
        
        return (flat and self.distinct(args[0] if not args[0] in ["id", "pk"] else "_id")) or zip(*[self.distinct(field if not field in ["id", "pk"] else "_id") for field in args])
#                
#            if self._document._meta['geo_indexes'] and \
#               pymongo.version >= "1.5.1":
#                from pymongo import GEO2D
#                for index in self._document._meta['geo_indexes']:
#                    self._collection.ensure_index([(index, GEO2D)])
#            
#            # Ensure all needed field indexes are created
#            for field_name, field_instance in self._document._fields.iteritems():
#                if field_instance.__class__.__name__ == 'GeoLocationField':
#                    self._collection.ensure_index([(field_name, pymongo.GEO2D),])
#        return self._collection_obj

    @property
    def _cursor(self):
        if self._cursor_obj is None:
            cursor_args = {}
            if self._loaded_fields:
                cursor_args = {'fields': self._loaded_fields}
            self._cursor_obj = self._collection.find(self._query, 
                                                     **cursor_args)
            # Apply where clauses to cursor
            if self._where_clause:
                self._cursor_obj.where(self._where_clause)

            # apply default ordering
#            if self._document._meta['ordering']:
#                self.order_by(*self._document._meta['ordering'])

        return self._cursor_obj.clone()

    @classmethod
    def _lookup_field(cls, document, fields):
        """
        Looks for "field" in "document"
        """
        if isinstance(fields, (tuple, list)):
            return [document._meta.get_field_by_name((field == "pk" and "id") or field)[0] for field in fields]
        return document._meta.get_field_by_name((fields == "pk" and "id") or fields)[0]

    @classmethod
    def _translate_field_name(cls, doc_cls, field, sep='.'):
        """Translate a field attribute name to a database field name.
        """
        parts = field.split(sep)
        parts = [f.attname for f in QuerySet._lookup_field(doc_cls, parts)]
        return '.'.join(parts)

    @classmethod
    def _transform_query(self,  _doc_cls=None, **parameters):
        """
        Converts parameters to mongodb queries. 
        """
        spec = {}
        operators = ['ne', 'gt', 'gte', 'lt', 'lte', 'in', 'nin', 'mod', 'all', 'size', 'exists']
        match_operators = ['contains', 'icontains', 'startswith', 'istartswith', 'endswith', 'iendswith', 'exact', 'iexact']
        exclude = parameters.pop("not", False)
        
        for key, value in parameters.items():
            
            
            parts  = key.split("__")
            lookup_type = (len(parts)>=2) and ( parts[-1] in operators + match_operators and parts.pop()) or ""
            
            # Let's get the right field and be sure that it exists
            parts[0] = QuerySet._lookup_field(_doc_cls, parts[0]).attname
            
            if not lookup_type and len(parts)==1:
                if exclude:
                    value = {"$ne" : value}
                spec.update({parts[0] : value})
                continue
            
            if parts[0] == "id":
                parts[0] = "_id"
                value = [isinstance(par, basestring) and ObjectId(par) or par for par in value]
                
            if lookup_type in ['contains', 'icontains',
                                 'startswith', 'istartswith',
                                 'endswith', 'iendswith',
                                 'exact', 'iexact']:
                flags = 0
                if lookup_type.startswith('i'):
                    flags = re.IGNORECASE
                    lookup_type = lookup_type.lstrip('i')
                    
                regex = r'%s'
                if lookup_type == 'startswith':
                    regex = r'^%s'
                elif lookup_type == 'endswith':
                    regex = r'%s$'
                elif lookup_type == 'exact':
                    regex = r'^%s$'
                    
                value = re.compile(regex % value, flags)
                
            elif lookup_type in operators:
                value = { "$" + lookup_type : value}
            elif lookup_type and len(parts)==1:
                raise DatabaseError("Unsupported lookup type: %r" % lookup_type)
    
            key = '.'.join(parts)
            if exclude:
                value = {"$ne" : value}
            spec.update({key : value})
            
        return spec
    
    def get(self, *q_objs, **query):
        """Retrieve the the matching object raising id django is available
        :class:`~django.core.exceptions.MultipleObjectsReturned` or
        :class:`~django.core.exceptions.ObjectDoesNotExist` exceptions if multiple or
        no results are found.
        If django is not available:
        :class:`~mongoengine.queryset.MultipleObjectsReturned` or
        `DocumentName.MultipleObjectsReturned` exception if multiple results and
        :class:`~mongoengine.queryset.DoesNotExist` or `DocumentName.DoesNotExist`
        if no results are found.

        .. versionadded:: 0.3
        """
        self.__call__(*q_objs, **query)
        count = self.count()
        if count == 1:
            return self[0]
        elif count > 1:
            message = u'%d items returned, instead of 1' % count
            raise self._document.MultipleObjectsReturned(message)
        else:
            raise self._document.DoesNotExist("%s matching query does not exist."
                                              % self._document._meta.object_name)

    def get_or_create(self, *q_objs, **query):
        """Retrieve unique object or create, if it doesn't exist. Returns a tuple of 
        ``(object, created)``, where ``object`` is the retrieved or created object 
        and ``created`` is a boolean specifying whether a new object was created. Raises
        :class:`~mongoengine.queryset.MultipleObjectsReturned` or
        `DocumentName.MultipleObjectsReturned` if multiple results are found.
        A new document will be created if the document doesn't exists; a
        dictionary of default values for the new document may be provided as a
        keyword argument called :attr:`defaults`.

        .. versionadded:: 0.3
        """
        defaults = query.get('defaults', {})
        if 'defaults' in query:
            del query['defaults']

        self.__call__(*q_objs, **query)
        count = self.count()
        if count == 0:
            query.update(defaults)
            doc = self._document(**query)
            doc.save()
            return doc, True
        elif count == 1:
            return self.first(), False
        else:
            message = u'%d items returned, instead of 1' % count
            raise self._document.MultipleObjectsReturned(message)

    def first(self):
        """Return the first matching document, or ``None`` if there is none.

        Only an ``IndexError`` from indexing position 0 is turned into
        ``None``; other errors propagate.
        """
        try:
            return self[0]
        except IndexError:
            return None

    def with_id(self, object_id):
        """Return the document whose ``_id`` equals ``object_id``, or ``None``.

        The id is first converted with the id field's ``to_mongo`` and, if the
        result is not already an ``ObjectId``, wrapped in one. The query goes
        straight to the collection with ``find_one`` and ignores the current
        filters.

        NOTE(review): ``_meta`` is subscripted here (``_meta['id_field']``)
        while other methods use attribute access on it -- confirm it supports
        both.

        :param object_id: the value for the id of the document to look up
        """
        id_field_name = self._document._meta['id_field']
        mongo_id = self._document._fields[id_field_name].to_mongo(object_id)
        if not isinstance(mongo_id, ObjectId):
            mongo_id = ObjectId(mongo_id)

        raw = self._collection.find_one({'_id': mongo_id})
        if raw is None:
            return None
        return self._document(**dict_keys_to_str(raw))

    def in_bulk(self, object_ids):
        """Fetch several documents by id in a single query.

        Ids that are not already ``ObjectId`` instances are converted. The
        query goes straight to the collection and ignores the current filters.

        :param object_ids: an iterable of ids (``ObjectId`` or convertible)
        :returns: dict mapping ``str(doc['id'])`` to document instances.
            NOTE(review): keys are read from ``'id'`` although the query is
            on ``'_id'`` -- presumably renamed by an outgoing transform; confirm.

        .. versionadded:: 0.3
        """
        wanted = []
        for oid in object_ids:
            wanted.append(oid if isinstance(oid, ObjectId) else ObjectId(oid))

        by_id = {}
        for raw in self._collection.find({'_id': {'$in': wanted}}):
            by_id[str(raw['id'])] = self._document(**dict_keys_to_str(raw))
        return by_id
    
    def count(self):
        """Count the selected elements in the query.
        """
        if self._limit == 0:
            return 0
        return self._cursor.count(with_limit_and_skip=False)

    def __len__(self):
        return self.count()

    def map_reduce(self, map_f, reduce_f, finalize_f=None, limit=None,
                   scope=None, keep_temp=False):
        """Perform a map/reduce over the documents matched by this QuerySet.

        The current query spec is passed as ``query`` and the results are
        sorted by ``self._ordering`` when it is set. This is a generator, so
        nothing (including the capability check) runs until it is iterated.
        It must be the last call in a chain, as it does not return a
        malleable ``QuerySet``.

        :param map_f: map function, as :class:`~pymongo.code.Code` or string
        :param reduce_f: reduce function, as :class:`~pymongo.code.Code` or
            string
        :param finalize_f: optional finalize function (Code or string) that
            performs any post-reduction processing
        :param limit: number of objects from the current query to provide to
            map/reduce (omitted when falsy)
        :param scope: values to insert into the map/reduce global scope
            (omitted when falsy)
        :param keep_temp: keep the temporary result collection (boolean,
            default ``False``)

        Yields, for every map/reduce output row, the result of
        ``self._document.objects.with_id(row['value'])`` -- i.e. each reduced
        ``value`` is treated as a document id.

        :raises NotImplementedError: if the collection has no ``map_reduce``.

        .. note:: Map/Reduce requires server version **>= 1.1.1**. The PyMongo
           :meth:`~pymongo.collection.Collection.map_reduce` helper requires
           PyMongo version **>= 1.2**.

        .. versionadded:: 0.3
        """
        if not hasattr(self._collection, "map_reduce"):
            raise NotImplementedError("Requires MongoDB >= 1.1.1")

        def to_code(func):
            # Normalise a string or Code object into a Code object, keeping
            # any scope already bound to a Code instance. (Field-name
            # substitution via _sub_js_fields is not applied.)
            func_scope = {}
            if isinstance(func, pymongo.code.Code):
                func_scope = func.scope
                func = unicode(func)
            return pymongo.code.Code(func, func_scope)

        map_f = to_code(map_f)
        reduce_f = to_code(reduce_f)

        mr_args = {'query': self._query, 'keeptemp': keep_temp}
        if finalize_f:
            mr_args['finalize'] = to_code(finalize_f)
        if scope:
            mr_args['scope'] = scope
        if limit:
            mr_args['limit'] = limit

        results = self._collection.map_reduce(map_f, reduce_f, **mr_args)
        results = results.find()

        if self._ordering:
            results = results.sort(self._ordering)

        for doc in results:
            yield self._document.objects.with_id(doc['value'])

    def limit(self, n):
        """Limit the number of returned documents to `n`. This may also be
        achieved using array-slicing syntax (e.g. ``User.objects[:5]``).

        :param n: the maximum number of objects to return
        """
        if n == 0:
            self._cursor.limit(1)
        else:
            self._cursor.limit(n)
        self._limit = n

        # Return self to allow chaining
        return self

    def skip(self, n):
        """Skip `n` documents before returning the results. This may also be
        achieved using array-slicing syntax (e.g. ``User.objects[5:]``).

        :param n: the number of objects to skip before returning results
        """
        self._cursor.skip(n)
        self._skip = n
        return self

    def __getitem__(self, key):
        """Support skip and limit using getitem and slicing syntax.
        """
        # Slice provided
        if isinstance(key, slice):
            try:
                self._cursor_obj = self._cursor[key]
                self._skip, self._limit = key.start, key.stop
            except IndexError, err:
                # PyMongo raises an error if key.start == key.stop, catch it,
                # bin it, kill it. 
                start = key.start or 0
                if start >= 0 and key.stop >= 0 and key.step is None:
                    if start == key.stop:
                        self.limit(0)
                        self._skip, self._limit = key.start, key.stop - start
                        return self
                raise err
            # Allow further QuerySet modifications to be performed
            return self
        # Integer index provided
        elif isinstance(key, int):
            return self._document(**dict_keys_to_str(self._cursor[key]))