def test_create_collection(self):
        """Exercise Database.create_collection against the test server.

        Creating an existing collection raises CollectionInvalid, non-string
        names raise TypeError, malformed names raise InvalidName, and a newly
        created collection is writable, listed, and reports the expected
        options.
        """
        db = Database(self.client, "pymongo_test")

        # The insert creates ``test`` implicitly, so an explicit create fails.
        db.test.insert({"hello": "world"})
        self.assertRaises(CollectionInvalid, db.create_collection, "test")
        db.drop_collection("test")

        rejected = ((5, TypeError),
                    (None, TypeError),
                    ("coll..ection", InvalidName))
        for bad_name, error in rejected:
            self.assertRaises(error, db.create_collection, bad_name)

        created = db.create_collection("test")
        created.save({"hello": u"world"})
        self.assertEqual(db.test.find_one()["hello"], "world")
        self.assertTrue(u"test" in db.collection_names())

        db.drop_collection("test.foo")
        db.create_collection("test.foo")
        self.assertTrue(u"test.foo" in db.collection_names())

        # From server 2.7.0 on, usePowerOf2Sizes is the default (flags == 1).
        if version.at_least(self.client, (2, 7, 0)):
            expected = {"flags": 1}
        else:
            expected = {}
        options = db.test.foo.options()
        # mongos 2.2.x adds an $auth field when auth is enabled.
        options.pop('$auth', None)
        self.assertEqual(options, expected)
        self.assertRaises(CollectionInvalid, db.create_collection, "test.foo")
Esempio n. 2
0
def temp_coll_name(evo_db: Database):
    """Create a scratch collection, yield its name, then drop it.

    This is written as a generator fixture. The decorator, if any, is not
    visible here; presumably it is a pytest fixture.

    The drop runs in ``finally``. That way the ``'tmp'`` collection is
    removed even when the generator is closed, or has an exception thrown
    into it, while suspended at the ``yield``. Without this, a leftover
    collection would make the next ``create_collection('tmp')`` call see an
    existing name.
    """
    coll_name = 'tmp'
    evo_db.create_collection(coll_name)
    try:
        yield coll_name
    finally:
        evo_db.drop_collection(coll_name)
    def test_drop_collection(self):
        """Check Database.drop_collection and Collection.drop.

        Non-string arguments raise TypeError. A collection can be dropped by
        str name, unicode name, Collection object, or through
        Collection.drop(). Dropping collections that are already gone does
        not raise. On replica sets running 3.3.9+, an unsatisfiable write
        concern surfaces as WriteConcernError.
        """
        db = Database(self.client, "pymongo_test")

        for bad_name in (5, None):
            self.assertRaises(TypeError, db.drop_collection, bad_name)

        # Each variant is called only after ``test`` has been re-populated.
        drop_variants = (
            lambda: db.drop_collection("test"),
            lambda: db.drop_collection(u"test"),
            lambda: db.drop_collection(db.test),
            lambda: db.test.drop(),
        )
        for drop in drop_variants:
            db.test.insert_one({"dummy": u"object"})
            self.assertTrue("test" in db.collection_names())
            drop()
            self.assertFalse("test" in db.collection_names())

        # Dropping missing collections is tolerated.
        db.test.drop()
        db.drop_collection(db.test.doesnotexist)

        needs_wc_check = (client_context.version.at_least(3, 3, 9)
                          and client_context.is_rs)
        if needs_wc_check:
            unsatisfiable = Database(self.client, 'pymongo_test',
                                     write_concern=IMPOSSIBLE_WRITE_CONCERN)
            with self.assertRaises(WriteConcernError):
                unsatisfiable.drop_collection('test')
    def test_drop_collection(self):
        """Verify every supported way of dropping a collection.

        Covers str and unicode names, a Collection argument, and
        Collection.drop(). It also checks TypeError for non-string names,
        silent success for missing collections, and WriteConcernError under
        an impossible write concern on 3.3.9+ replica sets.
        """
        db = Database(self.client, "pymongo_test")

        self.assertRaises(TypeError, db.drop_collection, 5)
        self.assertRaises(TypeError, db.drop_collection, None)

        def recreate_then(drop):
            # Re-create ``test`` with one document, drop it via ``drop``, and
            # confirm it vanished from the collection listing.
            db.test.insert_one({"dummy": u"object"})
            self.assertTrue("test" in db.collection_names())
            drop()
            self.assertFalse("test" in db.collection_names())

        recreate_then(lambda: db.drop_collection("test"))
        recreate_then(lambda: db.drop_collection(u"test"))
        recreate_then(lambda: db.drop_collection(db.test))
        recreate_then(lambda: db.test.drop())

        # Dropping collections that are already gone must not raise.
        db.test.drop()
        db.drop_collection(db.test.doesnotexist)

        if client_context.version.at_least(3, 3, 9) and client_context.is_rs:
            impossible_db = Database(self.client, 'pymongo_test',
                                     write_concern=IMPOSSIBLE_WRITE_CONCERN)
            with self.assertRaises(WriteConcernError):
                impossible_db.drop_collection('test')
    def test_create_collection(self):
        """Check create_collection errors, usability, listing and options."""
        db = Database(self.client, "pymongo_test")

        # ``test`` already exists after this insert, so creating it fails.
        db.test.insert({"hello": "world"})
        self.assertRaises(CollectionInvalid, db.create_collection, "test")
        db.drop_collection("test")

        for bad_name, error in ((5, TypeError), (None, TypeError),
                                ("coll..ection", InvalidName)):
            self.assertRaises(error, db.create_collection, bad_name)

        fresh = db.create_collection("test")
        fresh.save({"hello": u"world"})
        self.assertEqual(db.test.find_one()["hello"], "world")
        self.assertTrue(u"test" in db.collection_names())

        db.drop_collection("test.foo")
        db.create_collection("test.foo")
        self.assertTrue(u"test.foo" in db.collection_names())

        expected = {}
        # usePowerOf2Sizes server default
        if version.at_least(self.client, (2, 7, 0)):
            expected["flags"] = 1
        reported = db.test.foo.options()
        # mongos 2.2.x adds an $auth field when auth is enabled.
        reported.pop("$auth", None)
        self.assertEqual(reported, expected)

        self.assertRaises(CollectionInvalid, db.create_collection, "test.foo")
Esempio n. 6
0
 def binary_contents_test(self):
     """Save a document holding 1024 random bytes wrapped in bson.Binary.

     The ``test_binary`` collection is dropped in ``finally``, so a failing
     save does not leave it behind for the next run. The imports come first
     so that an import failure happens before any collection is created.
     """
     import os
     import bson
     db = Database(self._get_connection(), "pymongo_test")
     test = db.create_collection("test_binary")
     try:
         obj = os.urandom(1024)
         test.save({"hello": bson.Binary(obj)})
     finally:
         db.drop_collection("test_binary")
Esempio n. 7
0
def test_create_collection(mongo_handler: MongoHandler, evo_db: Database):
    """MongoHandler.create_collection succeeds only for a not-yet-existing name.

    The first call must return a truthy result, and the collection must be
    reachable on ``evo_db``. A second call with the same name must return a
    falsy result.

    The collection is dropped in ``finally``. If it were left behind after a
    failed assertion, the first ``assert result`` would presumably fail on
    every later run, because the name would already exist.
    """
    new_col_name = 'new_infra_col'
    try:
        result = mongo_handler.create_collection(new_col_name)
        assert result
        assert evo_db[new_col_name] is not None

        # Creating the same collection a second time must be reported as
        # unsuccessful.
        result = mongo_handler.create_collection(new_col_name)
        assert not result
    finally:
        evo_db.drop_collection(new_col_name)
    def test_drop_collection(self):
        """Drop ``test`` in every supported way and check the listing each time.

        Also checks that non-string names raise TypeError and that dropping
        missing collections does not raise. Text values go through the
        ``u()`` compatibility helper.
        """
        db = Database(self.client, "pymongo_test")

        for bad_name in (5, None):
            self.assertRaises(TypeError, db.drop_collection, bad_name)

        drop_variants = [
            lambda: db.drop_collection("test"),
            lambda: db.drop_collection(u("test")),
            lambda: db.drop_collection(db.test),
            lambda: db.test.drop(),
        ]
        for drop in drop_variants:
            db.test.insert_one({"dummy": u("object")})
            self.assertTrue("test" in db.collection_names())
            drop()
            self.assertFalse("test" in db.collection_names())

        # No error expected when the target no longer exists.
        db.test.drop()
        db.drop_collection(db.test.doesnotexist)
    def test_drop_collection(self):
        """Check the legacy connection-based drop paths.

        Covers dropping by str name, unicode name and Collection object, plus
        Collection.drop(). Non-string names raise TypeError, and missing
        collections can be dropped without error.
        """
        db = Database(self.connection, "pymongo_test")

        self.assertRaises(TypeError, db.drop_collection, 5)
        self.assertRaises(TypeError, db.drop_collection, None)

        def populate_and_drop(drop):
            # ``save`` re-creates ``test``; ``drop`` must then remove it.
            db.test.save({"dummy": u"object"})
            self.assertTrue("test" in db.collection_names())
            drop()
            self.assertFalse("test" in db.collection_names())

        populate_and_drop(lambda: db.drop_collection("test"))
        populate_and_drop(lambda: db.drop_collection(u"test"))
        populate_and_drop(lambda: db.drop_collection(db.test))
        populate_and_drop(lambda: db.test.drop())

        db.test.drop()
        db.drop_collection(db.test.doesnotexist)
    def test_drop_collection(self):
        """Check dropping collections by name, unicode name and Collection.

        Non-string arguments raise TypeError, and dropping a collection that
        no longer exists does not raise.

        Uses ``assertTrue`` rather than the ``assert_`` alias. That alias is
        deprecated and was removed from ``unittest`` in Python 3.12, where
        calling it raises AttributeError.
        """
        db = Database(self.connection, "pymongo_test")

        self.assertRaises(TypeError, db.drop_collection, 5)
        self.assertRaises(TypeError, db.drop_collection, None)

        db.test.save({"dummy": u"object"})
        self.assertTrue("test" in db.collection_names())
        db.drop_collection("test")
        self.assertFalse("test" in db.collection_names())

        db.test.save({"dummy": u"object"})
        self.assertTrue("test" in db.collection_names())
        db.drop_collection(u"test")
        self.assertFalse("test" in db.collection_names())

        db.test.save({"dummy": u"object"})
        self.assertTrue("test" in db.collection_names())
        db.drop_collection(db.test)
        self.assertFalse("test" in db.collection_names())

        db.test.save({"dummy": u"object"})
        self.assertTrue("test" in db.collection_names())
        db.test.drop()
        self.assertFalse("test" in db.collection_names())
        # Dropping an already-dropped collection must not raise.
        db.test.drop()

        db.drop_collection(db.test.doesnotexist)
Esempio n. 11
0
    def test4(self):
        """Smoke test: save five users, index user_id, run random group queries.

        Matching documents are printed rather than asserted on. The
        ``test_4`` collection is always dropped afterwards.
        """
        db = Database(self._get_connection(), "pymongo_test")
        test = db.create_collection("test_4")
        try:
            for user_id in range(5):
                label = "test %d" % (user_id,)
                test.save({"user_id": user_id, "name": label,
                           "group_id": user_id % 10, "posts": user_id % 20})

            test.create_index("user_id")

            for _ in xrange(6):
                query = {"group_id": random.randint(0, 10)}
                for doc in test.find(query):
                    print("Found: %s " % (doc,))
        finally:
            db.drop_collection("test_4")
    def test_create_collection(self):
        """Check create_collection errors and that new collections are usable.

        Text values go through the ``u()`` compatibility helper.
        """
        db = Database(self.client, "pymongo_test")

        # An insert creates ``test`` implicitly; creating it again must fail.
        db.test.insert_one({"hello": "world"})
        self.assertRaises(CollectionInvalid, db.create_collection, "test")
        db.drop_collection("test")

        invalid_names = ((5, TypeError), (None, TypeError),
                         ("coll..ection", InvalidName))
        for name, error in invalid_names:
            self.assertRaises(error, db.create_collection, name)

        created = db.create_collection("test")
        self.assertTrue(u("test") in db.collection_names())
        created.insert_one({"hello": u("world")})
        self.assertEqual(db.test.find_one()["hello"], "world")

        db.drop_collection("test.foo")
        db.create_collection("test.foo")
        self.assertTrue(u("test.foo") in db.collection_names())
        self.assertRaises(CollectionInvalid, db.create_collection, "test.foo")
    def test_create_collection(self):
        """Check create_collection errors and that new collections are usable."""
        db = Database(self.client, "pymongo_test")

        db.test.insert_one({"hello": "world"})
        # ``test`` now exists, so an explicit create is rejected.
        self.assertRaises(CollectionInvalid, db.create_collection, "test")
        db.drop_collection("test")

        for name, error in ((5, TypeError), (None, TypeError),
                            ("coll..ection", InvalidName)):
            self.assertRaises(error, db.create_collection, name)

        created = db.create_collection("test")
        self.assertTrue(u"test" in db.collection_names())
        created.insert_one({"hello": u"world"})
        self.assertEqual(db.test.find_one()["hello"], "world")

        # Dotted names are valid collection names.
        db.drop_collection("test.foo")
        db.create_collection("test.foo")
        self.assertTrue(u"test.foo" in db.collection_names())
        self.assertRaises(CollectionInvalid, db.create_collection, "test.foo")
    def test_create_collection(self):
        """Check create_collection on the legacy connection-based API.

        Creating an existing collection raises CollectionInvalid, bad names
        raise TypeError or InvalidName, and a fresh collection is writable,
        listed, and reports empty options.

        Uses ``assertTrue`` instead of the ``assert_`` alias. That alias is
        deprecated and was removed from ``unittest`` in Python 3.12.
        """
        db = Database(self.connection, "pymongo_test")

        db.test.insert({"hello": "world"})
        self.assertRaises(CollectionInvalid, db.create_collection, "test")

        db.drop_collection("test")

        self.assertRaises(TypeError, db.create_collection, 5)
        self.assertRaises(TypeError, db.create_collection, None)
        self.assertRaises(InvalidName, db.create_collection, "coll..ection")

        test = db.create_collection("test")
        test.save({"hello": u"world"})
        self.assertEqual(db.test.find_one()["hello"], "world")
        self.assertTrue(u"test" in db.collection_names())

        db.drop_collection("test.foo")
        db.create_collection("test.foo")
        self.assertTrue(u"test.foo" in db.collection_names())
        self.assertEqual(db.test.foo.options(), {})
        self.assertRaises(CollectionInvalid, db.create_collection, "test.foo")
Esempio n. 15
0
    def test2(self):
        """Smoke-test legacy CRUD helpers on a scratch ``test_2`` collection.

        Saves 100 documents, then exercises update (multi and single), index
        creation, find, remove, group, distinct and index maintenance. Each
        result is printed rather than asserted on. The collection is always
        dropped at the end.
        """
        db = Database(self._get_connection(), "pymongo_test")
        test = db.create_collection("test_2")
        try:
            for n in range(100):
                doc = {"name": "test %d" % (n), "group_id": n % 3,
                       "posts": n % 20}
                ret = test.save(doc)
                print("Save Ret: %s" % (ret))

            ret = test.update({"posts": 10}, {"$set": {"posts": 100}},
                              multi=True, safe=True)
            print("Update Ret: %s" % (ret))
            test.update({"name": "test 2"}, {"$set": {"posts": 200}})

            # Both index helpers are called on the same key.
            test.create_index("posts")
            test.ensure_index("posts")

            for doc in test.find({"posts": 100}):
                print("Found: %s" % (doc,))

            ret = test.remove({"posts": 1}, safe=True)
            print("Remove Ret: %s" % (ret))

            # NOTE(review): the reduce increments post_sum once per document,
            # so it is a per-group count rather than a sum of posts.
            groups = test.group(key={"group_id": 1},
                                condition=None,
                                initial={"post_sum": 0},
                                reduce="function(obj,prev) {prev.post_sum++;}")
            for group in groups:
                print("Group: %s" % (group,))

            for value in test.distinct('posts'):
                print("Distinct: %s" % (value,))

            # dir() rather than hasattr(): attribute access on a Collection
            # yields a sub-collection (cf. ``db.test.foo``), so hasattr would
            # always be true.
            if 'reindex' in dir(test):
                test.reindex()
            test.drop_indexes()
        finally:
            db.drop_collection("test_2")
Esempio n. 16
0
    def dbref_test(self):
        """Store DBRefs in documents and resolve them with Database.dereference.

        The test creates ``owners``, ``tasks`` and ``tasks_ref`` and links
        owner "Jim" to two tasks through DBRefs. It prints the dereferenced
        task names, then prints how many ``tasks_ref`` documents match a
        DBRef query. All three collections are dropped afterwards.
        """
        db = Database(self._get_connection(), "pymongo_test")
        scratch = ('owners', 'tasks', 'tasks_ref')

        def task_ref(task):
            # A new DBRef pointing at ``task`` in the ``tasks`` collection.
            return DBRef(collection="tasks", id=task["_id"])

        try:
            for name in scratch:
                db.create_collection(name)

            # owners and tasks
            db.owners.insert({"name": "Jim"})
            db.tasks.insert([{"name": "read"}, {"name": "sleep"}])

            # Give Jim the reading and sleeping tasks by reference.
            reading_task = db.tasks.find_one({"name": "read"})
            sleeping_task = db.tasks.find_one({"name": "sleep"})

            jim = db.owners.find_one({"name": "Jim"})
            jim["tasks"] = [task_ref(reading_task), task_ref(sleeping_task)]
            db.owners.save(jim)

            # Reload Jim and resolve each stored reference.
            fresh_jim = db.owners.find_one({"name": "Jim"})
            print("tasks are:")
            for ref in fresh_jim["tasks"]:
                print(db.dereference(ref)["name"])

            db.tasks_ref.insert({"ref": task_ref(reading_task)})
            db.tasks_ref.insert({"ref": task_ref(sleeping_task)})
            matches = db.tasks_ref.find({"ref": task_ref(reading_task)})
            print(matches.count())
        finally:
            for name in scratch:
                db.drop_collection(name)
Esempio n. 17
0
class TestCursor(unittest.TestCase):

    def setUp(self):
        """Open a client and select the ``pymongo_test`` database."""
        client = get_client()
        self.client = client
        self.db = Database(client, "pymongo_test")

    def tearDown(self):
        """Drop this test's reference to the database handle."""
        self.db = None

    def test_max_time_ms(self):
        """Check Cursor.max_time_ms argument handling, cloning and timeouts.

        The test is skipped on servers older than 2.5.3. The server-side
        timeout section runs only when the server command line contains
        ``enableTestCommands=1``, because it toggles the
        ``maxTimeAlwaysTimeOut`` failpoint.
        """
        if not version.at_least(self.db.connection, (2, 5, 3, -1)):
            raise SkipTest("MaxTimeMS requires MongoDB >= 2.5.3")

        db = self.db
        db.pymongo_test.drop()
        coll = db.pymongo_test
        # A non-integer value is rejected client-side.
        self.assertRaises(TypeError, coll.find().max_time_ms, 'foo')
        coll.insert({"amalia": 1})
        coll.insert({"amalia": 2})

        # None and long values are expected to be accepted without error.
        coll.find().max_time_ms(None)
        coll.find().max_time_ms(1L)

        # The value is stored on the cursor; a later call overrides an
        # earlier one.
        cursor = coll.find().max_time_ms(999)
        self.assertEqual(999, cursor._Cursor__max_time_ms)
        cursor = coll.find().max_time_ms(10).max_time_ms(1000)
        self.assertEqual(1000, cursor._Cursor__max_time_ms)

        # clone() copies the setting, and both cursors include $maxTimeMS in
        # the query spec they build.
        cursor = coll.find().max_time_ms(999)
        c2 = cursor.clone()
        self.assertEqual(999, c2._Cursor__max_time_ms)
        self.assertTrue("$maxTimeMS" in cursor._Cursor__query_spec())
        self.assertTrue("$maxTimeMS" in c2._Cursor__query_spec())

        self.assertTrue(coll.find_one(max_time_ms=1000))

        if "enableTestCommands=1" in get_command_line(self.client)["argv"]:
            # Cursor parses server timeout error in response to initial query.
            self.client.admin.command("configureFailPoint",
                                      "maxTimeAlwaysTimeOut",
                                      mode="alwaysOn")
            try:
                cursor = coll.find().max_time_ms(1)
                try:
                    cursor.next()
                except ExecutionTimeout:
                    pass
                else:
                    self.fail("ExecutionTimeout not raised")
                # find_one must surface the same server timeout.
                self.assertRaises(ExecutionTimeout,
                                  coll.find_one, max_time_ms=1)
            finally:
                # Always switch the failpoint off so it cannot affect later
                # tests, even when an assertion above fails.
                self.client.admin.command("configureFailPoint",
                                          "maxTimeAlwaysTimeOut",
                                          mode="off")

    def test_max_time_ms_getmore(self):
        """Test that Cursor handles server timeout error in response to getmore.

        Requires a server started with ``enableTestCommands=1`` (for the
        ``maxTimeAlwaysTimeOut`` failpoint) running MongoDB 2.5.3 or newer.
        """
        if "enableTestCommands=1" not in get_command_line(self.client)["argv"]:
            raise SkipTest("Need test commands enabled")

        if not version.at_least(self.db.connection, (2, 5, 3, -1)):
            raise SkipTest("MaxTimeMS requires MongoDB >= 2.5.3")

        coll = self.db.pymongo_test
        # 200 documents is presumably more than the first batch holds (the
        # default initial batch is 101 per test_limit_and_batch_size), so
        # exhausting the cursor needs a getmore. TODO confirm per server.
        coll.insert({} for _ in range(200))
        cursor = coll.find().max_time_ms(100)

        # Send initial query before turning on failpoint.
        cursor.next()
        self.client.admin.command("configureFailPoint",
                                  "maxTimeAlwaysTimeOut",
                                  mode="alwaysOn")
        try:
            try:
                # Iterate up to first getmore.
                list(cursor)
            except ExecutionTimeout:
                pass
            else:
                self.fail("ExecutionTimeout not raised")
        finally:
            # Always disable the failpoint so later tests are unaffected.
            self.client.admin.command("configureFailPoint",
                                      "maxTimeAlwaysTimeOut",
                                      mode="off")

    def test_explain(self):
        """explain() output must not change once the cursor has been iterated.

        The timing field ``millis`` and the optional ``oldPlan`` field are
        removed from both reports before they are compared.
        """
        cursor = self.db.test.find()
        before = cursor.explain()
        # Pull at most one document so the cursor has started.
        next(iter(cursor), None)
        after = cursor.explain()
        for report in (before, after):
            del report["millis"]
            report.pop("oldPlan", None)
        self.assertEqual(before, after)
        self.assertTrue("cursor" in before)

    def test_hint(self):
        """Check Cursor.hint validation and its effect on the query plan.

        Hinting an index that does not exist fails when explain() runs.
        hint(None) clears a hint. A hint cannot be changed after iteration
        starts. Non-list specs (a float, or the index name string returned
        by create_index) raise TypeError.
        """
        db = self.db
        self.assertRaises(TypeError, db.test.find().hint, 5.5)
        db.test.drop()

        for i in range(100):
            db.test.insert({"num": i, "foo": i})

        def hinted(field):
            # A fresh query on num/foo == 17, hinted on ``field`` ascending.
            return (db.test.find({"num": 17, "foo": 17})
                    .hint([(field, ASCENDING)]))

        # No indexes exist yet, so both hints fail.
        for field in ("num", "foo"):
            self.assertRaises(OperationFailure, hinted(field).explain)

        index = db.test.create_index("num")
        spec = [("num", ASCENDING)]

        plan = db.test.find({}).explain()
        self.assertEqual(plan["cursor"], "BasicCursor")
        plan = db.test.find({}).hint(spec).explain()
        self.assertEqual(plan["cursor"], "BtreeCursor %s" % index)
        plan = db.test.find({}).hint(spec).hint(None).explain()
        self.assertEqual(plan["cursor"], "BasicCursor")
        self.assertRaises(OperationFailure, hinted("foo").explain)

        started = db.test.find({"num": 17})
        started.hint(spec)
        for _ in started:
            break
        self.assertRaises(InvalidOperation, started.hint, spec)

        self.assertRaises(TypeError, db.test.find().hint, index)

    def test_limit(self):
        db = self.db

        self.assertRaises(TypeError, db.test.find().limit, None)
        self.assertRaises(TypeError, db.test.find().limit, "hello")
        self.assertRaises(TypeError, db.test.find().limit, 5.5)
        self.assertTrue(db.test.find().limit(5L))

        db.test.drop()
        for i in range(100):
            db.test.save({"x": i})

        count = 0
        for _ in db.test.find():
            count += 1
        self.assertEqual(count, 100)

        count = 0
        for _ in db.test.find().limit(20):
            count += 1
        self.assertEqual(count, 20)

        count = 0
        for _ in db.test.find().limit(99):
            count += 1
        self.assertEqual(count, 99)

        count = 0
        for _ in db.test.find().limit(1):
            count += 1
        self.assertEqual(count, 1)

        count = 0
        for _ in db.test.find().limit(0):
            count += 1
        self.assertEqual(count, 100)

        count = 0
        for _ in db.test.find().limit(0).limit(50).limit(10):
            count += 1
        self.assertEqual(count, 10)

        a = db.test.find()
        a.limit(10)
        for _ in a:
            break
        self.assertRaises(InvalidOperation, a.limit, 5)

    def test_max(self):
        """Cursor.max() bounds results by an index key, with an exclusive bound.

        With j in 0..9 and max j=3, three documents (0, 1, 2) are returned.
        The bound needs an index whose keys appear in the same order. An int
        or a dict spec raises TypeError.
        """
        db = self.db
        db.test.drop()
        db.test.ensure_index([("j", ASCENDING)])

        for value in range(10):
            db.test.insert({"j": value, "k": value})

        def matched(spec):
            return len(list(db.test.find().max(spec)))

        # List and tuple specs are both accepted.
        self.assertEqual(matched([("j", 3)]), 3)
        self.assertEqual(matched((("j", 3), )), 3)

        # Compound index.
        db.test.ensure_index([("j", ASCENDING), ("k", ASCENDING)])
        self.assertEqual(matched([("j", 3), ("k", 3)]), 3)

        # Keys in the wrong order, or keys with no matching index, fail.
        for bad_spec in ([("k", 3), ("j", 3)], [("k", 3)]):
            cursor = db.test.find().max(bad_spec)
            self.assertRaises(OperationFailure, list, cursor)

        for bad_type in (10, {"j": 10}):
            self.assertRaises(TypeError, db.test.find().max, bad_type)

    def test_min(self):
        """Cursor.min() bounds results by an index key, with an inclusive bound.

        With j in 0..9 and min j=3, seven documents (3..9) are returned. The
        bound needs an index whose keys appear in the same order. An int or
        a dict spec raises TypeError.
        """
        db = self.db
        db.test.drop()
        db.test.ensure_index([("j", ASCENDING)])

        for value in range(10):
            db.test.insert({"j": value, "k": value})

        def matched(spec):
            return len(list(db.test.find().min(spec)))

        # List and tuple specs are both accepted.
        self.assertEqual(matched([("j", 3)]), 7)
        self.assertEqual(matched((("j", 3), )), 7)

        # Compound index.
        db.test.ensure_index([("j", ASCENDING), ("k", ASCENDING)])
        self.assertEqual(matched([("j", 3), ("k", 3)]), 7)

        # Keys in the wrong order, or keys with no matching index, fail.
        for bad_spec in ([("k", 3), ("j", 3)], [("k", 3)]):
            cursor = db.test.find().min(bad_spec)
            self.assertRaises(OperationFailure, list, cursor)

        for bad_type in (10, {"j": 10}):
            self.assertRaises(TypeError, db.test.find().min, bad_type)

    def test_batch_size(self):
        db = self.db
        db.test.drop()
        for x in range(200):
            db.test.save({"x": x})

        self.assertRaises(TypeError, db.test.find().batch_size, None)
        self.assertRaises(TypeError, db.test.find().batch_size, "hello")
        self.assertRaises(TypeError, db.test.find().batch_size, 5.5)
        self.assertRaises(ValueError, db.test.find().batch_size, -1)
        self.assertTrue(db.test.find().batch_size(5L))
        a = db.test.find()
        for _ in a:
            break
        self.assertRaises(InvalidOperation, a.batch_size, 5)

        def cursor_count(cursor, expected_count):
            count = 0
            for _ in cursor:
                count += 1
            self.assertEqual(expected_count, count)

        cursor_count(db.test.find().batch_size(0), 200)
        cursor_count(db.test.find().batch_size(1), 200)
        cursor_count(db.test.find().batch_size(2), 200)
        cursor_count(db.test.find().batch_size(5), 200)
        cursor_count(db.test.find().batch_size(100), 200)
        cursor_count(db.test.find().batch_size(500), 200)

        cursor_count(db.test.find().batch_size(0).limit(1), 1)
        cursor_count(db.test.find().batch_size(1).limit(1), 1)
        cursor_count(db.test.find().batch_size(2).limit(1), 1)
        cursor_count(db.test.find().batch_size(5).limit(1), 1)
        cursor_count(db.test.find().batch_size(100).limit(1), 1)
        cursor_count(db.test.find().batch_size(500).limit(1), 1)

        cursor_count(db.test.find().batch_size(0).limit(10), 10)
        cursor_count(db.test.find().batch_size(1).limit(10), 10)
        cursor_count(db.test.find().batch_size(2).limit(10), 10)
        cursor_count(db.test.find().batch_size(5).limit(10), 10)
        cursor_count(db.test.find().batch_size(100).limit(10), 10)
        cursor_count(db.test.find().batch_size(500).limit(10), 10)

    def test_limit_and_batch_size(self):
        """Check how many docs the first batch fetches for limit/batch_size mixes.

        After one next() call, the cursor's private ``__retrieved`` counter
        must equal the expected size of the first batch.
        """
        db = self.db
        db.test.drop()
        for x in range(500):
            db.test.save({"x": x})

        cases = (
            (lambda c: c.limit(0).batch_size(10), 10),
            (lambda c: c.limit(-2).batch_size(0), 2),
            (lambda c: c.limit(-4).batch_size(5), 4),
            (lambda c: c.limit(50).batch_size(500), 50),
            (lambda c: c.batch_size(500), 500),
            (lambda c: c.limit(50), 50),
            # these two might be shaky, as the default
            # is set by the server. as of 2.0.0-rc0, 101
            # or 1MB (whichever is smaller) is default
            # for queries without ntoreturn
            (lambda c: c, 101),
            (lambda c: c.limit(0).batch_size(0), 101),
        )
        for configure, expected in cases:
            cursor = configure(db.test.find())
            cursor.next()
            self.assertEqual(expected, cursor._Cursor__retrieved)

    def test_skip(self):
        db = self.db

        self.assertRaises(TypeError, db.test.find().skip, None)
        self.assertRaises(TypeError, db.test.find().skip, "hello")
        self.assertRaises(TypeError, db.test.find().skip, 5.5)
        self.assertRaises(ValueError, db.test.find().skip, -5)
        self.assertTrue(db.test.find().skip(5L))

        db.drop_collection("test")

        for i in range(100):
            db.test.save({"x": i})

        for i in db.test.find():
            self.assertEqual(i["x"], 0)
            break

        for i in db.test.find().skip(20):
            self.assertEqual(i["x"], 20)
            break

        for i in db.test.find().skip(99):
            self.assertEqual(i["x"], 99)
            break

        for i in db.test.find().skip(1):
            self.assertEqual(i["x"], 1)
            break

        for i in db.test.find().skip(0):
            self.assertEqual(i["x"], 0)
            break

        for i in db.test.find().skip(0).skip(50).skip(10):
            self.assertEqual(i["x"], 10)
            break

        for i in db.test.find().skip(1000):
            self.fail()

        a = db.test.find()
        a.skip(10)
        for _ in a:
            break
        self.assertRaises(InvalidOperation, a.skip, 5)

    def test_sort(self):
        """Check Cursor.sort argument validation and result ordering.

        Covers single-key ascending and descending sorts (the last sort()
        call wins), a compound sort, and the rule that sort() cannot be
        changed once iteration has started.
        """
        db = self.db

        self.assertRaises(TypeError, db.test.find().sort, 5)
        self.assertRaises(ValueError, db.test.find().sort, [])
        self.assertRaises(TypeError, db.test.find().sort, [], ASCENDING)
        self.assertRaises(TypeError, db.test.find().sort,
                          [("hello", DESCENDING)], DESCENDING)

        db.test.drop()

        unsorted_values = list(range(10))
        random.shuffle(unsorted_values)
        for value in unsorted_values:
            db.test.save({"x": value})

        def xs(cursor):
            return [doc["x"] for doc in cursor]

        ascending = list(range(10))
        for cursor in (db.test.find().sort("x", ASCENDING),
                       db.test.find().sort("x"),
                       db.test.find().sort([("x", ASCENDING)])):
            self.assertEqual(xs(cursor), ascending)

        descending = list(reversed(range(10)))
        for cursor in (db.test.find().sort("x", DESCENDING),
                       db.test.find().sort([("x", DESCENDING)]),
                       db.test.find().sort("x", ASCENDING).sort("x", DESCENDING)):
            self.assertEqual(xs(cursor), descending)

        expected = [(1, 5), (2, 5), (0, 3), (7, 3), (9, 2), (2, 1), (3, 1)]
        shuffled = list(expected)
        random.shuffle(shuffled)

        db.test.drop()
        for a_val, b_val in shuffled:
            db.test.save({"a": a_val, "b": b_val})

        compound = db.test.find().sort([("b", DESCENDING), ("a", ASCENDING)])
        self.assertEqual([(doc["a"], doc["b"]) for doc in compound], expected)

        started = db.test.find()
        started.sort("x", ASCENDING)
        for _ in started:
            break
        self.assertRaises(InvalidOperation, started.sort, "x", ASCENDING)

    def test_count(self):
        """Check Cursor.count results.

        count() returns an int and reports the full match count whatever
        limit()/skip() are set. Iterating a cursor does not change its count,
        and a missing collection counts as 0.
        """
        db = self.db
        db.test.drop()

        self.assertEqual(0, db.test.find().count())

        for i in range(10):
            db.test.save({"x": i})

        self.assertEqual(10, db.test.find().count())
        self.assertTrue(isinstance(db.test.find().count(), int))
        for cursor in (db.test.find().limit(5), db.test.find().skip(5)):
            self.assertEqual(10, cursor.count())

        for spec, expected in (({"x": 1}, 1), ({"x": {"$lt": 5}}, 5)):
            self.assertEqual(expected, db.test.find(spec).count())

        cursor = db.test.find()
        before = cursor.count()
        next(iter(cursor), None)
        self.assertEqual(before, cursor.count())

        self.assertEqual(0, db.test.acollectionthatdoesntexist.find().count())

    def test_where(self):
        """Check Cursor.where with str, unicode and Code JavaScript clauses.

        Non-string arguments raise TypeError. where() combines with a find()
        spec. A second where() replaces the first. where() cannot be changed
        after iteration starts.
        """
        db = self.db
        db.test.drop()

        fresh = db.test.find()
        for bad in (5, None, {}):
            self.assertRaises(TypeError, fresh.where, bad)

        for i in range(10):
            db.test.save({"x": i})

        def xs(cursor):
            return [doc["x"] for doc in cursor]

        for clause in ('this.x < 3', Code('this.x < 3'),
                       Code('this.x < i', {"i": 3})):
            self.assertEqual(3, len(list(db.test.find().where(clause))))
        self.assertEqual(10, len(list(db.test.find())))

        self.assertEqual(3, db.test.find().where('this.x < 3').count())
        self.assertEqual(10, db.test.find().count())
        self.assertEqual(3, db.test.find().where(u'this.x < 3').count())
        self.assertEqual([0, 1, 2], xs(db.test.find().where('this.x < 3')))
        self.assertEqual([], xs(db.test.find({"x": 5}).where('this.x < 3')))
        self.assertEqual([5], xs(db.test.find({"x": 5}).where('this.x > 3')))

        replaced = db.test.find().where('this.x < 3').where('this.x > 7')
        self.assertEqual([8, 9], xs(replaced))

        started = db.test.find()
        started.where('this.x > 3')
        for _ in started:
            break
        self.assertRaises(InvalidOperation, started.where, 'this.x < 3')

    def test_rewind(self):
        """rewind() resets an exhausted or partially read cursor."""
        for value in (1, 2, 3):
            self.db.test.save({"x": value})

        def drain(cur):
            # Number of documents ``cur`` still yields.
            seen = 0
            for _ in cur:
                seen += 1
            return seen

        cursor = self.db.test.find().limit(2)

        self.assertEqual(2, drain(cursor))
        # An exhausted cursor yields nothing more...
        self.assertEqual(0, drain(cursor))
        # ...until it is rewound.
        cursor.rewind()
        self.assertEqual(2, drain(cursor))

        # Rewinding after a partial read restarts from the beginning.
        cursor.rewind()
        for _ in cursor:
            break
        cursor.rewind()
        self.assertEqual(2, drain(cursor))

        # rewind() returns a value equal to the cursor itself.
        self.assertEqual(cursor, cursor.rewind())

    def test_clone(self):
        """clone(), copy.copy() and copy.deepcopy() produce usable cursors.

        Checks that clones start unexhausted, carry over every query option,
        and (for clone()/deepcopy) do not share mutable state with the source.
        """
        self.db.test.save({"x": 1})
        self.db.test.save({"x": 2})
        self.db.test.save({"x": 3})

        cursor = self.db.test.find().limit(2)

        count = 0
        for _ in cursor:
            count += 1
        self.assertEqual(2, count)

        count = 0
        for _ in cursor:
            count += 1
        self.assertEqual(0, count)

        # Clones of an exhausted cursor can be iterated again.
        cursor = cursor.clone()
        cursor2 = cursor.clone()
        count = 0
        for _ in cursor:
            count += 1
        self.assertEqual(2, count)
        for _ in cursor2:
            count += 1
        self.assertEqual(4, count)

        cursor.rewind()
        count = 0
        for _ in cursor:
            break
        cursor = cursor.clone()
        for _ in cursor:
            count += 1
        self.assertEqual(2, count)

        self.assertNotEqual(cursor, cursor.clone())

        class MyClass(dict):
            pass

        cursor = self.db.test.find(as_class=MyClass)
        for e in cursor:
            self.assertEqual(MyClass, type(e))
        cursor = self.db.test.find(as_class=MyClass)
        self.assertEqual(MyClass, type(cursor[0]))

        # Every option set here must survive clone().
        cursor = self.db.test.find({"x": re.compile("^hello.*")},
                                   skip=1,
                                   timeout=False,
                                   snapshot=True,
                                   tailable=True,
                                   as_class=MyClass,
                                   slave_okay=True,
                                   await_data=True,
                                   partial=True,
                                   manipulate=False,
                                   compile_re=False,
                                   fields={'_id': False}).limit(2)
        cursor.min([('a', 1)]).max([('b', 3)])
        cursor.add_option(128)
        cursor.comment('hi!')

        cursor2 = cursor.clone()
        self.assertEqual(cursor._Cursor__skip, cursor2._Cursor__skip)
        self.assertEqual(cursor._Cursor__limit, cursor2._Cursor__limit)
        self.assertEqual(cursor._Cursor__snapshot, cursor2._Cursor__snapshot)
        # Compare the classes themselves: type() of any class is ``type``,
        # so comparing type(...) of the two values could never fail.
        self.assertEqual(cursor._Cursor__as_class,
                         cursor2._Cursor__as_class)
        self.assertEqual(cursor._Cursor__slave_okay,
                         cursor2._Cursor__slave_okay)
        self.assertEqual(cursor._Cursor__manipulate,
                         cursor2._Cursor__manipulate)
        self.assertEqual(cursor._Cursor__compile_re,
                         cursor2._Cursor__compile_re)
        self.assertEqual(cursor._Cursor__query_flags,
                         cursor2._Cursor__query_flags)
        self.assertEqual(cursor._Cursor__comment,
                         cursor2._Cursor__comment)
        self.assertEqual(cursor._Cursor__min,
                         cursor2._Cursor__min)
        self.assertEqual(cursor._Cursor__max,
                         cursor2._Cursor__max)

        # Shallow copies share mutable state, so mutating the copy's
        # fields is visible on the original.
        cursor2 = copy.copy(cursor)
        cursor2._Cursor__fields['cursor2'] = False
        self.assertTrue('cursor2' in cursor._Cursor__fields)

        # Deep copies must not share mutable state with the original.
        cursor3 = copy.deepcopy(cursor)
        cursor3._Cursor__fields['cursor3'] = False
        self.assertFalse('cursor3' in cursor._Cursor__fields)

        # Mutating a clone() must not affect the original either.
        cursor4 = cursor.clone()
        cursor4._Cursor__fields['cursor4'] = False
        self.assertFalse('cursor4' in cursor._Cursor__fields)

        # deepcopy must honour its memo: a self-referencing query is copied
        # (no infinite recursion) and the copy keeps its own cycle.
        query = {"hello": "world"}
        query["reflexive"] = query
        cursor = self.db.test.find(query)

        cursor2 = copy.deepcopy(cursor)

        self.assertNotEqual(id(cursor._Cursor__spec),
                            id(cursor2._Cursor__spec))
        self.assertEqual(id(cursor2._Cursor__spec['reflexive']),
                         id(cursor2._Cursor__spec))
        self.assertEqual(len(cursor2._Cursor__spec), 2)

        # Hints are stored as SON and must still be SON after deepcopy.
        cursor = self.db.test.find().hint([('z', 1), ("a", 1)])
        cursor2 = copy.deepcopy(cursor)
        self.assertTrue(isinstance(cursor2._Cursor__hint, SON))
        self.assertEqual(cursor._Cursor__hint, cursor2._Cursor__hint)

    def test_deepcopy_cursor_littered_with_regexes(self):
        """deepcopy() copes with compiled regexes at any depth of the spec."""
        spec = {"x": re.compile("^hmmm.*"),
                "y": [re.compile("^hm.*")],
                "z": {"a": [re.compile("^hm.*")]},
                re.compile("^key.*"): {"a": [re.compile("^hm.*")]}}
        original = self.db.test.find(spec)

        duplicate = copy.deepcopy(original)
        self.assertEqual(original._Cursor__spec, duplicate._Cursor__spec)

    def test_add_remove_option(self):
        """add_option()/remove_option() stay in sync with find() keyword flags.

        Bit values asserted below: 2 tailable, 4 slave_okay, 16 no cursor
        timeout (timeout=False), 32 await_data, 64 exhaust, 128 partial.
        """
        # Adding bits one at a time must match the equivalent find() flags.
        cursor = self.db.test.find()
        self.assertEqual(0, cursor._Cursor__query_options())
        cursor.add_option(2)
        cursor2 = self.db.test.find(tailable=True)
        self.assertEqual(2, cursor2._Cursor__query_options())
        self.assertEqual(cursor._Cursor__query_options(),
                         cursor2._Cursor__query_options())
        cursor.add_option(32)
        cursor2 = self.db.test.find(tailable=True, await_data=True)
        self.assertEqual(34, cursor2._Cursor__query_options())
        self.assertEqual(cursor._Cursor__query_options(),
                         cursor2._Cursor__query_options())
        cursor.add_option(128)
        cursor2 = self.db.test.find(tailable=True,
                                    await_data=True).add_option(128)
        self.assertEqual(162, cursor2._Cursor__query_options())
        self.assertEqual(cursor._Cursor__query_options(),
                         cursor2._Cursor__query_options())

        # Adding a bit that is already set is a no-op.
        self.assertEqual(162, cursor._Cursor__query_options())
        cursor.add_option(128)
        self.assertEqual(162, cursor._Cursor__query_options())

        # Removing bits must likewise mirror the find() flags.
        cursor.remove_option(128)
        cursor2 = self.db.test.find(tailable=True, await_data=True)
        self.assertEqual(34, cursor2._Cursor__query_options())
        self.assertEqual(cursor._Cursor__query_options(),
                         cursor2._Cursor__query_options())
        cursor.remove_option(32)
        cursor2 = self.db.test.find(tailable=True)
        self.assertEqual(2, cursor2._Cursor__query_options())
        self.assertEqual(cursor._Cursor__query_options(),
                         cursor2._Cursor__query_options())

        # Removing a bit that is not set is a no-op.
        self.assertEqual(2, cursor._Cursor__query_options())
        cursor.remove_option(32)
        self.assertEqual(2, cursor._Cursor__query_options())

        # Slave OK
        # remove_option(4) also clears the private __slave_okay attribute.
        cursor = self.db.test.find(slave_okay=True)
        self.assertEqual(4, cursor._Cursor__query_options())
        cursor2 = self.db.test.find().add_option(4)
        self.assertEqual(cursor._Cursor__query_options(),
                         cursor2._Cursor__query_options())
        self.assertTrue(cursor._Cursor__slave_okay)
        cursor.remove_option(4)
        self.assertEqual(0, cursor._Cursor__query_options())
        self.assertFalse(cursor._Cursor__slave_okay)

        # Timeout
        # timeout=False maps to bit 16.
        cursor = self.db.test.find(timeout=False)
        self.assertEqual(16, cursor._Cursor__query_options())
        cursor2 = self.db.test.find().add_option(16)
        self.assertEqual(cursor._Cursor__query_options(),
                         cursor2._Cursor__query_options())
        cursor.remove_option(16)
        self.assertEqual(0, cursor._Cursor__query_options())

        # Tailable / Await data
        # A combined mask (2 | 32) can be added in one call.
        cursor = self.db.test.find(tailable=True, await_data=True)
        self.assertEqual(34, cursor._Cursor__query_options())
        cursor2 = self.db.test.find().add_option(34)
        self.assertEqual(cursor._Cursor__query_options(),
                         cursor2._Cursor__query_options())
        cursor.remove_option(32)
        self.assertEqual(2, cursor._Cursor__query_options())

        # Exhaust - which mongos doesn't support
        # remove_option(64) also clears the private __exhaust attribute.
        if not is_mongos(self.db.connection):
            cursor = self.db.test.find(exhaust=True)
            self.assertEqual(64, cursor._Cursor__query_options())
            cursor2 = self.db.test.find().add_option(64)
            self.assertEqual(cursor._Cursor__query_options(),
                             cursor2._Cursor__query_options())
            self.assertTrue(cursor._Cursor__exhaust)
            cursor.remove_option(64)
            self.assertEqual(0, cursor._Cursor__query_options())
            self.assertFalse(cursor._Cursor__exhaust)

        # Partial
        # partial=True maps to bit 128.
        cursor = self.db.test.find(partial=True)
        self.assertEqual(128, cursor._Cursor__query_options())
        cursor2 = self.db.test.find().add_option(128)
        self.assertEqual(cursor._Cursor__query_options(),
                         cursor2._Cursor__query_options())
        cursor.remove_option(128)
        self.assertEqual(0, cursor._Cursor__query_options())

    def test_count_with_fields(self):
        """On servers >= 1.1.3, count() is not affected by a field projection."""
        coll = self.db.test
        coll.drop()
        coll.save({"x": 1})

        if version.at_least(self.db.connection, (1, 1, 3, -1)):
            self.assertEqual(1, coll.find({}, ["a"]).count())
            return

        # Older servers are expected to return nothing when no document
        # contains a projected field.
        for _ in coll.find({}, ["a"]):
            self.fail()
        self.assertEqual(0, coll.find({}, ["a"]).count())

    def test_bad_getitem(self):
        """Indexing a cursor with a non-integer, non-slice key is a TypeError."""
        def index_cursor(key):
            return self.db.test.find()[key]

        for bad_key in ("hello", 5.5, None):
            self.assertRaises(TypeError, index_cursor, bad_key)

    def test_getitem_slice_index(self):
        """Slicing a cursor maps onto skip/limit and validates the slice.

        The ``zip(count(start), ...)`` loops assume documents come back in
        insertion order (``i`` ascending) -- presumably natural order on a
        freshly created collection; confirm if this becomes flaky.
        """
        self.db.drop_collection("test")
        for i in range(100):
            self.db.test.save({"i": i})

        count = itertools.count

        # Negative starts and explicit steps are not supported.
        self.assertRaises(IndexError, lambda: self.db.test.find()[-1:])
        self.assertRaises(IndexError, lambda: self.db.test.find()[1:2:2])

        for a, b in zip(count(0), self.db.test.find()):
            self.assertEqual(a, b['i'])

        self.assertEqual(100, len(list(self.db.test.find()[0:])))
        for a, b in zip(count(0), self.db.test.find()[0:]):
            self.assertEqual(a, b['i'])

        self.assertEqual(80, len(list(self.db.test.find()[20:])))
        for a, b in zip(count(20), self.db.test.find()[20:]):
            self.assertEqual(a, b['i'])

        for a, b in zip(count(99), self.db.test.find()[99:]):
            self.assertEqual(a, b['i'])

        # A start past the end yields no documents.
        for i in self.db.test.find()[1000:]:
            self.fail()

        # Python 2 ``long`` bounds are accepted like ints.
        self.assertEqual(5, len(list(self.db.test.find()[20:25])))
        self.assertEqual(5, len(list(self.db.test.find()[20L:25L])))
        for a, b in zip(count(20), self.db.test.find()[20:25]):
            self.assertEqual(a, b['i'])

        # A later slice replaces an earlier one: [40:45][20:] acts like [20:].
        self.assertEqual(80, len(list(self.db.test.find()[40:45][20:])))
        for a, b in zip(count(20), self.db.test.find()[40:45][20:]):
            self.assertEqual(a, b['i'])

        # Explicit limit()/skip() after a slice likewise override it.
        self.assertEqual(80,
                         len(list(self.db.test.find()[40:45].limit(0).skip(20))
                            )
                        )
        for a, b in zip(count(20),
                         self.db.test.find()[40:45].limit(0).skip(20)):
            self.assertEqual(a, b['i'])

        # ...and a slice overrides earlier limit()/skip().
        self.assertEqual(80,
                         len(list(self.db.test.find().limit(10).skip(40)[20:]))
                        )
        for a, b in zip(count(20),
                         self.db.test.find().limit(10).skip(40)[20:]):
            self.assertEqual(a, b['i'])

        self.assertEqual(1, len(list(self.db.test.find()[:1])))
        self.assertEqual(5, len(list(self.db.test.find()[:5])))

        # Bounds past the end are clamped; empty slices yield nothing.
        self.assertEqual(1, len(list(self.db.test.find()[99:100])))
        self.assertEqual(1, len(list(self.db.test.find()[99:1000])))
        self.assertEqual(0, len(list(self.db.test.find()[10:10])))
        self.assertEqual(0, len(list(self.db.test.find()[:0])))
        self.assertEqual(80,
                         len(list(self.db.test.find()[10:10].limit(0).skip(20))
                            )
                        )

        # An end before the start is rejected.
        self.assertRaises(IndexError, lambda: self.db.test.find()[10:8])

    def test_getitem_numeric_index(self):
        self.db.drop_collection("test")
        for i in range(100):
            self.db.test.save({"i": i})

        self.assertEqual(0, self.db.test.find()[0]['i'])
        self.assertEqual(50, self.db.test.find()[50]['i'])
        self.assertEqual(50, self.db.test.find().skip(50)[0]['i'])
        self.assertEqual(50, self.db.test.find().skip(49)[1]['i'])
        self.assertEqual(50, self.db.test.find()[50L]['i'])
        self.assertEqual(99, self.db.test.find()[99]['i'])

        self.assertRaises(IndexError, lambda x: self.db.test.find()[x], -1)
        self.assertRaises(IndexError, lambda x: self.db.test.find()[x], 100)
        self.assertRaises(IndexError,
                          lambda x: self.db.test.find().skip(50)[x], 50)

    def test_count_with_limit_and_skip(self):
        """count(True) honours limit/skip and agrees with iteration."""
        if not version.at_least(self.db.connection, (1, 1, 4, -1)):
            raise SkipTest("count with limit / skip requires MongoDB >= 1.1.4")

        self.assertRaises(TypeError, self.db.test.find().count, "foo")

        def check_len(cursor, length):
            # Iterate first, then compare with count(True) and ``length``.
            iterated = len(list(cursor))
            self.assertEqual(iterated, cursor.count(True))
            self.assertEqual(length, cursor.count(True))

        self.db.drop_collection("test")
        for i in range(100):
            self.db.test.save({"i": i})

        find = self.db.test.find
        check_len(find(), 100)

        # Limits larger than the collection are clamped.
        check_len(find().limit(10), 10)
        check_len(find().limit(110), 100)

        check_len(find().skip(10), 90)
        check_len(find().skip(110), 0)

        # Chained calls and the equivalent slices agree.
        check_len(find().limit(10).skip(10), 10)
        check_len(find()[10:20], 10)
        check_len(find().limit(10).skip(95), 5)
        check_len(find()[95:105], 5)

    def test_len(self):
        """Cursors do not support len(); calling it raises TypeError."""
        cursor = self.db.test.find()
        self.assertRaises(TypeError, len, cursor)

    def test_properties(self):
        """Cursor.collection is readable but cannot be reassigned."""
        expected = self.db.test
        self.assertEqual(expected, self.db.test.find().collection)

        def assign_collection():
            self.db.test.find().collection = "hello"

        self.assertRaises(AttributeError, assign_collection)

    def test_get_more(self):
        """A batch_size smaller than the result set still yields every doc."""
        self.db.drop_collection("test")
        coll = self.db.test
        coll.insert([{'i': i} for i in range(10)])
        fetched = list(coll.find().batch_size(5))
        self.assertEqual(10, len(fetched))

    def test_tailable(self):
        """A tailable cursor on a capped collection picks up new inserts."""
        db = self.db
        db.drop_collection("test")
        db.create_collection("test", capped=True, size=1000, max=3)

        try:
            cursor = db.test.find(tailable=True)

            def expect_one_new(value):
                # Draining the cursor must yield exactly one doc: the newest.
                seen = 0
                for doc in cursor:
                    seen += 1
                    self.assertEqual(value, doc["x"])
                self.assertEqual(1, seen)

            for value in (1, 2, 3):
                db.test.insert({"x": value})
                expect_one_new(value)

            # Capped rollover - the collection never holds more than 3
            # documents. Make sure draining after the rollover doesn't raise...
            db.test.insert(({"x": i} for i in range(4, 7)))
            self.assertEqual(0, len(list(cursor)))

            # ...and that the cursor no longer reports itself alive.
            self.assertFalse(cursor.alive)

            self.assertEqual(3, db.test.count())
        finally:
            db.drop_collection("test")

    def test_distinct(self):
        """Cursor.distinct() respects the query, including dotted keys."""
        if not version.at_least(self.db.connection, (1, 1, 3, 1)):
            raise SkipTest("distinct with query requires MongoDB >= 1.1.3")

        coll = self.db.test
        self.db.drop_collection("test")
        for a in (1, 2, 2, 2, 3):
            coll.save({"a": a})

        self.assertEqual([1, 2],
                         sorted(coll.find({"a": {"$lt": 3}}).distinct("a")))

        self.db.drop_collection("test")
        for b, c in (("a", 12), ("b", 8), ("c", 12), ("c", 8)):
            coll.save({"a": {"b": b}, "c": c})

        self.assertEqual(["b", "c"],
                         sorted(coll.find({"c": 8}).distinct("a.b")))

    def test_max_scan(self):
        """max_scan caps the documents returned; the last call wins."""
        if not version.at_least(self.db.connection, (1, 5, 1)):
            raise SkipTest("maxScan requires MongoDB >= 1.5.1")

        self.db.drop_collection("test")
        coll = self.db.test
        for _ in range(100):
            coll.insert({})

        def fetched(cursor):
            return len(list(cursor))

        self.assertEqual(100, fetched(coll.find()))
        self.assertEqual(50, fetched(coll.find(max_scan=50)))
        self.assertEqual(50, fetched(coll.find().max_scan(90).max_scan(50)))

    def test_with_statement(self):
        if sys.version_info < (2, 6):
            raise SkipTest("With statement requires Python >= 2.6")

        self.db.drop_collection("test")
        for _ in range(100):
            self.db.test.insert({})

        c1 = self.db.test.find()
        exec """
with self.db.test.find() as c2:
    self.assertTrue(c2.alive)
self.assertFalse(c2.alive)

with self.db.test.find() as c2:
    self.assertEqual(100, len(list(c2)))
self.assertFalse(c2.alive)
"""
        self.assertTrue(c1.alive)

    def test_comment(self):
        """comment() is attached to find, count and distinct operations.

        Verified through the profiler, so the test is skipped on mongos, on
        servers older than 2.0 and when the server runs with auth. Profiling
        is always switched back off and system.profile dropped, even when an
        assertion fails, so later tests do not run with the profiler on.
        """
        if is_mongos(self.client):
            raise SkipTest("profile is not supported by mongos")
        if not version.at_least(self.db.connection, (2, 0)):
            raise SkipTest("Requires server >= 2.0")
        if server_started_with_auth(self.db.connection):
            raise SkipTest("SERVER-4754 - This test uses profiling.")

        def run_with_profiling(func):
            # Start from an empty profile collection with profiling on, and
            # guarantee it is turned off again whatever ``func`` does.
            self.db.set_profiling_level(OFF)
            self.db.system.profile.drop()
            self.db.set_profiling_level(ALL)
            try:
                func()
            finally:
                self.db.set_profiling_level(OFF)

        def find():
            list(self.db.test.find().comment('foo'))
            op = self.db.system.profile.find({'ns': 'pymongo_test.test',
                                              'op': 'query',
                                              'query.$comment': 'foo'})
            self.assertEqual(op.count(), 1)

        def count():
            self.db.test.find().comment('foo').count()
            op = self.db.system.profile.find({'ns': 'pymongo_test.$cmd',
                                              'op': 'command',
                                              'command.count': 'test',
                                              'command.$comment': 'foo'})
            self.assertEqual(op.count(), 1)

        def distinct():
            self.db.test.find().comment('foo').distinct('type')
            op = self.db.system.profile.find({'ns': 'pymongo_test.$cmd',
                                              'op': 'command',
                                              'command.distinct': 'test',
                                              'command.$comment': 'foo'})
            self.assertEqual(op.count(), 1)

        try:
            run_with_profiling(find)
            run_with_profiling(count)
            run_with_profiling(distinct)

            # comment() is rejected once iteration has started.
            self.db.test.insert([{}, {}])
            cursor = self.db.test.find()
            cursor.next()
            self.assertRaises(InvalidOperation, cursor.comment, 'hello')
        finally:
            self.db.system.profile.drop()

    def test_cursor_transfer(self):
        """A server cursor left open by one Cursor can be resumed by id.

        A do-nothing cursor manager is installed so that ``cursor.close()``
        (presumably routed through the manager) leaves the server-side cursor
        open; a CommandCursor then continues it via ``cursor.cursor_id``.
        """

        # This is just a test, don't try this at home...
        self.db.test.remove({})
        self.db.test.insert({'_id': i} for i in xrange(200))

        class CManager(CursorManager):
            def __init__(self, connection):
                super(CManager, self).__init__(connection)

            def close(self, dummy):
                # Do absolutely nothing...
                pass

        client = self.db.connection
        # Project-level catch_warnings helper (not used as a context
        # manager); ctx.exit() presumably restores the warning filters
        # changed below -- confirm against its definition.
        ctx = catch_warnings()
        try:
            # Installing a custom cursor manager presumably emits a
            # DeprecationWarning; silence it for this test.
            warnings.simplefilter("ignore", DeprecationWarning)
            client.set_cursor_manager(CManager)

            docs = []
            cursor = self.db.test.find().batch_size(10)
            docs.append(cursor.next())
            cursor.close()
            # Only the rest of the already-fetched first batch remains, so
            # 10 documents in total.
            docs.extend(cursor)
            self.assertEqual(len(docs), 10)
            # Resume the same server cursor with an empty first batch; it
            # must yield the remaining 190 documents.
            cmd_cursor = {'id': cursor.cursor_id, 'firstBatch': []}
            ccursor = CommandCursor(cursor.collection, cmd_cursor,
                                    cursor.conn_id, retrieved=cursor.retrieved)
            docs.extend(ccursor)
            self.assertEqual(len(docs), 200)
        finally:
            # Restore the default manager before warnings are re-enabled.
            client.set_cursor_manager(CursorManager)
            ctx.exit()
Esempio n. 18
0
class TestCursor(unittest.TestCase):

    def setUp(self):
        """Point every test at the ``pymongo_test`` database."""
        connection = get_connection()
        self.db = Database(connection, "pymongo_test")

    def test_explain(self):
        """explain() output is unchanged by partially iterating the cursor.

        Fields that may differ between the two calls (``millis`` and
        ``oldPlan``) are removed before comparing.
        """
        a = self.db.test.find()
        b = a.explain()
        for _ in a:
            break
        c = a.explain()
        del b["millis"]
        b.pop("oldPlan", None)
        del c["millis"]
        c.pop("oldPlan", None)
        self.assertEqual(b, c)
        # assertTrue, not the deprecated assert_ alias (gone in Python 3.12).
        self.assertTrue("cursor" in b)

    def test_hint(self):
        """hint() forces an index, can be cleared with None, validates input."""
        db = self.db
        self.assertRaises(TypeError, db.test.find().hint, 5.5)
        db.test.drop()

        for i in range(100):
            db.test.insert({"num": i, "foo": i})

        # Hinting an index that does not exist makes explain() fail.
        for field in ("num", "foo"):
            cursor = db.test.find({"num": 17, "foo": 17})
            self.assertRaises(OperationFailure,
                              cursor.hint([(field, ASCENDING)]).explain)

        index = db.test.create_index("num")

        spec = [("num", ASCENDING)]

        def plan(cursor):
            return cursor.explain()["cursor"]

        self.assertEqual(plan(db.test.find({})), "BasicCursor")
        self.assertEqual(plan(db.test.find({}).hint(spec)),
                         "BtreeCursor %s" % index)
        # hint(None) removes a previously set hint.
        self.assertEqual(plan(db.test.find({}).hint(spec).hint(None)),
                         "BasicCursor")
        self.assertRaises(OperationFailure,
                          db.test.find({"num": 17, "foo": 17})
                          .hint([("foo", ASCENDING)]).explain)

        # hint() may not be changed once iteration has started.
        cursor = db.test.find({"num": 17})
        cursor.hint(spec)
        for _ in cursor:
            break
        self.assertRaises(InvalidOperation, cursor.hint, spec)

        # The value returned by create_index() is rejected by hint().
        self.assertRaises(TypeError, db.test.find().hint, index)

    def test_limit(self):
        """limit() validates input, caps results, and the last call wins."""
        db = self.db

        for bad in (None, "hello", 5.5):
            self.assertRaises(TypeError, db.test.find().limit, bad)

        db.test.drop()
        for i in range(100):
            db.test.save({"x": i})

        def count_docs(cursor):
            total = 0
            for _ in cursor:
                total += 1
            return total

        self.assertEqual(count_docs(db.test.find()), 100)
        for n in (20, 99, 1):
            self.assertEqual(count_docs(db.test.find().limit(n)), n)
        # limit(0) means "no limit".
        self.assertEqual(count_docs(db.test.find().limit(0)), 100)
        # Only the most recent limit() applies.
        self.assertEqual(
            count_docs(db.test.find().limit(0).limit(50).limit(10)), 10)

        # limit() is rejected once iteration has started.
        cursor = db.test.find()
        cursor.limit(10)
        for _ in cursor:
            break
        self.assertRaises(InvalidOperation, cursor.limit, 5)

    def test_batch_size(self):
        """batch_size() validates input and never changes the result count."""
        db = self.db
        db.test.drop()
        for x in range(200):
            db.test.save({"x": x})

        for bad in (None, "hello", 5.5):
            self.assertRaises(TypeError, db.test.find().batch_size, bad)
        self.assertRaises(ValueError, db.test.find().batch_size, -1)
        started = db.test.find()
        for _ in started:
            break
        self.assertRaises(InvalidOperation, started.batch_size, 5)

        def cursor_count(cursor, expected_count):
            seen = sum(1 for _ in cursor)
            self.assertEqual(expected_count, seen)

        sizes = (0, 1, 2, 5, 100, 500)
        for size in sizes:
            cursor_count(db.test.find().batch_size(size), 200)
        # With a limit, the limit alone determines how many docs come back.
        for limit in (1, 10):
            for size in sizes:
                cursor_count(db.test.find().batch_size(size).limit(limit),
                             limit)

    def test_skip(self):
        """skip() validates input, offsets results, and the last call wins."""
        db = self.db

        for bad in (None, "hello", 5.5):
            self.assertRaises(TypeError, db.test.find().skip, bad)

        db.drop_collection("test")

        for i in range(100):
            db.test.save({"x": i})

        def check_first(cursor, expected):
            # Compare only the first document (if any) against ``expected``.
            for doc in cursor:
                self.assertEqual(doc["x"], expected)
                break

        check_first(db.test.find(), 0)
        for n in (20, 99, 1):
            check_first(db.test.find().skip(n), n)
        check_first(db.test.find().skip(0), 0)
        # Only the most recent skip() applies.
        check_first(db.test.find().skip(0).skip(50).skip(10), 10)

        # Skipping past the end yields nothing.
        for _ in db.test.find().skip(1000):
            self.fail()

        # skip() is rejected once iteration has started.
        cursor = db.test.find()
        cursor.skip(10)
        for _ in cursor:
            break
        self.assertRaises(InvalidOperation, cursor.skip, 5)

    def test_sort(self):
        """sort() validates arguments and orders by one or more keys."""
        db = self.db

        self.assertRaises(TypeError, db.test.find().sort, 5)
        self.assertRaises(ValueError, db.test.find().sort, [])
        self.assertRaises(TypeError, db.test.find().sort, [], ASCENDING)
        self.assertRaises(TypeError, db.test.find().sort,
                          [("hello", DESCENDING)], DESCENDING)
        self.assertRaises(TypeError, db.test.find().sort, "hello", "world")

        db.test.drop()

        ascending = list(range(10))
        shuffled_xs = list(ascending)
        random.shuffle(shuffled_xs)
        for x in shuffled_xs:
            db.test.save({"x": x})

        def xs(cursor):
            return [doc["x"] for doc in cursor]

        self.assertEqual(xs(db.test.find().sort("x", ASCENDING)), ascending)
        self.assertEqual(xs(db.test.find().sort("x")), ascending)
        self.assertEqual(xs(db.test.find().sort([("x", ASCENDING)])),
                         ascending)

        descending = ascending[::-1]
        self.assertEqual(xs(db.test.find().sort("x", DESCENDING)),
                         descending)
        self.assertEqual(xs(db.test.find().sort([("x", DESCENDING)])),
                         descending)
        # A later sort() replaces an earlier one.
        self.assertEqual(
            xs(db.test.find().sort("x", ASCENDING).sort("x", DESCENDING)),
            descending)

        expected = [(1, 5), (2, 5), (0, 3), (7, 3), (9, 2), (2, 1), (3, 1)]
        shuffled = list(expected)
        random.shuffle(shuffled)

        db.test.drop()
        for a, b in shuffled:
            db.test.save({"a": a, "b": b})

        # Compound sort: b descending, ties broken by a ascending.
        result = [(doc["a"], doc["b"]) for doc in
                  db.test.find().sort([("b", DESCENDING),
                                       ("a", ASCENDING)])]
        self.assertEqual(result, expected)

        # sort() is rejected once iteration has started.
        cursor = db.test.find()
        cursor.sort("x", ASCENDING)
        for _ in cursor:
            break
        self.assertRaises(InvalidOperation, cursor.sort, "x", ASCENDING)

    def test_count(self):
        """Exercise Cursor.count() on empty, populated and filtered results."""
        db = self.db
        db.test.drop()

        self.assertEqual(0, db.test.find().count())

        for i in range(10):
            db.test.save({"x": i})

        self.assertEqual(10, db.test.find().count())
        # assertTrue, not the deprecated assert_ alias (gone in Python 3.12).
        self.assertTrue(isinstance(db.test.find().count(), int))
        # By default count() ignores limit and skip.
        self.assertEqual(10, db.test.find().limit(5).count())
        self.assertEqual(10, db.test.find().skip(5).count())

        self.assertEqual(1, db.test.find({"x": 1}).count())
        self.assertEqual(5, db.test.find({"x": {"$lt": 5}}).count())

        # Starting iteration must not change the reported count.
        a = db.test.find()
        b = a.count()
        for _ in a:
            break
        self.assertEqual(b, a.count())

        self.assertEqual(0, db.test.acollectionthatdoesntexist.find().count())

    def test_where(self):
        """Exercise Cursor.where(): argument validation and $where filtering.

        Comprehension variables are named ``doc`` so that, under Python 2
        (where list-comprehension variables leak), they cannot rebind the
        cursor name ``a`` used elsewhere in this test.
        """
        db = self.db
        db.test.drop()

        a = db.test.find()
        self.assertRaises(TypeError, a.where, 5)
        self.assertRaises(TypeError, a.where, None)
        self.assertRaises(TypeError, a.where, {})

        for i in range(10):
            db.test.save({"x": i})

        self.assertEqual(3, len(list(db.test.find().where('this.x < 3'))))
        self.assertEqual(3,
                         len(list(db.test.find().where(Code('this.x < 3')))))
        self.assertEqual(3, len(list(db.test.find().where(Code('this.x < i',
                                                               {"i": 3})))))
        self.assertEqual(10, len(list(db.test.find())))

        self.assertEqual(3, db.test.find().where('this.x < 3').count())
        self.assertEqual(10, db.test.find().count())
        self.assertEqual(3, db.test.find().where(u'this.x < 3').count())
        self.assertEqual([0, 1, 2],
                         [doc["x"] for doc in
                          db.test.find().where('this.x < 3')])
        self.assertEqual([],
                         [doc["x"] for doc in
                          db.test.find({"x": 5}).where('this.x < 3')])
        self.assertEqual([5],
                         [doc["x"] for doc in
                          db.test.find({"x": 5}).where('this.x > 3')])

        # A second where() replaces the first rather than combining with it.
        cursor = db.test.find().where('this.x < 3').where('this.x > 7')
        self.assertEqual([8, 9], [doc["x"] for doc in cursor])

        # where() is rejected once iteration has started.
        a = db.test.find()
        a.where('this.x > 3')
        for _ in a:
            break
        self.assertRaises(InvalidOperation, a.where, 'this.x < 3')

    def test_kill_cursors(self):
        """Server-side cursors must not leak.

        The server's ``cursorInfo`` ``clientCursors_size`` counter has to
        return to its baseline after inserts, find_one(), abandoned loops and
        deleted cursors. It differs only while a partially-read cursor is
        still referenced.
        """
        db = self.db
        db.drop_collection("test")

        # Baseline count of open client cursors reported by the server.
        c = db.command("cursorInfo")["clientCursors_size"]

        test = db.test
        # Many documents, presumably so that a plain find() cannot be answered
        # in a single reply and an un-exhausted cursor stays open server-side
        # (see the assertNotEqual below).
        for i in range(10000):
            test.insert({"i": i})
        self.assertEqual(c, db.command("cursorInfo")["clientCursors_size"])

        for _ in range(10):
            db.test.find_one()
        self.assertEqual(c, db.command("cursorInfo")["clientCursors_size"])

        # Cursors abandoned after one document are no longer referenced once
        # each for-loop ends; the check assumes they are cleaned up right away
        # (presumably relying on CPython reference counting).
        for _ in range(10):
            for x in db.test.find():
                break
        self.assertEqual(c, db.command("cursorInfo")["clientCursors_size"])

        # A partially-read cursor that is still referenced stays open.
        a = db.test.find()
        for x in a:
            break
        self.assertNotEqual(c, db.command("cursorInfo")["clientCursors_size"])

        # Dropping the last reference must close it on the server.
        del a
        self.assertEqual(c, db.command("cursorInfo")["clientCursors_size"])

        # With limit(10) the server-side cursor is expected to be closed even
        # though the client stops early (presumably all 10 results arrive in
        # the first reply).
        a = db.test.find().limit(10)
        for x in a:
            break
        self.assertEqual(c, db.command("cursorInfo")["clientCursors_size"])

    def test_rewind(self):
        """rewind() restarts an exhausted or partially-read cursor from the
        beginning, and returns the cursor itself.
        """
        for value in (1, 2, 3):
            self.db.test.save({"x": value})

        def drain(cur):
            # Iterate cur to exhaustion; return how many documents it yielded.
            seen = 0
            for _ in cur:
                seen += 1
            return seen

        cursor = self.db.test.find().limit(2)

        self.assertEqual(2, drain(cursor))
        # Once exhausted, a cursor yields nothing until rewound.
        self.assertEqual(0, drain(cursor))

        cursor.rewind()
        self.assertEqual(2, drain(cursor))

        # Rewinding after a partial read also starts from the beginning.
        cursor.rewind()
        for _ in cursor:
            break
        cursor.rewind()
        self.assertEqual(2, drain(cursor))

        self.assertEqual(cursor, cursor.rewind())

    def test_clone(self):
        """clone() returns an independent, unevaluated copy of a cursor that
        keeps its limit, whatever the iteration state of the source. Also
        checks that ``as_class`` controls the document type.
        """
        for value in (1, 2, 3):
            self.db.test.save({"x": value})

        def drain(cur):
            # Consume cur completely and return the number of documents read.
            return sum(1 for _ in cur)

        cursor = self.db.test.find().limit(2)
        self.assertEqual(2, drain(cursor))
        self.assertEqual(0, drain(cursor))

        # Clones of an exhausted cursor start again from the beginning.
        cursor = cursor.clone()
        cursor2 = cursor.clone()
        total = drain(cursor)
        self.assertEqual(2, total)
        total += drain(cursor2)
        self.assertEqual(4, total)

        # So do clones of a partially-read cursor.
        cursor.rewind()
        for _ in cursor:
            break
        cursor = cursor.clone()
        self.assertEqual(2, drain(cursor))

        self.assertNotEqual(cursor, cursor.clone())

        class MyClass(dict):
            pass

        # Documents come back as the requested as_class, whether obtained by
        # iteration or by indexing.
        for doc in self.db.test.find(as_class=MyClass):
            self.assertEqual(type(MyClass()), type(doc))
        self.assertEqual(type(MyClass()),
                         type(self.db.test.find(as_class=MyClass)[0]))

    def test_count_with_fields(self):
        """count() on a query that projects onto a field the document lacks.

        Servers at version (1, 1, 3, -1) or later still count the document.
        Older servers return no documents at all for such a projection.
        """
        self.db.test.drop()
        self.db.test.save({"x": 1})

        def projected():
            # Fresh cursor over all documents, returning only field "a".
            return self.db.test.find({}, ["a"])

        if version.at_least(self.db.connection, (1, 1, 3, -1)):
            self.assertEqual(1, projected().count())
            return

        for _ in projected():
            self.fail()
        self.assertEqual(0, projected().count())

    def test_bad_getitem(self):
        """Indexing a cursor with a string, a float or None raises TypeError."""
        def index_with(key):
            return self.db.test.find()[key]

        for bad_key in ("hello", 5.5, None):
            self.assertRaises(TypeError, index_with, bad_key)

    def test_getitem_slice_index(self):
        self.db.drop_collection("test")
        for i in range(100):
            self.db.test.save({"i": i})

        izip = itertools.izip
        count = itertools.count

        self.assertRaises(IndexError, lambda: self.db.test.find()[-1:])
        self.assertRaises(IndexError, lambda: self.db.test.find()[1:2:2])

        for a, b in izip(count(0), self.db.test.find()):
            self.assertEqual(a, b['i'])

        self.assertEqual(100, len(list(self.db.test.find()[0:])))
        for a, b in izip(count(0), self.db.test.find()[0:]):
            self.assertEqual(a, b['i'])

        self.assertEqual(80, len(list(self.db.test.find()[20:])))
        for a, b in izip(count(20), self.db.test.find()[20:]):
            self.assertEqual(a, b['i'])

        for a, b in izip(count(99), self.db.test.find()[99:]):
            self.assertEqual(a, b['i'])

        for i in self.db.test.find()[1000:]:
            self.fail()

        self.assertEqual(5, len(list(self.db.test.find()[20:25])))
        self.assertEqual(5, len(list(self.db.test.find()[20L:25L])))
        for a, b in izip(count(20), self.db.test.find()[20:25]):
            self.assertEqual(a, b['i'])

        self.assertEqual(80, len(list(self.db.test.find()[40:45][20:])))
        for a, b in izip(count(20), self.db.test.find()[40:45][20:]):
            self.assertEqual(a, b['i'])

        self.assertEqual(80,
                         len(list(self.db.test.find()[40:45].limit(0).skip(20))
                            )
                        )
        for a, b in izip(count(20),
                         self.db.test.find()[40:45].limit(0).skip(20)):
            self.assertEqual(a, b['i'])

        self.assertEqual(80,
                         len(list(self.db.test.find().limit(10).skip(40)[20:]))
                        )
        for a, b in izip(count(20),
                         self.db.test.find().limit(10).skip(40)[20:]):
            self.assertEqual(a, b['i'])

        self.assertEqual(1, len(list(self.db.test.find()[:1])))
        self.assertEqual(5, len(list(self.db.test.find()[:5])))

        self.assertEqual(1, len(list(self.db.test.find()[99:100])))
        self.assertEqual(1, len(list(self.db.test.find()[99:1000])))
        self.assertEqual(0, len(list(self.db.test.find()[10:10])))
        self.assertEqual(0, len(list(self.db.test.find()[:0])))
        self.assertEqual(80,
                         len(list(self.db.test.find()[10:10].limit(0).skip(20))
                            )
                        )

        self.assertRaises(IndexError, lambda: self.db.test.find()[10:8])

    def test_getitem_numeric_index(self):
        self.db.drop_collection("test")
        for i in range(100):
            self.db.test.save({"i": i})

        self.assertEqual(0, self.db.test.find()[0]['i'])
        self.assertEqual(50, self.db.test.find()[50]['i'])
        self.assertEqual(50, self.db.test.find().skip(50)[0]['i'])
        self.assertEqual(50, self.db.test.find().skip(49)[1]['i'])
        self.assertEqual(50, self.db.test.find()[50L]['i'])
        self.assertEqual(99, self.db.test.find()[99]['i'])

        self.assertRaises(IndexError, lambda x: self.db.test.find()[x], -1)
        self.assertRaises(IndexError, lambda x: self.db.test.find()[x], 100)
        self.assertRaises(IndexError,
                          lambda x: self.db.test.find().skip(50)[x], 50)

    def test_count_with_limit_and_skip(self):
        """count(True) honours limit(), skip() and slice bounds, and always
        equals the number of documents the cursor actually yields.
        """
        if not version.at_least(self.db.connection, (1, 1, 4, -1)):
            raise SkipTest()

        def check_len(cursor, length):
            # Iterated length and count(True) must agree, and equal `length`.
            yielded = len(list(cursor))
            self.assertEqual(yielded, cursor.count(True))
            self.assertEqual(length, cursor.count(True))

        def find():
            return self.db.test.find()

        self.db.drop_collection("test")
        for i in range(100):
            self.db.test.save({"i": i})

        check_len(find(), 100)

        # limit() caps the count; a limit above the total changes nothing.
        check_len(find().limit(10), 10)
        check_len(find().limit(110), 100)

        # skip() reduces the count, down to zero past the end.
        check_len(find().skip(10), 90)
        check_len(find().skip(110), 0)

        # Combined limit+skip and the equivalent slices.
        check_len(find().limit(10).skip(10), 10)
        check_len(find()[10:20], 10)
        check_len(find().limit(10).skip(95), 5)
        check_len(find()[95:105], 5)

    def test_len(self):
        """Cursors do not support len(); calling it raises TypeError."""
        cursor = self.db.test.find()
        self.assertRaises(TypeError, len, cursor)

    def test_properties(self):
        """Cursor.collection is the collection queried and is read-only."""
        cursor = self.db.test.find()
        self.assertEqual(self.db.test, cursor.collection)

        def assign_collection():
            setattr(self.db.test.find(), "collection", "hello")

        self.assertRaises(AttributeError, assign_collection)

    def test_tailable(self):
        """A tailable cursor on a capped collection survives exhaustion.

        On each new pass it yields only the documents inserted since the
        previous pass.
        """
        db = self.db
        db.drop_collection("test")
        db.create_collection("test", capped=True, size=1000)

        # Always drop the capped collection, even when an assertion fails.
        # Otherwise later tests that reuse "test" without dropping it would
        # run against this size-1000 capped collection.
        try:
            cursor = db.test.find(tailable=True)

            # Each pass over the same cursor returns just the new document.
            for x in (1, 2, 3):
                db.test.insert({"x": x})
                count = 0
                for doc in cursor:
                    count += 1
                    self.assertEqual(x, doc["x"])
                self.assertEqual(1, count)

            self.assertEqual(3, db.test.count())
        finally:
            db.drop_collection("test")

    def test_distinct(self):
        """Cursor.distinct() returns the distinct values of a (possibly dotted)
        key among the documents matched by the cursor's query.
        """
        if not version.at_least(self.db.connection, (1, 1, 3, 1)):
            raise SkipTest()

        self.db.drop_collection("test")
        for value in (1, 2, 2, 2, 3):
            self.db.test.save({"a": value})

        found = self.db.test.find({"a": {"$lt": 3}}).distinct("a")
        found.sort()
        self.assertEqual([1, 2], found)

        self.db.drop_collection("test")
        for inner, c in (("a", 12), ("b", 8), ("c", 12), ("c", 8)):
            self.db.test.save({"a": {"b": inner}, "c": c})

        found = self.db.test.find({"c": 8}).distinct("a.b")
        found.sort()
        self.assertEqual(["b", "c"], found)

    def test_max_scan(self):
        """max_scan limits the documents scanned, whether passed to find() or
        set with Cursor.max_scan(); a later max_scan() call wins.
        """
        if not version.at_least(self.db.connection, (1, 5, 1)):
            raise SkipTest()

        self.db.drop_collection("test")
        for _ in range(100):
            self.db.test.insert({})

        def n_results(cursor):
            return len(list(cursor))

        self.assertEqual(100, n_results(self.db.test.find()))
        self.assertEqual(50, n_results(self.db.test.find(max_scan=50)))
        chained = self.db.test.find().max_scan(90).max_scan(50)
        self.assertEqual(50, n_results(chained))

    def test_with_statement(self):
        """Cursors work as context managers.

        Leaving the ``with`` block makes the cursor no longer ``alive``,
        whether or not it was fully read. A cursor created outside any
        ``with`` block is left alive.
        """
        # The ``with`` body is compiled at run time via ``exec``, presumably so
        # this module still parses on interpreters older than 2.6 that the
        # version check below skips.
        if sys.version_info < (2, 6):
            raise SkipTest()

        self.db.drop_collection("test")
        for _ in range(100):
            self.db.test.insert({})

        # Control cursor, never used in a ``with`` block; checked at the end.
        c1  = self.db.test.find()
        exec """
with self.db.test.find() as c2:
    self.assertTrue(c2.alive)
self.assertFalse(c2.alive)

with self.db.test.find() as c2:
    self.assertEqual(100, len(list(c2)))
self.assertFalse(c2.alive)
"""
        self.assertTrue(c1.alive)
Esempio n. 19
0
class TestCursor(unittest.TestCase):
    def setUp(self):
        self.db = Database(get_connection(), "pymongo_test")

    def test_explain(self):
        a = self.db.test.find()
        b = a.explain()
        for _ in a:
            break
        c = a.explain()
        del b["millis"]
        b.pop("oldPlan", None)
        del c["millis"]
        c.pop("oldPlan", None)
        self.assertEqual(b, c)
        self.assert_("cursor" in b)

    def test_hint(self):
        db = self.db
        self.assertRaises(TypeError, db.test.find().hint, 5.5)
        db.test.remove({})
        db.test.drop_indexes()

        for i in range(100):
            db.test.insert({"num": i, "foo": i})

        self.assertRaises(OperationFailure, db.test.find({"num": 17, "foo": 17}).hint([("num", ASCENDING)]).explain)
        self.assertRaises(OperationFailure, db.test.find({"num": 17, "foo": 17}).hint([("foo", ASCENDING)]).explain)

        index = db.test.create_index("num")

        spec = [("num", ASCENDING)]
        self.assertEqual(db.test.find({}).explain()["cursor"], "BasicCursor")
        self.assertEqual(db.test.find({}).hint(spec).explain()["cursor"], "BtreeCursor %s" % index)
        self.assertEqual(db.test.find({}).hint(spec).hint(None).explain()["cursor"], "BasicCursor")
        self.assertRaises(OperationFailure, db.test.find({"num": 17, "foo": 17}).hint([("foo", ASCENDING)]).explain)

        a = db.test.find({"num": 17})
        a.hint(spec)
        for _ in a:
            break
        self.assertRaises(InvalidOperation, a.hint, spec)

        self.assertRaises(TypeError, db.test.find().hint, index)

    # This is deprecated - test that a warning is actually raised
    def test_slave_okay(self):
        db = self.db
        db.drop_collection("test")
        db.test.save({"x": 1})

        warnings.simplefilter("error")

        self.assertEqual(1, db.test.find().next()["x"])
        self.assertRaises(DeprecationWarning, db.test.find, slave_okay=True)
        self.assertRaises(DeprecationWarning, db.test.find, slave_okay=False)

        warnings.simplefilter("default")

    def test_limit(self):
        db = self.db

        self.assertRaises(TypeError, db.test.find().limit, None)
        self.assertRaises(TypeError, db.test.find().limit, "hello")
        self.assertRaises(TypeError, db.test.find().limit, 5.5)

        db.test.remove({})
        for i in range(100):
            db.test.save({"x": i})

        count = 0
        for _ in db.test.find():
            count += 1
        self.assertEqual(count, 100)

        count = 0
        for _ in db.test.find().limit(20):
            count += 1
        self.assertEqual(count, 20)

        count = 0
        for _ in db.test.find().limit(99):
            count += 1
        self.assertEqual(count, 99)

        count = 0
        for _ in db.test.find().limit(1):
            count += 1
        self.assertEqual(count, 1)

        count = 0
        for _ in db.test.find().limit(0):
            count += 1
        self.assertEqual(count, 100)

        count = 0
        for _ in db.test.find().limit(0).limit(50).limit(10):
            count += 1
        self.assertEqual(count, 10)

        a = db.test.find()
        a.limit(10)
        for _ in a:
            break
        self.assertRaises(InvalidOperation, a.limit, 5)

    def test_skip(self):
        db = self.db

        self.assertRaises(TypeError, db.test.find().skip, None)
        self.assertRaises(TypeError, db.test.find().skip, "hello")
        self.assertRaises(TypeError, db.test.find().skip, 5.5)

        db.drop_collection("test")

        for i in range(100):
            db.test.save({"x": i})

        for i in db.test.find():
            self.assertEqual(i["x"], 0)
            break

        for i in db.test.find().skip(20):
            self.assertEqual(i["x"], 20)
            break

        for i in db.test.find().skip(99):
            self.assertEqual(i["x"], 99)
            break

        for i in db.test.find().skip(1):
            self.assertEqual(i["x"], 1)
            break

        for i in db.test.find().skip(0):
            self.assertEqual(i["x"], 0)
            break

        for i in db.test.find().skip(0).skip(50).skip(10):
            self.assertEqual(i["x"], 10)
            break

        for i in db.test.find().skip(1000):
            self.fail()

        a = db.test.find()
        a.skip(10)
        for _ in a:
            break
        self.assertRaises(InvalidOperation, a.skip, 5)

    def test_sort(self):
        db = self.db

        self.assertRaises(TypeError, db.test.find().sort, 5)
        self.assertRaises(ValueError, db.test.find().sort, [])
        self.assertRaises(TypeError, db.test.find().sort, [], ASCENDING)
        self.assertRaises(TypeError, db.test.find().sort, [("hello", DESCENDING)], DESCENDING)
        self.assertRaises(TypeError, db.test.find().sort, "hello", "world")

        db.test.remove({})

        unsort = range(10)
        random.shuffle(unsort)

        for i in unsort:
            db.test.save({"x": i})

        asc = [i["x"] for i in db.test.find().sort("x", ASCENDING)]
        self.assertEqual(asc, range(10))
        asc = [i["x"] for i in db.test.find().sort("x")]
        self.assertEqual(asc, range(10))
        asc = [i["x"] for i in db.test.find().sort([("x", ASCENDING)])]
        self.assertEqual(asc, range(10))

        expect = range(10)
        expect.reverse()
        desc = [i["x"] for i in db.test.find().sort("x", DESCENDING)]
        self.assertEqual(desc, expect)
        desc = [i["x"] for i in db.test.find().sort([("x", DESCENDING)])]
        self.assertEqual(desc, expect)
        desc = [i["x"] for i in db.test.find().sort("x", ASCENDING).sort("x", DESCENDING)]
        self.assertEqual(desc, expect)

        expected = [(1, 5), (2, 5), (0, 3), (7, 3), (9, 2), (2, 1), (3, 1)]
        shuffled = list(expected)
        random.shuffle(shuffled)

        db.test.remove({})
        for (a, b) in shuffled:
            db.test.save({"a": a, "b": b})

        result = [(i["a"], i["b"]) for i in db.test.find().sort([("b", DESCENDING), ("a", ASCENDING)])]
        self.assertEqual(result, expected)

        a = db.test.find()
        a.sort("x", ASCENDING)
        for _ in a:
            break
        self.assertRaises(InvalidOperation, a.sort, "x", ASCENDING)

    def test_count(self):
        db = self.db
        db.test.remove({})

        self.assertEqual(0, db.test.find().count())

        for i in range(10):
            db.test.save({"x": i})

        self.assertEqual(10, db.test.find().count())
        self.assert_(isinstance(db.test.find().count(), types.IntType))
        self.assertEqual(10, db.test.find().limit(5).count())
        self.assertEqual(10, db.test.find().skip(5).count())

        self.assertEqual(1, db.test.find({"x": 1}).count())
        self.assertEqual(5, db.test.find({"x": {"$lt": 5}}).count())

        a = db.test.find()
        b = a.count()
        for _ in a:
            break
        self.assertEqual(b, a.count())

        self.assertEqual(0, db.test.acollectionthatdoesntexist.find().count())

    def test_where(self):
        db = self.db
        db.test.remove({})

        a = db.test.find()
        self.assertRaises(TypeError, a.where, 5)
        self.assertRaises(TypeError, a.where, None)
        self.assertRaises(TypeError, a.where, {})

        for i in range(10):
            db.test.save({"x": i})

        self.assertEqual(3, len(list(db.test.find().where("this.x < 3"))))
        self.assertEqual(3, len(list(db.test.find().where(Code("this.x < 3")))))
        self.assertEqual(3, len(list(db.test.find().where(Code("this.x < i", {"i": 3})))))
        self.assertEqual(10, len(list(db.test.find())))

        self.assertEqual(3, db.test.find().where("this.x < 3").count())
        self.assertEqual(10, db.test.find().count())
        self.assertEqual(3, db.test.find().where(u"this.x < 3").count())
        self.assertEqual([0, 1, 2], [a["x"] for a in db.test.find().where("this.x < 3")])
        self.assertEqual([], [a["x"] for a in db.test.find({"x": 5}).where("this.x < 3")])
        self.assertEqual([5], [a["x"] for a in db.test.find({"x": 5}).where("this.x > 3")])

        cursor = db.test.find().where("this.x < 3").where("this.x > 7")
        self.assertEqual([8, 9], [a["x"] for a in cursor])

        a = db.test.find()
        b = a.where("this.x > 3")
        for _ in a:
            break
        self.assertRaises(InvalidOperation, a.where, "this.x < 3")

    def test_kill_cursors(self):
        db = self.db
        db.drop_collection("test")

        client_cursors = db._command({"cursorInfo": 1})["clientCursors_size"]
        by_location = db._command({"cursorInfo": 1})["byLocation_size"]

        for i in range(10000):
            db.test.insert({"i": i})

        self.assertEqual(client_cursors, db._command({"cursorInfo": 1})["clientCursors_size"])
        self.assertEqual(by_location, db._command({"cursorInfo": 1})["byLocation_size"])

        for _ in range(10):
            db.test.find_one()

        self.assertEqual(client_cursors, db._command({"cursorInfo": 1})["clientCursors_size"])
        self.assertEqual(by_location, db._command({"cursorInfo": 1})["byLocation_size"])

        for _ in range(10):
            for x in db.test.find():
                break

        self.assertEqual(client_cursors, db._command({"cursorInfo": 1})["clientCursors_size"])
        self.assertEqual(by_location, db._command({"cursorInfo": 1})["byLocation_size"])

        a = db.test.find()
        for x in a:
            break

        self.assertNotEqual(client_cursors, db._command({"cursorInfo": 1})["clientCursors_size"])
        self.assertNotEqual(by_location, db._command({"cursorInfo": 1})["byLocation_size"])

        del a

        self.assertEqual(client_cursors, db._command({"cursorInfo": 1})["clientCursors_size"])
        self.assertEqual(by_location, db._command({"cursorInfo": 1})["byLocation_size"])

        a = db.test.find().limit(10)
        for x in a:
            break

        self.assertEqual(client_cursors, db._command({"cursorInfo": 1})["clientCursors_size"])
        self.assertEqual(by_location, db._command({"cursorInfo": 1})["byLocation_size"])

    def test_rewind(self):
        self.db.test.save({"x": 1})
        self.db.test.save({"x": 2})
        self.db.test.save({"x": 3})

        cursor = self.db.test.find().limit(2)

        count = 0
        for _ in cursor:
            count += 1
        self.assertEqual(2, count)

        count = 0
        for _ in cursor:
            count += 1
        self.assertEqual(0, count)

        cursor.rewind()
        count = 0
        for _ in cursor:
            count += 1
        self.assertEqual(2, count)

        cursor.rewind()
        count = 0
        for _ in cursor:
            break
        cursor.rewind()
        for _ in cursor:
            count += 1
        self.assertEqual(2, count)

        self.assertEqual(cursor, cursor.rewind())

    def test_clone(self):
        self.db.test.save({"x": 1})
        self.db.test.save({"x": 2})
        self.db.test.save({"x": 3})

        cursor = self.db.test.find().limit(2)

        count = 0
        for _ in cursor:
            count += 1
        self.assertEqual(2, count)

        count = 0
        for _ in cursor:
            count += 1
        self.assertEqual(0, count)

        cursor = cursor.clone()
        cursor2 = cursor.clone()
        count = 0
        for _ in cursor:
            count += 1
        self.assertEqual(2, count)
        for _ in cursor2:
            count += 1
        self.assertEqual(4, count)

        cursor.rewind()
        count = 0
        for _ in cursor:
            break
        cursor = cursor.clone()
        for _ in cursor:
            count += 1
        self.assertEqual(2, count)

        self.assertNotEqual(cursor, cursor.clone())

    def test_count_with_fields(self):
        self.db.test.remove({})
        self.db.test.save({"x": 1})

        if not version.at_least(self.db.connection(), (1, 1, 3, -1)):
            for _ in self.db.test.find({}, ["a"]):
                self.fail()

            self.assertEqual(0, self.db.test.find({}, ["a"]).count())
        else:
            self.assertEqual(1, self.db.test.find({}, ["a"]).count())

    def test_bad_getitem(self):
        self.assertRaises(TypeError, lambda x: self.db.test.find()[x], "hello")
        self.assertRaises(TypeError, lambda x: self.db.test.find()[x], 5.5)
        self.assertRaises(TypeError, lambda x: self.db.test.find()[x], None)

    def test_getitem_slice_index(self):
        self.db.drop_collection("test")
        for i in range(100):
            self.db.test.save({"i": i})

        izip = itertools.izip
        count = itertools.count

        self.assertRaises(IndexError, lambda: self.db.test.find()[-1:])
        self.assertRaises(IndexError, lambda: self.db.test.find()[1:2:2])

        for a, b in izip(count(0), self.db.test.find()):
            self.assertEqual(a, b["i"])

        self.assertEqual(100, len(list(self.db.test.find()[0:])))
        for a, b in izip(count(0), self.db.test.find()[0:]):
            self.assertEqual(a, b["i"])

        self.assertEqual(80, len(list(self.db.test.find()[20:])))
        for a, b in izip(count(20), self.db.test.find()[20:]):
            self.assertEqual(a, b["i"])

        for a, b in izip(count(99), self.db.test.find()[99:]):
            self.assertEqual(a, b["i"])

        for i in self.db.test.find()[1000:]:
            self.fail()

        self.assertEqual(5, len(list(self.db.test.find()[20:25])))
        self.assertEqual(5, len(list(self.db.test.find()[20L:25L])))
        for a, b in izip(count(20), self.db.test.find()[20:25]):
            self.assertEqual(a, b["i"])

        self.assertEqual(80, len(list(self.db.test.find()[40:45][20:])))
        for a, b in izip(count(20), self.db.test.find()[40:45][20:]):
            self.assertEqual(a, b["i"])

        self.assertEqual(80, len(list(self.db.test.find()[40:45].limit(0).skip(20))))
        for a, b in izip(count(20), self.db.test.find()[40:45].limit(0).skip(20)):
            self.assertEqual(a, b["i"])

        self.assertEqual(80, len(list(self.db.test.find().limit(10).skip(40)[20:])))
        for a, b in izip(count(20), self.db.test.find().limit(10).skip(40)[20:]):
            self.assertEqual(a, b["i"])

        self.assertEqual(1, len(list(self.db.test.find()[:1])))
        self.assertEqual(5, len(list(self.db.test.find()[:5])))

        self.assertEqual(1, len(list(self.db.test.find()[99:100])))
        self.assertEqual(1, len(list(self.db.test.find()[99:1000])))

        self.assertRaises(IndexError, lambda: self.db.test.find()[10:8])
        self.assertRaises(IndexError, lambda: self.db.test.find()[10:10])
        self.assertRaises(IndexError, lambda: self.db.test.find()[:0])

    def test_getitem_numeric_index(self):
        self.db.drop_collection("test")
        for i in range(100):
            self.db.test.save({"i": i})

        self.assertEqual(0, self.db.test.find()[0]["i"])
        self.assertEqual(50, self.db.test.find()[50]["i"])
        self.assertEqual(50, self.db.test.find().skip(50)[0]["i"])
        self.assertEqual(50, self.db.test.find().skip(49)[1]["i"])
        self.assertEqual(50, self.db.test.find()[50L]["i"])
        self.assertEqual(99, self.db.test.find()[99]["i"])

        self.assertRaises(IndexError, lambda x: self.db.test.find()[x], -1)
        self.assertRaises(IndexError, lambda x: self.db.test.find()[x], 100)
        self.assertRaises(IndexError, lambda x: self.db.test.find().skip(50)[x], 50)

    def test_count_with_limit_and_skip(self):
        if not version.at_least(self.db.connection(), (1, 1, 4, -1)):
            raise SkipTest()

        def check_len(cursor, length):
            self.assertEqual(len(list(cursor)), cursor.count(True))
            self.assertEqual(length, cursor.count(True))

        self.db.drop_collection("test")
        for i in range(100):
            self.db.test.save({"i": i})

        check_len(self.db.test.find(), 100)

        check_len(self.db.test.find().limit(10), 10)
        check_len(self.db.test.find().limit(110), 100)

        check_len(self.db.test.find().skip(10), 90)
        check_len(self.db.test.find().skip(110), 0)

        check_len(self.db.test.find().limit(10).skip(10), 10)
        check_len(self.db.test.find()[10:20], 10)
        check_len(self.db.test.find().limit(10).skip(95), 5)
        check_len(self.db.test.find()[95:105], 5)

    def test_len(self):
        self.assertRaises(TypeError, len, self.db.test.find())

    def test_properties(self):
        self.assertEqual(self.db.test, self.db.test.find().collection)

        def set_coll():
            self.db.test.find().collection = "hello"

        self.assertRaises(AttributeError, set_coll)

    def test_tailable(self):
        db = self.db
        db.drop_collection("test")

        cursor = db.test.find(tailable=True)

        db.test.insert({"x": 1})
        count = 0
        for doc in cursor:
            count += 1
            self.assertEqual(1, doc["x"])
        self.assertEqual(1, count)

        db.test.insert({"x": 2})
        count = 0
        for doc in cursor:
            count += 1
            self.assertEqual(2, doc["x"])
        self.assertEqual(1, count)

        db.test.insert({"x": 3})
        count = 0
        for doc in cursor:
            count += 1
            self.assertEqual(3, doc["x"])
        self.assertEqual(1, count)

        self.assertEqual(3, db.test.count())

    def test_distinct(self):
        if not version.at_least(self.db.connection(), (1, 1, 3, 1)):
            raise SkipTest()

        self.db.drop_collection("test")

        self.db.test.save({"a": 1})
        self.db.test.save({"a": 2})
        self.db.test.save({"a": 2})
        self.db.test.save({"a": 2})
        self.db.test.save({"a": 3})

        distinct = self.db.test.find({"a": {"$lt": 3}}).distinct("a")
        distinct.sort()

        self.assertEqual([1, 2], distinct)

        self.db.drop_collection("test")

        self.db.test.save({"a": {"b": "a"}, "c": 12})
        self.db.test.save({"a": {"b": "b"}, "c": 8})
        self.db.test.save({"a": {"b": "c"}, "c": 12})
        self.db.test.save({"a": {"b": "c"}, "c": 8})

        distinct = self.db.test.find({"c": 8}).distinct("a.b")
        distinct.sort()

        self.assertEqual(["b", "c"], distinct)
Esempio n. 20
0
class TestCursor(unittest.TestCase):

    def setUp(self):
        self.db = Database(get_connection(), "pymongo_test")

    def test_explain(self):
        a = self.db.test.find()
        b = a.explain()
        for _ in a:
            break
        c = a.explain()
        del b["millis"]
        b.pop("oldPlan", None)
        del c["millis"]
        c.pop("oldPlan", None)
        self.assertEqual(b, c)
        self.assert_("cursor" in b)

    def test_hint(self):
        db = self.db
        self.assertRaises(TypeError, db.test.find().hint, 5.5)
        db.test.drop()

        for i in range(100):
            db.test.insert({"num": i, "foo": i})

        self.assertRaises(OperationFailure,
                          db.test.find({"num": 17, "foo": 17})
                          .hint([("num", ASCENDING)]).explain)
        self.assertRaises(OperationFailure,
                          db.test.find({"num": 17, "foo": 17})
                          .hint([("foo", ASCENDING)]).explain)

        index = db.test.create_index("num")

        spec = [("num", ASCENDING)]
        self.assertEqual(db.test.find({}).explain()["cursor"], "BasicCursor")
        self.assertEqual(db.test.find({}).hint(spec).explain()["cursor"],
                         "BtreeCursor %s" % index)
        self.assertEqual(db.test.find({}).hint(spec).hint(None)
                         .explain()["cursor"],
                         "BasicCursor")
        self.assertRaises(OperationFailure,
                          db.test.find({"num": 17, "foo": 17})
                          .hint([("foo", ASCENDING)]).explain)

        a = db.test.find({"num": 17})
        a.hint(spec)
        for _ in a:
            break
        self.assertRaises(InvalidOperation, a.hint, spec)

        self.assertRaises(TypeError, db.test.find().hint, index)

    def test_limit(self):
        db = self.db

        self.assertRaises(TypeError, db.test.find().limit, None)
        self.assertRaises(TypeError, db.test.find().limit, "hello")
        self.assertRaises(TypeError, db.test.find().limit, 5.5)

        db.test.drop()
        for i in range(100):
            db.test.save({"x": i})

        count = 0
        for _ in db.test.find():
            count += 1
        self.assertEqual(count, 100)

        count = 0
        for _ in db.test.find().limit(20):
            count += 1
        self.assertEqual(count, 20)

        count = 0
        for _ in db.test.find().limit(99):
            count += 1
        self.assertEqual(count, 99)

        count = 0
        for _ in db.test.find().limit(1):
            count += 1
        self.assertEqual(count, 1)

        count = 0
        for _ in db.test.find().limit(0):
            count += 1
        self.assertEqual(count, 100)

        count = 0
        for _ in db.test.find().limit(0).limit(50).limit(10):
            count += 1
        self.assertEqual(count, 10)

        a = db.test.find()
        a.limit(10)
        for _ in a:
            break
        self.assertRaises(InvalidOperation, a.limit, 5)

    def test_batch_size(self):
        db = self.db
        db.test.drop()
        for x in range(200):
            db.test.save({"x": x})

        self.assertRaises(TypeError, db.test.find().batch_size, None)
        self.assertRaises(TypeError, db.test.find().batch_size, "hello")
        self.assertRaises(TypeError, db.test.find().batch_size, 5.5)
        self.assertRaises(ValueError, db.test.find().batch_size, -1)
        a = db.test.find()
        for _ in a:
            break
        self.assertRaises(InvalidOperation, a.batch_size, 5)

        def cursor_count(cursor, expected_count):
            count = 0
            for _ in cursor:
                count += 1
            self.assertEqual(expected_count, count)

        cursor_count(db.test.find().batch_size(0), 200)
        cursor_count(db.test.find().batch_size(1), 200)
        cursor_count(db.test.find().batch_size(2), 200)
        cursor_count(db.test.find().batch_size(5), 200)
        cursor_count(db.test.find().batch_size(100), 200)
        cursor_count(db.test.find().batch_size(500), 200)

        cursor_count(db.test.find().batch_size(0).limit(1), 1)
        cursor_count(db.test.find().batch_size(1).limit(1), 1)
        cursor_count(db.test.find().batch_size(2).limit(1), 1)
        cursor_count(db.test.find().batch_size(5).limit(1), 1)
        cursor_count(db.test.find().batch_size(100).limit(1), 1)
        cursor_count(db.test.find().batch_size(500).limit(1), 1)

        cursor_count(db.test.find().batch_size(0).limit(10), 10)
        cursor_count(db.test.find().batch_size(1).limit(10), 10)
        cursor_count(db.test.find().batch_size(2).limit(10), 10)
        cursor_count(db.test.find().batch_size(5).limit(10), 10)
        cursor_count(db.test.find().batch_size(100).limit(10), 10)
        cursor_count(db.test.find().batch_size(500).limit(10), 10)

    def test_limit_and_batch_size(self):
        db = self.db
        db.test.drop()
        for x in range(500):
            db.test.save({"x": x})

        curs = db.test.find().limit(0).batch_size(10)
        curs.next()
        self.assertEquals(10, curs._Cursor__retrieved)

        curs = db.test.find().limit(-2).batch_size(0)
        curs.next()
        self.assertEquals(2, curs._Cursor__retrieved)

        curs = db.test.find().limit(-4).batch_size(5)
        curs.next()
        self.assertEquals(4, curs._Cursor__retrieved)

        curs = db.test.find().limit(50).batch_size(500)
        curs.next()
        self.assertEquals(50, curs._Cursor__retrieved)

        curs = db.test.find().batch_size(500)
        curs.next()
        self.assertEquals(500, curs._Cursor__retrieved)

        curs = db.test.find().limit(50)
        curs.next()
        self.assertEquals(50, curs._Cursor__retrieved)

        # these two might be shaky, as the default
        # is set by the server. as of 2.0.0-rc0, 101
        # or 1MB (whichever is smaller) is default
        # for queries without ntoreturn
        curs = db.test.find()
        curs.next()
        self.assertEquals(101, curs._Cursor__retrieved)

        curs = db.test.find().limit(0).batch_size(0)
        curs.next()
        self.assertEquals(101, curs._Cursor__retrieved)

    def test_skip(self):
        db = self.db

        self.assertRaises(TypeError, db.test.find().skip, None)
        self.assertRaises(TypeError, db.test.find().skip, "hello")
        self.assertRaises(TypeError, db.test.find().skip, 5.5)

        db.drop_collection("test")

        for i in range(100):
            db.test.save({"x": i})

        for i in db.test.find():
            self.assertEqual(i["x"], 0)
            break

        for i in db.test.find().skip(20):
            self.assertEqual(i["x"], 20)
            break

        for i in db.test.find().skip(99):
            self.assertEqual(i["x"], 99)
            break

        for i in db.test.find().skip(1):
            self.assertEqual(i["x"], 1)
            break

        for i in db.test.find().skip(0):
            self.assertEqual(i["x"], 0)
            break

        for i in db.test.find().skip(0).skip(50).skip(10):
            self.assertEqual(i["x"], 10)
            break

        for i in db.test.find().skip(1000):
            self.fail()

        a = db.test.find()
        a.skip(10)
        for _ in a:
            break
        self.assertRaises(InvalidOperation, a.skip, 5)

    def test_sort(self):
        db = self.db

        self.assertRaises(TypeError, db.test.find().sort, 5)
        self.assertRaises(ValueError, db.test.find().sort, [])
        self.assertRaises(TypeError, db.test.find().sort, [], ASCENDING)
        self.assertRaises(TypeError, db.test.find().sort,
                          [("hello", DESCENDING)], DESCENDING)
        self.assertRaises(TypeError, db.test.find().sort, "hello", "world")

        db.test.drop()

        unsort = range(10)
        random.shuffle(unsort)

        for i in unsort:
            db.test.save({"x": i})

        asc = [i["x"] for i in db.test.find().sort("x", ASCENDING)]
        self.assertEqual(asc, range(10))
        asc = [i["x"] for i in db.test.find().sort("x")]
        self.assertEqual(asc, range(10))
        asc = [i["x"] for i in db.test.find().sort([("x", ASCENDING)])]
        self.assertEqual(asc, range(10))

        expect = range(10)
        expect.reverse()
        desc = [i["x"] for i in db.test.find().sort("x", DESCENDING)]
        self.assertEqual(desc, expect)
        desc = [i["x"] for i in db.test.find().sort([("x", DESCENDING)])]
        self.assertEqual(desc, expect)
        desc = [i["x"] for i in
                db.test.find().sort("x", ASCENDING).sort("x", DESCENDING)]
        self.assertEqual(desc, expect)

        expected = [(1, 5), (2, 5), (0, 3), (7, 3), (9, 2), (2, 1), (3, 1)]
        shuffled = list(expected)
        random.shuffle(shuffled)

        db.test.drop()
        for (a, b) in shuffled:
            db.test.save({"a": a, "b": b})

        result = [(i["a"], i["b"]) for i in
                  db.test.find().sort([("b", DESCENDING),
                                       ("a", ASCENDING)])]
        self.assertEqual(result, expected)

        a = db.test.find()
        a.sort("x", ASCENDING)
        for _ in a:
            break
        self.assertRaises(InvalidOperation, a.sort, "x", ASCENDING)

    def test_count(self):
        db = self.db
        db.test.drop()

        self.assertEqual(0, db.test.find().count())

        for i in range(10):
            db.test.save({"x": i})

        self.assertEqual(10, db.test.find().count())
        self.assert_(isinstance(db.test.find().count(), int))
        self.assertEqual(10, db.test.find().limit(5).count())
        self.assertEqual(10, db.test.find().skip(5).count())

        self.assertEqual(1, db.test.find({"x": 1}).count())
        self.assertEqual(5, db.test.find({"x": {"$lt": 5}}).count())

        a = db.test.find()
        b = a.count()
        for _ in a:
            break
        self.assertEqual(b, a.count())

        self.assertEqual(0, db.test.acollectionthatdoesntexist.find().count())

    def test_where(self):
        db = self.db
        db.test.drop()

        a = db.test.find()
        self.assertRaises(TypeError, a.where, 5)
        self.assertRaises(TypeError, a.where, None)
        self.assertRaises(TypeError, a.where, {})

        for i in range(10):
            db.test.save({"x": i})

        self.assertEqual(3, len(list(db.test.find().where('this.x < 3'))))
        self.assertEqual(3,
                         len(list(db.test.find().where(Code('this.x < 3')))))
        self.assertEqual(3, len(list(db.test.find().where(Code('this.x < i',
                                                               {"i": 3})))))
        self.assertEqual(10, len(list(db.test.find())))

        self.assertEqual(3, db.test.find().where('this.x < 3').count())
        self.assertEqual(10, db.test.find().count())
        self.assertEqual(3, db.test.find().where(u'this.x < 3').count())
        self.assertEqual([0, 1, 2],
                         [a["x"] for a in
                          db.test.find().where('this.x < 3')])
        self.assertEqual([],
                         [a["x"] for a in
                          db.test.find({"x": 5}).where('this.x < 3')])
        self.assertEqual([5],
                         [a["x"] for a in
                          db.test.find({"x": 5}).where('this.x > 3')])

        cursor = db.test.find().where('this.x < 3').where('this.x > 7')
        self.assertEqual([8, 9], [a["x"] for a in cursor])

        a = db.test.find()
        b = a.where('this.x > 3')
        for _ in a:
            break
        self.assertRaises(InvalidOperation, a.where, 'this.x < 3')

    def test_kill_cursors_implicit(self):
        # Only CPython does reference counting garbage collection.
        if (sys.platform.startswith('java') or
            sys.platform == 'cli' or
            'PyPy' in sys.version):
            raise SkipTest()

        db = self.db
        db.drop_collection("test")

        c = db.command("cursorInfo")["clientCursors_size"]

        test = db.test
        for i in range(10000):
            test.insert({"i": i})
        self.assertEqual(c, db.command("cursorInfo")["clientCursors_size"])

        # Automatically closed by the server (limit == -1).
        for _ in range(10):
            db.test.find_one()
        self.assertEqual(c, db.command("cursorInfo")["clientCursors_size"])

        for _ in range(10):
            for x in db.test.find():
                break
        self.assertEqual(c, db.command("cursorInfo")["clientCursors_size"])

        a = db.test.find()
        for x in a:
            break
        self.assertNotEqual(c, db.command("cursorInfo")["clientCursors_size"])

        # Explicitly close (won't work with PyPy and Jython).
        del a
        self.assertEqual(c, db.command("cursorInfo")["clientCursors_size"])

        # Automatically closed by the server since the entire
        # result was returned.
        a = db.test.find().limit(10)
        for x in a:
            break
        self.assertEqual(c, db.command("cursorInfo")["clientCursors_size"])

    def test_kill_cursors_explicit(self):
        db = self.db
        db.drop_collection("test")

        c = db.command("cursorInfo")["clientCursors_size"]

        test = db.test
        for i in range(10000):
            test.insert({"i": i})
        self.assertEqual(c, db.command("cursorInfo")["clientCursors_size"])

        # Automatically closed by the server (limit == -1).
        for _ in range(10):
            db.test.find_one()
        self.assertEqual(c, db.command("cursorInfo")["clientCursors_size"])

        a = db.test.find()
        for x in a:
            break
        self.assertNotEqual(c, db.command("cursorInfo")["clientCursors_size"])

        # Explicitly close (should work with all interpreter implementations).
        a.close()
        self.assertEqual(c, db.command("cursorInfo")["clientCursors_size"])

        # Automatically closed by the server since the entire
        # result was returned.
        a = db.test.find().limit(10)
        for x in a:
            break
        self.assertEqual(c, db.command("cursorInfo")["clientCursors_size"])

    def test_rewind(self):
        self.db.test.save({"x": 1})
        self.db.test.save({"x": 2})
        self.db.test.save({"x": 3})

        cursor = self.db.test.find().limit(2)

        count = 0
        for _ in cursor:
            count += 1
        self.assertEqual(2, count)

        count = 0
        for _ in cursor:
            count += 1
        self.assertEqual(0, count)

        cursor.rewind()
        count = 0
        for _ in cursor:
            count += 1
        self.assertEqual(2, count)

        cursor.rewind()
        count = 0
        for _ in cursor:
            break
        cursor.rewind()
        for _ in cursor:
            count += 1
        self.assertEqual(2, count)

        self.assertEqual(cursor, cursor.rewind())

    def test_clone(self):
        self.db.test.save({"x": 1})
        self.db.test.save({"x": 2})
        self.db.test.save({"x": 3})

        cursor = self.db.test.find().limit(2)

        count = 0
        for _ in cursor:
            count += 1
        self.assertEqual(2, count)

        count = 0
        for _ in cursor:
            count += 1
        self.assertEqual(0, count)

        cursor = cursor.clone()
        cursor2 = cursor.clone()
        count = 0
        for _ in cursor:
            count += 1
        self.assertEqual(2, count)
        for _ in cursor2:
            count += 1
        self.assertEqual(4, count)

        cursor.rewind()
        count = 0
        for _ in cursor:
            break
        cursor = cursor.clone()
        for _ in cursor:
            count += 1
        self.assertEqual(2, count)

        self.assertNotEqual(cursor, cursor.clone())

        class MyClass(dict):
            pass

        cursor = self.db.test.find(as_class=MyClass)
        for e in cursor:
            self.assertEqual(type(MyClass()), type(e))
        cursor = self.db.test.find(as_class=MyClass)
        self.assertEqual(type(MyClass()), type(cursor[0]))

        # Just test attributes
        cursor = self.db.test.find(skip=1,
                                   timeout=False,
                                   snapshot=True,
                                   tailable=True,
                                   as_class=MyClass,
                                   slave_okay=True,
                                   await_data=True,
                                   partial=True,
                                   manipulate=False).limit(2)
        cursor.add_option(64)

        cursor2 = cursor.clone()
        self.assertEqual(cursor._Cursor__skip, cursor2._Cursor__skip)
        self.assertEqual(cursor._Cursor__limit, cursor2._Cursor__limit)
        self.assertEqual(cursor._Cursor__timeout, cursor2._Cursor__timeout)
        self.assertEqual(cursor._Cursor__snapshot, cursor2._Cursor__snapshot)
        self.assertEqual(cursor._Cursor__tailable, cursor2._Cursor__tailable)
        self.assertEqual(type(cursor._Cursor__as_class),
                         type(cursor2._Cursor__as_class))
        self.assertEqual(cursor._Cursor__slave_okay,
                         cursor2._Cursor__slave_okay)
        self.assertEqual(cursor._Cursor__await_data,
                         cursor2._Cursor__await_data)
        self.assertEqual(cursor._Cursor__partial, cursor2._Cursor__partial)
        self.assertEqual(cursor._Cursor__manipulate,
                         cursor2._Cursor__manipulate)
        self.assertEqual(cursor._Cursor__query_flags,
                         cursor2._Cursor__query_flags)

    def test_add_remove_option(self):
        cursor = self.db.test.find()
        self.assertEqual(0, cursor._Cursor__query_options())
        cursor.add_option(2)
        cursor2 = self.db.test.find(tailable=True)
        self.assertEqual(2, cursor2._Cursor__query_options())
        self.assertEqual(cursor._Cursor__query_options(),
                         cursor2._Cursor__query_options())
        cursor.add_option(32)
        cursor2 = self.db.test.find(tailable=True, await_data=True)
        self.assertEqual(34, cursor2._Cursor__query_options())
        self.assertEqual(cursor._Cursor__query_options(),
                         cursor2._Cursor__query_options())
        cursor.add_option(128)
        cursor2 = self.db.test.find(tailable=True,
                                    await_data=True).add_option(128)
        self.assertEqual(162, cursor2._Cursor__query_options())
        self.assertEqual(cursor._Cursor__query_options(),
                         cursor2._Cursor__query_options())

        self.assertEqual(162, cursor._Cursor__query_options())
        cursor.add_option(128)
        self.assertEqual(162, cursor._Cursor__query_options())

        cursor.remove_option(128)
        cursor2 = self.db.test.find(tailable=True, await_data=True)
        self.assertEqual(34, cursor2._Cursor__query_options())
        self.assertEqual(cursor._Cursor__query_options(),
                         cursor2._Cursor__query_options())
        cursor.remove_option(32)
        cursor2 = self.db.test.find(tailable=True)
        self.assertEqual(2, cursor2._Cursor__query_options())
        self.assertEqual(cursor._Cursor__query_options(),
                         cursor2._Cursor__query_options())

        self.assertEqual(2, cursor._Cursor__query_options())
        cursor.remove_option(32)
        self.assertEqual(2, cursor._Cursor__query_options())

    def test_count_with_fields(self):
        self.db.test.drop()
        self.db.test.save({"x": 1})

        if not version.at_least(self.db.connection, (1, 1, 3, -1)):
            for _ in self.db.test.find({}, ["a"]):
                self.fail()

            self.assertEqual(0, self.db.test.find({}, ["a"]).count())
        else:
            self.assertEqual(1, self.db.test.find({}, ["a"]).count())

    def test_bad_getitem(self):
        self.assertRaises(TypeError, lambda x: self.db.test.find()[x], "hello")
        self.assertRaises(TypeError, lambda x: self.db.test.find()[x], 5.5)
        self.assertRaises(TypeError, lambda x: self.db.test.find()[x], None)

    def test_getitem_slice_index(self):
        self.db.drop_collection("test")
        for i in range(100):
            self.db.test.save({"i": i})

        izip = itertools.izip
        count = itertools.count

        self.assertRaises(IndexError, lambda: self.db.test.find()[-1:])
        self.assertRaises(IndexError, lambda: self.db.test.find()[1:2:2])

        for a, b in izip(count(0), self.db.test.find()):
            self.assertEqual(a, b['i'])

        self.assertEqual(100, len(list(self.db.test.find()[0:])))
        for a, b in izip(count(0), self.db.test.find()[0:]):
            self.assertEqual(a, b['i'])

        self.assertEqual(80, len(list(self.db.test.find()[20:])))
        for a, b in izip(count(20), self.db.test.find()[20:]):
            self.assertEqual(a, b['i'])

        for a, b in izip(count(99), self.db.test.find()[99:]):
            self.assertEqual(a, b['i'])

        for i in self.db.test.find()[1000:]:
            self.fail()

        self.assertEqual(5, len(list(self.db.test.find()[20:25])))
        self.assertEqual(5, len(list(self.db.test.find()[20L:25L])))
        for a, b in izip(count(20), self.db.test.find()[20:25]):
            self.assertEqual(a, b['i'])

        self.assertEqual(80, len(list(self.db.test.find()[40:45][20:])))
        for a, b in izip(count(20), self.db.test.find()[40:45][20:]):
            self.assertEqual(a, b['i'])

        self.assertEqual(80,
                         len(list(self.db.test.find()[40:45].limit(0).skip(20))
                            )
                        )
        for a, b in izip(count(20),
                         self.db.test.find()[40:45].limit(0).skip(20)):
            self.assertEqual(a, b['i'])

        self.assertEqual(80,
                         len(list(self.db.test.find().limit(10).skip(40)[20:]))
                        )
        for a, b in izip(count(20),
                         self.db.test.find().limit(10).skip(40)[20:]):
            self.assertEqual(a, b['i'])

        self.assertEqual(1, len(list(self.db.test.find()[:1])))
        self.assertEqual(5, len(list(self.db.test.find()[:5])))

        self.assertEqual(1, len(list(self.db.test.find()[99:100])))
        self.assertEqual(1, len(list(self.db.test.find()[99:1000])))
        self.assertEqual(0, len(list(self.db.test.find()[10:10])))
        self.assertEqual(0, len(list(self.db.test.find()[:0])))
        self.assertEqual(80,
                         len(list(self.db.test.find()[10:10].limit(0).skip(20))
                            )
                        )

        self.assertRaises(IndexError, lambda: self.db.test.find()[10:8])

    def test_getitem_numeric_index(self):
        self.db.drop_collection("test")
        for i in range(100):
            self.db.test.save({"i": i})

        self.assertEqual(0, self.db.test.find()[0]['i'])
        self.assertEqual(50, self.db.test.find()[50]['i'])
        self.assertEqual(50, self.db.test.find().skip(50)[0]['i'])
        self.assertEqual(50, self.db.test.find().skip(49)[1]['i'])
        self.assertEqual(50, self.db.test.find()[50L]['i'])
        self.assertEqual(99, self.db.test.find()[99]['i'])

        self.assertRaises(IndexError, lambda x: self.db.test.find()[x], -1)
        self.assertRaises(IndexError, lambda x: self.db.test.find()[x], 100)
        self.assertRaises(IndexError,
                          lambda x: self.db.test.find().skip(50)[x], 50)

    def test_count_with_limit_and_skip(self):
        if not version.at_least(self.db.connection, (1, 1, 4, -1)):
            raise SkipTest()

        def check_len(cursor, length):
            self.assertEqual(len(list(cursor)), cursor.count(True))
            self.assertEqual(length, cursor.count(True))

        self.db.drop_collection("test")
        for i in range(100):
            self.db.test.save({"i": i})

        check_len(self.db.test.find(), 100)

        check_len(self.db.test.find().limit(10), 10)
        check_len(self.db.test.find().limit(110), 100)

        check_len(self.db.test.find().skip(10), 90)
        check_len(self.db.test.find().skip(110), 0)

        check_len(self.db.test.find().limit(10).skip(10), 10)
        check_len(self.db.test.find()[10:20], 10)
        check_len(self.db.test.find().limit(10).skip(95), 5)
        check_len(self.db.test.find()[95:105], 5)

    def test_len(self):
        self.assertRaises(TypeError, len, self.db.test.find())

    def test_properties(self):
        self.assertEqual(self.db.test, self.db.test.find().collection)

        def set_coll():
            self.db.test.find().collection = "hello"

        self.assertRaises(AttributeError, set_coll)

    def test_get_more(self):
        db = self.db
        db.drop_collection("test")
        db.test.insert([{'i': i} for i in range(10)])
        self.assertEqual(10, len(list(db.test.find().batch_size(5))))

    def test_tailable(self):
        db = self.db
        db.drop_collection("test")
        db.create_collection("test", capped=True, size=1000)

        cursor = db.test.find(tailable=True)

        db.test.insert({"x": 1})
        count = 0
        for doc in cursor:
            count += 1
            self.assertEqual(1, doc["x"])
        self.assertEqual(1, count)

        db.test.insert({"x": 2})
        count = 0
        for doc in cursor:
            count += 1
            self.assertEqual(2, doc["x"])
        self.assertEqual(1, count)

        db.test.insert({"x": 3})
        count = 0
        for doc in cursor:
            count += 1
            self.assertEqual(3, doc["x"])
        self.assertEqual(1, count)

        self.assertEqual(3, db.test.count())
        db.drop_collection("test")

    def test_distinct(self):
        if not version.at_least(self.db.connection, (1, 1, 3, 1)):
            raise SkipTest()

        self.db.drop_collection("test")

        self.db.test.save({"a": 1})
        self.db.test.save({"a": 2})
        self.db.test.save({"a": 2})
        self.db.test.save({"a": 2})
        self.db.test.save({"a": 3})

        distinct = self.db.test.find({"a": {"$lt": 3}}).distinct("a")
        distinct.sort()

        self.assertEqual([1, 2], distinct)

        self.db.drop_collection("test")

        self.db.test.save({"a": {"b": "a"}, "c": 12})
        self.db.test.save({"a": {"b": "b"}, "c": 8})
        self.db.test.save({"a": {"b": "c"}, "c": 12})
        self.db.test.save({"a": {"b": "c"}, "c": 8})

        distinct = self.db.test.find({"c": 8}).distinct("a.b")
        distinct.sort()

        self.assertEqual(["b", "c"], distinct)

    def test_max_scan(self):
        if not version.at_least(self.db.connection, (1, 5, 1)):
            raise SkipTest()

        self.db.drop_collection("test")
        for _ in range(100):
            self.db.test.insert({})

        self.assertEqual(100, len(list(self.db.test.find())))
        self.assertEqual(50, len(list(self.db.test.find(max_scan=50))))
        self.assertEqual(50, len(list(self.db.test.find()
                                      .max_scan(90).max_scan(50))))

    def test_with_statement(self):
        if sys.version_info < (2, 6):
            raise SkipTest()

        self.db.drop_collection("test")
        for _ in range(100):
            self.db.test.insert({})

        c1 = self.db.test.find()
        exec """
with self.db.test.find() as c2:
    self.assertTrue(c2.alive)
self.assertFalse(c2.alive)

with self.db.test.find() as c2:
    self.assertEqual(100, len(list(c2)))
self.assertFalse(c2.alive)
"""
        self.assertTrue(c1.alive)
Esempio n. 21
0
class TestCursor(unittest.TestCase):
    def setUp(self):
        self.db = Database(get_connection(), "pymongo_test")

    def test_explain(self):
        a = self.db.test.find()
        b = a.explain()
        for _ in a:
            break
        c = a.explain()
        del b["millis"]
        b.pop("oldPlan", None)
        del c["millis"]
        c.pop("oldPlan", None)
        self.assertEqual(b, c)
        self.assert_("cursor" in b)

    def test_hint(self):
        db = self.db
        self.assertRaises(TypeError, db.test.find().hint, 5.5)
        db.test.drop()

        for i in range(100):
            db.test.insert({"num": i, "foo": i})

        self.assertRaises(
            OperationFailure,
            db.test.find({
                "num": 17,
                "foo": 17
            }).hint([("num", ASCENDING)]).explain)
        self.assertRaises(
            OperationFailure,
            db.test.find({
                "num": 17,
                "foo": 17
            }).hint([("foo", ASCENDING)]).explain)

        index = db.test.create_index("num")

        spec = [("num", ASCENDING)]
        self.assertEqual(db.test.find({}).explain()["cursor"], "BasicCursor")
        self.assertEqual(
            db.test.find({}).hint(spec).explain()["cursor"],
            "BtreeCursor %s" % index)
        self.assertEqual(
            db.test.find({}).hint(spec).hint(None).explain()["cursor"],
            "BasicCursor")
        self.assertRaises(
            OperationFailure,
            db.test.find({
                "num": 17,
                "foo": 17
            }).hint([("foo", ASCENDING)]).explain)

        a = db.test.find({"num": 17})
        a.hint(spec)
        for _ in a:
            break
        self.assertRaises(InvalidOperation, a.hint, spec)

        self.assertRaises(TypeError, db.test.find().hint, index)

    def test_limit(self):
        db = self.db

        self.assertRaises(TypeError, db.test.find().limit, None)
        self.assertRaises(TypeError, db.test.find().limit, "hello")
        self.assertRaises(TypeError, db.test.find().limit, 5.5)

        db.test.drop()
        for i in range(100):
            db.test.save({"x": i})

        count = 0
        for _ in db.test.find():
            count += 1
        self.assertEqual(count, 100)

        count = 0
        for _ in db.test.find().limit(20):
            count += 1
        self.assertEqual(count, 20)

        count = 0
        for _ in db.test.find().limit(99):
            count += 1
        self.assertEqual(count, 99)

        count = 0
        for _ in db.test.find().limit(1):
            count += 1
        self.assertEqual(count, 1)

        count = 0
        for _ in db.test.find().limit(0):
            count += 1
        self.assertEqual(count, 100)

        count = 0
        for _ in db.test.find().limit(0).limit(50).limit(10):
            count += 1
        self.assertEqual(count, 10)

        a = db.test.find()
        a.limit(10)
        for _ in a:
            break
        self.assertRaises(InvalidOperation, a.limit, 5)

    def test_batch_size(self):
        db = self.db
        db.test.drop()
        for x in range(200):
            db.test.save({"x": x})

        self.assertRaises(TypeError, db.test.find().batch_size, None)
        self.assertRaises(TypeError, db.test.find().batch_size, "hello")
        self.assertRaises(TypeError, db.test.find().batch_size, 5.5)
        self.assertRaises(ValueError, db.test.find().batch_size, -1)
        a = db.test.find()
        for _ in a:
            break
        self.assertRaises(InvalidOperation, a.batch_size, 5)

        def cursor_count(cursor, expected_count):
            count = 0
            for _ in cursor:
                count += 1
            self.assertEqual(expected_count, count)

        cursor_count(db.test.find().batch_size(0), 200)
        cursor_count(db.test.find().batch_size(1), 200)
        cursor_count(db.test.find().batch_size(2), 200)
        cursor_count(db.test.find().batch_size(5), 200)
        cursor_count(db.test.find().batch_size(100), 200)
        cursor_count(db.test.find().batch_size(500), 200)

        cursor_count(db.test.find().batch_size(0).limit(1), 1)
        cursor_count(db.test.find().batch_size(1).limit(1), 1)
        cursor_count(db.test.find().batch_size(2).limit(1), 1)
        cursor_count(db.test.find().batch_size(5).limit(1), 1)
        cursor_count(db.test.find().batch_size(100).limit(1), 1)
        cursor_count(db.test.find().batch_size(500).limit(1), 1)

        cursor_count(db.test.find().batch_size(0).limit(10), 10)
        cursor_count(db.test.find().batch_size(1).limit(10), 10)
        cursor_count(db.test.find().batch_size(2).limit(10), 10)
        cursor_count(db.test.find().batch_size(5).limit(10), 10)
        cursor_count(db.test.find().batch_size(100).limit(10), 10)
        cursor_count(db.test.find().batch_size(500).limit(10), 10)

    def test_skip(self):
        db = self.db

        self.assertRaises(TypeError, db.test.find().skip, None)
        self.assertRaises(TypeError, db.test.find().skip, "hello")
        self.assertRaises(TypeError, db.test.find().skip, 5.5)

        db.drop_collection("test")

        for i in range(100):
            db.test.save({"x": i})

        for i in db.test.find():
            self.assertEqual(i["x"], 0)
            break

        for i in db.test.find().skip(20):
            self.assertEqual(i["x"], 20)
            break

        for i in db.test.find().skip(99):
            self.assertEqual(i["x"], 99)
            break

        for i in db.test.find().skip(1):
            self.assertEqual(i["x"], 1)
            break

        for i in db.test.find().skip(0):
            self.assertEqual(i["x"], 0)
            break

        for i in db.test.find().skip(0).skip(50).skip(10):
            self.assertEqual(i["x"], 10)
            break

        for i in db.test.find().skip(1000):
            self.fail()

        a = db.test.find()
        a.skip(10)
        for _ in a:
            break
        self.assertRaises(InvalidOperation, a.skip, 5)

    def test_sort(self):
        db = self.db

        self.assertRaises(TypeError, db.test.find().sort, 5)
        self.assertRaises(ValueError, db.test.find().sort, [])
        self.assertRaises(TypeError, db.test.find().sort, [], ASCENDING)
        self.assertRaises(TypeError,
                          db.test.find().sort, [("hello", DESCENDING)],
                          DESCENDING)
        self.assertRaises(TypeError, db.test.find().sort, "hello", "world")

        db.test.drop()

        unsort = range(10)
        random.shuffle(unsort)

        for i in unsort:
            db.test.save({"x": i})

        asc = [i["x"] for i in db.test.find().sort("x", ASCENDING)]
        self.assertEqual(asc, range(10))
        asc = [i["x"] for i in db.test.find().sort("x")]
        self.assertEqual(asc, range(10))
        asc = [i["x"] for i in db.test.find().sort([("x", ASCENDING)])]
        self.assertEqual(asc, range(10))

        expect = range(10)
        expect.reverse()
        desc = [i["x"] for i in db.test.find().sort("x", DESCENDING)]
        self.assertEqual(desc, expect)
        desc = [i["x"] for i in db.test.find().sort([("x", DESCENDING)])]
        self.assertEqual(desc, expect)
        desc = [
            i["x"]
            for i in db.test.find().sort("x", ASCENDING).sort("x", DESCENDING)
        ]
        self.assertEqual(desc, expect)

        expected = [(1, 5), (2, 5), (0, 3), (7, 3), (9, 2), (2, 1), (3, 1)]
        shuffled = list(expected)
        random.shuffle(shuffled)

        db.test.drop()
        for (a, b) in shuffled:
            db.test.save({"a": a, "b": b})

        result = [
            (i["a"], i["b"])
            for i in db.test.find().sort([("b", DESCENDING), ("a", ASCENDING)])
        ]
        self.assertEqual(result, expected)

        a = db.test.find()
        a.sort("x", ASCENDING)
        for _ in a:
            break
        self.assertRaises(InvalidOperation, a.sort, "x", ASCENDING)

    def test_count(self):
        db = self.db
        db.test.drop()

        self.assertEqual(0, db.test.find().count())

        for i in range(10):
            db.test.save({"x": i})

        self.assertEqual(10, db.test.find().count())
        self.assert_(isinstance(db.test.find().count(), int))
        self.assertEqual(10, db.test.find().limit(5).count())
        self.assertEqual(10, db.test.find().skip(5).count())

        self.assertEqual(1, db.test.find({"x": 1}).count())
        self.assertEqual(5, db.test.find({"x": {"$lt": 5}}).count())

        a = db.test.find()
        b = a.count()
        for _ in a:
            break
        self.assertEqual(b, a.count())

        self.assertEqual(0, db.test.acollectionthatdoesntexist.find().count())

    def test_where(self):
        db = self.db
        db.test.drop()

        a = db.test.find()
        self.assertRaises(TypeError, a.where, 5)
        self.assertRaises(TypeError, a.where, None)
        self.assertRaises(TypeError, a.where, {})

        for i in range(10):
            db.test.save({"x": i})

        self.assertEqual(3, len(list(db.test.find().where('this.x < 3'))))
        self.assertEqual(3,
                         len(list(db.test.find().where(Code('this.x < 3')))))
        self.assertEqual(
            3, len(list(db.test.find().where(Code('this.x < i', {"i": 3})))))
        self.assertEqual(10, len(list(db.test.find())))

        self.assertEqual(3, db.test.find().where('this.x < 3').count())
        self.assertEqual(10, db.test.find().count())
        self.assertEqual(3, db.test.find().where(u'this.x < 3').count())
        self.assertEqual([0, 1, 2],
                         [a["x"] for a in db.test.find().where('this.x < 3')])
        self.assertEqual(
            [], [a["x"] for a in db.test.find({
                "x": 5
            }).where('this.x < 3')])
        self.assertEqual(
            [5], [a["x"] for a in db.test.find({
                "x": 5
            }).where('this.x > 3')])

        cursor = db.test.find().where('this.x < 3').where('this.x > 7')
        self.assertEqual([8, 9], [a["x"] for a in cursor])

        a = db.test.find()
        b = a.where('this.x > 3')
        for _ in a:
            break
        self.assertRaises(InvalidOperation, a.where, 'this.x < 3')

    def test_kill_cursors(self):
        db = self.db
        db.drop_collection("test")

        c = db.command("cursorInfo")["clientCursors_size"]

        test = db.test
        for i in range(10000):
            test.insert({"i": i})
        self.assertEqual(c, db.command("cursorInfo")["clientCursors_size"])

        for _ in range(10):
            db.test.find_one()
        self.assertEqual(c, db.command("cursorInfo")["clientCursors_size"])

        for _ in range(10):
            for x in db.test.find():
                break
        self.assertEqual(c, db.command("cursorInfo")["clientCursors_size"])

        a = db.test.find()
        for x in a:
            break
        self.assertNotEqual(c, db.command("cursorInfo")["clientCursors_size"])

        del a
        self.assertEqual(c, db.command("cursorInfo")["clientCursors_size"])

        a = db.test.find().limit(10)
        for x in a:
            break
        self.assertEqual(c, db.command("cursorInfo")["clientCursors_size"])

    def test_rewind(self):
        self.db.test.save({"x": 1})
        self.db.test.save({"x": 2})
        self.db.test.save({"x": 3})

        cursor = self.db.test.find().limit(2)

        count = 0
        for _ in cursor:
            count += 1
        self.assertEqual(2, count)

        count = 0
        for _ in cursor:
            count += 1
        self.assertEqual(0, count)

        cursor.rewind()
        count = 0
        for _ in cursor:
            count += 1
        self.assertEqual(2, count)

        cursor.rewind()
        count = 0
        for _ in cursor:
            break
        cursor.rewind()
        for _ in cursor:
            count += 1
        self.assertEqual(2, count)

        self.assertEqual(cursor, cursor.rewind())

    def test_clone(self):
        self.db.test.save({"x": 1})
        self.db.test.save({"x": 2})
        self.db.test.save({"x": 3})

        cursor = self.db.test.find().limit(2)

        count = 0
        for _ in cursor:
            count += 1
        self.assertEqual(2, count)

        count = 0
        for _ in cursor:
            count += 1
        self.assertEqual(0, count)

        cursor = cursor.clone()
        cursor2 = cursor.clone()
        count = 0
        for _ in cursor:
            count += 1
        self.assertEqual(2, count)
        for _ in cursor2:
            count += 1
        self.assertEqual(4, count)

        cursor.rewind()
        count = 0
        for _ in cursor:
            break
        cursor = cursor.clone()
        for _ in cursor:
            count += 1
        self.assertEqual(2, count)

        self.assertNotEqual(cursor, cursor.clone())

    def test_count_with_fields(self):
        self.db.test.drop()
        self.db.test.save({"x": 1})

        if not version.at_least(self.db.connection, (1, 1, 3, -1)):
            for _ in self.db.test.find({}, ["a"]):
                self.fail()

            self.assertEqual(0, self.db.test.find({}, ["a"]).count())
        else:
            self.assertEqual(1, self.db.test.find({}, ["a"]).count())

    def test_bad_getitem(self):
        self.assertRaises(TypeError, lambda x: self.db.test.find()[x], "hello")
        self.assertRaises(TypeError, lambda x: self.db.test.find()[x], 5.5)
        self.assertRaises(TypeError, lambda x: self.db.test.find()[x], None)

    def test_getitem_slice_index(self):
        self.db.drop_collection("test")
        for i in range(100):
            self.db.test.save({"i": i})

        izip = itertools.izip
        count = itertools.count

        self.assertRaises(IndexError, lambda: self.db.test.find()[-1:])
        self.assertRaises(IndexError, lambda: self.db.test.find()[1:2:2])

        for a, b in izip(count(0), self.db.test.find()):
            self.assertEqual(a, b['i'])

        self.assertEqual(100, len(list(self.db.test.find()[0:])))
        for a, b in izip(count(0), self.db.test.find()[0:]):
            self.assertEqual(a, b['i'])

        self.assertEqual(80, len(list(self.db.test.find()[20:])))
        for a, b in izip(count(20), self.db.test.find()[20:]):
            self.assertEqual(a, b['i'])

        for a, b in izip(count(99), self.db.test.find()[99:]):
            self.assertEqual(a, b['i'])

        for i in self.db.test.find()[1000:]:
            self.fail()

        self.assertEqual(5, len(list(self.db.test.find()[20:25])))
        self.assertEqual(5, len(list(self.db.test.find()[20L:25L])))
        for a, b in izip(count(20), self.db.test.find()[20:25]):
            self.assertEqual(a, b['i'])

        self.assertEqual(80, len(list(self.db.test.find()[40:45][20:])))
        for a, b in izip(count(20), self.db.test.find()[40:45][20:]):
            self.assertEqual(a, b['i'])

        self.assertEqual(
            80, len(list(self.db.test.find()[40:45].limit(0).skip(20))))
        for a, b in izip(count(20),
                         self.db.test.find()[40:45].limit(0).skip(20)):
            self.assertEqual(a, b['i'])

        self.assertEqual(
            80, len(list(self.db.test.find().limit(10).skip(40)[20:])))
        for a, b in izip(count(20),
                         self.db.test.find().limit(10).skip(40)[20:]):
            self.assertEqual(a, b['i'])

        self.assertEqual(1, len(list(self.db.test.find()[:1])))
        self.assertEqual(5, len(list(self.db.test.find()[:5])))

        self.assertEqual(1, len(list(self.db.test.find()[99:100])))
        self.assertEqual(1, len(list(self.db.test.find()[99:1000])))
        self.assertEqual(0, len(list(self.db.test.find()[10:10])))
        self.assertEqual(0, len(list(self.db.test.find()[:0])))
        self.assertEqual(
            80, len(list(self.db.test.find()[10:10].limit(0).skip(20))))

        self.assertRaises(IndexError, lambda: self.db.test.find()[10:8])

    def test_getitem_numeric_index(self):
        self.db.drop_collection("test")
        for i in range(100):
            self.db.test.save({"i": i})

        self.assertEqual(0, self.db.test.find()[0]['i'])
        self.assertEqual(50, self.db.test.find()[50]['i'])
        self.assertEqual(50, self.db.test.find().skip(50)[0]['i'])
        self.assertEqual(50, self.db.test.find().skip(49)[1]['i'])
        self.assertEqual(50, self.db.test.find()[50L]['i'])
        self.assertEqual(99, self.db.test.find()[99]['i'])

        self.assertRaises(IndexError, lambda x: self.db.test.find()[x], -1)
        self.assertRaises(IndexError, lambda x: self.db.test.find()[x], 100)
        self.assertRaises(IndexError,
                          lambda x: self.db.test.find().skip(50)[x], 50)

    def test_count_with_limit_and_skip(self):
        if not version.at_least(self.db.connection, (1, 1, 4, -1)):
            raise SkipTest()

        def check_len(cursor, length):
            self.assertEqual(len(list(cursor)), cursor.count(True))
            self.assertEqual(length, cursor.count(True))

        self.db.drop_collection("test")
        for i in range(100):
            self.db.test.save({"i": i})

        check_len(self.db.test.find(), 100)

        check_len(self.db.test.find().limit(10), 10)
        check_len(self.db.test.find().limit(110), 100)

        check_len(self.db.test.find().skip(10), 90)
        check_len(self.db.test.find().skip(110), 0)

        check_len(self.db.test.find().limit(10).skip(10), 10)
        check_len(self.db.test.find()[10:20], 10)
        check_len(self.db.test.find().limit(10).skip(95), 5)
        check_len(self.db.test.find()[95:105], 5)

    def test_len(self):
        self.assertRaises(TypeError, len, self.db.test.find())

    def test_properties(self):
        self.assertEqual(self.db.test, self.db.test.find().collection)

        def set_coll():
            self.db.test.find().collection = "hello"

        self.assertRaises(AttributeError, set_coll)

    def test_tailable(self):
        db = self.db
        db.drop_collection("test")
        db.create_collection("test", capped=True, size=1000)

        cursor = db.test.find(tailable=True)

        db.test.insert({"x": 1})
        count = 0
        for doc in cursor:
            count += 1
            self.assertEqual(1, doc["x"])
        self.assertEqual(1, count)

        db.test.insert({"x": 2})
        count = 0
        for doc in cursor:
            count += 1
            self.assertEqual(2, doc["x"])
        self.assertEqual(1, count)

        db.test.insert({"x": 3})
        count = 0
        for doc in cursor:
            count += 1
            self.assertEqual(3, doc["x"])
        self.assertEqual(1, count)

        self.assertEqual(3, db.test.count())
        db.drop_collection("test")

    def test_distinct(self):
        if not version.at_least(self.db.connection, (1, 1, 3, 1)):
            raise SkipTest()

        self.db.drop_collection("test")

        self.db.test.save({"a": 1})
        self.db.test.save({"a": 2})
        self.db.test.save({"a": 2})
        self.db.test.save({"a": 2})
        self.db.test.save({"a": 3})

        distinct = self.db.test.find({"a": {"$lt": 3}}).distinct("a")
        distinct.sort()

        self.assertEqual([1, 2], distinct)

        self.db.drop_collection("test")

        self.db.test.save({"a": {"b": "a"}, "c": 12})
        self.db.test.save({"a": {"b": "b"}, "c": 8})
        self.db.test.save({"a": {"b": "c"}, "c": 12})
        self.db.test.save({"a": {"b": "c"}, "c": 8})

        distinct = self.db.test.find({"c": 8}).distinct("a.b")
        distinct.sort()

        self.assertEqual(["b", "c"], distinct)

    def test_max_scan(self):
        if not version.at_least(self.db.connection, (1, 5, 1)):
            raise SkipTest()

        self.db.drop_collection("test")
        for _ in range(100):
            self.db.test.insert({})

        self.assertEqual(100, len(list(self.db.test.find())))
        self.assertEqual(50, len(list(self.db.test.find(max_scan=50))))
        self.assertEqual(
            50, len(list(self.db.test.find().max_scan(90).max_scan(50))))
# Esempio n. 22 (Example no. 22)
# 0
class TestCursor(unittest.TestCase):

    def setUp(self):
        """Open a client and bind the ``pymongo_test`` database for this test."""
        client = get_client()
        self.client = client
        self.db = Database(client, "pymongo_test")

    def tearDown(self):
        """Forget the per-test database handle."""
        # Only the database handle is cleared; self.client is left as is.
        self.db = None

    def test_max_time_ms(self):
        if not version.at_least(self.db.connection, (2, 5, 3, -1)):
            raise SkipTest("MaxTimeMS requires MongoDB >= 2.5.3")

        max_time_ms_response = {
            '$err': 'operation exceeded time limit',
            'code': 50
        }
        bson_response = BSON.encode(max_time_ms_response)
        response_flags = pack("<i", 2)
        cursor_id = pack("<q", 0)
        starting_from = pack("<i", 0)
        number_returned = pack("<i", 1)
        op_reply = (response_flags + cursor_id + starting_from +
                    number_returned + bson_response)
        self.assertRaises(ExecutionTimeout, _unpack_response,
                          op_reply)

        command_response = {
            'ok': 0,
            'errmsg': 'operation exceeded time limit',
            'code': 50
        }
        self.assertRaises(ExecutionTimeout, _check_command_response,
                          command_response, None)

        db = self.db
        db.pymongo_test.drop()
        coll = db.pymongo_test
        self.assertRaises(TypeError, coll.find().max_time_ms, 'foo')
        coll.insert({"amalia": 1})
        coll.insert({"amalia": 2})

        coll.find().max_time_ms(None)
        coll.find().max_time_ms(1L)

        cursor = coll.find().max_time_ms(999)
        self.assertEqual(999, cursor._Cursor__max_time_ms)
        cursor = coll.find().max_time_ms(10).max_time_ms(1000)
        self.assertEqual(1000, cursor._Cursor__max_time_ms)

        cursor = coll.find().max_time_ms(999)
        c2 = cursor.clone()
        self.assertEqual(999, c2._Cursor__max_time_ms)
        self.assertTrue("$maxTimeMS" in cursor._Cursor__query_spec())
        self.assertTrue("$maxTimeMS" in c2._Cursor__query_spec())

        self.assertTrue(coll.find_one(max_time_ms=1000))

        reducer = Code("""function(obj, prev){prev.count++;}""")
        coll.group(key={"amalia": 1}, condition={}, initial={"count": 0},
                   reduce=reducer, maxTimeMS=1000)

        if "enableTestCommands=1" in get_command_line(self.client):
            self.client.admin.command("configureFailPoint",
                                      "maxTimeAlwaysTimeOut",
                                      mode="alwaysOn")
            self.assertRaises(ExecutionTimeout,
                              coll.find_one, max_time_ms=1)
            self.client.admin.command("configureFailPoint",
                                      "maxTimeAlwaysTimeOut",
                                      mode="off")

    def test_explain(self):
        """explain() reports the same plan before and after iteration begins."""
        cursor = self.db.test.find()
        before = cursor.explain()
        for _ in cursor:
            break  # start iterating so the cursor has been executed
        after = cursor.explain()

        # Drop fields that may legitimately differ between the two calls.
        for plan in (before, after):
            del plan["millis"]
            plan.pop("oldPlan", None)

        self.assertEqual(before, after)
        self.assertTrue("cursor" in before)

    def test_hint(self):
        """hint() forces an index and hint(None) clears it.

        A non-list argument raises TypeError, and changing the hint after
        iteration has started raises InvalidOperation.
        """
        coll = self.db.test
        self.assertRaises(TypeError, coll.find().hint, 5.5)
        coll.drop()

        for i in range(100):
            coll.insert({"num": i, "foo": i})

        def hinted_explain(key):
            # Bound explain() of a num == foo == 17 query hinted on `key`.
            query = coll.find({"num": 17, "foo": 17})
            return query.hint([(key, ASCENDING)]).explain

        # Nothing is indexed yet, so hinting either field must fail.
        self.assertRaises(OperationFailure, hinted_explain("num"))
        self.assertRaises(OperationFailure, hinted_explain("foo"))

        index = coll.create_index("num")

        spec = [("num", ASCENDING)]
        self.assertEqual(coll.find({}).explain()["cursor"], "BasicCursor")
        self.assertEqual(coll.find({}).hint(spec).explain()["cursor"],
                         "BtreeCursor %s" % index)
        cleared = coll.find({}).hint(spec).hint(None)
        self.assertEqual(cleared.explain()["cursor"], "BasicCursor")
        # "foo" is still unindexed.
        self.assertRaises(OperationFailure, hinted_explain("foo"))

        started = coll.find({"num": 17})
        started.hint(spec)
        for _ in started:
            break
        self.assertRaises(InvalidOperation, started.hint, spec)

        # create_index's return value is not an accepted hint argument.
        self.assertRaises(TypeError, coll.find().hint, index)

    def test_limit(self):
        db = self.db

        self.assertRaises(TypeError, db.test.find().limit, None)
        self.assertRaises(TypeError, db.test.find().limit, "hello")
        self.assertRaises(TypeError, db.test.find().limit, 5.5)
        self.assertTrue(db.test.find().limit(5L))

        db.test.drop()
        for i in range(100):
            db.test.save({"x": i})

        count = 0
        for _ in db.test.find():
            count += 1
        self.assertEqual(count, 100)

        count = 0
        for _ in db.test.find().limit(20):
            count += 1
        self.assertEqual(count, 20)

        count = 0
        for _ in db.test.find().limit(99):
            count += 1
        self.assertEqual(count, 99)

        count = 0
        for _ in db.test.find().limit(1):
            count += 1
        self.assertEqual(count, 1)

        count = 0
        for _ in db.test.find().limit(0):
            count += 1
        self.assertEqual(count, 100)

        count = 0
        for _ in db.test.find().limit(0).limit(50).limit(10):
            count += 1
        self.assertEqual(count, 10)

        a = db.test.find()
        a.limit(10)
        for _ in a:
            break
        self.assertRaises(InvalidOperation, a.limit, 5)

    def test_batch_size(self):
        db = self.db
        db.test.drop()
        for x in range(200):
            db.test.save({"x": x})

        self.assertRaises(TypeError, db.test.find().batch_size, None)
        self.assertRaises(TypeError, db.test.find().batch_size, "hello")
        self.assertRaises(TypeError, db.test.find().batch_size, 5.5)
        self.assertRaises(ValueError, db.test.find().batch_size, -1)
        self.assertTrue(db.test.find().batch_size(5L))
        a = db.test.find()
        for _ in a:
            break
        self.assertRaises(InvalidOperation, a.batch_size, 5)

        def cursor_count(cursor, expected_count):
            count = 0
            for _ in cursor:
                count += 1
            self.assertEqual(expected_count, count)

        cursor_count(db.test.find().batch_size(0), 200)
        cursor_count(db.test.find().batch_size(1), 200)
        cursor_count(db.test.find().batch_size(2), 200)
        cursor_count(db.test.find().batch_size(5), 200)
        cursor_count(db.test.find().batch_size(100), 200)
        cursor_count(db.test.find().batch_size(500), 200)

        cursor_count(db.test.find().batch_size(0).limit(1), 1)
        cursor_count(db.test.find().batch_size(1).limit(1), 1)
        cursor_count(db.test.find().batch_size(2).limit(1), 1)
        cursor_count(db.test.find().batch_size(5).limit(1), 1)
        cursor_count(db.test.find().batch_size(100).limit(1), 1)
        cursor_count(db.test.find().batch_size(500).limit(1), 1)

        cursor_count(db.test.find().batch_size(0).limit(10), 10)
        cursor_count(db.test.find().batch_size(1).limit(10), 10)
        cursor_count(db.test.find().batch_size(2).limit(10), 10)
        cursor_count(db.test.find().batch_size(5).limit(10), 10)
        cursor_count(db.test.find().batch_size(100).limit(10), 10)
        cursor_count(db.test.find().batch_size(500).limit(10), 10)

    def test_limit_and_batch_size(self):
        """The first fetch from the server reflects both limit() and batch_size()."""
        coll = self.db.test
        coll.drop()
        for x in range(500):
            coll.save({"x": x})

        def assert_first_fetch(expected, cursor):
            # Pull one document, then check how many the cursor has retrieved.
            cursor.next()
            self.assertEqual(expected, cursor._Cursor__retrieved)

        assert_first_fetch(10, coll.find().limit(0).batch_size(10))
        assert_first_fetch(2, coll.find().limit(-2).batch_size(0))
        assert_first_fetch(4, coll.find().limit(-4).batch_size(5))
        assert_first_fetch(50, coll.find().limit(50).batch_size(500))
        assert_first_fetch(500, coll.find().batch_size(500))
        assert_first_fetch(50, coll.find().limit(50))

        # These two depend on a server-side default and may be fragile: as
        # of 2.0.0-rc0 the first reply for a query without ntoreturn holds
        # 101 documents or 1MB, whichever is smaller.
        assert_first_fetch(101, coll.find())
        assert_first_fetch(101, coll.find().limit(0).batch_size(0))

    def test_skip(self):
        db = self.db

        self.assertRaises(TypeError, db.test.find().skip, None)
        self.assertRaises(TypeError, db.test.find().skip, "hello")
        self.assertRaises(TypeError, db.test.find().skip, 5.5)
        self.assertRaises(ValueError, db.test.find().skip, -5)
        self.assertTrue(db.test.find().skip(5L))

        db.drop_collection("test")

        for i in range(100):
            db.test.save({"x": i})

        for i in db.test.find():
            self.assertEqual(i["x"], 0)
            break

        for i in db.test.find().skip(20):
            self.assertEqual(i["x"], 20)
            break

        for i in db.test.find().skip(99):
            self.assertEqual(i["x"], 99)
            break

        for i in db.test.find().skip(1):
            self.assertEqual(i["x"], 1)
            break

        for i in db.test.find().skip(0):
            self.assertEqual(i["x"], 0)
            break

        for i in db.test.find().skip(0).skip(50).skip(10):
            self.assertEqual(i["x"], 10)
            break

        for i in db.test.find().skip(1000):
            self.fail()

        a = db.test.find()
        a.skip(10)
        for _ in a:
            break
        self.assertRaises(InvalidOperation, a.skip, 5)

    def test_sort(self):
        """sort() validates arguments and orders on one or several keys.

        The last call wins, and changing the sort after iteration has
        started raises InvalidOperation.
        """
        coll = self.db.test

        self.assertRaises(TypeError, coll.find().sort, 5)
        self.assertRaises(ValueError, coll.find().sort, [])
        self.assertRaises(TypeError, coll.find().sort, [], ASCENDING)
        self.assertRaises(TypeError, coll.find().sort,
                          [("hello", DESCENDING)], DESCENDING)

        coll.drop()

        shuffled_xs = list(range(10))
        random.shuffle(shuffled_xs)
        for value in shuffled_xs:
            coll.save({"x": value})

        def xs(cursor):
            # The "x" values in the order the cursor returns them.
            return [doc["x"] for doc in cursor]

        ascending = list(range(10))
        self.assertEqual(xs(coll.find().sort("x", ASCENDING)), ascending)
        self.assertEqual(xs(coll.find().sort("x")), ascending)
        self.assertEqual(xs(coll.find().sort([("x", ASCENDING)])), ascending)

        descending = ascending[::-1]
        self.assertEqual(xs(coll.find().sort("x", DESCENDING)), descending)
        self.assertEqual(xs(coll.find().sort([("x", DESCENDING)])),
                         descending)
        self.assertEqual(
            xs(coll.find().sort("x", ASCENDING).sort("x", DESCENDING)),
            descending)

        # Compound sort: b descending, ties broken by a ascending.
        expected = [(1, 5), (2, 5), (0, 3), (7, 3), (9, 2), (2, 1), (3, 1)]
        shuffled = list(expected)
        random.shuffle(shuffled)

        coll.drop()
        for a_val, b_val in shuffled:
            coll.save({"a": a_val, "b": b_val})

        ordered = coll.find().sort([("b", DESCENDING), ("a", ASCENDING)])
        self.assertEqual([(doc["a"], doc["b"]) for doc in ordered], expected)

        started = coll.find()
        started.sort("x", ASCENDING)
        for _ in started:
            break
        self.assertRaises(InvalidOperation, started.sort, "x", ASCENDING)

    def test_count(self):
        """Cursor.count ignores limit/skip and is stable across iteration."""
        coll = self.db.test
        coll.drop()

        self.assertEqual(0, coll.find().count())

        for value in range(10):
            coll.save({"x": value})

        self.assertEqual(10, coll.find().count())
        self.assertTrue(isinstance(coll.find().count(), int))
        # count() with no arguments does not apply limit() or skip().
        self.assertEqual(10, coll.find().limit(5).count())
        self.assertEqual(10, coll.find().skip(5).count())

        self.assertEqual(1, coll.find({"x": 1}).count())
        self.assertEqual(5, coll.find({"x": {"$lt": 5}}).count())

        # Starting iteration must not change what count() reports.
        cursor = coll.find()
        before = cursor.count()
        for _ in cursor:
            break
        self.assertEqual(before, cursor.count())

        self.assertEqual(0, coll.acollectionthatdoesntexist.find().count())

    def test_where(self):
        """Check Cursor.where validation, filtering and chaining.

        where() accepts a JavaScript string or a Code object (optionally
        with a scope), combines with the find() filter, is replaced by a
        later where() call, and raises InvalidOperation once iteration has
        started.
        """
        db = self.db
        db.test.drop()

        cursor = db.test.find()
        self.assertRaises(TypeError, cursor.where, 5)
        self.assertRaises(TypeError, cursor.where, None)
        self.assertRaises(TypeError, cursor.where, {})

        for i in range(10):
            db.test.save({"x": i})

        self.assertEqual(3, len(list(db.test.find().where('this.x < 3'))))
        self.assertEqual(3,
                         len(list(db.test.find().where(Code('this.x < 3')))))
        # Code with a scope: "i" is bound from the scope dict.
        self.assertEqual(3, len(list(db.test.find().where(Code('this.x < i',
                                                               {"i": 3})))))
        self.assertEqual(10, len(list(db.test.find())))

        self.assertEqual(3, db.test.find().where('this.x < 3').count())
        self.assertEqual(10, db.test.find().count())
        self.assertEqual(3, db.test.find().where(u'this.x < 3').count())
        # Comprehensions use ``doc`` rather than reusing a cursor name: on
        # Python 2 the comprehension variable leaks into the method scope.
        self.assertEqual([0, 1, 2],
                         [doc["x"] for doc in
                          db.test.find().where('this.x < 3')])
        # where() is ANDed with the find() filter.
        self.assertEqual([],
                         [doc["x"] for doc in
                          db.test.find({"x": 5}).where('this.x < 3')])
        self.assertEqual([5],
                         [doc["x"] for doc in
                          db.test.find({"x": 5}).where('this.x > 3')])

        # A second where() replaces the first instead of combining with it.
        cursor = db.test.find().where('this.x < 3').where('this.x > 7')
        self.assertEqual([8, 9], [doc["x"] for doc in cursor])

        cursor = db.test.find()
        cursor.where('this.x > 3')
        for _ in cursor:
            break
        self.assertRaises(InvalidOperation, cursor.where, 'this.x < 3')

    def test_rewind(self):
        """rewind() resets an exhausted or partially-read cursor."""
        for value in (1, 2, 3):
            self.db.test.save({"x": value})

        def drain(cur):
            # Number of documents still to be produced by ``cur``.
            return sum(1 for _ in cur)

        cursor = self.db.test.find().limit(2)

        self.assertEqual(2, drain(cursor))
        # An exhausted cursor yields nothing more until it is rewound.
        self.assertEqual(0, drain(cursor))

        cursor.rewind()
        self.assertEqual(2, drain(cursor))

        # Rewinding after a partial read restarts from the beginning.
        cursor.rewind()
        for _ in cursor:
            break
        cursor.rewind()
        self.assertEqual(2, drain(cursor))

        # rewind() returns the cursor itself.
        self.assertEqual(cursor, cursor.rewind())

    def test_clone(self):
        """Check clone(), copy.copy and copy.deepcopy of cursors.

        A clone starts from the beginning even when the source cursor is
        exhausted or partially read, and carries over the source's query
        settings. clone() and copy.deepcopy do not share the projection
        dict with the source, while copy.copy does.
        """
        self.db.test.save({"x": 1})
        self.db.test.save({"x": 2})
        self.db.test.save({"x": 3})

        cursor = self.db.test.find().limit(2)

        # Exhaust the cursor once...
        count = 0
        for _ in cursor:
            count += 1
        self.assertEqual(2, count)

        # ...after which it yields nothing more.
        count = 0
        for _ in cursor:
            count += 1
        self.assertEqual(0, count)

        # Clones of an exhausted cursor start over, and two clones iterate
        # independently of each other (count accumulates to 2 + 2).
        cursor = cursor.clone()
        cursor2 = cursor.clone()
        count = 0
        for _ in cursor:
            count += 1
        self.assertEqual(2, count)
        for _ in cursor2:
            count += 1
        self.assertEqual(4, count)

        # Cloning a partially-read cursor also starts from the beginning.
        cursor.rewind()
        count = 0
        for _ in cursor:
            break
        cursor = cursor.clone()
        for _ in cursor:
            count += 1
        self.assertEqual(2, count)

        # A clone does not compare equal to its source.
        self.assertNotEqual(cursor, cursor.clone())

        class MyClass(dict):
            pass

        # as_class sets the document type for both iteration and indexing.
        cursor = self.db.test.find(as_class=MyClass)
        for e in cursor:
            self.assertEqual(type(MyClass()), type(e))
        cursor = self.db.test.find(as_class=MyClass)
        self.assertEqual(type(MyClass()), type(cursor[0]))

        # Only attributes are compared below; this cursor is never iterated.
        # Option 128 (partial) is added on top of the keyword flags.
        cursor = self.db.test.find({"x": re.compile("^hello.*")},
                                   skip=1,
                                   timeout=False,
                                   snapshot=True,
                                   tailable=True,
                                   as_class=MyClass,
                                   slave_okay=True,
                                   await_data=True,
                                   partial=True,
                                   manipulate=False,
                                   compile_re=False,
                                   fields={'_id': False}).limit(2)
        cursor.add_option(128)

        # clone() must carry over every query setting.
        cursor2 = cursor.clone()
        self.assertEqual(cursor._Cursor__skip, cursor2._Cursor__skip)
        self.assertEqual(cursor._Cursor__limit, cursor2._Cursor__limit)
        self.assertEqual(cursor._Cursor__snapshot, cursor2._Cursor__snapshot)
        self.assertEqual(type(cursor._Cursor__as_class),
                         type(cursor2._Cursor__as_class))
        self.assertEqual(cursor._Cursor__slave_okay,
                         cursor2._Cursor__slave_okay)
        self.assertEqual(cursor._Cursor__manipulate,
                         cursor2._Cursor__manipulate)
        self.assertEqual(cursor._Cursor__compile_re,
                         cursor2._Cursor__compile_re)
        self.assertEqual(cursor._Cursor__query_flags,
                         cursor2._Cursor__query_flags)

        # Shallow copies share the projection dict, so mutating the copy's
        # fields also mutates the original's.
        cursor2 = copy.copy(cursor)
        cursor2._Cursor__fields['cursor2'] = False
        self.assertTrue('cursor2' in cursor._Cursor__fields)

        # Deep copies get their own projection dict and must not mutate
        # the original.
        cursor3 = copy.deepcopy(cursor)
        cursor3._Cursor__fields['cursor3'] = False
        self.assertFalse('cursor3' in cursor._Cursor__fields)

        # clone() behaves like a deep copy with respect to the projection.
        cursor4 = cursor.clone()
        cursor4._Cursor__fields['cursor4'] = False
        self.assertFalse('cursor4' in cursor._Cursor__fields)

        # deepcopy must honor its memo so a self-referencing query spec is
        # copied once and keeps its cycle.
        query = {"hello": "world"}
        query["reflexive"] = query
        cursor = self.db.test.find(query)

        cursor2 = copy.deepcopy(cursor)

        self.assertNotEqual(id(cursor._Cursor__spec),
                            id(cursor2._Cursor__spec))
        self.assertEqual(id(cursor2._Cursor__spec['reflexive']),
                         id(cursor2._Cursor__spec))
        self.assertEqual(len(cursor2._Cursor__spec), 2)

        # Ensure hints are cloned as the correct type
        cursor = self.db.test.find().hint([('z', 1), ("a", 1)])
        cursor2 = copy.deepcopy(cursor)
        self.assertTrue(isinstance(cursor2._Cursor__hint, SON))
        self.assertEqual(cursor._Cursor__hint, cursor2._Cursor__hint)

    def test_deepcopy_cursor_littered_with_regexes(self):
        """deepcopy copes with compiled regexes anywhere in the query spec."""
        nested_regex_spec = {
            "x": re.compile("^hmmm.*"),
            "y": [re.compile("^hmm.*")],
            "z": {"a": [re.compile("^hm.*")]},
            re.compile("^key.*"): {"a": [re.compile("^hm.*")]},
        }
        original = self.db.test.find(nested_regex_spec)

        duplicate = copy.deepcopy(original)
        self.assertEqual(original._Cursor__spec, duplicate._Cursor__spec)

    def test_add_remove_option(self):
        """Check that add_option/remove_option agree with find() keywords.

        For each query flag, the bits set by a find() keyword must equal
        the bits set by add_option, and remove_option must clear them.
        Flag values exercised here: 2 tailable, 4 slave_okay, 16
        timeout=False, 32 await_data, 64 exhaust, 128 partial.
        """
        cursor = self.db.test.find()
        self.assertEqual(0, cursor._Cursor__query_options())
        # Build up tailable (2) + await_data (32) + partial (128) on one
        # cursor, comparing against a freshly built equivalent each step.
        cursor.add_option(2)
        cursor2 = self.db.test.find(tailable=True)
        self.assertEqual(2, cursor2._Cursor__query_options())
        self.assertEqual(cursor._Cursor__query_options(),
                         cursor2._Cursor__query_options())
        cursor.add_option(32)
        cursor2 = self.db.test.find(tailable=True, await_data=True)
        self.assertEqual(34, cursor2._Cursor__query_options())
        self.assertEqual(cursor._Cursor__query_options(),
                         cursor2._Cursor__query_options())
        cursor.add_option(128)
        cursor2 = self.db.test.find(tailable=True,
                                    await_data=True).add_option(128)
        self.assertEqual(162, cursor2._Cursor__query_options())
        self.assertEqual(cursor._Cursor__query_options(),
                         cursor2._Cursor__query_options())

        # Adding a flag that is already set leaves the mask unchanged.
        self.assertEqual(162, cursor._Cursor__query_options())
        cursor.add_option(128)
        self.assertEqual(162, cursor._Cursor__query_options())

        # Peel the flags back off in reverse order.
        cursor.remove_option(128)
        cursor2 = self.db.test.find(tailable=True, await_data=True)
        self.assertEqual(34, cursor2._Cursor__query_options())
        self.assertEqual(cursor._Cursor__query_options(),
                         cursor2._Cursor__query_options())
        cursor.remove_option(32)
        cursor2 = self.db.test.find(tailable=True)
        self.assertEqual(2, cursor2._Cursor__query_options())
        self.assertEqual(cursor._Cursor__query_options(),
                         cursor2._Cursor__query_options())

        # Removing a flag that is not set leaves the mask unchanged.
        self.assertEqual(2, cursor._Cursor__query_options())
        cursor.remove_option(32)
        self.assertEqual(2, cursor._Cursor__query_options())

        # Slave OK (4): also toggles the cursor's __slave_okay attribute.
        cursor = self.db.test.find(slave_okay=True)
        self.assertEqual(4, cursor._Cursor__query_options())
        cursor2 = self.db.test.find().add_option(4)
        self.assertEqual(cursor._Cursor__query_options(),
                         cursor2._Cursor__query_options())
        self.assertTrue(cursor._Cursor__slave_okay)
        cursor.remove_option(4)
        self.assertEqual(0, cursor._Cursor__query_options())
        self.assertFalse(cursor._Cursor__slave_okay)

        # Timeout: timeout=False corresponds to flag 16.
        cursor = self.db.test.find(timeout=False)
        self.assertEqual(16, cursor._Cursor__query_options())
        cursor2 = self.db.test.find().add_option(16)
        self.assertEqual(cursor._Cursor__query_options(),
                         cursor2._Cursor__query_options())
        cursor.remove_option(16)
        self.assertEqual(0, cursor._Cursor__query_options())

        # Tailable / Await data: add_option also accepts a combined mask.
        cursor = self.db.test.find(tailable=True, await_data=True)
        self.assertEqual(34, cursor._Cursor__query_options())
        cursor2 = self.db.test.find().add_option(34)
        self.assertEqual(cursor._Cursor__query_options(),
                         cursor2._Cursor__query_options())
        cursor.remove_option(32)
        self.assertEqual(2, cursor._Cursor__query_options())

        # Exhaust (64) - which mongos doesn't support; also toggles the
        # cursor's __exhaust attribute.
        if not is_mongos(self.db.connection):
            cursor = self.db.test.find(exhaust=True)
            self.assertEqual(64, cursor._Cursor__query_options())
            cursor2 = self.db.test.find().add_option(64)
            self.assertEqual(cursor._Cursor__query_options(),
                             cursor2._Cursor__query_options())
            self.assertTrue(cursor._Cursor__exhaust)
            cursor.remove_option(64)
            self.assertEqual(0, cursor._Cursor__query_options())
            self.assertFalse(cursor._Cursor__exhaust)

        # Partial (128)
        cursor = self.db.test.find(partial=True)
        self.assertEqual(128, cursor._Cursor__query_options())
        cursor2 = self.db.test.find().add_option(128)
        self.assertEqual(cursor._Cursor__query_options(),
                         cursor2._Cursor__query_options())
        cursor.remove_option(128)
        self.assertEqual(0, cursor._Cursor__query_options())

    def test_count_with_fields(self):
        """A field projection does not change count() on servers >= 1.1.3."""
        coll = self.db.test
        coll.drop()
        coll.save({"x": 1})

        if version.at_least(self.db.connection, (1, 1, 3, -1)):
            self.assertEqual(1, coll.find({}, ["a"]).count())
            return

        # Older servers: projecting a field the document lacks yields (and
        # counts) no documents.
        for _ in coll.find({}, ["a"]):
            self.fail()
        self.assertEqual(0, coll.find({}, ["a"]).count())

    def test_bad_getitem(self):
        """Indexing a cursor with anything but an int or slice is a TypeError."""
        def index_cursor(key):
            return self.db.test.find()[key]

        for bad_key in ("hello", 5.5, None):
            self.assertRaises(TypeError, index_cursor, bad_key)

    def test_getitem_slice_index(self):
        """Check cursor slicing against skip()/limit() semantics.

        cursor[start:stop] returns documents start..stop-1, like
        skip(start) plus limit(stop - start). Negative starts, steps other
        than 1, and stop < start raise IndexError. Empty slices ([n:n],
        [:0]) return nothing. A later slice, skip() or limit() replaces
        the bound set by an earlier one rather than composing with it.
        """
        self.db.drop_collection("test")
        for i in range(100):
            self.db.test.save({"i": i})

        # Alias: count(n) yields n, n+1, ... to pair with returned docs.
        count = itertools.count

        self.assertRaises(IndexError, lambda: self.db.test.find()[-1:])
        self.assertRaises(IndexError, lambda: self.db.test.find()[1:2:2])

        for a, b in zip(count(0), self.db.test.find()):
            self.assertEqual(a, b['i'])

        self.assertEqual(100, len(list(self.db.test.find()[0:])))
        for a, b in zip(count(0), self.db.test.find()[0:]):
            self.assertEqual(a, b['i'])

        self.assertEqual(80, len(list(self.db.test.find()[20:])))
        for a, b in zip(count(20), self.db.test.find()[20:]):
            self.assertEqual(a, b['i'])

        for a, b in zip(count(99), self.db.test.find()[99:]):
            self.assertEqual(a, b['i'])

        # Starting past the end yields nothing.
        for i in self.db.test.find()[1000:]:
            self.fail()

        # Python 2 long bounds are accepted as well as ints.
        self.assertEqual(5, len(list(self.db.test.find()[20:25])))
        self.assertEqual(5, len(list(self.db.test.find()[20L:25L])))
        for a, b in zip(count(20), self.db.test.find()[20:25]):
            self.assertEqual(a, b['i'])

        # A second slice replaces the first: [40:45][20:] behaves as [20:].
        self.assertEqual(80, len(list(self.db.test.find()[40:45][20:])))
        for a, b in zip(count(20), self.db.test.find()[40:45][20:]):
            self.assertEqual(a, b['i'])

        # limit(0) / skip() after a slice likewise override its bounds.
        self.assertEqual(80,
                         len(list(self.db.test.find()[40:45].limit(0).skip(20))
                            )
                        )
        for a, b in zip(count(20),
                         self.db.test.find()[40:45].limit(0).skip(20)):
            self.assertEqual(a, b['i'])

        # And a slice after limit()/skip() overrides those.
        self.assertEqual(80,
                         len(list(self.db.test.find().limit(10).skip(40)[20:]))
                        )
        for a, b in zip(count(20),
                         self.db.test.find().limit(10).skip(40)[20:]):
            self.assertEqual(a, b['i'])

        self.assertEqual(1, len(list(self.db.test.find()[:1])))
        self.assertEqual(5, len(list(self.db.test.find()[:5])))

        # A stop beyond the end is clamped; empty slices return nothing.
        self.assertEqual(1, len(list(self.db.test.find()[99:100])))
        self.assertEqual(1, len(list(self.db.test.find()[99:1000])))
        self.assertEqual(0, len(list(self.db.test.find()[10:10])))
        self.assertEqual(0, len(list(self.db.test.find()[:0])))
        self.assertEqual(80,
                         len(list(self.db.test.find()[10:10].limit(0).skip(20))
                            )
                        )

        # stop < start is rejected.
        self.assertRaises(IndexError, lambda: self.db.test.find()[10:8])

    def test_getitem_numeric_index(self):
        self.db.drop_collection("test")
        for i in range(100):
            self.db.test.save({"i": i})

        self.assertEqual(0, self.db.test.find()[0]['i'])
        self.assertEqual(50, self.db.test.find()[50]['i'])
        self.assertEqual(50, self.db.test.find().skip(50)[0]['i'])
        self.assertEqual(50, self.db.test.find().skip(49)[1]['i'])
        self.assertEqual(50, self.db.test.find()[50L]['i'])
        self.assertEqual(99, self.db.test.find()[99]['i'])

        self.assertRaises(IndexError, lambda x: self.db.test.find()[x], -1)
        self.assertRaises(IndexError, lambda x: self.db.test.find()[x], 100)
        self.assertRaises(IndexError,
                          lambda x: self.db.test.find().skip(50)[x], 50)

    def test_count_with_limit_and_skip(self):
        """count(True) honors limit()/skip() and matches the iterated size."""
        if not version.at_least(self.db.connection, (1, 1, 4, -1)):
            raise SkipTest("count with limit / skip requires MongoDB >= 1.1.4")

        self.assertRaises(TypeError, self.db.test.find().count, "foo")

        def assert_size(cursor, expected):
            # Iterate first, then compare with the limit/skip-aware count.
            iterated = len(list(cursor))
            self.assertEqual(iterated, cursor.count(True))
            self.assertEqual(expected, cursor.count(True))

        self.db.drop_collection("test")
        for n in range(100):
            self.db.test.save({"i": n})

        find = self.db.test.find

        assert_size(find(), 100)

        # A limit larger than the collection returns every document.
        assert_size(find().limit(10), 10)
        assert_size(find().limit(110), 100)

        assert_size(find().skip(10), 90)
        assert_size(find().skip(110), 0)

        # limit()/skip() combinations and their slice equivalents.
        assert_size(find().limit(10).skip(10), 10)
        assert_size(find()[10:20], 10)
        assert_size(find().limit(10).skip(95), 5)
        assert_size(find()[95:105], 5)

    def test_len(self):
        """Cursors do not support len()."""
        cursor = self.db.test.find()
        self.assertRaises(TypeError, len, cursor)

    def test_properties(self):
        """Cursor.collection returns the source collection and is read-only."""
        cursor = self.db.test.find()
        self.assertEqual(self.db.test, cursor.collection)

        def assign_collection():
            self.db.test.find().collection = "hello"

        self.assertRaises(AttributeError, assign_collection)

    def test_get_more(self):
        """All documents are returned when they span several batches."""
        self.db.drop_collection("test")
        self.db.test.insert([{'i': n} for n in range(10)])
        # batch_size(5) with 10 documents forces more than one batch.
        batched = self.db.test.find().batch_size(5)
        self.assertEqual(10, len(list(batched)))

    def test_tailable(self):
        """A tailable cursor on a capped collection picks up new inserts.

        Each pass over the same cursor yields exactly the one document
        inserted since the previous pass. The capped collection is dropped
        in a ``finally`` so it is not left behind when an assertion fails.
        """
        db = self.db
        db.drop_collection("test")
        db.create_collection("test", capped=True, size=1000)
        try:
            cursor = db.test.find(tailable=True)

            for value in (1, 2, 3):
                db.test.insert({"x": value})
                count = 0
                for doc in cursor:
                    count += 1
                    self.assertEqual(value, doc["x"])
                self.assertEqual(1, count)

            self.assertEqual(3, db.test.count())
        finally:
            db.drop_collection("test")

    def test_distinct(self):
        """Cursor.distinct honors the find() filter, including dotted keys."""
        if not version.at_least(self.db.connection, (1, 1, 3, 1)):
            raise SkipTest("distinct with query requires MongoDB >= 1.1.3")

        coll = self.db.test

        self.db.drop_collection("test")
        for value in (1, 2, 2, 2, 3):
            coll.save({"a": value})

        below_three = sorted(coll.find({"a": {"$lt": 3}}).distinct("a"))
        self.assertEqual([1, 2], below_three)

        self.db.drop_collection("test")
        for inner, c_value in (("a", 12), ("b", 8), ("c", 12), ("c", 8)):
            coll.save({"a": {"b": inner}, "c": c_value})

        nested = sorted(coll.find({"c": 8}).distinct("a.b"))
        self.assertEqual(["b", "c"], nested)

    def test_max_scan(self):
        """max_scan caps the documents scanned; the last call wins."""
        if not version.at_least(self.db.connection, (1, 5, 1)):
            raise SkipTest("maxScan requires MongoDB >= 1.5.1")

        self.db.drop_collection("test")
        for _ in range(100):
            self.db.test.insert({})

        def fetched(cursor):
            return len(list(cursor))

        self.assertEqual(100, fetched(self.db.test.find()))
        self.assertEqual(50, fetched(self.db.test.find(max_scan=50)))
        chained = self.db.test.find().max_scan(90).max_scan(50)
        self.assertEqual(50, fetched(chained))

    def test_with_statement(self):
        """Cursors act as context managers and are dead after the block.

        The ``with`` blocks are compiled through ``exec`` so this module
        still parses on Python versions older than 2.6; the test itself is
        skipped there. A cursor created outside the blocks (c1) must stay
        alive.
        """
        if sys.version_info < (2, 6):
            raise SkipTest("With statement requires Python >= 2.6")

        self.db.drop_collection("test")
        for _ in range(100):
            self.db.test.insert({})

        c1 = self.db.test.find()
        # First block: exit without iterating; second: exit after draining.
        exec """
with self.db.test.find() as c2:
    self.assertTrue(c2.alive)
self.assertFalse(c2.alive)

with self.db.test.find() as c2:
    self.assertEqual(100, len(list(c2)))
self.assertFalse(c2.alive)
"""
        self.assertTrue(c1.alive)
    def test_list_collections(self):
        """list_collections lists each collection once and honors filter.

        NOTE(review): this class defines ``test_list_collections`` twice in
        a row with the same body; the later definition replaces this one
        when the class is created, so only one of them actually runs.
        """
        self.client.drop_database("pymongo_test")
        db = Database(self.client, "pymongo_test")

        def check_names(colls, allowed):
            # No listed collection name may contain '$'.
            for name in colls:
                self.assertTrue("$" not in name)
            # Each collection is listed exactly once.
            self.assertEqual(len(colls), len(set(colls)))
            # Nothing beyond what this test created; system.indexes may
            # also be reported.
            self.assertTrue(
                set(colls) <= set(list(allowed) + ["system.indexes"]))

        try:
            db.test.insert_one({"dummy": u"object"})
            db.test.mike.insert_one({"dummy": u"object"})

            colls = [result["name"] for result in db.list_collections()]

            # All the collections present.
            self.assertTrue("test" in colls)
            self.assertTrue("test.mike" in colls)
            check_names(colls, ["test", "test.mike"])

            colls = db.list_collections(filter={"name": {"$regex": "^test$"}})
            self.assertEqual(1, len(list(colls)))

            # Escape the dot so only the literal name "test.mike" matches.
            colls = db.list_collections(
                filter={"name": {"$regex": r"^test\.mike$"}})
            self.assertEqual(1, len(list(colls)))

            db.drop_collection("test")

            db.create_collection("test", capped=True, size=4096)
            results = db.list_collections(filter={'options.capped': True})
            colls = [result["name"] for result in results]

            # Checking only capped collections are present
            self.assertTrue("test" in colls)
            self.assertFalse("test.mike" in colls)
            check_names(colls, ["test"])
        finally:
            # Clean up even when an assertion above fails.
            self.client.drop_database("pymongo_test")
    def test_list_collections(self):
        """list_collections lists each collection once and honors filter.

        NOTE(review): an identical ``test_list_collections`` is defined
        immediately before this one in the same class; this later
        definition is the one that takes effect.
        """
        self.client.drop_database("pymongo_test")
        db = Database(self.client, "pymongo_test")

        def check_names(colls, allowed):
            # No listed collection name may contain '$'.
            for name in colls:
                self.assertTrue("$" not in name)
            # Each collection is listed exactly once.
            self.assertEqual(len(colls), len(set(colls)))
            # Nothing beyond what this test created; system.indexes may
            # also be reported.
            self.assertTrue(
                set(colls) <= set(list(allowed) + ["system.indexes"]))

        try:
            db.test.insert_one({"dummy": u"object"})
            db.test.mike.insert_one({"dummy": u"object"})

            colls = [result["name"] for result in db.list_collections()]

            # All the collections present.
            self.assertTrue("test" in colls)
            self.assertTrue("test.mike" in colls)
            check_names(colls, ["test", "test.mike"])

            colls = db.list_collections(filter={"name": {"$regex": "^test$"}})
            self.assertEqual(1, len(list(colls)))

            # Escape the dot so only the literal name "test.mike" matches.
            colls = db.list_collections(
                filter={"name": {"$regex": r"^test\.mike$"}})
            self.assertEqual(1, len(list(colls)))

            db.drop_collection("test")

            db.create_collection("test", capped=True, size=4096)
            results = db.list_collections(filter={'options.capped': True})
            colls = [result["name"] for result in results]

            # Checking only capped collections are present
            self.assertTrue("test" in colls)
            self.assertFalse("test.mike" in colls)
            check_names(colls, ["test"])
        finally:
            # Clean up even when an assertion above fails.
            self.client.drop_database("pymongo_test")
Esempio n. 25
0
class TestCursor(unittest.TestCase):
    def setUp(self):
        """Open a fresh client and a handle on the pymongo_test database."""
        client = get_client()
        self.client = client
        self.db = Database(client, "pymongo_test")

    def tearDown(self):
        self.db = None

    def test_max_time_ms(self):
        if not version.at_least(self.db.connection, (2, 5, 3, -1)):
            raise SkipTest("MaxTimeMS requires MongoDB >= 2.5.3")

        max_time_ms_response = {
            '$err': 'operation exceeded time limit',
            'code': 50
        }
        bson_response = BSON.encode(max_time_ms_response)
        response_flags = pack("<i", 2)
        cursor_id = pack("<q", 0)
        starting_from = pack("<i", 0)
        number_returned = pack("<i", 1)
        op_reply = (response_flags + cursor_id + starting_from +
                    number_returned + bson_response)
        self.assertRaises(ExecutionTimeout, _unpack_response, op_reply)

        command_response = {
            'ok': 0,
            'errmsg': 'operation exceeded time limit',
            'code': 50
        }
        self.assertRaises(ExecutionTimeout, _check_command_response,
                          command_response, None)

        db = self.db
        db.pymongo_test.drop()
        coll = db.pymongo_test
        self.assertRaises(TypeError, coll.find().max_time_ms, 'foo')
        coll.insert({"amalia": 1})
        coll.insert({"amalia": 2})

        coll.find().max_time_ms(None)
        coll.find().max_time_ms(1L)

        cursor = coll.find().max_time_ms(999)
        self.assertEqual(999, cursor._Cursor__max_time_ms)
        cursor = coll.find().max_time_ms(10).max_time_ms(1000)
        self.assertEqual(1000, cursor._Cursor__max_time_ms)

        cursor = coll.find().max_time_ms(999)
        c2 = cursor.clone()
        self.assertEqual(999, c2._Cursor__max_time_ms)
        self.assertTrue("$maxTimeMS" in cursor._Cursor__query_spec())
        self.assertTrue("$maxTimeMS" in c2._Cursor__query_spec())

        self.assertTrue(coll.find_one(max_time_ms=1000))

        reducer = Code("""function(obj, prev){prev.count++;}""")
        coll.group(key={"amalia": 1},
                   condition={},
                   initial={"count": 0},
                   reduce=reducer,
                   maxTimeMS=1000)

        if "enableTestCommands=1" in get_command_line(self.client)["argv"]:
            self.client.admin.command("configureFailPoint",
                                      "maxTimeAlwaysTimeOut",
                                      mode="alwaysOn")
            self.assertRaises(ExecutionTimeout, coll.find_one, max_time_ms=1)
            self.client.admin.command("configureFailPoint",
                                      "maxTimeAlwaysTimeOut",
                                      mode="off")

    def test_explain(self):
        """explain() gives the same plan before and after iteration starts."""
        cursor = self.db.test.find()
        before = cursor.explain()
        for _ in cursor:
            break
        after = cursor.explain()
        # Drop fields that may vary between the two calls: the timing, and
        # oldPlan when it is present.
        for plan in (before, after):
            del plan["millis"]
            plan.pop("oldPlan", None)
        self.assertEqual(before, after)
        self.assertTrue("cursor" in before)

    def test_hint(self):
        """hint() validation, its effect on the plan, and post-iteration lock."""
        db = self.db
        self.assertRaises(TypeError, db.test.find().hint, 5.5)
        db.test.drop()

        for n in range(100):
            db.test.insert({"num": n, "foo": n})

        def explain_hinted_on(field):
            # Bound explain() of a {num: 17, foo: 17} query hinted on one
            # ascending key.
            return db.test.find({
                "num": 17,
                "foo": 17
            }).hint([(field, ASCENDING)]).explain

        # No index exists yet, so the server rejects either hint.
        self.assertRaises(OperationFailure, explain_hinted_on("num"))
        self.assertRaises(OperationFailure, explain_hinted_on("foo"))

        index = db.test.create_index("num")

        spec = [("num", ASCENDING)]
        self.assertEqual(db.test.find({}).explain()["cursor"], "BasicCursor")
        self.assertEqual(
            db.test.find({}).hint(spec).explain()["cursor"],
            "BtreeCursor %s" % index)
        # hint(None) clears an earlier hint.
        self.assertEqual(
            db.test.find({}).hint(spec).hint(None).explain()["cursor"],
            "BasicCursor")
        # There is still no index on "foo".
        self.assertRaises(OperationFailure, explain_hinted_on("foo"))

        cursor = db.test.find({"num": 17})
        cursor.hint(spec)
        for _ in cursor:
            break
        self.assertRaises(InvalidOperation, cursor.hint, spec)

        # The name returned by create_index is not accepted as a hint here.
        self.assertRaises(TypeError, db.test.find().hint, index)

    def test_limit(self):
        db = self.db

        self.assertRaises(TypeError, db.test.find().limit, None)
        self.assertRaises(TypeError, db.test.find().limit, "hello")
        self.assertRaises(TypeError, db.test.find().limit, 5.5)
        self.assertTrue(db.test.find().limit(5L))

        db.test.drop()
        for i in range(100):
            db.test.save({"x": i})

        count = 0
        for _ in db.test.find():
            count += 1
        self.assertEqual(count, 100)

        count = 0
        for _ in db.test.find().limit(20):
            count += 1
        self.assertEqual(count, 20)

        count = 0
        for _ in db.test.find().limit(99):
            count += 1
        self.assertEqual(count, 99)

        count = 0
        for _ in db.test.find().limit(1):
            count += 1
        self.assertEqual(count, 1)

        count = 0
        for _ in db.test.find().limit(0):
            count += 1
        self.assertEqual(count, 100)

        count = 0
        for _ in db.test.find().limit(0).limit(50).limit(10):
            count += 1
        self.assertEqual(count, 10)

        a = db.test.find()
        a.limit(10)
        for _ in a:
            break
        self.assertRaises(InvalidOperation, a.limit, 5)

    def test_max(self):
        """max() bounds results by index key and must match an index."""
        coll = self.db.test
        coll.drop()
        coll.ensure_index([("j", ASCENDING)])

        for j in range(10):
            coll.insert({"j": j, "k": j})

        # Exclusive upper bound j < 3, given as a list and as a tuple.
        for bound in ([("j", 3)], (("j", 3), )):
            self.assertEqual(len(list(coll.find().max(bound))), 3)

        # Compound index.
        coll.ensure_index([("j", ASCENDING), ("k", ASCENDING)])
        compound = coll.find().max([("j", 3), ("k", 3)])
        self.assertEqual(len(list(compound)), 3)

        # Wrong key order, then a key with no index at all.
        for bound in ([("k", 3), ("j", 3)], [("k", 3)]):
            self.assertRaises(OperationFailure, list, coll.find().max(bound))

        for bad in (10, {"j": 10}):
            self.assertRaises(TypeError, coll.find().max, bad)

    def test_min(self):
        """Check Cursor.min(), which acts as an inclusive lower index bound."""
        db = self.db
        db.test.drop()
        db.test.ensure_index([("j", ASCENDING)])

        for value in range(10):
            db.test.insert({"j": value, "k": value})

        def at_or_above(bound):
            # Number of documents returned with the given min() bound.
            return len(list(db.test.find().min(bound)))

        # The bound may be given as a list or as a tuple of (key, value)
        # pairs; j >= 3 matches 7 of the 10 documents.
        self.assertEqual(at_or_above([("j", 3)]), 7)
        self.assertEqual(at_or_above((("j", 3), )), 7)

        # Compound bound over a compound index.
        db.test.ensure_index([("j", ASCENDING), ("k", ASCENDING)])
        self.assertEqual(at_or_above([("j", 3), ("k", 3)]), 7)

        # Keys given out of index order, or with no index covering them,
        # make the query fail once the cursor is iterated.
        for bad_bound in ([("k", 3), ("j", 3)], [("k", 3)]):
            cursor = db.test.find().min(bad_bound)
            self.assertRaises(OperationFailure, list, cursor)

        # Scalars and plain dicts are rejected as bounds.
        for bad_arg in (10, {"j": 10}):
            self.assertRaises(TypeError, db.test.find().min, bad_arg)

    def test_batch_size(self):
        db = self.db
        db.test.drop()
        for x in range(200):
            db.test.save({"x": x})

        self.assertRaises(TypeError, db.test.find().batch_size, None)
        self.assertRaises(TypeError, db.test.find().batch_size, "hello")
        self.assertRaises(TypeError, db.test.find().batch_size, 5.5)
        self.assertRaises(ValueError, db.test.find().batch_size, -1)
        self.assertTrue(db.test.find().batch_size(5L))
        a = db.test.find()
        for _ in a:
            break
        self.assertRaises(InvalidOperation, a.batch_size, 5)

        def cursor_count(cursor, expected_count):
            count = 0
            for _ in cursor:
                count += 1
            self.assertEqual(expected_count, count)

        cursor_count(db.test.find().batch_size(0), 200)
        cursor_count(db.test.find().batch_size(1), 200)
        cursor_count(db.test.find().batch_size(2), 200)
        cursor_count(db.test.find().batch_size(5), 200)
        cursor_count(db.test.find().batch_size(100), 200)
        cursor_count(db.test.find().batch_size(500), 200)

        cursor_count(db.test.find().batch_size(0).limit(1), 1)
        cursor_count(db.test.find().batch_size(1).limit(1), 1)
        cursor_count(db.test.find().batch_size(2).limit(1), 1)
        cursor_count(db.test.find().batch_size(5).limit(1), 1)
        cursor_count(db.test.find().batch_size(100).limit(1), 1)
        cursor_count(db.test.find().batch_size(500).limit(1), 1)

        cursor_count(db.test.find().batch_size(0).limit(10), 10)
        cursor_count(db.test.find().batch_size(1).limit(10), 10)
        cursor_count(db.test.find().batch_size(2).limit(10), 10)
        cursor_count(db.test.find().batch_size(5).limit(10), 10)
        cursor_count(db.test.find().batch_size(100).limit(10), 10)
        cursor_count(db.test.find().batch_size(500).limit(10), 10)

    def test_limit_and_batch_size(self):
        """Check how many documents the first fetch retrieves for various
        combinations of limit() and batch_size().
        """
        db = self.db
        db.test.drop()
        for value in range(500):
            db.test.save({"x": value})

        # Each row is (limit, batch_size, documents retrieved after the
        # first next() call); None means that method is not called.
        #
        # The last two rows might be shaky, as the default is set by the
        # server. As of 2.0.0-rc0, 101 documents or 1MB (whichever is
        # smaller) is the default for queries without ntoreturn.
        cases = (
            (0, 10, 10),
            (-2, 0, 2),
            (-4, 5, 4),
            (50, 500, 50),
            (None, 500, 500),
            (50, None, 50),
            (None, None, 101),
            (0, 0, 101),
        )
        for limit, batch, retrieved in cases:
            curs = db.test.find()
            if limit is not None:
                curs = curs.limit(limit)
            if batch is not None:
                curs = curs.batch_size(batch)
            curs.next()
            self.assertEqual(retrieved, curs._Cursor__retrieved)

    def test_skip(self):
        db = self.db

        self.assertRaises(TypeError, db.test.find().skip, None)
        self.assertRaises(TypeError, db.test.find().skip, "hello")
        self.assertRaises(TypeError, db.test.find().skip, 5.5)
        self.assertRaises(ValueError, db.test.find().skip, -5)
        self.assertTrue(db.test.find().skip(5L))

        db.drop_collection("test")

        for i in range(100):
            db.test.save({"x": i})

        for i in db.test.find():
            self.assertEqual(i["x"], 0)
            break

        for i in db.test.find().skip(20):
            self.assertEqual(i["x"], 20)
            break

        for i in db.test.find().skip(99):
            self.assertEqual(i["x"], 99)
            break

        for i in db.test.find().skip(1):
            self.assertEqual(i["x"], 1)
            break

        for i in db.test.find().skip(0):
            self.assertEqual(i["x"], 0)
            break

        for i in db.test.find().skip(0).skip(50).skip(10):
            self.assertEqual(i["x"], 10)
            break

        for i in db.test.find().skip(1000):
            self.fail()

        a = db.test.find()
        a.skip(10)
        for _ in a:
            break
        self.assertRaises(InvalidOperation, a.skip, 5)

    def test_sort(self):
        """Check Cursor.sort() argument validation and ordering semantics."""
        db = self.db

        self.assertRaises(TypeError, db.test.find().sort, 5)
        self.assertRaises(ValueError, db.test.find().sort, [])
        # A direction may not accompany a list of (key, direction) pairs.
        self.assertRaises(TypeError, db.test.find().sort, [], ASCENDING)
        self.assertRaises(TypeError,
                          db.test.find().sort, [("hello", DESCENDING)],
                          DESCENDING)

        db.test.drop()

        shuffled_xs = list(range(10))
        random.shuffle(shuffled_xs)
        for value in shuffled_xs:
            db.test.save({"x": value})

        def xs(cursor):
            # The "x" values in the order the cursor returns them.
            return [doc["x"] for doc in cursor]

        # A bare key with no direction sorts ascending.
        ascending = list(range(10))
        for cursor in (db.test.find().sort("x", ASCENDING),
                       db.test.find().sort("x"),
                       db.test.find().sort([("x", ASCENDING)])):
            self.assertEqual(xs(cursor), ascending)

        # A later sort() call replaces an earlier one.
        descending = list(range(9, -1, -1))
        replaced = db.test.find().sort("x", ASCENDING).sort("x", DESCENDING)
        for cursor in (db.test.find().sort("x", DESCENDING),
                       db.test.find().sort([("x", DESCENDING)]),
                       replaced):
            self.assertEqual(xs(cursor), descending)

        expected = [(1, 5), (2, 5), (0, 3), (7, 3), (9, 2), (2, 1), (3, 1)]
        shuffled = list(expected)
        random.shuffle(shuffled)

        db.test.drop()
        for a_val, b_val in shuffled:
            db.test.save({"a": a_val, "b": b_val})

        # Compound sort: "b" descending, ties broken by "a" ascending.
        compound = db.test.find().sort([("b", DESCENDING), ("a", ASCENDING)])
        result = [(doc["a"], doc["b"]) for doc in compound]
        self.assertEqual(result, expected)

        # sort() cannot be changed once iteration has started.
        started = db.test.find()
        started.sort("x", ASCENDING)
        for _ in started:
            break
        self.assertRaises(InvalidOperation, started.sort, "x", ASCENDING)

    def test_count(self):
        """Check Cursor.count(); without arguments it ignores limit/skip."""
        db = self.db
        db.test.drop()

        self.assertEqual(0, db.test.find().count())

        for value in range(10):
            db.test.save({"x": value})

        self.assertEqual(10, db.test.find().count())
        self.assertTrue(isinstance(db.test.find().count(), int))
        # limit() and skip() do not affect the default count().
        for modified in (db.test.find().limit(5), db.test.find().skip(5)):
            self.assertEqual(10, modified.count())

        # The query filter is honoured.
        self.assertEqual(1, db.test.find({"x": 1}).count())
        self.assertEqual(5, db.test.find({"x": {"$lt": 5}}).count())

        # Starting iteration does not change what count() reports.
        cursor = db.test.find()
        before = cursor.count()
        for _ in cursor:
            break
        self.assertEqual(before, cursor.count())

        # Counting a collection that does not exist returns 0.
        missing = db.test.acollectionthatdoesntexist
        self.assertEqual(0, missing.find().count())

    def test_where(self):
        """Check Cursor.where() argument validation and filtering.

        List-comprehension variables are named ``doc`` rather than ``a``.
        Under Python 2 a comprehension's loop variable leaks into the
        enclosing scope, so reusing ``a`` silently rebound the cursor variable
        of the same name.
        """
        db = self.db
        db.test.drop()

        a = db.test.find()
        self.assertRaises(TypeError, a.where, 5)
        self.assertRaises(TypeError, a.where, None)
        self.assertRaises(TypeError, a.where, {})

        for i in range(10):
            db.test.save({"x": i})

        # Plain strings and Code objects (optionally with a scope) work.
        self.assertEqual(3, len(list(db.test.find().where('this.x < 3'))))
        self.assertEqual(3,
                         len(list(db.test.find().where(Code('this.x < 3')))))
        self.assertEqual(
            3, len(list(db.test.find().where(Code('this.x < i', {"i": 3})))))
        self.assertEqual(10, len(list(db.test.find())))

        # count() honours the where() clause; unicode expressions work too.
        self.assertEqual(3, db.test.find().where('this.x < 3').count())
        self.assertEqual(10, db.test.find().count())
        self.assertEqual(3, db.test.find().where(u'this.x < 3').count())
        below_three = db.test.find().where('this.x < 3')
        self.assertEqual([0, 1, 2], [doc["x"] for doc in below_three])

        # A document must match both the query filter and the where() clause.
        no_match = db.test.find({"x": 5}).where('this.x < 3')
        self.assertEqual([], [doc["x"] for doc in no_match])
        match = db.test.find({"x": 5}).where('this.x > 3')
        self.assertEqual([5], [doc["x"] for doc in match])

        # A second where() call replaces the first rather than combining.
        cursor = db.test.find().where('this.x < 3').where('this.x > 7')
        self.assertEqual([8, 9], [doc["x"] for doc in cursor])

        # where() cannot be changed once iteration has started.
        a = db.test.find()
        a.where('this.x > 3')
        for _ in a:
            break
        self.assertRaises(InvalidOperation, a.where, 'this.x < 3')

    def test_rewind(self):
        """Check that rewind() lets a cursor be iterated again from the start."""
        for value in (1, 2, 3):
            self.db.test.save({"x": value})

        def drain(cur):
            # Iterate ``cur`` to exhaustion and return how many docs it gave.
            return sum(1 for _ in cur)

        cursor = self.db.test.find().limit(2)
        self.assertEqual(2, drain(cursor))
        # An exhausted cursor yields nothing further...
        self.assertEqual(0, drain(cursor))
        # ...until it is rewound.
        cursor.rewind()
        self.assertEqual(2, drain(cursor))

        # Rewinding a partially consumed cursor also restarts it.
        cursor.rewind()
        for _ in cursor:
            break
        cursor.rewind()
        self.assertEqual(2, drain(cursor))

        # rewind() returns a value equal to the cursor, allowing chaining.
        self.assertEqual(cursor, cursor.rewind())

    def test_clone(self):
        """Check Cursor.clone(), copy.copy() and copy.deepcopy().

        Clones restart iteration, carry over the cursor's options, and (for
        clone() and deepcopy) do not share the fields dict with the original.
        """
        self.db.test.save({"x": 1})
        self.db.test.save({"x": 2})
        self.db.test.save({"x": 3})

        cursor = self.db.test.find().limit(2)

        count = 0
        for _ in cursor:
            count += 1
        self.assertEqual(2, count)

        # The cursor is now exhausted.
        count = 0
        for _ in cursor:
            count += 1
        self.assertEqual(0, count)

        # A clone of an exhausted cursor starts again from the beginning, and
        # each clone iterates independently (count accumulates to 4).
        cursor = cursor.clone()
        cursor2 = cursor.clone()
        count = 0
        for _ in cursor:
            count += 1
        self.assertEqual(2, count)
        for _ in cursor2:
            count += 1
        self.assertEqual(4, count)

        # Cloning a partially consumed cursor also restarts iteration.
        cursor.rewind()
        count = 0
        for _ in cursor:
            break
        cursor = cursor.clone()
        for _ in cursor:
            count += 1
        self.assertEqual(2, count)

        # A clone does not compare equal to the cursor it came from.
        self.assertNotEqual(cursor, cursor.clone())

        class MyClass(dict):
            pass

        # as_class controls the type of returned documents, both when
        # iterating and when indexing.
        cursor = self.db.test.find(as_class=MyClass)
        for e in cursor:
            self.assertEqual(type(MyClass()), type(e))
        cursor = self.db.test.find(as_class=MyClass)
        self.assertEqual(type(MyClass()), type(cursor[0]))

        # Only the copied attributes are checked; this query is never run.
        cursor = self.db.test.find({
            "x": re.compile("^hello.*")
        },
                                   skip=1,
                                   timeout=False,
                                   snapshot=True,
                                   tailable=True,
                                   as_class=MyClass,
                                   slave_okay=True,
                                   await_data=True,
                                   partial=True,
                                   manipulate=False,
                                   compile_re=False,
                                   fields={
                                       '_id': False
                                   }).limit(2)
        cursor.min([('a', 1)]).max([('b', 3)])
        cursor.add_option(128)
        cursor.comment('hi!')

        cursor2 = cursor.clone()
        self.assertEqual(cursor._Cursor__skip, cursor2._Cursor__skip)
        self.assertEqual(cursor._Cursor__limit, cursor2._Cursor__limit)
        self.assertEqual(cursor._Cursor__snapshot, cursor2._Cursor__snapshot)
        self.assertEqual(type(cursor._Cursor__as_class),
                         type(cursor2._Cursor__as_class))
        self.assertEqual(cursor._Cursor__slave_okay,
                         cursor2._Cursor__slave_okay)
        self.assertEqual(cursor._Cursor__manipulate,
                         cursor2._Cursor__manipulate)
        self.assertEqual(cursor._Cursor__compile_re,
                         cursor2._Cursor__compile_re)
        self.assertEqual(cursor._Cursor__query_flags,
                         cursor2._Cursor__query_flags)
        self.assertEqual(cursor._Cursor__comment, cursor2._Cursor__comment)
        self.assertEqual(cursor._Cursor__min, cursor2._Cursor__min)
        self.assertEqual(cursor._Cursor__max, cursor2._Cursor__max)

        # A shallow copy shares the fields dict, so mutating the copy's
        # fields also mutates the original's.
        cursor2 = copy.copy(cursor)
        cursor2._Cursor__fields['cursor2'] = False
        self.assertTrue('cursor2' in cursor._Cursor__fields)

        # A deep copy does not share the fields dict, so mutating it leaves
        # the original untouched.
        cursor3 = copy.deepcopy(cursor)
        cursor3._Cursor__fields['cursor3'] = False
        self.assertFalse('cursor3' in cursor._Cursor__fields)

        # clone() behaves like the deep copy here.
        cursor4 = cursor.clone()
        cursor4._Cursor__fields['cursor4'] = False
        self.assertFalse('cursor4' in cursor._Cursor__fields)

        # Test memo when deepcopying queries: a self-referencing spec must be
        # copied once, with the copy's reference pointing at the copy.
        query = {"hello": "world"}
        query["reflexive"] = query
        cursor = self.db.test.find(query)

        cursor2 = copy.deepcopy(cursor)

        self.assertNotEqual(id(cursor._Cursor__spec),
                            id(cursor2._Cursor__spec))
        self.assertEqual(id(cursor2._Cursor__spec['reflexive']),
                         id(cursor2._Cursor__spec))
        self.assertEqual(len(cursor2._Cursor__spec), 2)

        # Ensure hints are cloned as the correct type
        cursor = self.db.test.find().hint([('z', 1), ("a", 1)])
        cursor2 = copy.deepcopy(cursor)
        self.assertTrue(isinstance(cursor2._Cursor__hint, SON))
        self.assertEqual(cursor._Cursor__hint, cursor2._Cursor__hint)

    def test_deepcopy_cursor_littered_with_regexes(self):
        """deepcopy() copes with compiled regexes nested anywhere in a spec,
        including inside lists, sub-documents, and as a dict key.
        """
        spec = {
            "x": re.compile("^hmmm.*"),
            "y": [re.compile("^hmm.*")],
            "z": {"a": [re.compile("^hm.*")]},
            re.compile("^key.*"): {"a": [re.compile("^hm.*")]},
        }
        original = self.db.test.find(spec)

        duplicate = copy.deepcopy(original)
        self.assertEqual(original._Cursor__spec, duplicate._Cursor__spec)

    def test_add_remove_option(self):
        """Check that add_option()/remove_option() agree with the matching
        find() keyword arguments.

        Flag values exercised here: tailable=2, slave_okay=4, timeout=False
        -> 16, await_data=32, exhaust=64, partial=128.
        """
        cursor = self.db.test.find()
        self.assertEqual(0, cursor._Cursor__query_options())
        cursor.add_option(2)
        cursor2 = self.db.test.find(tailable=True)
        self.assertEqual(2, cursor2._Cursor__query_options())
        self.assertEqual(cursor._Cursor__query_options(),
                         cursor2._Cursor__query_options())
        cursor.add_option(32)
        cursor2 = self.db.test.find(tailable=True, await_data=True)
        self.assertEqual(34, cursor2._Cursor__query_options())
        self.assertEqual(cursor._Cursor__query_options(),
                         cursor2._Cursor__query_options())
        cursor.add_option(128)
        cursor2 = self.db.test.find(tailable=True,
                                    await_data=True).add_option(128)
        self.assertEqual(162, cursor2._Cursor__query_options())
        self.assertEqual(cursor._Cursor__query_options(),
                         cursor2._Cursor__query_options())

        # Adding a flag that is already set leaves the options unchanged.
        self.assertEqual(162, cursor._Cursor__query_options())
        cursor.add_option(128)
        self.assertEqual(162, cursor._Cursor__query_options())

        # Removing flags one at a time mirrors the find() kwargs in reverse.
        cursor.remove_option(128)
        cursor2 = self.db.test.find(tailable=True, await_data=True)
        self.assertEqual(34, cursor2._Cursor__query_options())
        self.assertEqual(cursor._Cursor__query_options(),
                         cursor2._Cursor__query_options())
        cursor.remove_option(32)
        cursor2 = self.db.test.find(tailable=True)
        self.assertEqual(2, cursor2._Cursor__query_options())
        self.assertEqual(cursor._Cursor__query_options(),
                         cursor2._Cursor__query_options())

        # Removing a flag that is not set leaves the options unchanged.
        self.assertEqual(2, cursor._Cursor__query_options())
        cursor.remove_option(32)
        self.assertEqual(2, cursor._Cursor__query_options())

        # Slave OK: the option bit and the __slave_okay attribute stay in sync.
        cursor = self.db.test.find(slave_okay=True)
        self.assertEqual(4, cursor._Cursor__query_options())
        cursor2 = self.db.test.find().add_option(4)
        self.assertEqual(cursor._Cursor__query_options(),
                         cursor2._Cursor__query_options())
        self.assertTrue(cursor._Cursor__slave_okay)
        cursor.remove_option(4)
        self.assertEqual(0, cursor._Cursor__query_options())
        self.assertFalse(cursor._Cursor__slave_okay)

        # Timeout
        cursor = self.db.test.find(timeout=False)
        self.assertEqual(16, cursor._Cursor__query_options())
        cursor2 = self.db.test.find().add_option(16)
        self.assertEqual(cursor._Cursor__query_options(),
                         cursor2._Cursor__query_options())
        cursor.remove_option(16)
        self.assertEqual(0, cursor._Cursor__query_options())

        # Tailable / Await data
        cursor = self.db.test.find(tailable=True, await_data=True)
        self.assertEqual(34, cursor._Cursor__query_options())
        cursor2 = self.db.test.find().add_option(34)
        self.assertEqual(cursor._Cursor__query_options(),
                         cursor2._Cursor__query_options())
        cursor.remove_option(32)
        self.assertEqual(2, cursor._Cursor__query_options())

        # Exhaust - which mongos doesn't support. The option bit and the
        # __exhaust attribute stay in sync.
        if not is_mongos(self.db.connection):
            cursor = self.db.test.find(exhaust=True)
            self.assertEqual(64, cursor._Cursor__query_options())
            cursor2 = self.db.test.find().add_option(64)
            self.assertEqual(cursor._Cursor__query_options(),
                             cursor2._Cursor__query_options())
            self.assertTrue(cursor._Cursor__exhaust)
            cursor.remove_option(64)
            self.assertEqual(0, cursor._Cursor__query_options())
            self.assertFalse(cursor._Cursor__exhaust)

        # Partial
        cursor = self.db.test.find(partial=True)
        self.assertEqual(128, cursor._Cursor__query_options())
        cursor2 = self.db.test.find().add_option(128)
        self.assertEqual(cursor._Cursor__query_options(),
                         cursor2._Cursor__query_options())
        cursor.remove_option(128)
        self.assertEqual(0, cursor._Cursor__query_options())

    def test_count_with_fields(self):
        """Check count() on a cursor that has a field projection."""
        self.db.test.drop()
        self.db.test.save({"x": 1})

        if version.at_least(self.db.connection, (1, 1, 3, -1)):
            # From 1.1.3 on, the projection does not affect the count.
            self.assertEqual(1, self.db.test.find({}, ["a"]).count())
            return

        # Older servers, as asserted here, return no documents (and a count
        # of 0) when no document has the projected field.
        for _ in self.db.test.find({}, ["a"]):
            self.fail()
        self.assertEqual(0, self.db.test.find({}, ["a"]).count())

    def test_bad_getitem(self):
        """Indexing a cursor with a non-integer, non-slice key is a TypeError."""
        def index_cursor(key):
            return self.db.test.find()[key]

        for bad_key in ("hello", 5.5, None):
            self.assertRaises(TypeError, index_cursor, bad_key)

    def test_getitem_slice_index(self):
        """Check slicing a cursor.

        Slices behave like skip/limit: a later slice, limit() or skip() call
        replaces the values set by an earlier one. Negative starts, non-default
        steps and reversed bounds raise IndexError.
        """
        self.db.drop_collection("test")
        for i in range(100):
            self.db.test.save({"i": i})

        count = itertools.count

        # Negative starts and non-default steps are rejected.
        self.assertRaises(IndexError, lambda: self.db.test.find()[-1:])
        self.assertRaises(IndexError, lambda: self.db.test.find()[1:2:2])

        for a, b in zip(count(0), self.db.test.find()):
            self.assertEqual(a, b['i'])

        self.assertEqual(100, len(list(self.db.test.find()[0:])))
        for a, b in zip(count(0), self.db.test.find()[0:]):
            self.assertEqual(a, b['i'])

        self.assertEqual(80, len(list(self.db.test.find()[20:])))
        for a, b in zip(count(20), self.db.test.find()[20:]):
            self.assertEqual(a, b['i'])

        for a, b in zip(count(99), self.db.test.find()[99:]):
            self.assertEqual(a, b['i'])

        # Starting past the end yields nothing.
        for i in self.db.test.find()[1000:]:
            self.fail()

        # Bounded slices; Python 2 longs are accepted as bounds.
        self.assertEqual(5, len(list(self.db.test.find()[20:25])))
        self.assertEqual(5, len(list(self.db.test.find()[20L:25L])))
        for a, b in zip(count(20), self.db.test.find()[20:25]):
            self.assertEqual(a, b['i'])

        # A second slice replaces the first one rather than narrowing it.
        self.assertEqual(80, len(list(self.db.test.find()[40:45][20:])))
        for a, b in zip(count(20), self.db.test.find()[40:45][20:]):
            self.assertEqual(a, b['i'])

        # limit()/skip() after a slice override the slice's bounds.
        self.assertEqual(
            80, len(list(self.db.test.find()[40:45].limit(0).skip(20))))
        for a, b in zip(count(20),
                        self.db.test.find()[40:45].limit(0).skip(20)):
            self.assertEqual(a, b['i'])

        # A slice after limit()/skip() overrides them.
        self.assertEqual(
            80, len(list(self.db.test.find().limit(10).skip(40)[20:])))
        for a, b in zip(count(20),
                        self.db.test.find().limit(10).skip(40)[20:]):
            self.assertEqual(a, b['i'])

        self.assertEqual(1, len(list(self.db.test.find()[:1])))
        self.assertEqual(5, len(list(self.db.test.find()[:5])))

        # An end beyond the collection is clamped; empty slices return nothing.
        self.assertEqual(1, len(list(self.db.test.find()[99:100])))
        self.assertEqual(1, len(list(self.db.test.find()[99:1000])))
        self.assertEqual(0, len(list(self.db.test.find()[10:10])))
        self.assertEqual(0, len(list(self.db.test.find()[:0])))
        self.assertEqual(
            80, len(list(self.db.test.find()[10:10].limit(0).skip(20))))

        # An end before the start is rejected.
        self.assertRaises(IndexError, lambda: self.db.test.find()[10:8])

    def test_getitem_numeric_index(self):
        self.db.drop_collection("test")
        for i in range(100):
            self.db.test.save({"i": i})

        self.assertEqual(0, self.db.test.find()[0]['i'])
        self.assertEqual(50, self.db.test.find()[50]['i'])
        self.assertEqual(50, self.db.test.find().skip(50)[0]['i'])
        self.assertEqual(50, self.db.test.find().skip(49)[1]['i'])
        self.assertEqual(50, self.db.test.find()[50L]['i'])
        self.assertEqual(99, self.db.test.find()[99]['i'])

        self.assertRaises(IndexError, lambda x: self.db.test.find()[x], -1)
        self.assertRaises(IndexError, lambda x: self.db.test.find()[x], 100)
        self.assertRaises(IndexError,
                          lambda x: self.db.test.find().skip(50)[x], 50)

    def test_count_with_limit_and_skip(self):
        """Check that count(True) applies the cursor's limit and skip."""
        if not version.at_least(self.db.connection, (1, 1, 4, -1)):
            raise SkipTest("count with limit / skip requires MongoDB >= 1.1.4")

        self.assertRaises(TypeError, self.db.test.find().count, "foo")

        def check_len(cursor, length):
            # count(True) must match both the documents actually returned
            # and the expected length.
            returned = len(list(cursor))
            self.assertEqual(returned, cursor.count(True))
            self.assertEqual(length, cursor.count(True))

        self.db.drop_collection("test")
        for value in range(100):
            self.db.test.save({"i": value})

        # (cursor, expected length); cursors are lazy, so building them up
        # front does not contact the server.
        cases = [
            (self.db.test.find(), 100),
            (self.db.test.find().limit(10), 10),
            (self.db.test.find().limit(110), 100),
            (self.db.test.find().skip(10), 90),
            (self.db.test.find().skip(110), 0),
            (self.db.test.find().limit(10).skip(10), 10),
            (self.db.test.find()[10:20], 10),
            (self.db.test.find().limit(10).skip(95), 5),
            (self.db.test.find()[95:105], 5),
        ]
        for cursor, length in cases:
            check_len(cursor, length)

    def test_len(self):
        """Cursors do not implement len(); calling it raises TypeError."""
        cursor = self.db.test.find()
        self.assertRaises(TypeError, len, cursor)

    def test_properties(self):
        """Cursor.collection returns the source collection and is read-only."""
        self.assertEqual(self.db.test, self.db.test.find().collection)

        # Assigning to the property must fail.
        self.assertRaises(AttributeError, setattr, self.db.test.find(),
                          "collection", "hello")

    def test_get_more(self):
        """Iterating with a batch size smaller than the result set still
        returns every document (presumably via follow-up getMore requests,
        per the test name).
        """
        db = self.db
        db.drop_collection("test")
        docs = [{'i': n} for n in range(10)]
        db.test.insert(docs)
        fetched = list(db.test.find().batch_size(5))
        self.assertEqual(10, len(fetched))

    def test_tailable(self):
        """A tailable cursor on a capped collection can be re-iterated to pick
        up documents inserted after it was last exhausted.
        """
        db = self.db
        db.drop_collection("test")
        db.create_collection("test", capped=True, size=1000)

        cursor = db.test.find(tailable=True)

        # After each insert, iterating the same cursor yields exactly the
        # one new document.
        for x in (1, 2, 3):
            db.test.insert({"x": x})
            seen = 0
            for doc in cursor:
                seen += 1
                self.assertEqual(x, doc["x"])
            self.assertEqual(1, seen)

        self.assertEqual(3, db.test.count())
        db.drop_collection("test")

    def test_distinct(self):
        """Check that Cursor.distinct() honours the cursor's query, including
        for dotted (embedded) field names.
        """
        if not version.at_least(self.db.connection, (1, 1, 3, 1)):
            raise SkipTest("distinct with query requires MongoDB >= 1.1.3")

        self.db.drop_collection("test")

        for value in (1, 2, 2, 2, 3):
            self.db.test.save({"a": value})

        values = sorted(self.db.test.find({"a": {"$lt": 3}}).distinct("a"))
        self.assertEqual([1, 2], values)

        self.db.drop_collection("test")

        for inner, c in (("a", 12), ("b", 8), ("c", 12), ("c", 8)):
            self.db.test.save({"a": {"b": inner}, "c": c})

        values = sorted(self.db.test.find({"c": 8}).distinct("a.b"))
        self.assertEqual(["b", "c"], values)

    def test_max_scan(self):
        """Check max_scan both as a find() keyword and as a cursor method."""
        if not version.at_least(self.db.connection, (1, 5, 1)):
            raise SkipTest("maxScan requires MongoDB >= 1.5.1")

        self.db.drop_collection("test")
        for _ in range(100):
            self.db.test.insert({})

        def n_docs(cursor):
            return len(list(cursor))

        self.assertEqual(100, n_docs(self.db.test.find()))
        self.assertEqual(50, n_docs(self.db.test.find(max_scan=50)))
        # When max_scan() is called repeatedly, the last call wins.
        self.assertEqual(
            50, n_docs(self.db.test.find().max_scan(90).max_scan(50)))

    def test_with_statement(self):
        """Cursors work as context managers and are no longer alive on exit.

        The ``with`` blocks are run via ``exec``, presumably so the module
        still parses on Pythons older than 2.6 (where this test is skipped).
        """
        if sys.version_info < (2, 6):
            raise SkipTest("With statement requires Python >= 2.6")

        self.db.drop_collection("test")
        for _ in range(100):
            self.db.test.insert({})

        # An unrelated cursor, used below to check that leaving the with
        # blocks does not affect other cursors.
        c1 = self.db.test.find()
        exec """
with self.db.test.find() as c2:
    self.assertTrue(c2.alive)
self.assertFalse(c2.alive)

with self.db.test.find() as c2:
    self.assertEqual(100, len(list(c2)))
self.assertFalse(c2.alive)
"""
        self.assertTrue(c1.alive)

    def test_comment(self):
        """Check that Cursor.comment() reaches the server for find, count and
        distinct, as observed through the database profiler.
        """
        def run_with_profiling(func):
            # Start from an empty profile collection with profiling at ALL.
            # Always switch profiling back OFF, even if ``func`` fails an
            # assertion, so the database is not left profiling every
            # operation for later tests.
            self.db.set_profiling_level(OFF)
            self.db.system.profile.drop()
            self.db.set_profiling_level(ALL)
            try:
                func()
            finally:
                self.db.set_profiling_level(OFF)

        def find():
            list(self.db.test.find().comment('foo'))
            op = self.db.system.profile.find({
                'ns': 'pymongo_test.test',
                'op': 'query',
                'query.$comment': 'foo'
            })
            self.assertEqual(op.count(), 1)

        run_with_profiling(find)

        def count():
            self.db.test.find().comment('foo').count()
            op = self.db.system.profile.find({
                'ns': 'pymongo_test.$cmd',
                'op': 'command',
                'command.count': 'test',
                'command.$comment': 'foo'
            })
            self.assertEqual(op.count(), 1)

        run_with_profiling(count)

        def distinct():
            self.db.test.find().comment('foo').distinct('type')
            op = self.db.system.profile.find({
                'ns': 'pymongo_test.$cmd',
                'op': 'command',
                'command.distinct': 'test',
                'command.$comment': 'foo'
            })
            self.assertEqual(op.count(), 1)

        run_with_profiling(distinct)

        # comment() cannot be changed once iteration has started.
        self.db.test.insert([{}, {}])
        cursor = self.db.test.find()
        cursor.next()
        self.assertRaises(InvalidOperation, cursor.comment, 'hello')

        self.db.system.profile.drop()
Esempio n. 26
0
class TestCursor(unittest.TestCase):

    def setUp(self):
        self.db = Database(get_connection(), "pymongo_test")

    def test_explain(self):
        a = self.db.test.find()
        b = a.explain()
        for _ in a:
            break
        c = a.explain()
        del b["millis"]
        b.pop("oldPlan", None)
        del c["millis"]
        c.pop("oldPlan", None)
        self.assertEqual(b, c)
        self.assertTrue("cursor" in b)

    def test_hint(self):
        db = self.db
        self.assertRaises(TypeError, db.test.find().hint, 5.5)
        db.test.drop()

        for i in range(100):
            db.test.insert({"num": i, "foo": i})

        self.assertRaises(OperationFailure,
                          db.test.find({"num": 17, "foo": 17})
                          .hint([("num", ASCENDING)]).explain)
        self.assertRaises(OperationFailure,
                          db.test.find({"num": 17, "foo": 17})
                          .hint([("foo", ASCENDING)]).explain)

        index = db.test.create_index("num")

        spec = [("num", ASCENDING)]
        self.assertEqual(db.test.find({}).explain()["cursor"], "BasicCursor")
        self.assertEqual(db.test.find({}).hint(spec).explain()["cursor"],
                         "BtreeCursor %s" % index)
        self.assertEqual(db.test.find({}).hint(spec).hint(None)
                         .explain()["cursor"],
                         "BasicCursor")
        self.assertRaises(OperationFailure,
                          db.test.find({"num": 17, "foo": 17})
                          .hint([("foo", ASCENDING)]).explain)

        a = db.test.find({"num": 17})
        a.hint(spec)
        for _ in a:
            break
        self.assertRaises(InvalidOperation, a.hint, spec)

        self.assertRaises(TypeError, db.test.find().hint, index)

    def test_limit(self):
        db = self.db

        self.assertRaises(TypeError, db.test.find().limit, None)
        self.assertRaises(TypeError, db.test.find().limit, "hello")
        self.assertRaises(TypeError, db.test.find().limit, 5.5)

        db.test.drop()
        for i in range(100):
            db.test.save({"x": i})

        count = 0
        for _ in db.test.find():
            count += 1
        self.assertEqual(count, 100)

        count = 0
        for _ in db.test.find().limit(20):
            count += 1
        self.assertEqual(count, 20)

        count = 0
        for _ in db.test.find().limit(99):
            count += 1
        self.assertEqual(count, 99)

        count = 0
        for _ in db.test.find().limit(1):
            count += 1
        self.assertEqual(count, 1)

        count = 0
        for _ in db.test.find().limit(0):
            count += 1
        self.assertEqual(count, 100)

        count = 0
        for _ in db.test.find().limit(0).limit(50).limit(10):
            count += 1
        self.assertEqual(count, 10)

        a = db.test.find()
        a.limit(10)
        for _ in a:
            break
        self.assertRaises(InvalidOperation, a.limit, 5)

    def test_batch_size(self):
        db = self.db
        db.test.drop()
        for x in range(200):
            db.test.save({"x": x})

        self.assertRaises(TypeError, db.test.find().batch_size, None)
        self.assertRaises(TypeError, db.test.find().batch_size, "hello")
        self.assertRaises(TypeError, db.test.find().batch_size, 5.5)
        self.assertRaises(ValueError, db.test.find().batch_size, -1)
        a = db.test.find()
        for _ in a:
            break
        self.assertRaises(InvalidOperation, a.batch_size, 5)

        def cursor_count(cursor, expected_count):
            count = 0
            for _ in cursor:
                count += 1
            self.assertEqual(expected_count, count)

        cursor_count(db.test.find().batch_size(0), 200)
        cursor_count(db.test.find().batch_size(1), 200)
        cursor_count(db.test.find().batch_size(2), 200)
        cursor_count(db.test.find().batch_size(5), 200)
        cursor_count(db.test.find().batch_size(100), 200)
        cursor_count(db.test.find().batch_size(500), 200)

        cursor_count(db.test.find().batch_size(0).limit(1), 1)
        cursor_count(db.test.find().batch_size(1).limit(1), 1)
        cursor_count(db.test.find().batch_size(2).limit(1), 1)
        cursor_count(db.test.find().batch_size(5).limit(1), 1)
        cursor_count(db.test.find().batch_size(100).limit(1), 1)
        cursor_count(db.test.find().batch_size(500).limit(1), 1)

        cursor_count(db.test.find().batch_size(0).limit(10), 10)
        cursor_count(db.test.find().batch_size(1).limit(10), 10)
        cursor_count(db.test.find().batch_size(2).limit(10), 10)
        cursor_count(db.test.find().batch_size(5).limit(10), 10)
        cursor_count(db.test.find().batch_size(100).limit(10), 10)
        cursor_count(db.test.find().batch_size(500).limit(10), 10)

    def test_limit_and_batch_size(self):
        db = self.db
        db.test.drop()
        for x in range(500):
            db.test.save({"x": x})

        curs = db.test.find().limit(0).batch_size(10)
        curs.next()
        self.assertEqual(10, curs._Cursor__retrieved)

        curs = db.test.find().limit(-2).batch_size(0)
        curs.next()
        self.assertEqual(2, curs._Cursor__retrieved)

        curs = db.test.find().limit(-4).batch_size(5)
        curs.next()
        self.assertEqual(4, curs._Cursor__retrieved)

        curs = db.test.find().limit(50).batch_size(500)
        curs.next()
        self.assertEqual(50, curs._Cursor__retrieved)

        curs = db.test.find().batch_size(500)
        curs.next()
        self.assertEqual(500, curs._Cursor__retrieved)

        curs = db.test.find().limit(50)
        curs.next()
        self.assertEqual(50, curs._Cursor__retrieved)

        # these two might be shaky, as the default
        # is set by the server. as of 2.0.0-rc0, 101
        # or 1MB (whichever is smaller) is default
        # for queries without ntoreturn
        curs = db.test.find()
        curs.next()
        self.assertEqual(101, curs._Cursor__retrieved)

        curs = db.test.find().limit(0).batch_size(0)
        curs.next()
        self.assertEqual(101, curs._Cursor__retrieved)

    def test_skip(self):
        db = self.db

        self.assertRaises(TypeError, db.test.find().skip, None)
        self.assertRaises(TypeError, db.test.find().skip, "hello")
        self.assertRaises(TypeError, db.test.find().skip, 5.5)

        db.drop_collection("test")

        for i in range(100):
            db.test.save({"x": i})

        for i in db.test.find():
            self.assertEqual(i["x"], 0)
            break

        for i in db.test.find().skip(20):
            self.assertEqual(i["x"], 20)
            break

        for i in db.test.find().skip(99):
            self.assertEqual(i["x"], 99)
            break

        for i in db.test.find().skip(1):
            self.assertEqual(i["x"], 1)
            break

        for i in db.test.find().skip(0):
            self.assertEqual(i["x"], 0)
            break

        for i in db.test.find().skip(0).skip(50).skip(10):
            self.assertEqual(i["x"], 10)
            break

        for i in db.test.find().skip(1000):
            self.fail()

        a = db.test.find()
        a.skip(10)
        for _ in a:
            break
        self.assertRaises(InvalidOperation, a.skip, 5)

    def test_sort(self):
        db = self.db

        self.assertRaises(TypeError, db.test.find().sort, 5)
        self.assertRaises(ValueError, db.test.find().sort, [])
        self.assertRaises(TypeError, db.test.find().sort, [], ASCENDING)
        self.assertRaises(TypeError, db.test.find().sort,
                          [("hello", DESCENDING)], DESCENDING)
        self.assertRaises(TypeError, db.test.find().sort, "hello", "world")

        db.test.drop()

        unsort = range(10)
        random.shuffle(unsort)

        for i in unsort:
            db.test.save({"x": i})

        asc = [i["x"] for i in db.test.find().sort("x", ASCENDING)]
        self.assertEqual(asc, range(10))
        asc = [i["x"] for i in db.test.find().sort("x")]
        self.assertEqual(asc, range(10))
        asc = [i["x"] for i in db.test.find().sort([("x", ASCENDING)])]
        self.assertEqual(asc, range(10))

        expect = range(10)
        expect.reverse()
        desc = [i["x"] for i in db.test.find().sort("x", DESCENDING)]
        self.assertEqual(desc, expect)
        desc = [i["x"] for i in db.test.find().sort([("x", DESCENDING)])]
        self.assertEqual(desc, expect)
        desc = [i["x"] for i in
                db.test.find().sort("x", ASCENDING).sort("x", DESCENDING)]
        self.assertEqual(desc, expect)

        expected = [(1, 5), (2, 5), (0, 3), (7, 3), (9, 2), (2, 1), (3, 1)]
        shuffled = list(expected)
        random.shuffle(shuffled)

        db.test.drop()
        for (a, b) in shuffled:
            db.test.save({"a": a, "b": b})

        result = [(i["a"], i["b"]) for i in
                  db.test.find().sort([("b", DESCENDING),
                                       ("a", ASCENDING)])]
        self.assertEqual(result, expected)

        a = db.test.find()
        a.sort("x", ASCENDING)
        for _ in a:
            break
        self.assertRaises(InvalidOperation, a.sort, "x", ASCENDING)

    def test_count(self):
        db = self.db
        db.test.drop()

        self.assertEqual(0, db.test.find().count())

        for i in range(10):
            db.test.save({"x": i})

        self.assertEqual(10, db.test.find().count())
        self.assertTrue(isinstance(db.test.find().count(), int))
        self.assertEqual(10, db.test.find().limit(5).count())
        self.assertEqual(10, db.test.find().skip(5).count())

        self.assertEqual(1, db.test.find({"x": 1}).count())
        self.assertEqual(5, db.test.find({"x": {"$lt": 5}}).count())

        a = db.test.find()
        b = a.count()
        for _ in a:
            break
        self.assertEqual(b, a.count())

        self.assertEqual(0, db.test.acollectionthatdoesntexist.find().count())

    def test_where(self):
        db = self.db
        db.test.drop()

        a = db.test.find()
        self.assertRaises(TypeError, a.where, 5)
        self.assertRaises(TypeError, a.where, None)
        self.assertRaises(TypeError, a.where, {})

        for i in range(10):
            db.test.save({"x": i})

        self.assertEqual(3, len(list(db.test.find().where('this.x < 3'))))
        self.assertEqual(3,
                         len(list(db.test.find().where(Code('this.x < 3')))))
        self.assertEqual(3, len(list(db.test.find().where(Code('this.x < i',
                                                               {"i": 3})))))
        self.assertEqual(10, len(list(db.test.find())))

        self.assertEqual(3, db.test.find().where('this.x < 3').count())
        self.assertEqual(10, db.test.find().count())
        self.assertEqual(3, db.test.find().where(u'this.x < 3').count())
        self.assertEqual([0, 1, 2],
                         [a["x"] for a in
                          db.test.find().where('this.x < 3')])
        self.assertEqual([],
                         [a["x"] for a in
                          db.test.find({"x": 5}).where('this.x < 3')])
        self.assertEqual([5],
                         [a["x"] for a in
                          db.test.find({"x": 5}).where('this.x > 3')])

        cursor = db.test.find().where('this.x < 3').where('this.x > 7')
        self.assertEqual([8, 9], [a["x"] for a in cursor])

        a = db.test.find()
        b = a.where('this.x > 3')
        for _ in a:
            break
        self.assertRaises(InvalidOperation, a.where, 'this.x < 3')

    def test_rewind(self):
        self.db.test.save({"x": 1})
        self.db.test.save({"x": 2})
        self.db.test.save({"x": 3})

        cursor = self.db.test.find().limit(2)

        count = 0
        for _ in cursor:
            count += 1
        self.assertEqual(2, count)

        count = 0
        for _ in cursor:
            count += 1
        self.assertEqual(0, count)

        cursor.rewind()
        count = 0
        for _ in cursor:
            count += 1
        self.assertEqual(2, count)

        cursor.rewind()
        count = 0
        for _ in cursor:
            break
        cursor.rewind()
        for _ in cursor:
            count += 1
        self.assertEqual(2, count)

        self.assertEqual(cursor, cursor.rewind())

    def test_clone(self):
        self.db.test.save({"x": 1})
        self.db.test.save({"x": 2})
        self.db.test.save({"x": 3})

        cursor = self.db.test.find().limit(2)

        count = 0
        for _ in cursor:
            count += 1
        self.assertEqual(2, count)

        count = 0
        for _ in cursor:
            count += 1
        self.assertEqual(0, count)

        cursor = cursor.clone()
        cursor2 = cursor.clone()
        count = 0
        for _ in cursor:
            count += 1
        self.assertEqual(2, count)
        for _ in cursor2:
            count += 1
        self.assertEqual(4, count)

        cursor.rewind()
        count = 0
        for _ in cursor:
            break
        cursor = cursor.clone()
        for _ in cursor:
            count += 1
        self.assertEqual(2, count)

        self.assertNotEqual(cursor, cursor.clone())

        class MyClass(dict):
            pass

        cursor = self.db.test.find(as_class=MyClass)
        for e in cursor:
            self.assertEqual(type(MyClass()), type(e))
        cursor = self.db.test.find(as_class=MyClass)
        self.assertEqual(type(MyClass()), type(cursor[0]))

        # Just test attributes
        cursor = self.db.test.find(skip=1,
                                   timeout=False,
                                   snapshot=True,
                                   tailable=True,
                                   as_class=MyClass,
                                   slave_okay=True,
                                   await_data=True,
                                   partial=True,
                                   manipulate=False,
                                   fields={'_id': False}).limit(2)
        cursor.add_option(64)

        cursor2 = cursor.clone()
        self.assertEqual(cursor._Cursor__skip, cursor2._Cursor__skip)
        self.assertEqual(cursor._Cursor__limit, cursor2._Cursor__limit)
        self.assertEqual(cursor._Cursor__timeout, cursor2._Cursor__timeout)
        self.assertEqual(cursor._Cursor__snapshot, cursor2._Cursor__snapshot)
        self.assertEqual(cursor._Cursor__tailable, cursor2._Cursor__tailable)
        self.assertEqual(type(cursor._Cursor__as_class),
                         type(cursor2._Cursor__as_class))
        self.assertEqual(cursor._Cursor__slave_okay,
                         cursor2._Cursor__slave_okay)
        self.assertEqual(cursor._Cursor__await_data,
                         cursor2._Cursor__await_data)
        self.assertEqual(cursor._Cursor__partial, cursor2._Cursor__partial)
        self.assertEqual(cursor._Cursor__manipulate,
                         cursor2._Cursor__manipulate)
        self.assertEqual(cursor._Cursor__query_flags,
                         cursor2._Cursor__query_flags)

        # Shallow copies can so can mutate
        cursor2 = copy.copy(cursor)
        cursor2._Cursor__fields['cursor2'] = False
        self.assertTrue('cursor2' in cursor._Cursor__fields)

        # Deepcopies and shouldn't mutate
        cursor3 = copy.deepcopy(cursor)
        cursor3._Cursor__fields['cursor3'] = False
        self.assertFalse('cursor3' in cursor._Cursor__fields)

        cursor4 = cursor.clone()
        cursor4._Cursor__fields['cursor4'] = False
        self.assertFalse('cursor4' in cursor._Cursor__fields)

    def test_add_remove_option(self):
        cursor = self.db.test.find()
        self.assertEqual(0, cursor._Cursor__query_options())
        cursor.add_option(2)
        cursor2 = self.db.test.find(tailable=True)
        self.assertEqual(2, cursor2._Cursor__query_options())
        self.assertEqual(cursor._Cursor__query_options(),
                         cursor2._Cursor__query_options())
        cursor.add_option(32)
        cursor2 = self.db.test.find(tailable=True, await_data=True)
        self.assertEqual(34, cursor2._Cursor__query_options())
        self.assertEqual(cursor._Cursor__query_options(),
                         cursor2._Cursor__query_options())
        cursor.add_option(128)
        cursor2 = self.db.test.find(tailable=True,
                                    await_data=True).add_option(128)
        self.assertEqual(162, cursor2._Cursor__query_options())
        self.assertEqual(cursor._Cursor__query_options(),
                         cursor2._Cursor__query_options())

        self.assertEqual(162, cursor._Cursor__query_options())
        cursor.add_option(128)
        self.assertEqual(162, cursor._Cursor__query_options())

        cursor.remove_option(128)
        cursor2 = self.db.test.find(tailable=True, await_data=True)
        self.assertEqual(34, cursor2._Cursor__query_options())
        self.assertEqual(cursor._Cursor__query_options(),
                         cursor2._Cursor__query_options())
        cursor.remove_option(32)
        cursor2 = self.db.test.find(tailable=True)
        self.assertEqual(2, cursor2._Cursor__query_options())
        self.assertEqual(cursor._Cursor__query_options(),
                         cursor2._Cursor__query_options())

        self.assertEqual(2, cursor._Cursor__query_options())
        cursor.remove_option(32)
        self.assertEqual(2, cursor._Cursor__query_options())

    def test_count_with_fields(self):
        self.db.test.drop()
        self.db.test.save({"x": 1})

        if not version.at_least(self.db.connection, (1, 1, 3, -1)):
            for _ in self.db.test.find({}, ["a"]):
                self.fail()

            self.assertEqual(0, self.db.test.find({}, ["a"]).count())
        else:
            self.assertEqual(1, self.db.test.find({}, ["a"]).count())

    def test_bad_getitem(self):
        self.assertRaises(TypeError, lambda x: self.db.test.find()[x], "hello")
        self.assertRaises(TypeError, lambda x: self.db.test.find()[x], 5.5)
        self.assertRaises(TypeError, lambda x: self.db.test.find()[x], None)

    def test_getitem_slice_index(self):
        self.db.drop_collection("test")
        for i in range(100):
            self.db.test.save({"i": i})

        count = itertools.count

        self.assertRaises(IndexError, lambda: self.db.test.find()[-1:])
        self.assertRaises(IndexError, lambda: self.db.test.find()[1:2:2])

        for a, b in zip(count(0), self.db.test.find()):
            self.assertEqual(a, b['i'])

        self.assertEqual(100, len(list(self.db.test.find()[0:])))
        for a, b in zip(count(0), self.db.test.find()[0:]):
            self.assertEqual(a, b['i'])

        self.assertEqual(80, len(list(self.db.test.find()[20:])))
        for a, b in zip(count(20), self.db.test.find()[20:]):
            self.assertEqual(a, b['i'])

        for a, b in zip(count(99), self.db.test.find()[99:]):
            self.assertEqual(a, b['i'])

        for i in self.db.test.find()[1000:]:
            self.fail()

        self.assertEqual(5, len(list(self.db.test.find()[20:25])))
        self.assertEqual(5, len(list(self.db.test.find()[20L:25L])))
        for a, b in zip(count(20), self.db.test.find()[20:25]):
            self.assertEqual(a, b['i'])

        self.assertEqual(80, len(list(self.db.test.find()[40:45][20:])))
        for a, b in zip(count(20), self.db.test.find()[40:45][20:]):
            self.assertEqual(a, b['i'])

        self.assertEqual(80,
                         len(list(self.db.test.find()[40:45].limit(0).skip(20))
                            )
                        )
        for a, b in zip(count(20),
                         self.db.test.find()[40:45].limit(0).skip(20)):
            self.assertEqual(a, b['i'])

        self.assertEqual(80,
                         len(list(self.db.test.find().limit(10).skip(40)[20:]))
                        )
        for a, b in zip(count(20),
                         self.db.test.find().limit(10).skip(40)[20:]):
            self.assertEqual(a, b['i'])

        self.assertEqual(1, len(list(self.db.test.find()[:1])))
        self.assertEqual(5, len(list(self.db.test.find()[:5])))

        self.assertEqual(1, len(list(self.db.test.find()[99:100])))
        self.assertEqual(1, len(list(self.db.test.find()[99:1000])))
        self.assertEqual(0, len(list(self.db.test.find()[10:10])))
        self.assertEqual(0, len(list(self.db.test.find()[:0])))
        self.assertEqual(80,
                         len(list(self.db.test.find()[10:10].limit(0).skip(20))
                            )
                        )

        self.assertRaises(IndexError, lambda: self.db.test.find()[10:8])

    def test_getitem_numeric_index(self):
        self.db.drop_collection("test")
        for i in range(100):
            self.db.test.save({"i": i})

        self.assertEqual(0, self.db.test.find()[0]['i'])
        self.assertEqual(50, self.db.test.find()[50]['i'])
        self.assertEqual(50, self.db.test.find().skip(50)[0]['i'])
        self.assertEqual(50, self.db.test.find().skip(49)[1]['i'])
        self.assertEqual(50, self.db.test.find()[50L]['i'])
        self.assertEqual(99, self.db.test.find()[99]['i'])

        self.assertRaises(IndexError, lambda x: self.db.test.find()[x], -1)
        self.assertRaises(IndexError, lambda x: self.db.test.find()[x], 100)
        self.assertRaises(IndexError,
                          lambda x: self.db.test.find().skip(50)[x], 50)

    def test_count_with_limit_and_skip(self):
        if not version.at_least(self.db.connection, (1, 1, 4, -1)):
            raise SkipTest("count with limit / skip requires MongoDB >= 1.1.4")

        def check_len(cursor, length):
            self.assertEqual(len(list(cursor)), cursor.count(True))
            self.assertEqual(length, cursor.count(True))

        self.db.drop_collection("test")
        for i in range(100):
            self.db.test.save({"i": i})

        check_len(self.db.test.find(), 100)

        check_len(self.db.test.find().limit(10), 10)
        check_len(self.db.test.find().limit(110), 100)

        check_len(self.db.test.find().skip(10), 90)
        check_len(self.db.test.find().skip(110), 0)

        check_len(self.db.test.find().limit(10).skip(10), 10)
        check_len(self.db.test.find()[10:20], 10)
        check_len(self.db.test.find().limit(10).skip(95), 5)
        check_len(self.db.test.find()[95:105], 5)

    def test_len(self):
        self.assertRaises(TypeError, len, self.db.test.find())

    def test_properties(self):
        self.assertEqual(self.db.test, self.db.test.find().collection)

        def set_coll():
            self.db.test.find().collection = "hello"

        self.assertRaises(AttributeError, set_coll)

    def test_get_more(self):
        db = self.db
        db.drop_collection("test")
        db.test.insert([{'i': i} for i in range(10)])
        self.assertEqual(10, len(list(db.test.find().batch_size(5))))

    def test_tailable(self):
        db = self.db
        db.drop_collection("test")
        db.create_collection("test", capped=True, size=1000)

        cursor = db.test.find(tailable=True)

        db.test.insert({"x": 1}, safe=True)
        count = 0
        for doc in cursor:
            count += 1
            self.assertEqual(1, doc["x"])
        self.assertEqual(1, count)

        db.test.insert({"x": 2}, safe=True)
        count = 0
        for doc in cursor:
            count += 1
            self.assertEqual(2, doc["x"])
        self.assertEqual(1, count)

        db.test.insert({"x": 3}, safe=True)
        count = 0
        for doc in cursor:
            count += 1
            self.assertEqual(3, doc["x"])
        self.assertEqual(1, count)

        self.assertEqual(3, db.test.count())
        db.drop_collection("test")

    def test_distinct(self):
        if not version.at_least(self.db.connection, (1, 1, 3, 1)):
            raise SkipTest("distinct with query requires MongoDB >= 1.1.3")

        self.db.drop_collection("test")

        self.db.test.save({"a": 1})
        self.db.test.save({"a": 2})
        self.db.test.save({"a": 2})
        self.db.test.save({"a": 2})
        self.db.test.save({"a": 3})

        distinct = self.db.test.find({"a": {"$lt": 3}}).distinct("a")
        distinct.sort()

        self.assertEqual([1, 2], distinct)

        self.db.drop_collection("test")

        self.db.test.save({"a": {"b": "a"}, "c": 12})
        self.db.test.save({"a": {"b": "b"}, "c": 8})
        self.db.test.save({"a": {"b": "c"}, "c": 12})
        self.db.test.save({"a": {"b": "c"}, "c": 8})

        distinct = self.db.test.find({"c": 8}).distinct("a.b")
        distinct.sort()

        self.assertEqual(["b", "c"], distinct)

    def test_max_scan(self):
        if not version.at_least(self.db.connection, (1, 5, 1)):
            raise SkipTest("maxScan requires MongoDB >= 1.5.1")

        self.db.drop_collection("test")
        for _ in range(100):
            self.db.test.insert({})

        self.assertEqual(100, len(list(self.db.test.find())))
        self.assertEqual(50, len(list(self.db.test.find(max_scan=50))))
        self.assertEqual(50, len(list(self.db.test.find()
                                      .max_scan(90).max_scan(50))))

    def test_with_statement(self):
        if sys.version_info < (2, 6):
            raise SkipTest("With statement requires Python >= 2.6")

        self.db.drop_collection("test")
        for _ in range(100):
            self.db.test.insert({})

        c1 = self.db.test.find()
        exec """
with self.db.test.find() as c2:
    self.assertTrue(c2.alive)
self.assertFalse(c2.alive)

with self.db.test.find() as c2:
    self.assertEqual(100, len(list(c2)))
self.assertFalse(c2.alive)
"""
        self.assertTrue(c1.alive)
class TestCursor(unittest.TestCase):

    def setUp(self):
        self.db = Database(get_client(), "pymongo_test")

    def test_explain(self):
        a = self.db.test.find()
        b = a.explain()
        for _ in a:
            break
        c = a.explain()
        del b["millis"]
        b.pop("oldPlan", None)
        del c["millis"]
        c.pop("oldPlan", None)
        self.assertEqual(b, c)
        self.assertTrue("cursor" in b)

    def test_hint(self):
        db = self.db
        self.assertRaises(TypeError, db.test.find().hint, 5.5)
        db.test.drop()

        for i in range(100):
            db.test.insert({"num": i, "foo": i})

        self.assertRaises(OperationFailure,
                          db.test.find({"num": 17, "foo": 17})
                          .hint([("num", ASCENDING)]).explain)
        self.assertRaises(OperationFailure,
                          db.test.find({"num": 17, "foo": 17})
                          .hint([("foo", ASCENDING)]).explain)

        index = db.test.create_index("num")

        spec = [("num", ASCENDING)]
        self.assertEqual(db.test.find({}).explain()["cursor"], "BasicCursor")
        self.assertEqual(db.test.find({}).hint(spec).explain()["cursor"],
                         "BtreeCursor %s" % index)
        self.assertEqual(db.test.find({}).hint(spec).hint(None)
                         .explain()["cursor"],
                         "BasicCursor")
        self.assertRaises(OperationFailure,
                          db.test.find({"num": 17, "foo": 17})
                          .hint([("foo", ASCENDING)]).explain)

        a = db.test.find({"num": 17})
        a.hint(spec)
        for _ in a:
            break
        self.assertRaises(InvalidOperation, a.hint, spec)

        self.assertRaises(TypeError, db.test.find().hint, index)

    def test_limit(self):
        db = self.db

        self.assertRaises(TypeError, db.test.find().limit, None)
        self.assertRaises(TypeError, db.test.find().limit, "hello")
        self.assertRaises(TypeError, db.test.find().limit, 5.5)

        db.test.drop()
        for i in range(100):
            db.test.save({"x": i})

        count = 0
        for _ in db.test.find():
            count += 1
        self.assertEqual(count, 100)

        count = 0
        for _ in db.test.find().limit(20):
            count += 1
        self.assertEqual(count, 20)

        count = 0
        for _ in db.test.find().limit(99):
            count += 1
        self.assertEqual(count, 99)

        count = 0
        for _ in db.test.find().limit(1):
            count += 1
        self.assertEqual(count, 1)

        count = 0
        for _ in db.test.find().limit(0):
            count += 1
        self.assertEqual(count, 100)

        count = 0
        for _ in db.test.find().limit(0).limit(50).limit(10):
            count += 1
        self.assertEqual(count, 10)

        a = db.test.find()
        a.limit(10)
        for _ in a:
            break
        self.assertRaises(InvalidOperation, a.limit, 5)

    def test_batch_size(self):
        db = self.db
        db.test.drop()
        for x in range(200):
            db.test.save({"x": x})

        self.assertRaises(TypeError, db.test.find().batch_size, None)
        self.assertRaises(TypeError, db.test.find().batch_size, "hello")
        self.assertRaises(TypeError, db.test.find().batch_size, 5.5)
        self.assertRaises(ValueError, db.test.find().batch_size, -1)
        a = db.test.find()
        for _ in a:
            break
        self.assertRaises(InvalidOperation, a.batch_size, 5)

        def cursor_count(cursor, expected_count):
            count = 0
            for _ in cursor:
                count += 1
            self.assertEqual(expected_count, count)

        cursor_count(db.test.find().batch_size(0), 200)
        cursor_count(db.test.find().batch_size(1), 200)
        cursor_count(db.test.find().batch_size(2), 200)
        cursor_count(db.test.find().batch_size(5), 200)
        cursor_count(db.test.find().batch_size(100), 200)
        cursor_count(db.test.find().batch_size(500), 200)

        cursor_count(db.test.find().batch_size(0).limit(1), 1)
        cursor_count(db.test.find().batch_size(1).limit(1), 1)
        cursor_count(db.test.find().batch_size(2).limit(1), 1)
        cursor_count(db.test.find().batch_size(5).limit(1), 1)
        cursor_count(db.test.find().batch_size(100).limit(1), 1)
        cursor_count(db.test.find().batch_size(500).limit(1), 1)

        cursor_count(db.test.find().batch_size(0).limit(10), 10)
        cursor_count(db.test.find().batch_size(1).limit(10), 10)
        cursor_count(db.test.find().batch_size(2).limit(10), 10)
        cursor_count(db.test.find().batch_size(5).limit(10), 10)
        cursor_count(db.test.find().batch_size(100).limit(10), 10)
        cursor_count(db.test.find().batch_size(500).limit(10), 10)

    def test_limit_and_batch_size(self):
        db = self.db
        db.test.drop()
        for x in range(500):
            db.test.save({"x": x})

        curs = db.test.find().limit(0).batch_size(10)
        curs.next()
        self.assertEqual(10, curs._Cursor__retrieved)

        curs = db.test.find().limit(-2).batch_size(0)
        curs.next()
        self.assertEqual(2, curs._Cursor__retrieved)

        curs = db.test.find().limit(-4).batch_size(5)
        curs.next()
        self.assertEqual(4, curs._Cursor__retrieved)

        curs = db.test.find().limit(50).batch_size(500)
        curs.next()
        self.assertEqual(50, curs._Cursor__retrieved)

        curs = db.test.find().batch_size(500)
        curs.next()
        self.assertEqual(500, curs._Cursor__retrieved)

        curs = db.test.find().limit(50)
        curs.next()
        self.assertEqual(50, curs._Cursor__retrieved)

        # these two might be shaky, as the default
        # is set by the server. as of 2.0.0-rc0, 101
        # or 1MB (whichever is smaller) is default
        # for queries without ntoreturn
        curs = db.test.find()
        curs.next()
        self.assertEqual(101, curs._Cursor__retrieved)

        curs = db.test.find().limit(0).batch_size(0)
        curs.next()
        self.assertEqual(101, curs._Cursor__retrieved)

    def test_skip(self):
        db = self.db

        self.assertRaises(TypeError, db.test.find().skip, None)
        self.assertRaises(TypeError, db.test.find().skip, "hello")
        self.assertRaises(TypeError, db.test.find().skip, 5.5)

        db.drop_collection("test")

        for i in range(100):
            db.test.save({"x": i})

        for i in db.test.find():
            self.assertEqual(i["x"], 0)
            break

        for i in db.test.find().skip(20):
            self.assertEqual(i["x"], 20)
            break

        for i in db.test.find().skip(99):
            self.assertEqual(i["x"], 99)
            break

        for i in db.test.find().skip(1):
            self.assertEqual(i["x"], 1)
            break

        for i in db.test.find().skip(0):
            self.assertEqual(i["x"], 0)
            break

        for i in db.test.find().skip(0).skip(50).skip(10):
            self.assertEqual(i["x"], 10)
            break

        for i in db.test.find().skip(1000):
            self.fail()

        a = db.test.find()
        a.skip(10)
        for _ in a:
            break
        self.assertRaises(InvalidOperation, a.skip, 5)

    def test_sort(self):
        db = self.db

        self.assertRaises(TypeError, db.test.find().sort, 5)
        self.assertRaises(ValueError, db.test.find().sort, [])
        self.assertRaises(TypeError, db.test.find().sort, [], ASCENDING)
        self.assertRaises(TypeError, db.test.find().sort,
                          [("hello", DESCENDING)], DESCENDING)

        db.test.drop()

        unsort = range(10)
        random.shuffle(unsort)

        for i in unsort:
            db.test.save({"x": i})

        asc = [i["x"] for i in db.test.find().sort("x", ASCENDING)]
        self.assertEqual(asc, range(10))
        asc = [i["x"] for i in db.test.find().sort("x")]
        self.assertEqual(asc, range(10))
        asc = [i["x"] for i in db.test.find().sort([("x", ASCENDING)])]
        self.assertEqual(asc, range(10))

        expect = range(10)
        expect.reverse()
        desc = [i["x"] for i in db.test.find().sort("x", DESCENDING)]
        self.assertEqual(desc, expect)
        desc = [i["x"] for i in db.test.find().sort([("x", DESCENDING)])]
        self.assertEqual(desc, expect)
        desc = [i["x"] for i in
                db.test.find().sort("x", ASCENDING).sort("x", DESCENDING)]
        self.assertEqual(desc, expect)

        expected = [(1, 5), (2, 5), (0, 3), (7, 3), (9, 2), (2, 1), (3, 1)]
        shuffled = list(expected)
        random.shuffle(shuffled)

        db.test.drop()
        for (a, b) in shuffled:
            db.test.save({"a": a, "b": b})

        result = [(i["a"], i["b"]) for i in
                  db.test.find().sort([("b", DESCENDING),
                                       ("a", ASCENDING)])]
        self.assertEqual(result, expected)

        a = db.test.find()
        a.sort("x", ASCENDING)
        for _ in a:
            break
        self.assertRaises(InvalidOperation, a.sort, "x", ASCENDING)

    def test_count(self):
        db = self.db
        db.test.drop()

        self.assertEqual(0, db.test.find().count())

        for i in range(10):
            db.test.save({"x": i})

        self.assertEqual(10, db.test.find().count())
        self.assertTrue(isinstance(db.test.find().count(), int))
        self.assertEqual(10, db.test.find().limit(5).count())
        self.assertEqual(10, db.test.find().skip(5).count())

        self.assertEqual(1, db.test.find({"x": 1}).count())
        self.assertEqual(5, db.test.find({"x": {"$lt": 5}}).count())

        a = db.test.find()
        b = a.count()
        for _ in a:
            break
        self.assertEqual(b, a.count())

        self.assertEqual(0, db.test.acollectionthatdoesntexist.find().count())

    def test_where(self):
        db = self.db
        db.test.drop()

        a = db.test.find()
        self.assertRaises(TypeError, a.where, 5)
        self.assertRaises(TypeError, a.where, None)
        self.assertRaises(TypeError, a.where, {})

        for i in range(10):
            db.test.save({"x": i})

        self.assertEqual(3, len(list(db.test.find().where('this.x < 3'))))
        self.assertEqual(3,
                         len(list(db.test.find().where(Code('this.x < 3')))))
        self.assertEqual(3, len(list(db.test.find().where(Code('this.x < i',
                                                               {"i": 3})))))
        self.assertEqual(10, len(list(db.test.find())))

        self.assertEqual(3, db.test.find().where('this.x < 3').count())
        self.assertEqual(10, db.test.find().count())
        self.assertEqual(3, db.test.find().where(u'this.x < 3').count())
        self.assertEqual([0, 1, 2],
                         [a["x"] for a in
                          db.test.find().where('this.x < 3')])
        self.assertEqual([],
                         [a["x"] for a in
                          db.test.find({"x": 5}).where('this.x < 3')])
        self.assertEqual([5],
                         [a["x"] for a in
                          db.test.find({"x": 5}).where('this.x > 3')])

        cursor = db.test.find().where('this.x < 3').where('this.x > 7')
        self.assertEqual([8, 9], [a["x"] for a in cursor])

        a = db.test.find()
        b = a.where('this.x > 3')
        for _ in a:
            break
        self.assertRaises(InvalidOperation, a.where, 'this.x < 3')

    def test_rewind(self):
        self.db.test.save({"x": 1})
        self.db.test.save({"x": 2})
        self.db.test.save({"x": 3})

        cursor = self.db.test.find().limit(2)

        count = 0
        for _ in cursor:
            count += 1
        self.assertEqual(2, count)

        count = 0
        for _ in cursor:
            count += 1
        self.assertEqual(0, count)

        cursor.rewind()
        count = 0
        for _ in cursor:
            count += 1
        self.assertEqual(2, count)

        cursor.rewind()
        count = 0
        for _ in cursor:
            break
        cursor.rewind()
        for _ in cursor:
            count += 1
        self.assertEqual(2, count)

        self.assertEqual(cursor, cursor.rewind())

    def test_clone(self):
        self.db.test.save({"x": 1})
        self.db.test.save({"x": 2})
        self.db.test.save({"x": 3})

        cursor = self.db.test.find().limit(2)

        count = 0
        for _ in cursor:
            count += 1
        self.assertEqual(2, count)

        count = 0
        for _ in cursor:
            count += 1
        self.assertEqual(0, count)

        cursor = cursor.clone()
        cursor2 = cursor.clone()
        count = 0
        for _ in cursor:
            count += 1
        self.assertEqual(2, count)
        for _ in cursor2:
            count += 1
        self.assertEqual(4, count)

        cursor.rewind()
        count = 0
        for _ in cursor:
            break
        cursor = cursor.clone()
        for _ in cursor:
            count += 1
        self.assertEqual(2, count)

        self.assertNotEqual(cursor, cursor.clone())

        class MyClass(dict):
            pass

        cursor = self.db.test.find(as_class=MyClass)
        for e in cursor:
            self.assertEqual(type(MyClass()), type(e))
        cursor = self.db.test.find(as_class=MyClass)
        self.assertEqual(type(MyClass()), type(cursor[0]))

        # Just test attributes
        cursor = self.db.test.find({"x": re.compile("^hello.*")},
                                   skip=1,
                                   timeout=False,
                                   snapshot=True,
                                   tailable=True,
                                   as_class=MyClass,
                                   slave_okay=True,
                                   await_data=True,
                                   partial=True,
                                   manipulate=False,
                                   fields={'_id': False}).limit(2)
        cursor.add_option(128)

        cursor2 = cursor.clone()
        self.assertEqual(cursor._Cursor__skip, cursor2._Cursor__skip)
        self.assertEqual(cursor._Cursor__limit, cursor2._Cursor__limit)
        self.assertEqual(cursor._Cursor__timeout, cursor2._Cursor__timeout)
        self.assertEqual(cursor._Cursor__snapshot, cursor2._Cursor__snapshot)
        self.assertEqual(cursor._Cursor__tailable, cursor2._Cursor__tailable)
        self.assertEqual(type(cursor._Cursor__as_class),
                         type(cursor2._Cursor__as_class))
        self.assertEqual(cursor._Cursor__slave_okay,
                         cursor2._Cursor__slave_okay)
        self.assertEqual(cursor._Cursor__await_data,
                         cursor2._Cursor__await_data)
        self.assertEqual(cursor._Cursor__partial, cursor2._Cursor__partial)
        self.assertEqual(cursor._Cursor__manipulate,
                         cursor2._Cursor__manipulate)
        self.assertEqual(cursor._Cursor__query_flags,
                         cursor2._Cursor__query_flags)

        # Shallow copies can so can mutate
        cursor2 = copy.copy(cursor)
        cursor2._Cursor__fields['cursor2'] = False
        self.assertTrue('cursor2' in cursor._Cursor__fields)

        # Deepcopies and shouldn't mutate
        cursor3 = copy.deepcopy(cursor)
        cursor3._Cursor__fields['cursor3'] = False
        self.assertFalse('cursor3' in cursor._Cursor__fields)

        cursor4 = cursor.clone()
        cursor4._Cursor__fields['cursor4'] = False
        self.assertFalse('cursor4' in cursor._Cursor__fields)

        # Test memo when deepcopying queries
        query = {"hello": "world"}
        query["reflexive"] = query
        cursor = self.db.test.find(query)

        cursor2 = copy.deepcopy(cursor)

        self.assertNotEqual(id(cursor._Cursor__spec),
                            id(cursor2._Cursor__spec))
        self.assertEqual(id(cursor2._Cursor__spec['reflexive']),
                         id(cursor2._Cursor__spec))
        self.assertEqual(len(cursor2._Cursor__spec), 2)

        # Ensure hints are cloned as the correct type
        cursor = self.db.test.find().hint([('z', 1), ("a", 1)])
        cursor2 = copy.deepcopy(cursor)
        self.assertTrue(isinstance(cursor2._Cursor__hint, SON))
        self.assertEqual(cursor._Cursor__hint, cursor2._Cursor__hint)

    def test_deepcopy_cursor_littered_with_regexes(self):

        cursor = self.db.test.find({"x": re.compile("^hmmm.*"),
                                    "y": [re.compile("^hmm.*")],
                                    "z": {"a": [re.compile("^hm.*")]},
                                    re.compile("^key.*"): {"a": [re.compile("^hm.*")]}})

        cursor2 = copy.deepcopy(cursor)
        self.assertEqual(cursor._Cursor__spec, cursor2._Cursor__spec)

    def test_add_remove_option(self):
        cursor = self.db.test.find()
        self.assertEqual(0, cursor._Cursor__query_options())
        cursor.add_option(2)
        cursor2 = self.db.test.find(tailable=True)
        self.assertEqual(2, cursor2._Cursor__query_options())
        self.assertEqual(cursor._Cursor__query_options(),
                         cursor2._Cursor__query_options())
        cursor.add_option(32)
        cursor2 = self.db.test.find(tailable=True, await_data=True)
        self.assertEqual(34, cursor2._Cursor__query_options())
        self.assertEqual(cursor._Cursor__query_options(),
                         cursor2._Cursor__query_options())
        cursor.add_option(128)
        cursor2 = self.db.test.find(tailable=True,
                                    await_data=True).add_option(128)
        self.assertEqual(162, cursor2._Cursor__query_options())
        self.assertEqual(cursor._Cursor__query_options(),
                         cursor2._Cursor__query_options())

        self.assertEqual(162, cursor._Cursor__query_options())
        cursor.add_option(128)
        self.assertEqual(162, cursor._Cursor__query_options())

        cursor.remove_option(128)
        cursor2 = self.db.test.find(tailable=True, await_data=True)
        self.assertEqual(34, cursor2._Cursor__query_options())
        self.assertEqual(cursor._Cursor__query_options(),
                         cursor2._Cursor__query_options())
        cursor.remove_option(32)
        cursor2 = self.db.test.find(tailable=True)
        self.assertEqual(2, cursor2._Cursor__query_options())
        self.assertEqual(cursor._Cursor__query_options(),
                         cursor2._Cursor__query_options())

        self.assertEqual(2, cursor._Cursor__query_options())
        cursor.remove_option(32)
        self.assertEqual(2, cursor._Cursor__query_options())

    def test_count_with_fields(self):
        self.db.test.drop()
        self.db.test.save({"x": 1})

        if not version.at_least(self.db.connection, (1, 1, 3, -1)):
            for _ in self.db.test.find({}, ["a"]):
                self.fail()

            self.assertEqual(0, self.db.test.find({}, ["a"]).count())
        else:
            self.assertEqual(1, self.db.test.find({}, ["a"]).count())

    def test_bad_getitem(self):
        self.assertRaises(TypeError, lambda x: self.db.test.find()[x], "hello")
        self.assertRaises(TypeError, lambda x: self.db.test.find()[x], 5.5)
        self.assertRaises(TypeError, lambda x: self.db.test.find()[x], None)

    def test_getitem_slice_index(self):
        self.db.drop_collection("test")
        for i in range(100):
            self.db.test.save({"i": i})

        count = itertools.count

        self.assertRaises(IndexError, lambda: self.db.test.find()[-1:])
        self.assertRaises(IndexError, lambda: self.db.test.find()[1:2:2])

        for a, b in zip(count(0), self.db.test.find()):
            self.assertEqual(a, b['i'])

        self.assertEqual(100, len(list(self.db.test.find()[0:])))
        for a, b in zip(count(0), self.db.test.find()[0:]):
            self.assertEqual(a, b['i'])

        self.assertEqual(80, len(list(self.db.test.find()[20:])))
        for a, b in zip(count(20), self.db.test.find()[20:]):
            self.assertEqual(a, b['i'])

        for a, b in zip(count(99), self.db.test.find()[99:]):
            self.assertEqual(a, b['i'])

        for i in self.db.test.find()[1000:]:
            self.fail()

        self.assertEqual(5, len(list(self.db.test.find()[20:25])))
        self.assertEqual(5, len(list(self.db.test.find()[20L:25L])))
        for a, b in zip(count(20), self.db.test.find()[20:25]):
            self.assertEqual(a, b['i'])

        self.assertEqual(80, len(list(self.db.test.find()[40:45][20:])))
        for a, b in zip(count(20), self.db.test.find()[40:45][20:]):
            self.assertEqual(a, b['i'])

        self.assertEqual(80,
                         len(list(self.db.test.find()[40:45].limit(0).skip(20))
                            )
                        )
        for a, b in zip(count(20),
                         self.db.test.find()[40:45].limit(0).skip(20)):
            self.assertEqual(a, b['i'])

        self.assertEqual(80,
                         len(list(self.db.test.find().limit(10).skip(40)[20:]))
                        )
        for a, b in zip(count(20),
                         self.db.test.find().limit(10).skip(40)[20:]):
            self.assertEqual(a, b['i'])

        self.assertEqual(1, len(list(self.db.test.find()[:1])))
        self.assertEqual(5, len(list(self.db.test.find()[:5])))

        self.assertEqual(1, len(list(self.db.test.find()[99:100])))
        self.assertEqual(1, len(list(self.db.test.find()[99:1000])))
        self.assertEqual(0, len(list(self.db.test.find()[10:10])))
        self.assertEqual(0, len(list(self.db.test.find()[:0])))
        self.assertEqual(80,
                         len(list(self.db.test.find()[10:10].limit(0).skip(20))
                            )
                        )

        self.assertRaises(IndexError, lambda: self.db.test.find()[10:8])

    def test_getitem_numeric_index(self):
        self.db.drop_collection("test")
        for i in range(100):
            self.db.test.save({"i": i})

        self.assertEqual(0, self.db.test.find()[0]['i'])
        self.assertEqual(50, self.db.test.find()[50]['i'])
        self.assertEqual(50, self.db.test.find().skip(50)[0]['i'])
        self.assertEqual(50, self.db.test.find().skip(49)[1]['i'])
        self.assertEqual(50, self.db.test.find()[50L]['i'])
        self.assertEqual(99, self.db.test.find()[99]['i'])

        self.assertRaises(IndexError, lambda x: self.db.test.find()[x], -1)
        self.assertRaises(IndexError, lambda x: self.db.test.find()[x], 100)
        self.assertRaises(IndexError,
                          lambda x: self.db.test.find().skip(50)[x], 50)

    def test_count_with_limit_and_skip(self):
        if not version.at_least(self.db.connection, (1, 1, 4, -1)):
            raise SkipTest("count with limit / skip requires MongoDB >= 1.1.4")

        def check_len(cursor, length):
            self.assertEqual(len(list(cursor)), cursor.count(True))
            self.assertEqual(length, cursor.count(True))

        self.db.drop_collection("test")
        for i in range(100):
            self.db.test.save({"i": i})

        check_len(self.db.test.find(), 100)

        check_len(self.db.test.find().limit(10), 10)
        check_len(self.db.test.find().limit(110), 100)

        check_len(self.db.test.find().skip(10), 90)
        check_len(self.db.test.find().skip(110), 0)

        check_len(self.db.test.find().limit(10).skip(10), 10)
        check_len(self.db.test.find()[10:20], 10)
        check_len(self.db.test.find().limit(10).skip(95), 5)
        check_len(self.db.test.find()[95:105], 5)

    def test_len(self):
        self.assertRaises(TypeError, len, self.db.test.find())

    def test_properties(self):
        self.assertEqual(self.db.test, self.db.test.find().collection)

        def set_coll():
            self.db.test.find().collection = "hello"

        self.assertRaises(AttributeError, set_coll)

    def test_get_more(self):
        db = self.db
        db.drop_collection("test")
        db.test.insert([{'i': i} for i in range(10)])
        self.assertEqual(10, len(list(db.test.find().batch_size(5))))

    def test_tailable(self):
        db = self.db
        db.drop_collection("test")
        db.create_collection("test", capped=True, size=1000)

        cursor = db.test.find(tailable=True)

        db.test.insert({"x": 1})
        count = 0
        for doc in cursor:
            count += 1
            self.assertEqual(1, doc["x"])
        self.assertEqual(1, count)

        db.test.insert({"x": 2})
        count = 0
        for doc in cursor:
            count += 1
            self.assertEqual(2, doc["x"])
        self.assertEqual(1, count)

        db.test.insert({"x": 3})
        count = 0
        for doc in cursor:
            count += 1
            self.assertEqual(3, doc["x"])
        self.assertEqual(1, count)

        self.assertEqual(3, db.test.count())
        db.drop_collection("test")

    def test_distinct(self):
        if not version.at_least(self.db.connection, (1, 1, 3, 1)):
            raise SkipTest("distinct with query requires MongoDB >= 1.1.3")

        self.db.drop_collection("test")

        self.db.test.save({"a": 1})
        self.db.test.save({"a": 2})
        self.db.test.save({"a": 2})
        self.db.test.save({"a": 2})
        self.db.test.save({"a": 3})

        distinct = self.db.test.find({"a": {"$lt": 3}}).distinct("a")
        distinct.sort()

        self.assertEqual([1, 2], distinct)

        self.db.drop_collection("test")

        self.db.test.save({"a": {"b": "a"}, "c": 12})
        self.db.test.save({"a": {"b": "b"}, "c": 8})
        self.db.test.save({"a": {"b": "c"}, "c": 12})
        self.db.test.save({"a": {"b": "c"}, "c": 8})

        distinct = self.db.test.find({"c": 8}).distinct("a.b")
        distinct.sort()

        self.assertEqual(["b", "c"], distinct)

    def test_max_scan(self):
        if not version.at_least(self.db.connection, (1, 5, 1)):
            raise SkipTest("maxScan requires MongoDB >= 1.5.1")

        self.db.drop_collection("test")
        for _ in range(100):
            self.db.test.insert({})

        self.assertEqual(100, len(list(self.db.test.find())))
        self.assertEqual(50, len(list(self.db.test.find(max_scan=50))))
        self.assertEqual(50, len(list(self.db.test.find()
                                      .max_scan(90).max_scan(50))))

    def test_with_statement(self):
        if sys.version_info < (2, 6):
            raise SkipTest("With statement requires Python >= 2.6")

        self.db.drop_collection("test")
        for _ in range(100):
            self.db.test.insert({})

        c1 = self.db.test.find()
        exec """
with self.db.test.find() as c2:
    self.assertTrue(c2.alive)
self.assertFalse(c2.alive)

with self.db.test.find() as c2:
    self.assertEqual(100, len(list(c2)))
self.assertFalse(c2.alive)
"""
        self.assertTrue(c1.alive)
Esempio n. 28
0
 def test1(self):
     """Create a collection, save a document into it, then rename it.

     Collections left over from an interrupted earlier run are dropped
     first: create_collection() raises when "test_1_4" already exists,
     and rename() may fail if "test_1_new" is already present.  Both
     names are dropped afterwards even if a step fails.
     """
     db = Database(self._get_connection(), "pymongo_test")
     db.drop_collection("test_1_4")
     db.drop_collection("test_1_new")
     try:
         test = db.create_collection("test_1_4")
         test.save({"hello": u"world"})
         test.rename("test_1_new")

         # The rename must actually move the collection.
         names = db.collection_names()
         self.assertTrue(u"test_1_new" in names)
         self.assertFalse(u"test_1_4" in names)
     finally:
         db.drop_collection("test_1_4")
         db.drop_collection("test_1_new")
Esempio n. 29
0
class TestCursor(unittest.TestCase):

    def setUp(self):
        """Obtain a client and bind the ``pymongo_test`` database."""
        client = get_client()
        self.client = client
        self.db = Database(client, "pymongo_test")

    def tearDown(self):
        """Release this test's reference to the database handle."""
        self.db = None

    def test_max_time_ms(self):
        """Exercise Cursor.max_time_ms: argument validation, chaining,
        cloning, and (with test commands enabled) the server time-out error.

        The failpoint section only runs when the server's command line
        contains ``enableTestCommands=1``; it always switches the
        ``maxTimeAlwaysTimeOut`` failpoint back off in ``finally``.
        """
        if not version.at_least(self.db.connection, (2, 5, 3, -1)):
            raise SkipTest("MaxTimeMS requires MongoDB >= 2.5.3")

        db = self.db
        db.pymongo_test.drop()
        coll = db.pymongo_test
        # A non-integer time limit is rejected client-side.
        self.assertRaises(TypeError, coll.find().max_time_ms, 'foo')
        coll.insert({"amalia": 1})
        coll.insert({"amalia": 2})

        # None and a Python 2 long are both accepted without error.
        coll.find().max_time_ms(None)
        coll.find().max_time_ms(1L)

        # The value is stored on the cursor; the last call wins when chained.
        cursor = coll.find().max_time_ms(999)
        self.assertEqual(999, cursor._Cursor__max_time_ms)
        cursor = coll.find().max_time_ms(10).max_time_ms(1000)
        self.assertEqual(1000, cursor._Cursor__max_time_ms)

        # clone() copies the limit, and both cursors include $maxTimeMS in
        # their query spec.
        cursor = coll.find().max_time_ms(999)
        c2 = cursor.clone()
        self.assertEqual(999, c2._Cursor__max_time_ms)
        self.assertTrue("$maxTimeMS" in cursor._Cursor__query_spec())
        self.assertTrue("$maxTimeMS" in c2._Cursor__query_spec())

        # find_one accepts max_time_ms as a keyword argument.
        self.assertTrue(coll.find_one(max_time_ms=1000))

        if "enableTestCommands=1" in get_command_line(self.client)["argv"]:
            # Cursor parses server timeout error in response to initial query.
            self.client.admin.command("configureFailPoint",
                                      "maxTimeAlwaysTimeOut",
                                      mode="alwaysOn")
            try:
                cursor = coll.find().max_time_ms(1)
                try:
                    cursor.next()
                except ExecutionTimeout:
                    pass
                else:
                    self.fail("ExecutionTimeout not raised")
                # find_one surfaces the same error.
                self.assertRaises(ExecutionTimeout,
                                  coll.find_one, max_time_ms=1)
            finally:
                # Turn the failpoint off even if an assertion failed, so
                # later tests are not affected.
                self.client.admin.command("configureFailPoint",
                                          "maxTimeAlwaysTimeOut",
                                          mode="off")

    def test_max_time_ms_getmore(self):
        """Cursor raises ExecutionTimeout when a getmore times out.

        Needs a server started with ``enableTestCommands=1`` so the
        ``maxTimeAlwaysTimeOut`` failpoint can be toggled; the failpoint is
        always switched back off in ``finally``.
        """
        # Test that Cursor handles server timeout error in response to getmore.
        if "enableTestCommands=1" not in get_command_line(self.client)["argv"]:
            raise SkipTest("Need test commands enabled")

        if not version.at_least(self.db.connection, (2, 5, 3, -1)):
            raise SkipTest("MaxTimeMS requires MongoDB >= 2.5.3")

        coll = self.db.pymongo_test
        # Start from an empty collection, as test_max_time_ms does, so the
        # document count (and hence the batch boundaries) does not depend on
        # what earlier tests or runs left behind.
        coll.drop()
        # 200 documents exceed the server's default first batch (101
        # documents as of MongoDB 2.0), so draining the cursor has to issue
        # at least one getmore.
        coll.insert({} for _ in range(200))
        cursor = coll.find().max_time_ms(100)

        # Send initial query before turning on failpoint.
        cursor.next()
        self.client.admin.command("configureFailPoint",
                                  "maxTimeAlwaysTimeOut",
                                  mode="alwaysOn")
        try:
            try:
                # Iterate up to first getmore.
                list(cursor)
            except ExecutionTimeout:
                pass
            else:
                self.fail("ExecutionTimeout not raised")
        finally:
            self.client.admin.command("configureFailPoint",
                                      "maxTimeAlwaysTimeOut",
                                      mode="off")

    def test_explain(self):
        """explain() works both before and after the cursor is iterated."""
        cursor = self.db.test.find()
        cursor.explain()
        for _ in cursor:
            break
        plan = cursor.explain()
        # "cursor" pre MongoDB 2.7.6, "executionStats" post
        has_stats = any(key in plan for key in ("cursor", "executionStats"))
        self.assertTrue(has_stats)

    def test_hint(self):
        """hint() validates its argument, forces use of the given index, and
        cannot be changed once iteration has started."""
        db = self.db
        self.assertRaises(TypeError, db.test.find().hint, 5.5)
        db.test.drop()

        for i in range(100):
            db.test.insert({"num": i, "foo": i})

        # Hinting an index that does not exist fails on the server.
        self.assertRaises(OperationFailure,
                          db.test.find({"num": 17, "foo": 17})
                          .hint([("num", ASCENDING)]).explain)
        self.assertRaises(OperationFailure,
                          db.test.find({"num": 17, "foo": 17})
                          .hint([("foo", ASCENDING)]).explain)

        spec = [("num", DESCENDING)]
        # Only the index's existence matters here, so the return value of
        # create_index is not kept.
        db.test.create_index(spec)

        # Unhinted, the first document has num 0 (natural order); hinting
        # the descending index returns the highest num first.
        first = db.test.find().next()
        self.assertEqual(0, first.get('num'))
        first = db.test.find().hint(spec).next()
        self.assertEqual(99, first.get('num'))
        # There is still no index on "foo".
        self.assertRaises(OperationFailure,
                          db.test.find({"num": 17, "foo": 17})
                          .hint([("foo", ASCENDING)]).explain)

        # hint() cannot be changed once iteration has started.
        a = db.test.find({"num": 17})
        a.hint(spec)
        for _ in a:
            break
        self.assertRaises(InvalidOperation, a.hint, spec)

    def test_hint_by_name(self):
        """An index can be hinted by its name instead of its key spec."""
        coll = self.db.test
        coll.drop()

        for value in range(100):
            coll.insert({'i': value})

        coll.create_index([('i', DESCENDING)], name='fooindex')

        unhinted = coll.find().next()
        self.assertEqual(0, unhinted.get('i'))
        hinted = coll.find().hint('fooindex').next()
        self.assertEqual(99, hinted.get('i'))

    def test_limit(self):
        db = self.db

        self.assertRaises(TypeError, db.test.find().limit, None)
        self.assertRaises(TypeError, db.test.find().limit, "hello")
        self.assertRaises(TypeError, db.test.find().limit, 5.5)
        self.assertTrue(db.test.find().limit(5L))

        db.test.drop()
        for i in range(100):
            db.test.save({"x": i})

        count = 0
        for _ in db.test.find():
            count += 1
        self.assertEqual(count, 100)

        count = 0
        for _ in db.test.find().limit(20):
            count += 1
        self.assertEqual(count, 20)

        count = 0
        for _ in db.test.find().limit(99):
            count += 1
        self.assertEqual(count, 99)

        count = 0
        for _ in db.test.find().limit(1):
            count += 1
        self.assertEqual(count, 1)

        count = 0
        for _ in db.test.find().limit(0):
            count += 1
        self.assertEqual(count, 100)

        count = 0
        for _ in db.test.find().limit(0).limit(50).limit(10):
            count += 1
        self.assertEqual(count, 10)

        a = db.test.find()
        a.limit(10)
        for _ in a:
            break
        self.assertRaises(InvalidOperation, a.limit, 5)

    def test_max(self):
        """Cursor.max() bounds results by index key values.

        With documents j = 0..9, a bound of j=3 yields 3 documents (the
        bound is exclusive).  The key pattern must match an existing index
        in order; an int or dict bound is rejected client-side.
        """
        coll = self.db.test
        coll.drop()
        coll.ensure_index([("j", ASCENDING)])

        for j in range(10):
            coll.insert({"j": j, "k": j})

        def n_found(bound):
            return len(list(coll.find().max(bound)))

        # Both the list and the tuple form of the bound are accepted.
        self.assertEqual(n_found([("j", 3)]), 3)
        self.assertEqual(n_found((("j", 3), )), 3)

        # Compound index.
        coll.ensure_index([("j", ASCENDING), ("k", ASCENDING)])
        self.assertEqual(n_found([("j", 3), ("k", 3)]), 3)

        # Keys in the wrong order, or a pattern with no matching index, are
        # rejected by the server when the query runs.
        for bad_bound in ([("k", 3), ("j", 3)], [("k", 3)]):
            cursor = coll.find().max(bad_bound)
            self.assertRaises(OperationFailure, list, cursor)

        for bad_type in (10, {"j": 10}):
            self.assertRaises(TypeError, coll.find().max, bad_type)

    def test_min(self):
        """Cursor.min() bounds results by index key values.

        With documents j = 0..9, a bound of j=3 yields 7 documents (the
        bound is inclusive).  The key pattern must match an existing index
        in order; an int or dict bound is rejected client-side.
        """
        coll = self.db.test
        coll.drop()
        coll.ensure_index([("j", ASCENDING)])

        for j in range(10):
            coll.insert({"j": j, "k": j})

        def n_found(bound):
            return len(list(coll.find().min(bound)))

        # Both the list and the tuple form of the bound are accepted.
        self.assertEqual(n_found([("j", 3)]), 7)
        self.assertEqual(n_found((("j", 3), )), 7)

        # Compound index.
        coll.ensure_index([("j", ASCENDING), ("k", ASCENDING)])
        self.assertEqual(n_found([("j", 3), ("k", 3)]), 7)

        # Keys in the wrong order, or a pattern with no matching index, are
        # rejected by the server when the query runs.
        for bad_bound in ([("k", 3), ("j", 3)], [("k", 3)]):
            cursor = coll.find().min(bad_bound)
            self.assertRaises(OperationFailure, list, cursor)

        for bad_type in (10, {"j": 10}):
            self.assertRaises(TypeError, coll.find().min, bad_type)

    def test_batch_size(self):
        db = self.db
        db.test.drop()
        for x in range(200):
            db.test.save({"x": x})

        self.assertRaises(TypeError, db.test.find().batch_size, None)
        self.assertRaises(TypeError, db.test.find().batch_size, "hello")
        self.assertRaises(TypeError, db.test.find().batch_size, 5.5)
        self.assertRaises(ValueError, db.test.find().batch_size, -1)
        self.assertTrue(db.test.find().batch_size(5L))
        a = db.test.find()
        for _ in a:
            break
        self.assertRaises(InvalidOperation, a.batch_size, 5)

        def cursor_count(cursor, expected_count):
            count = 0
            for _ in cursor:
                count += 1
            self.assertEqual(expected_count, count)

        cursor_count(db.test.find().batch_size(0), 200)
        cursor_count(db.test.find().batch_size(1), 200)
        cursor_count(db.test.find().batch_size(2), 200)
        cursor_count(db.test.find().batch_size(5), 200)
        cursor_count(db.test.find().batch_size(100), 200)
        cursor_count(db.test.find().batch_size(500), 200)

        cursor_count(db.test.find().batch_size(0).limit(1), 1)
        cursor_count(db.test.find().batch_size(1).limit(1), 1)
        cursor_count(db.test.find().batch_size(2).limit(1), 1)
        cursor_count(db.test.find().batch_size(5).limit(1), 1)
        cursor_count(db.test.find().batch_size(100).limit(1), 1)
        cursor_count(db.test.find().batch_size(500).limit(1), 1)

        cursor_count(db.test.find().batch_size(0).limit(10), 10)
        cursor_count(db.test.find().batch_size(1).limit(10), 10)
        cursor_count(db.test.find().batch_size(2).limit(10), 10)
        cursor_count(db.test.find().batch_size(5).limit(10), 10)
        cursor_count(db.test.find().batch_size(100).limit(10), 10)
        cursor_count(db.test.find().batch_size(500).limit(10), 10)

    def test_limit_and_batch_size(self):
        """Check how many documents the first reply carries (the cursor's
        ``__retrieved`` count after one ``next()``) for various
        combinations of limit and batch size."""
        coll = self.db.test
        coll.drop()
        for x in range(500):
            coll.save({"x": x})

        def assert_first_batch(cursor, expected):
            cursor.next()
            self.assertEqual(expected, cursor._Cursor__retrieved)

        assert_first_batch(coll.find().limit(0).batch_size(10), 10)
        assert_first_batch(coll.find().limit(-2).batch_size(0), 2)
        assert_first_batch(coll.find().limit(-4).batch_size(5), 4)
        assert_first_batch(coll.find().limit(50).batch_size(500), 50)
        assert_first_batch(coll.find().batch_size(500), 500)
        assert_first_batch(coll.find().limit(50), 50)

        # These last two depend on a server-chosen default: as of
        # 2.0.0-rc0 a query without ntoreturn gets 101 documents or 1MB,
        # whichever is smaller, so they may be fragile.
        assert_first_batch(coll.find(), 101)
        assert_first_batch(coll.find().limit(0).batch_size(0), 101)

    def test_skip(self):
        db = self.db

        self.assertRaises(TypeError, db.test.find().skip, None)
        self.assertRaises(TypeError, db.test.find().skip, "hello")
        self.assertRaises(TypeError, db.test.find().skip, 5.5)
        self.assertRaises(ValueError, db.test.find().skip, -5)
        self.assertTrue(db.test.find().skip(5L))

        db.drop_collection("test")

        for i in range(100):
            db.test.save({"x": i})

        for i in db.test.find():
            self.assertEqual(i["x"], 0)
            break

        for i in db.test.find().skip(20):
            self.assertEqual(i["x"], 20)
            break

        for i in db.test.find().skip(99):
            self.assertEqual(i["x"], 99)
            break

        for i in db.test.find().skip(1):
            self.assertEqual(i["x"], 1)
            break

        for i in db.test.find().skip(0):
            self.assertEqual(i["x"], 0)
            break

        for i in db.test.find().skip(0).skip(50).skip(10):
            self.assertEqual(i["x"], 10)
            break

        for i in db.test.find().skip(1000):
            self.fail()

        a = db.test.find()
        a.skip(10)
        for _ in a:
            break
        self.assertRaises(InvalidOperation, a.skip, 5)

    def test_sort(self):
        """sort() validates its arguments and orders by one or more keys.

        A later sort() replaces an earlier one, and the sort cannot be
        changed once iteration has started.
        """
        coll = self.db.test

        self.assertRaises(TypeError, coll.find().sort, 5)
        self.assertRaises(ValueError, coll.find().sort, [])
        self.assertRaises(TypeError, coll.find().sort, [], ASCENDING)
        self.assertRaises(TypeError, coll.find().sort,
                          [("hello", DESCENDING)], DESCENDING)

        coll.drop()

        values = list(range(10))
        random.shuffle(values)
        for value in values:
            coll.save({"x": value})

        def xs(cursor):
            return [doc["x"] for doc in cursor]

        ascending = list(range(10))
        self.assertEqual(xs(coll.find().sort("x", ASCENDING)), ascending)
        self.assertEqual(xs(coll.find().sort("x")), ascending)
        self.assertEqual(xs(coll.find().sort([("x", ASCENDING)])), ascending)

        descending = list(range(9, -1, -1))
        self.assertEqual(xs(coll.find().sort("x", DESCENDING)), descending)
        self.assertEqual(xs(coll.find().sort([("x", DESCENDING)])),
                         descending)
        self.assertEqual(
            xs(coll.find().sort("x", ASCENDING).sort("x", DESCENDING)),
            descending)

        expected = [(1, 5), (2, 5), (0, 3), (7, 3), (9, 2), (2, 1), (3, 1)]
        shuffled = list(expected)
        random.shuffle(shuffled)

        coll.drop()
        for (a, b) in shuffled:
            coll.save({"a": a, "b": b})

        result = [(doc["a"], doc["b"]) for doc in
                  coll.find().sort([("b", DESCENDING), ("a", ASCENDING)])]
        self.assertEqual(result, expected)

        started = coll.find()
        started.sort("x", ASCENDING)
        for _ in started:
            break
        self.assertRaises(InvalidOperation, started.sort, "x", ASCENDING)

    def test_count(self):
        """count() applies the filter, ignores limit/skip by default, is
        unaffected by iteration, and is 0 for a missing collection."""
        coll = self.db.test
        coll.drop()

        self.assertEqual(0, coll.find().count())

        for i in range(10):
            coll.save({"x": i})

        self.assertEqual(10, coll.find().count())
        self.assertTrue(isinstance(coll.find().count(), int))
        # Without arguments, count() ignores limit() and skip().
        self.assertEqual(10, coll.find().limit(5).count())
        self.assertEqual(10, coll.find().skip(5).count())

        self.assertEqual(1, coll.find({"x": 1}).count())
        self.assertEqual(5, coll.find({"x": {"$lt": 5}}).count())

        cursor = coll.find()
        before = cursor.count()
        for _ in cursor:
            break
        self.assertEqual(before, cursor.count())

        self.assertEqual(0, coll.acollectionthatdoesntexist.find().count())

    def test_count_with_hint(self):
        """count() honours hint() on MongoDB >= 2.6; older servers ignore
        the hint."""
        coll = self.db.test
        coll.drop()

        coll.save({'i': 1})
        coll.save({'i': 2})
        self.assertEqual(2, coll.find().count())

        coll.create_index([('i', 1)])

        self.assertEqual(1, coll.find({'i': 1}).hint("_id_").count())
        self.assertEqual(2, coll.find().hint("_id_").count())

        hint_supported = version.at_least(self.client, (2, 6, 0))

        bad_hint = coll.find({'i': 1}).hint("BAD HINT")
        if hint_supported:
            # Count supports hint
            self.assertRaises(OperationFailure, bad_hint.count)
        else:
            # Hint is ignored
            self.assertEqual(1, bad_hint.count())

        # Create a sparse index which should have no entries.
        coll.create_index([('x', 1)], sparse=True)

        # Honouring the hint means scanning the empty sparse index (0);
        # ignoring it means the filter matches one document.
        expected = 0 if hint_supported else 1
        self.assertEqual(expected, coll.find({'i': 1}).hint("x_1").count())

        self.assertEqual(2, coll.find().hint("x_1").count())

    def test_where(self):
        """where() accepts JavaScript as str, unicode or Code (with scope).

        It combines with the query filter, a second where() replaces the
        first, and it cannot be changed once iteration has started.
        """
        db = self.db
        db.test.drop()

        a = db.test.find()
        self.assertRaises(TypeError, a.where, 5)
        self.assertRaises(TypeError, a.where, None)
        self.assertRaises(TypeError, a.where, {})

        for i in range(10):
            db.test.save({"x": i})

        self.assertEqual(3, len(list(db.test.find().where('this.x < 3'))))
        self.assertEqual(3,
                         len(list(db.test.find().where(Code('this.x < 3')))))
        self.assertEqual(3, len(list(db.test.find().where(Code('this.x < i',
                                                               {"i": 3})))))
        self.assertEqual(10, len(list(db.test.find())))

        self.assertEqual(3, db.test.find().where('this.x < 3').count())
        self.assertEqual(10, db.test.find().count())
        self.assertEqual(3, db.test.find().where(u'this.x < 3').count())
        # The comprehension variable is ``doc`` rather than ``a``: on Python 2
        # list-comprehension variables leak into the enclosing scope and
        # would silently rebind the cursor variable ``a``.
        self.assertEqual([0, 1, 2],
                         [doc["x"] for doc in
                          db.test.find().where('this.x < 3')])
        self.assertEqual([],
                         [doc["x"] for doc in
                          db.test.find({"x": 5}).where('this.x < 3')])
        self.assertEqual([5],
                         [doc["x"] for doc in
                          db.test.find({"x": 5}).where('this.x > 3')])

        # A second where() replaces the first rather than combining with it.
        cursor = db.test.find().where('this.x < 3').where('this.x > 7')
        self.assertEqual([8, 9], [doc["x"] for doc in cursor])

        # where() cannot be changed once iteration has started.
        a = db.test.find()
        a.where('this.x > 3')
        for _ in a:
            break
        self.assertRaises(InvalidOperation, a.where, 'this.x < 3')

    def test_rewind(self):
        """rewind() restarts an exhausted or partially read cursor and
        returns the cursor itself."""
        for value in (1, 2, 3):
            self.db.test.save({"x": value})

        cursor = self.db.test.find().limit(2)

        def drain(cur):
            return sum(1 for _ in cur)

        self.assertEqual(2, drain(cursor))
        # An exhausted cursor yields nothing more...
        self.assertEqual(0, drain(cursor))

        # ...until it is rewound.
        cursor.rewind()
        self.assertEqual(2, drain(cursor))

        # Rewinding after a partial read also starts over.
        cursor.rewind()
        for _ in cursor:
            break
        cursor.rewind()
        self.assertEqual(2, drain(cursor))

        self.assertEqual(cursor, cursor.rewind())

    def test_clone(self):
        """clone(), copy.copy and copy.deepcopy produce independent cursors
        that carry over the original's query and options."""
        self.db.test.save({"x": 1})
        self.db.test.save({"x": 2})
        self.db.test.save({"x": 3})

        cursor = self.db.test.find().limit(2)

        count = 0
        for _ in cursor:
            count += 1
        self.assertEqual(2, count)

        # An exhausted cursor yields nothing more.
        count = 0
        for _ in cursor:
            count += 1
        self.assertEqual(0, count)

        # A clone of an exhausted cursor starts from the beginning, keeps the
        # limit of 2, and iterates independently of other clones.
        cursor = cursor.clone()
        cursor2 = cursor.clone()
        count = 0
        for _ in cursor:
            count += 1
        self.assertEqual(2, count)
        for _ in cursor2:
            count += 1
        self.assertEqual(4, count)

        # Cloning a partially read cursor also starts from the beginning.
        cursor.rewind()
        count = 0
        for _ in cursor:
            break
        cursor = cursor.clone()
        for _ in cursor:
            count += 1
        self.assertEqual(2, count)

        # A clone does not compare equal to the cursor it came from.
        self.assertNotEqual(cursor, cursor.clone())

        class MyClass(dict):
            pass

        # as_class sets the type of returned documents, both when iterating
        # and when indexing.
        cursor = self.db.test.find(as_class=MyClass)
        for e in cursor:
            self.assertEqual(type(MyClass()), type(e))
        cursor = self.db.test.find(as_class=MyClass)
        self.assertEqual(type(MyClass()), type(cursor[0]))

        # Just test attributes: this cursor is never iterated; only the
        # options clone() copies are compared.
        cursor = self.db.test.find({"x": re.compile("^hello.*")},
                                   skip=1,
                                   timeout=False,
                                   snapshot=True,
                                   tailable=True,
                                   as_class=MyClass,
                                   slave_okay=True,
                                   await_data=True,
                                   partial=True,
                                   manipulate=False,
                                   compile_re=False,
                                   fields={'_id': False}).limit(2)
        cursor.min([('a', 1)]).max([('b', 3)])
        cursor.add_option(128)
        cursor.comment('hi!')

        cursor2 = cursor.clone()
        self.assertEqual(cursor._Cursor__skip, cursor2._Cursor__skip)
        self.assertEqual(cursor._Cursor__limit, cursor2._Cursor__limit)
        self.assertEqual(cursor._Cursor__snapshot, cursor2._Cursor__snapshot)
        self.assertEqual(type(cursor._Cursor__as_class),
                         type(cursor2._Cursor__as_class))
        self.assertEqual(cursor._Cursor__slave_okay,
                         cursor2._Cursor__slave_okay)
        self.assertEqual(cursor._Cursor__manipulate,
                         cursor2._Cursor__manipulate)
        self.assertEqual(cursor._Cursor__compile_re,
                         cursor2._Cursor__compile_re)
        self.assertEqual(cursor._Cursor__query_flags,
                         cursor2._Cursor__query_flags)
        self.assertEqual(cursor._Cursor__comment,
                         cursor2._Cursor__comment)
        self.assertEqual(cursor._Cursor__min,
                         cursor2._Cursor__min)
        self.assertEqual(cursor._Cursor__max,
                         cursor2._Cursor__max)

        # A shallow copy shares the fields dict with the original, so a
        # mutation made through the copy is visible in the original.
        cursor2 = copy.copy(cursor)
        cursor2._Cursor__fields['cursor2'] = False
        self.assertTrue('cursor2' in cursor._Cursor__fields)

        # A deep copy gets its own fields dict; the original is unaffected.
        cursor3 = copy.deepcopy(cursor)
        cursor3._Cursor__fields['cursor3'] = False
        self.assertFalse('cursor3' in cursor._Cursor__fields)

        # clone() also gives the new cursor its own fields dict.
        cursor4 = cursor.clone()
        cursor4._Cursor__fields['cursor4'] = False
        self.assertFalse('cursor4' in cursor._Cursor__fields)

        # deepcopy must honour its memo: a self-referencing query is copied
        # once, and the copy's "reflexive" entry points at the copy itself.
        query = {"hello": "world"}
        query["reflexive"] = query
        cursor = self.db.test.find(query)

        cursor2 = copy.deepcopy(cursor)

        self.assertNotEqual(id(cursor._Cursor__spec),
                            id(cursor2._Cursor__spec))
        self.assertEqual(id(cursor2._Cursor__spec['reflexive']),
                         id(cursor2._Cursor__spec))
        self.assertEqual(len(cursor2._Cursor__spec), 2)

        # Ensure hints are cloned as the correct type (an ordered SON).
        cursor = self.db.test.find().hint([('z', 1), ("a", 1)])
        cursor2 = copy.deepcopy(cursor)
        self.assertTrue(isinstance(cursor2._Cursor__hint, SON))
        self.assertEqual(cursor._Cursor__hint, cursor2._Cursor__hint)

    def test_deepcopy_cursor_littered_with_regexes(self):
        """deepcopy copes with compiled regexes anywhere in the query spec,
        including as a dict key, and produces an equal spec."""
        spec = {
            "x": re.compile("^hmmm.*"),
            "y": [re.compile("^hmm.*")],
            "z": {"a": [re.compile("^hm.*")]},
            re.compile("^key.*"): {"a": [re.compile("^hm.*")]},
        }
        cursor = self.db.test.find(spec)

        duplicate = copy.deepcopy(cursor)
        self.assertEqual(cursor._Cursor__spec, duplicate._Cursor__spec)

    def test_add_remove_option(self):
        """add_option()/remove_option() manipulate the same wire-protocol
        query flag bits that the corresponding find() keywords set.

        Bits checked below: 2 tailable, 4 slave_okay, 16 timeout=False,
        32 await_data, 64 exhaust, 128 partial.
        """
        cursor = self.db.test.find()
        self.assertEqual(0, cursor._Cursor__query_options())
        cursor.add_option(2)
        cursor2 = self.db.test.find(tailable=True)
        self.assertEqual(2, cursor2._Cursor__query_options())
        self.assertEqual(cursor._Cursor__query_options(),
                         cursor2._Cursor__query_options())
        cursor.add_option(32)
        cursor2 = self.db.test.find(tailable=True, await_data=True)
        self.assertEqual(34, cursor2._Cursor__query_options())
        self.assertEqual(cursor._Cursor__query_options(),
                         cursor2._Cursor__query_options())
        cursor.add_option(128)
        cursor2 = self.db.test.find(tailable=True,
                                    await_data=True).add_option(128)
        self.assertEqual(162, cursor2._Cursor__query_options())
        self.assertEqual(cursor._Cursor__query_options(),
                         cursor2._Cursor__query_options())

        # Adding an option that is already set changes nothing.
        self.assertEqual(162, cursor._Cursor__query_options())
        cursor.add_option(128)
        self.assertEqual(162, cursor._Cursor__query_options())

        cursor.remove_option(128)
        cursor2 = self.db.test.find(tailable=True, await_data=True)
        self.assertEqual(34, cursor2._Cursor__query_options())
        self.assertEqual(cursor._Cursor__query_options(),
                         cursor2._Cursor__query_options())
        cursor.remove_option(32)
        cursor2 = self.db.test.find(tailable=True)
        self.assertEqual(2, cursor2._Cursor__query_options())
        self.assertEqual(cursor._Cursor__query_options(),
                         cursor2._Cursor__query_options())

        # Removing an option that is not set changes nothing.
        self.assertEqual(2, cursor._Cursor__query_options())
        cursor.remove_option(32)
        self.assertEqual(2, cursor._Cursor__query_options())

        # Slave OK (bit 4) also tracks the cursor's __slave_okay attribute.
        cursor = self.db.test.find(slave_okay=True)
        self.assertEqual(4, cursor._Cursor__query_options())
        cursor2 = self.db.test.find().add_option(4)
        self.assertEqual(cursor._Cursor__query_options(),
                         cursor2._Cursor__query_options())
        self.assertTrue(cursor._Cursor__slave_okay)
        cursor.remove_option(4)
        self.assertEqual(0, cursor._Cursor__query_options())
        self.assertFalse(cursor._Cursor__slave_okay)

        # Timeout: timeout=False sets bit 16.
        cursor = self.db.test.find(timeout=False)
        self.assertEqual(16, cursor._Cursor__query_options())
        cursor2 = self.db.test.find().add_option(16)
        self.assertEqual(cursor._Cursor__query_options(),
                         cursor2._Cursor__query_options())
        cursor.remove_option(16)
        self.assertEqual(0, cursor._Cursor__query_options())

        # Tailable / Await data
        cursor = self.db.test.find(tailable=True, await_data=True)
        self.assertEqual(34, cursor._Cursor__query_options())
        cursor2 = self.db.test.find().add_option(34)
        self.assertEqual(cursor._Cursor__query_options(),
                         cursor2._Cursor__query_options())
        cursor.remove_option(32)
        self.assertEqual(2, cursor._Cursor__query_options())

        # Exhaust - which mongos doesn't support.  Bit 64 also tracks the
        # cursor's __exhaust attribute.
        if not is_mongos(self.db.connection):
            cursor = self.db.test.find(exhaust=True)
            self.assertEqual(64, cursor._Cursor__query_options())
            cursor2 = self.db.test.find().add_option(64)
            self.assertEqual(cursor._Cursor__query_options(),
                             cursor2._Cursor__query_options())
            self.assertTrue(cursor._Cursor__exhaust)
            cursor.remove_option(64)
            self.assertEqual(0, cursor._Cursor__query_options())
            self.assertFalse(cursor._Cursor__exhaust)

        # Partial
        cursor = self.db.test.find(partial=True)
        self.assertEqual(128, cursor._Cursor__query_options())
        cursor2 = self.db.test.find().add_option(128)
        self.assertEqual(cursor._Cursor__query_options(),
                         cursor2._Cursor__query_options())
        cursor.remove_option(128)
        self.assertEqual(0, cursor._Cursor__query_options())

    def test_count_with_fields(self):
        """A field projection does not change count() on MongoDB 1.1.3+;
        before that, projecting onto a missing field matched nothing."""
        coll = self.db.test
        coll.drop()
        coll.save({"x": 1})

        if version.at_least(self.db.connection, (1, 1, 3, -1)):
            self.assertEqual(1, coll.find({}, ["a"]).count())
            return

        for _ in coll.find({}, ["a"]):
            self.fail()
        self.assertEqual(0, coll.find({}, ["a"]).count())

    def test_bad_getitem(self):
        """Indexing a cursor with a str, float or None raises TypeError."""
        def index_with(key):
            return self.db.test.find()[key]

        for key in ("hello", 5.5, None):
            self.assertRaises(TypeError, index_with, key)

    def test_getitem_slice_index(self):
        """Slicing a cursor translates into skip/limit.

        A later slice or skip()/limit() call replaces an earlier one.
        Negative starts, steps, and stop < start raise IndexError.
        """
        self.db.drop_collection("test")
        for i in range(100):
            self.db.test.save({"i": i})

        # count(n) yields n, n+1, ...; zipped with results to check order.
        count = itertools.count

        self.assertRaises(IndexError, lambda: self.db.test.find()[-1:])
        self.assertRaises(IndexError, lambda: self.db.test.find()[1:2:2])

        for a, b in zip(count(0), self.db.test.find()):
            self.assertEqual(a, b['i'])

        # Open-ended slices: [n:] skips n documents with no limit.
        self.assertEqual(100, len(list(self.db.test.find()[0:])))
        for a, b in zip(count(0), self.db.test.find()[0:]):
            self.assertEqual(a, b['i'])

        self.assertEqual(80, len(list(self.db.test.find()[20:])))
        for a, b in zip(count(20), self.db.test.find()[20:]):
            self.assertEqual(a, b['i'])

        for a, b in zip(count(99), self.db.test.find()[99:]):
            self.assertEqual(a, b['i'])

        # Starting past the end yields nothing.
        for i in self.db.test.find()[1000:]:
            self.fail()

        # Bounded slices, with int and (Python 2) long indexes.
        self.assertEqual(5, len(list(self.db.test.find()[20:25])))
        self.assertEqual(5, len(list(self.db.test.find()[20L:25L])))
        for a, b in zip(count(20), self.db.test.find()[20:25]):
            self.assertEqual(a, b['i'])

        # A second slice replaces the first: [40:45][20:] behaves like [20:].
        self.assertEqual(80, len(list(self.db.test.find()[40:45][20:])))
        for a, b in zip(count(20), self.db.test.find()[40:45][20:]):
            self.assertEqual(a, b['i'])

        # limit()/skip() after a slice also replace the slice's values.
        self.assertEqual(80,
                         len(list(self.db.test.find()[40:45].limit(0).skip(20))
                            )
                        )
        for a, b in zip(count(20),
                         self.db.test.find()[40:45].limit(0).skip(20)):
            self.assertEqual(a, b['i'])

        # ...and a slice after limit()/skip() replaces those.
        self.assertEqual(80,
                         len(list(self.db.test.find().limit(10).skip(40)[20:]))
                        )
        for a, b in zip(count(20),
                         self.db.test.find().limit(10).skip(40)[20:]):
            self.assertEqual(a, b['i'])

        self.assertEqual(1, len(list(self.db.test.find()[:1])))
        self.assertEqual(5, len(list(self.db.test.find()[:5])))

        # Stops past the end are clamped; empty slices yield nothing.
        self.assertEqual(1, len(list(self.db.test.find()[99:100])))
        self.assertEqual(1, len(list(self.db.test.find()[99:1000])))
        self.assertEqual(0, len(list(self.db.test.find()[10:10])))
        self.assertEqual(0, len(list(self.db.test.find()[:0])))
        self.assertEqual(80,
                         len(list(self.db.test.find()[10:10].limit(0).skip(20))
                            )
                        )

        self.assertRaises(IndexError, lambda: self.db.test.find()[10:8])

    def test_getitem_numeric_index(self):
        self.db.drop_collection("test")
        for i in range(100):
            self.db.test.save({"i": i})

        self.assertEqual(0, self.db.test.find()[0]['i'])
        self.assertEqual(50, self.db.test.find()[50]['i'])
        self.assertEqual(50, self.db.test.find().skip(50)[0]['i'])
        self.assertEqual(50, self.db.test.find().skip(49)[1]['i'])
        self.assertEqual(50, self.db.test.find()[50L]['i'])
        self.assertEqual(99, self.db.test.find()[99]['i'])

        self.assertRaises(IndexError, lambda x: self.db.test.find()[x], -1)
        self.assertRaises(IndexError, lambda x: self.db.test.find()[x], 100)
        self.assertRaises(IndexError,
                          lambda x: self.db.test.find().skip(50)[x], 50)

    def test_count_with_limit_and_skip(self):
        """count(True) honours limit(), skip() and slice notation.

        For every cursor shape, the count must match both the number of
        documents iteration actually yields and the expected value.
        """
        if not version.at_least(self.db.connection, (1, 1, 4, -1)):
            raise SkipTest("count with limit / skip requires MongoDB >= 1.1.4")

        # A non-boolean argument to count() is rejected.
        self.assertRaises(TypeError, self.db.test.find().count, "foo")

        def assert_count(cur, expected):
            self.assertEqual(len(list(cur)), cur.count(True))
            self.assertEqual(expected, cur.count(True))

        self.db.drop_collection("test")
        for value in range(100):
            self.db.test.save({"i": value})

        # (cursor transformation, expected count), checked in this order.
        shapes = [
            (lambda c: c, 100),
            (lambda c: c.limit(10), 10),
            (lambda c: c.limit(110), 100),
            (lambda c: c.skip(10), 90),
            (lambda c: c.skip(110), 0),
            (lambda c: c.limit(10).skip(10), 10),
            (lambda c: c[10:20], 10),
            (lambda c: c.limit(10).skip(95), 5),
            (lambda c: c[95:105], 5),
        ]
        for shape, expected in shapes:
            assert_count(shape(self.db.test.find()), expected)

    def test_len(self):
        """len() is not supported on a cursor and raises TypeError."""
        cursor = self.db.test.find()
        self.assertRaises(TypeError, len, cursor)

    def test_properties(self):
        """Cursor.collection points back at its collection and is read-only."""
        coll = self.db.test
        self.assertEqual(coll, self.db.test.find().collection)

        def assign_collection():
            self.db.test.find().collection = "hello"

        # Assigning to the property must fail.
        self.assertRaises(AttributeError, assign_collection)

    def test_get_more(self):
        """A batch size smaller than the result set still yields every doc.

        Ten documents are read with batch_size(5), so the cursor has to
        fetch more than one batch.
        """
        db = self.db
        db.drop_collection("test")
        documents = [{'i': n} for n in range(10)]
        db.test.insert(documents)
        cursor = db.test.find().batch_size(5)
        self.assertEqual(10, len(list(cursor)))

    def test_tailable(self):
        """A tailable cursor on a capped collection picks up new inserts.

        Each pass over the cursor yields only the single document inserted
        since the previous pass. After enough inserts to roll the capped
        collection (max=3) over, iterating yields nothing and the cursor
        reports that it is no longer alive.
        """
        db = self.db
        db.drop_collection("test")
        db.create_collection("test", capped=True, size=1000, max=3)

        try:
            cursor = db.test.find(tailable=True)

            for expected_x in (1, 2, 3):
                db.test.insert({"x": expected_x})
                seen = 0
                for doc in cursor:
                    seen += 1
                    self.assertEqual(expected_x, doc["x"])
                self.assertEqual(1, seen)

            # Capped rollover: the collection can never hold more than 3
            # documents. Inserting three more must not raise...
            db.test.insert(({"x": n} for n in xrange(4, 7)))
            self.assertEqual(0, len(list(cursor)))

            # ...and the cursor must not think it is still alive.
            self.assertFalse(cursor.alive)

            self.assertEqual(3, db.test.count())
        finally:
            db.drop_collection("test")

    def test_distinct(self):
        """Cursor.distinct() respects the cursor's query filter.

        Both a top-level key and a dotted key into an embedded document are
        checked.
        """
        if not version.at_least(self.db.connection, (1, 1, 3, 1)):
            raise SkipTest("distinct with query requires MongoDB >= 1.1.3")

        self.db.drop_collection("test")
        for a in (1, 2, 2, 2, 3):
            self.db.test.save({"a": a})

        values = sorted(self.db.test.find({"a": {"$lt": 3}}).distinct("a"))
        self.assertEqual([1, 2], values)

        self.db.drop_collection("test")
        for b, c in (("a", 12), ("b", 8), ("c", 12), ("c", 8)):
            self.db.test.save({"a": {"b": b}, "c": c})

        values = sorted(self.db.test.find({"c": 8}).distinct("a.b"))
        self.assertEqual(["b", "c"], values)

    def test_max_scan(self):
        """max_scan limits the documents returned from a full scan.

        The limit can be given as a find() keyword or set with the chainable
        method; when the method is called twice, the last value wins.
        """
        if not version.at_least(self.db.connection, (1, 5, 1)):
            raise SkipTest("maxScan requires MongoDB >= 1.5.1")

        self.db.drop_collection("test")
        for _ in range(100):
            self.db.test.insert({})

        unbounded = self.db.test.find()
        self.assertEqual(100, len(list(unbounded)))

        via_kwarg = self.db.test.find(max_scan=50)
        self.assertEqual(50, len(list(via_kwarg)))

        via_method = self.db.test.find().max_scan(90).max_scan(50)
        self.assertEqual(50, len(list(via_method)))

    def test_with_statement(self):
        if sys.version_info < (2, 6):
            raise SkipTest("With statement requires Python >= 2.6")

        self.db.drop_collection("test")
        for _ in range(100):
            self.db.test.insert({})

        c1 = self.db.test.find()
        exec """
with self.db.test.find() as c2:
    self.assertTrue(c2.alive)
self.assertFalse(c2.alive)

with self.db.test.find() as c2:
    self.assertEqual(100, len(list(c2)))
self.assertFalse(c2.alive)
"""
        self.assertTrue(c1.alive)

    def test_comment(self):
        """Cursor.comment() attaches a $comment to find, count and distinct.

        Each operation runs with database profiling set to ALL. The
        system.profile collection is then queried for exactly one entry that
        carries the comment. comment() must raise InvalidOperation once the
        cursor has started iterating.

        Profiling is always switched back OFF, and system.profile is always
        dropped, even if an assertion fails. Otherwise later tests would run
        against a profiled database.
        """
        if is_mongos(self.client):
            raise SkipTest("profile is not supported by mongos")
        if not version.at_least(self.db.connection, (2, 0)):
            raise SkipTest("Requires server >= 2.0")
        if server_started_with_auth(self.db.connection):
            raise SkipTest("SERVER-4754 - This test uses profiling.")

        def run_with_profiling(func):
            # Start from an empty profile collection, so only entries
            # produced inside func() are counted.
            self.db.set_profiling_level(OFF)
            self.db.system.profile.drop()
            self.db.set_profiling_level(ALL)
            try:
                func()
            finally:
                # Restore the profiler even when func() raises.
                self.db.set_profiling_level(OFF)

        # The helper names below avoid shadowing the module-level `count`.
        def check_find():
            list(self.db.test.find().comment('foo'))
            op = self.db.system.profile.find({'ns': 'pymongo_test.test',
                                              'op': 'query',
                                              'query.$comment': 'foo'})
            self.assertEqual(op.count(), 1)

        def check_count():
            self.db.test.find().comment('foo').count()
            op = self.db.system.profile.find({'ns': 'pymongo_test.$cmd',
                                              'op': 'command',
                                              'command.count': 'test',
                                              'command.$comment': 'foo'})
            self.assertEqual(op.count(), 1)

        def check_distinct():
            self.db.test.find().comment('foo').distinct('type')
            op = self.db.system.profile.find({'ns': 'pymongo_test.$cmd',
                                              'op': 'command',
                                              'command.distinct': 'test',
                                              'command.$comment': 'foo'})
            self.assertEqual(op.count(), 1)

        try:
            run_with_profiling(check_find)
            run_with_profiling(check_count)
            run_with_profiling(check_distinct)

            # A comment cannot be added once iteration has begun.
            self.db.test.insert([{}, {}])
            cursor = self.db.test.find()
            cursor.next()
            self.assertRaises(InvalidOperation, cursor.comment, 'hello')
        finally:
            self.db.system.profile.drop()

    def test_cursor_transfer(self):
        """Resume a server-side cursor from a Cursor through a CommandCursor.

        A CursorManager whose close() does nothing is installed. After
        Cursor.close(), the same cursor id is handed to a CommandCursor,
        which must fetch the remaining documents: 10 from the first batch
        plus 190 more, 200 in total.
        """

        # This is just a test, don't try this at home...
        self.db.test.remove({})
        self.db.test.insert({'_id': i} for i in xrange(200))

        class CManager(CursorManager):
            # Cursor manager whose close() is a no-op. Presumably
            # Cursor.close() routes through the client's cursor manager, so
            # the server-side cursor stays open - the 200-doc assertion
            # below depends on that.
            def __init__(self, connection):
                super(CManager, self).__init__(connection)

            def close(self, dummy):
                # Do absolutely nothing...
                pass

        client = self.db.connection
        # NOTE(review): this catch_warnings is torn down with ctx.exit()
        # rather than used as a context manager; presumably it snapshots the
        # warning filters on construction so that the simplefilter call
        # below is undone in the finally clause - confirm against its
        # definition.
        ctx = catch_warnings()
        try:
            # Presumably set_cursor_manager emits a DeprecationWarning;
            # silence it for this test.
            warnings.simplefilter("ignore", DeprecationWarning)
            client.set_cursor_manager(CManager)

            docs = []
            cursor = self.db.test.find().batch_size(10)
            # Pull one document, so the first batch of 10 is buffered, then
            # "close" the cursor. The remaining buffered documents are still
            # iterable, which the len == 10 check confirms.
            docs.append(cursor.next())
            cursor.close()
            docs.extend(cursor)
            self.assertEqual(len(docs), 10)
            # Resume the same server-side cursor id with an empty first
            # batch. Iterating the CommandCursor must fetch the other 190.
            cmd_cursor = {'id': cursor.cursor_id, 'firstBatch': []}
            ccursor = CommandCursor(cursor.collection, cmd_cursor,
                                    cursor.conn_id, retrieved=cursor.retrieved)
            docs.extend(ccursor)
            self.assertEqual(len(docs), 200)
        finally:
            # Put back the default cursor manager and the warning filters.
            client.set_cursor_manager(CursorManager)
            ctx.exit()