Exemple #1
0
    def test_survive_cursor_not_found(self):
        """Reading a GridOut survives the server killing its chunk cursor.

        Writes a file spanning one more chunk than fits in the first ``find``
        batch, reads one chunk, kills the server-side chunk cursor, and then
        checks that the remainder of the file can still be read.
        """
        # By default the find command returns 101 documents in the first batch.
        # Use 102 chunks to cause a single getMore.
        chunk_size = 1024
        data = b'd' * (102 * chunk_size)
        listener = EventListener()
        client = rs_or_single_client(event_listeners=[listener])
        # Release the client's sockets and monitors when the test finishes.
        self.addCleanup(client.close)
        db = client.pymongo_test
        with GridIn(db.fs, chunk_size=chunk_size) as infile:
            infile.write(data)

        with GridOut(db.fs, infile._id) as outfile:
            self.assertEqual(len(outfile.readchunk()), chunk_size)

            # Kill the cursor to simulate the cursor timing out on the server
            # when an application spends a long time between two calls to
            # readchunk().
            client._close_cursor_now(
                outfile._GridOut__chunk_iter._cursor.cursor_id,
                _CursorAddress(client.address, db.fs.chunks.full_name))

            # Read the rest of the file without error.
            self.assertEqual(len(outfile.read()), len(data) - chunk_size)

        # Paranoid, ensure that a getMore was actually sent.
        self.assertIn("getMore", listener.started_command_names())
    def test_survive_cursor_not_found(self):
        """Reading a GridOut survives the server killing its chunk cursor.

        Writes a file spanning one more chunk than fits in the first ``find``
        batch, reads one chunk, kills the server-side chunk cursor, and then
        checks that the remainder of the file can still be read.
        """
        # By default the find command returns 101 documents in the first batch.
        # Use 102 chunks to cause a single getMore.
        chunk_size = 1024
        data = b'd' * (102 * chunk_size)
        listener = EventListener()
        client = rs_or_single_client(event_listeners=[listener])
        # Release the client's sockets and monitors when the test finishes.
        self.addCleanup(client.close)
        db = client.pymongo_test
        with GridIn(db.fs, chunk_size=chunk_size) as infile:
            infile.write(data)

        with GridOut(db.fs, infile._id) as outfile:
            self.assertEqual(len(outfile.readchunk()), chunk_size)

            # Kill the cursor to simulate the cursor timing out on the server
            # when an application spends a long time between two calls to
            # readchunk().
            client._close_cursor_now(
                outfile._GridOut__chunk_iter._cursor.cursor_id,
                _CursorAddress(client.address, db.fs.chunks.full_name))

            # Read the rest of the file without error.
            self.assertEqual(len(outfile.read()), len(data) - chunk_size)

        # Paranoid, ensure that a getMore was actually sent.
        self.assertIn("getMore", listener.started_command_names())
 def setUpClass(cls):
     """Build a shared client that reports events to ``cls.listener``.

     The current global monitoring listeners are stored on
     ``cls.saved_listeners`` and replaced with an empty ``_Listeners``
     before the client is created (presumably restored in a tearDownClass
     that is not visible here).
     """
     listener = EventListener()
     cls.listener = listener
     # Keep a handle on the global listeners, then clear them so the client
     # built below does not pick up global subscribers.
     cls.saved_listeners = monitoring._LISTENERS
     monitoring._LISTENERS = monitoring._Listeners([], [], [], [])
     client = rs_or_single_client(event_listeners=[listener])
     cls.client = client
     cls.db = client.pymongo_test
     cls.collation = Collation('en_US')
 def test_batch_size_is_honored(self):
     """The change stream's ``batch_size`` reaches the server.

     Checks the ``cursor`` option of the first started command and the
     ``batchSize`` field of the command issued by the first ``try_next()``.
     """
     listener = EventListener()
     client = rs_or_single_client(event_listeners=[listener])
     # Registered first so it runs last, after the collection is dropped.
     self.addCleanup(client.close)
     # Connect to the cluster.
     client.admin.command('ping')
     listener.results.clear()
     # ChangeStreams only read majority committed data so use w:majority.
     coll = self.watched_collection().with_options(
         write_concern=WriteConcern("majority"))
     coll.drop()
     # Create the watched collection before starting the change stream to
     # skip any "create" events.
     coll.insert_one({'_id': 1})
     self.addCleanup(coll.drop)
     # Expected batchSize.
     expected = {'batchSize': 23}
     with self.change_stream_with_client(
             client, max_await_time_ms=250, batch_size=23) as stream:
         # Confirm that batchSize is honored for initial batch.
         cmd = listener.results['started'][0].command
         self.assertEqual(cmd['cursor'], expected)
         listener.results.clear()
         # Confirm that batchSize is honored by getMores.
         self.assertIsNone(stream.try_next())
         cmd = listener.results['started'][0].command
         self.assertEqual(expected['batchSize'], cmd['batchSize'])
Exemple #5
0
 def setUpClass(cls):
     """Build a shared single-server client reporting to ``cls.listener``.

     The global monitoring listeners are stashed on ``cls.saved_listeners``
     and swapped for an empty ``_Listeners`` before the client is created.
     """
     cls.listener = listener = EventListener()
     cls.saved_listeners = monitoring._LISTENERS
     # Don't use any global subscribers.
     no_listeners = monitoring._Listeners([], [], [], [])
     monitoring._LISTENERS = no_listeners
     cls.client = client = single_client(event_listeners=[listener])
     cls.db = client.pymongo_test
 def setUpClass(cls):
     """Build a shared client and silence DeprecationWarning for the class.

     A ``warnings.catch_warnings()`` context is entered and kept on
     ``cls.warn_context``; exiting it is presumably done in a tearDownClass
     that is not visible here.
     """
     listener = EventListener()
     cls.listener = listener
     client = rs_or_single_client(event_listeners=[listener])
     cls.client = client
     cls.db = client.pymongo_test
     cls.collation = Collation('en_US')
     # Enter the context before installing the filter so the filter change
     # is scoped to this context.
     ctx = warnings.catch_warnings()
     cls.warn_context = ctx
     ctx.__enter__()
     warnings.simplefilter("ignore", DeprecationWarning)
    def test_omit_default_read_write_concern(self):
        """Operations on a default-configured client omit read/write concern.

        Runs each operation in ``ops`` and asserts that every command it
        starts lacks both a ``readConcern`` and a ``writeConcern`` field.
        """
        listener = EventListener()
        # Client with default readConcern and writeConcern
        client = rs_or_single_client(event_listeners=[listener])
        # Registered first so it runs last, after the collection drops.
        self.addCleanup(client.close)
        collection = client.pymongo_test.collection
        # Prepare for tests of find() and aggregate().
        collection.insert_many([{} for _ in range(10)])
        self.addCleanup(collection.drop)
        self.addCleanup(client.pymongo_test.collection2.drop)

        # Commands MUST NOT send the default read/write concern to the server.

        def rename_and_drop():
            # Ensure collection exists.
            collection.insert_one({})
            collection.rename('collection2')
            client.pymongo_test.collection2.drop()

        def insert_command_default_write_concern():
            collection.database.command('insert',
                                        'collection',
                                        documents=[{}],
                                        write_concern=WriteConcern())

        ops = [
            ('aggregate', lambda: list(collection.aggregate([]))),
            ('find', lambda: list(collection.find())),
            ('insert_one', lambda: collection.insert_one({})),
            ('update_one',
             lambda: collection.update_one({}, {'$set': {'x': 1}})),
            ('update_many',
             lambda: collection.update_many({}, {'$set': {'x': 1}})),
            ('delete_one', lambda: collection.delete_one({})),
            ('delete_many', lambda: collection.delete_many({})),
            ('bulk_write', lambda: collection.bulk_write([InsertOne({})])),
            ('rename_and_drop', rename_and_drop),
            ('command', insert_command_default_write_concern),
        ]

        for name, f in ops:
            listener.results.clear()
            f()

            started = listener.results['started']
            self.assertGreaterEqual(len(started), 1)
            for event in started:
                self.assertNotIn(
                    'readConcern', event.command,
                    "%s sent default readConcern with %s" %
                    (name, event.command_name))
                self.assertNotIn(
                    'writeConcern', event.command,
                    "%s sent default writeConcern with %s" %
                    (name, event.command_name))
Exemple #8
0
    def test_try_next_runs_one_getmore(self):
        """Each ``try_next()`` call issues at most one getMore.

        This holds both when no change is available and after the stream has
        resumed from a killed cursor. The assertions check the exact command
        names started by each call.
        """
        listener = EventListener()
        client = rs_or_single_client(event_listeners=[listener])
        # Registered first so it runs last, after the collection is dropped.
        self.addCleanup(client.close)
        # Connect to the cluster.
        client.admin.command('ping')
        listener.results.clear()
        # ChangeStreams only read majority committed data so use w:majority.
        coll = self.watched_collection().with_options(
            write_concern=WriteConcern("majority"))
        coll.drop()
        # Create the watched collection before starting the change stream to
        # skip any "create" events.
        coll.insert_one({'_id': 1})
        self.addCleanup(coll.drop)
        with self.change_stream_with_client(client,
                                            max_await_time_ms=250) as stream:
            self.assertEqual(listener.started_command_names(), ["aggregate"])
            listener.results.clear()

            # Confirm that only a single getMore is run even when no documents
            # are returned.
            self.assertIsNone(stream.try_next())
            self.assertEqual(listener.started_command_names(), ["getMore"])
            listener.results.clear()
            self.assertIsNone(stream.try_next())
            self.assertEqual(listener.started_command_names(), ["getMore"])
            listener.results.clear()

            # Get at least one change before resuming.
            coll.insert_one({'_id': 2})
            change = stream.try_next()
            self.assertEqual(change['_id'], stream._resume_token)
            listener.results.clear()

            # Cause the next request to initiate the resume process.
            self.kill_change_stream_cursor(stream)
            listener.results.clear()

            # The sequence should be:
            # - getMore, fail
            # - resume with aggregate command
            # - no results, return immediately without another getMore
            self.assertIsNone(stream.try_next())
            self.assertEqual(listener.started_command_names(),
                             ["getMore", "aggregate"])
            listener.results.clear()

            # Stream still works after a resume.
            coll.insert_one({'_id': 3})
            change = stream.try_next()
            self.assertEqual(change['_id'], stream._resume_token)
            self.assertEqual(listener.started_command_names(), ["getMore"])
            self.assertIsNone(stream.try_next())
    def test_try_next_runs_one_getmore(self):
        """Each ``try_next()`` call issues at most one getMore.

        This holds both when no change is available and after the stream has
        resumed from a killed cursor. The assertions check the exact command
        names started by each call.
        """
        listener = EventListener()
        client = rs_or_single_client(event_listeners=[listener])
        # Registered first so it runs last, after the collection is dropped.
        self.addCleanup(client.close)
        # Connect to the cluster.
        client.admin.command('ping')
        listener.results.clear()
        # ChangeStreams only read majority committed data so use w:majority.
        coll = self.watched_collection().with_options(
            write_concern=WriteConcern("majority"))
        coll.drop()
        # Create the watched collection before starting the change stream to
        # skip any "create" events.
        coll.insert_one({'_id': 1})
        self.addCleanup(coll.drop)
        with self.change_stream_with_client(
                client, max_await_time_ms=250) as stream:
            self.assertEqual(listener.started_command_names(), ["aggregate"])
            listener.results.clear()

            # Confirm that only a single getMore is run even when no documents
            # are returned.
            self.assertIsNone(stream.try_next())
            self.assertEqual(listener.started_command_names(), ["getMore"])
            listener.results.clear()
            self.assertIsNone(stream.try_next())
            self.assertEqual(listener.started_command_names(), ["getMore"])
            listener.results.clear()

            # Get at least one change before resuming.
            coll.insert_one({'_id': 2})
            change = stream.try_next()
            self.assertEqual(change['_id'], stream._resume_token)
            listener.results.clear()

            # Cause the next request to initiate the resume process.
            self.kill_change_stream_cursor(stream)
            listener.results.clear()

            # The sequence should be:
            # - getMore, fail
            # - resume with aggregate command
            # - no results, return immediately without another getMore
            self.assertIsNone(stream.try_next())
            self.assertEqual(
                listener.started_command_names(), ["getMore", "aggregate"])
            listener.results.clear()

            # Stream still works after a resume.
            coll.insert_one({'_id': 3})
            change = stream.try_next()
            self.assertEqual(change['_id'], stream._resume_token)
            self.assertEqual(listener.started_command_names(), ["getMore"])
            self.assertIsNone(stream.try_next())
 def test_write_error_details_exposes_errinfo(self):
     """A validation ``WriteError`` exposes the server's write error details.

     Inserting a document that violates the collection validator must raise
     ``WriteError`` with code 121 and non-empty ``details['errInfo']``.
     ``details`` must equal the first ``writeErrors`` entry of the insert
     command's reply.
     """
     listener = EventListener()
     client = rs_or_single_client(event_listeners=[listener])
     # Registered first so it runs last, after the database is dropped.
     self.addCleanup(client.close)
     db = client.errinfotest
     self.addCleanup(client.drop_database, "errinfotest")
     validator = {"x": {"$type": "string"}}
     db.create_collection("test", validator=validator)
     with self.assertRaises(WriteError) as ctx:
         db.test.insert_one({'x': 1})
     self.assertEqual(ctx.exception.code, 121)
     self.assertIsNotNone(ctx.exception.details)
     self.assertIsNotNone(ctx.exception.details.get('errInfo'))
     for event in listener.results['succeeded']:
         if event.command_name == 'insert':
             self.assertEqual(
                 event.reply['writeErrors'][0], ctx.exception.details)
             break
     else:
         self.fail("Couldn't find insert event.")
    def test_functional_select_max_port_number_host(self):
        """A custom ``server_selector`` routes reads to the chosen server.

        The selector picks the server with the highest port. Every ``find``
        command must then go to that port, and at least one ``find`` must
        be observed so the check cannot pass vacuously.
        """
        # Selector that returns server with highest port number.
        def custom_selector(servers):
            ports = [s.address[1] for s in servers]
            idx = ports.index(max(ports))
            return [servers[idx]]

        # Initialize client with appropriate listeners.
        listener = EventListener()
        client = rs_or_single_client(server_selector=custom_selector,
                                     event_listeners=[listener])
        self.addCleanup(client.close)
        coll = client.get_database('testdb',
                                   read_preference=ReadPreference.NEAREST).coll
        self.addCleanup(client.drop_database, 'testdb')

        # Wait the node list to be fully populated.
        def all_hosts_started():
            return (len(client.admin.command('isMaster')['hosts']) == len(
                client._topology._description.readable_servers))

        wait_until(all_hosts_started, 'receive heartbeat from all hosts')
        expected_port = max(
            n.address[1]
            for n in client._topology._description.readable_servers)

        # Insert 1 record and access it 10 times.
        coll.insert_one({'name': 'John Doe'})
        for _ in range(10):
            coll.find_one({'name': 'John Doe'})

        # Confirm all find commands are run against appropriate host.
        find_events = [event for event in listener.results['started']
                       if event.command_name == 'find']
        # Without this, an empty event list would let the test pass silently.
        self.assertTrue(find_events, "no find commands were observed")
        for event in find_events:
            self.assertEqual(event.connection_id[1], expected_port)
Exemple #12
0
 def setUpClass(cls):
     """Build a shared client whose events are recorded by ``cls.listener``."""
     listener = EventListener()
     cls.listener = listener
     cls.client = rs_or_single_client(event_listeners=[listener])
    def test_max_await_time_ms(self):
        """Check how max_await_time_ms / max_time_ms map onto find/getMore.

        ``max_await_time_ms`` is only recorded for TAILABLE_AWAIT cursors and
        is then sent as ``maxTimeMS`` on getMore. ``max_time_ms`` is sent as
        ``maxTimeMS`` on the initial find only.
        """
        db = self.db
        db.pymongo_test.drop()
        coll = db.create_collection("pymongo_test", capped=True, size=4096)

        self.assertRaises(TypeError, coll.find().max_await_time_ms, 'foo')
        coll.insert_one({"amalia": 1})
        coll.insert_one({"amalia": 2})

        coll.find().max_await_time_ms(None)
        coll.find().max_await_time_ms(long(1))

        # When cursor is not tailable_await
        cursor = coll.find()
        self.assertEqual(None, cursor._Cursor__max_await_time_ms)
        cursor = coll.find().max_await_time_ms(99)
        self.assertEqual(None, cursor._Cursor__max_await_time_ms)

        # If cursor is tailable_await and timeout is unset
        cursor = coll.find(cursor_type=CursorType.TAILABLE_AWAIT)
        self.assertEqual(None, cursor._Cursor__max_await_time_ms)

        # If cursor is tailable_await and timeout is set
        cursor = coll.find(
            cursor_type=CursorType.TAILABLE_AWAIT).max_await_time_ms(99)
        self.assertEqual(99, cursor._Cursor__max_await_time_ms)

        cursor = coll.find(
            cursor_type=CursorType.TAILABLE_AWAIT).max_await_time_ms(
                10).max_await_time_ms(90)
        self.assertEqual(90, cursor._Cursor__max_await_time_ms)

        listener = EventListener()
        listener.add_command_filter('killCursors')
        saved_listeners = monitoring._LISTENERS
        monitoring._LISTENERS = monitoring._Listeners([], [], [], [])
        client = rs_or_single_client(event_listeners=[listener])
        # Release the monitored client's resources when the test finishes.
        self.addCleanup(client.close)
        coll = client[self.db.name].pymongo_test
        results = listener.results

        try:
            # Tailable_await defaults.
            list(coll.find(cursor_type=CursorType.TAILABLE_AWAIT))
            # find
            self.assertNotIn('maxTimeMS', results['started'][0].command)
            # getMore
            self.assertNotIn('maxTimeMS', results['started'][1].command)
            results.clear()

            # Tailable_await with max_await_time_ms set.
            list(coll.find(
                cursor_type=CursorType.TAILABLE_AWAIT).max_await_time_ms(99))
            # find
            self.assertEqual('find', results['started'][0].command_name)
            self.assertNotIn('maxTimeMS', results['started'][0].command)
            # getMore
            self.assertEqual('getMore', results['started'][1].command_name)
            self.assertIn('maxTimeMS', results['started'][1].command)
            self.assertEqual(99, results['started'][1].command['maxTimeMS'])
            results.clear()

            # Tailable_await with max_time_ms
            list(coll.find(
                cursor_type=CursorType.TAILABLE_AWAIT).max_time_ms(1))
            # find
            self.assertEqual('find', results['started'][0].command_name)
            self.assertIn('maxTimeMS', results['started'][0].command)
            self.assertEqual(1, results['started'][0].command['maxTimeMS'])
            # getMore
            self.assertEqual('getMore', results['started'][1].command_name)
            self.assertNotIn('maxTimeMS', results['started'][1].command)
            results.clear()

            # Tailable_await with both max_time_ms and max_await_time_ms
            list(coll.find(
                cursor_type=CursorType.TAILABLE_AWAIT).max_time_ms(
                    1).max_await_time_ms(99))
            # find
            self.assertEqual('find', results['started'][0].command_name)
            self.assertIn('maxTimeMS', results['started'][0].command)
            self.assertEqual(1, results['started'][0].command['maxTimeMS'])
            # getMore
            self.assertEqual('getMore', results['started'][1].command_name)
            self.assertIn('maxTimeMS', results['started'][1].command)
            self.assertEqual(99, results['started'][1].command['maxTimeMS'])
            results.clear()

            # Non tailable_await with max_await_time_ms
            list(coll.find(batch_size=1).max_await_time_ms(99))
            # find
            self.assertEqual('find', results['started'][0].command_name)
            self.assertNotIn('maxTimeMS', results['started'][0].command)
            # getMore
            self.assertEqual('getMore', results['started'][1].command_name)
            self.assertNotIn('maxTimeMS', results['started'][1].command)
            results.clear()

            # Non tailable_await with max_time_ms
            list(coll.find(batch_size=1).max_time_ms(99))
            # find
            self.assertEqual('find', results['started'][0].command_name)
            self.assertIn('maxTimeMS', results['started'][0].command)
            self.assertEqual(99, results['started'][0].command['maxTimeMS'])
            # getMore
            self.assertEqual('getMore', results['started'][1].command_name)
            self.assertNotIn('maxTimeMS', results['started'][1].command)
            # Discard this query's events; otherwise the checks below would
            # re-inspect this find/getMore instead of the next query's.
            results.clear()

            # Non tailable_await with both max_time_ms and max_await_time_ms
            list(coll.find(batch_size=1).max_time_ms(99).max_await_time_ms(88))
            # find
            self.assertEqual('find', results['started'][0].command_name)
            self.assertIn('maxTimeMS', results['started'][0].command)
            self.assertEqual(99, results['started'][0].command['maxTimeMS'])
            # getMore
            self.assertEqual('getMore', results['started'][1].command_name)
            self.assertNotIn('maxTimeMS', results['started'][1].command)
            results.clear()

        finally:
            monitoring._LISTENERS = saved_listeners
 def setUpClass(cls):
     """Build a shared single-server client reporting to ``cls.listener``.

     The global monitoring listeners are kept on ``cls.saved_listeners`` and
     replaced with an empty ``_Listeners`` before the client is created.
     """
     cls.listener = listener = EventListener()
     cls.saved_listeners = monitoring._LISTENERS
     empty = monitoring._Listeners([], [], [], [])
     monitoring._LISTENERS = empty
     cls.client = single_client(event_listeners=[listener])
Exemple #15
0
    def test_authenticate_multiple(self):
        """One client can hold credentials for several databases.

        Permissions are checked after each ``authenticate`` call. After the
        client's sockets are closed, the exact sequence of auth commands used
        to re-authenticate is also checked.
        """
        # "self.client" is logged in as root.
        self.client.drop_database("pymongo_test")
        self.client.drop_database("pymongo_test1")
        admin_db_auth = self.client.admin
        users_db_auth = self.client.pymongo_test

        admin_db_auth.add_user('ro-admin',
                               'pass',
                               roles=["userAdmin", "readAnyDatabase"])

        self.addCleanup(client_context.drop_user, 'admin', 'ro-admin')
        users_db_auth.add_user('user',
                               'pass',
                               roles=["userAdmin", "readWrite"])
        self.addCleanup(remove_all_users, users_db_auth)

        # Non-root client.
        listener = EventListener()
        client = rs_or_single_client_noauth(event_listeners=[listener])
        # Runs before the user cleanups registered above (LIFO order).
        self.addCleanup(client.close)
        admin_db = client.admin
        users_db = client.pymongo_test
        other_db = client.pymongo_test1

        self.assertRaises(OperationFailure, users_db.test.find_one)
        self.assertEqual(listener.started_command_names(), ['find'])
        listener.reset()

        # Regular user should be able to query its own db, but
        # no other.
        users_db.authenticate('user', 'pass')
        if client_context.version.at_least(3, 0):
            self.assertEqual(listener.started_command_names()[0], 'saslStart')
        else:
            self.assertEqual(listener.started_command_names()[0], 'getnonce')

        self.assertEqual(0, users_db.test.count_documents({}))
        self.assertRaises(OperationFailure, other_db.test.find_one)

        listener.reset()
        # Admin read-only user should be able to query any db,
        # but not write.
        admin_db.authenticate('ro-admin', 'pass')
        if client_context.version.at_least(3, 0):
            self.assertEqual(listener.started_command_names()[0], 'saslStart')
        else:
            self.assertEqual(listener.started_command_names()[0], 'getnonce')
        self.assertEqual(None, other_db.test.find_one())
        self.assertRaises(OperationFailure, other_db.test.insert_one, {})

        # Close all sockets.
        client.close()

        listener.reset()
        # We should still be able to write to the regular user's db.
        # A result object is always truthy, so check the acknowledgement.
        self.assertTrue(users_db.test.delete_many({}).acknowledged)
        names = listener.started_command_names()
        if client_context.version.at_least(4, 4, -1):
            # No speculation with multiple users (but we do skipEmptyExchange).
            self.assertEqual(names, [
                'saslStart', 'saslContinue', 'saslStart', 'saslContinue',
                'delete'
            ])
        elif client_context.version.at_least(3, 0):
            self.assertEqual(names, [
                'saslStart', 'saslContinue', 'saslContinue', 'saslStart',
                'saslContinue', 'saslContinue', 'delete'
            ])
        else:
            self.assertEqual(names, [
                'getnonce', 'authenticate', 'getnonce', 'authenticate',
                'delete'
            ])

        # And read from other dbs...
        self.assertEqual(0, other_db.test.count_documents({}))

        # But still not write to other dbs.
        self.assertRaises(OperationFailure, other_db.test.insert_one, {})
Exemple #16
0
    def test_mongodb_x509_auth(self):
        """MONGODB-X509 authentication succeeds or fails as expected.

        Covers the following cases, with the username required only on
        servers older than 3.3.12 as encoded in the version checks below:
        - a username passed in the URI;
        - no username at all;
        - a username that does not match the certificate;
        - the CA certificate used as the client certificate.

        Every client created here is closed via ``addCleanup``.
        """
        host, port = client_context.host, client_context.port
        ssl_client = MongoClient(client_context.pair,
                                 ssl=True,
                                 ssl_cert_reqs=ssl.CERT_NONE,
                                 ssl_certfile=CLIENT_PEM)
        # Registered before remove_all_users so it runs after it (LIFO).
        self.addCleanup(ssl_client.close)
        self.addCleanup(remove_all_users, ssl_client['$external'])

        ssl_client.admin.authenticate(db_user, db_pwd)

        # Give x509 user all necessary privileges.
        client_context.create_user('$external',
                                   MONGODB_X509_USERNAME,
                                   roles=[{
                                       'role': 'readWriteAnyDatabase',
                                       'db': 'admin'
                                   }, {
                                       'role': 'userAdminAnyDatabase',
                                       'db': 'admin'
                                   }])

        noauth = MongoClient(client_context.pair,
                             ssl=True,
                             ssl_cert_reqs=ssl.CERT_NONE,
                             ssl_certfile=CLIENT_PEM)
        self.addCleanup(noauth.close)

        self.assertRaises(OperationFailure, noauth.pymongo_test.test.count)

        listener = EventListener()
        auth = MongoClient(client_context.pair,
                           authMechanism='MONGODB-X509',
                           ssl=True,
                           ssl_cert_reqs=ssl.CERT_NONE,
                           ssl_certfile=CLIENT_PEM,
                           event_listeners=[listener])
        self.addCleanup(auth.close)

        if client_context.version.at_least(3, 3, 12):
            # No error
            auth.pymongo_test.test.find_one()
            names = listener.started_command_names()
            if client_context.version.at_least(4, 4, -1):
                # Speculative auth skips the authenticate command.
                self.assertEqual(names, ['find'])
            else:
                self.assertEqual(names, ['authenticate', 'find'])
        else:
            # Should require a username
            with self.assertRaises(ConfigurationError):
                auth.pymongo_test.test.find_one()

        uri = ('mongodb://%s@%s:%d/?authMechanism='
               'MONGODB-X509' %
               (quote_plus(MONGODB_X509_USERNAME), host, port))
        client = MongoClient(uri,
                             ssl=True,
                             ssl_cert_reqs=ssl.CERT_NONE,
                             ssl_certfile=CLIENT_PEM)
        self.addCleanup(client.close)
        # No error
        client.pymongo_test.test.find_one()

        uri = 'mongodb://%s:%d/?authMechanism=MONGODB-X509' % (host, port)
        client = MongoClient(uri,
                             ssl=True,
                             ssl_cert_reqs=ssl.CERT_NONE,
                             ssl_certfile=CLIENT_PEM)
        self.addCleanup(client.close)
        if client_context.version.at_least(3, 3, 12):
            # No error
            client.pymongo_test.test.find_one()
        else:
            # Should require a username
            with self.assertRaises(ConfigurationError):
                client.pymongo_test.test.find_one()

        # Auth should fail if username and certificate do not match
        uri = ('mongodb://%s@%s:%d/?authMechanism='
               'MONGODB-X509' % (quote_plus("not the username"), host, port))

        bad_client = MongoClient(uri,
                                 ssl=True,
                                 ssl_cert_reqs=ssl.CERT_NONE,
                                 ssl_certfile=CLIENT_PEM)
        self.addCleanup(bad_client.close)

        with self.assertRaises(OperationFailure):
            bad_client.pymongo_test.test.find_one()

        bad_client = MongoClient(client_context.pair,
                                 username="******",
                                 authMechanism='MONGODB-X509',
                                 ssl=True,
                                 ssl_cert_reqs=ssl.CERT_NONE,
                                 ssl_certfile=CLIENT_PEM)
        self.addCleanup(bad_client.close)

        with self.assertRaises(OperationFailure):
            bad_client.pymongo_test.test.find_one()

        # Invalid certificate (using CA certificate as client certificate)
        uri = ('mongodb://%s@%s:%d/?authMechanism='
               'MONGODB-X509' %
               (quote_plus(MONGODB_X509_USERNAME), host, port))
        try:
            # Construction stays inside the try: it may itself raise one of
            # the expected errors.
            invalid_client = MongoClient(uri,
                                         ssl=True,
                                         ssl_cert_reqs=ssl.CERT_NONE,
                                         ssl_certfile=CA_PEM,
                                         serverSelectionTimeoutMS=100)
            self.addCleanup(invalid_client.close)
            connected(invalid_client)
        except (ConnectionFailure, ConfigurationError):
            pass
        else:
            self.fail("Invalid certificate accepted.")
    def run_scenario(self):
        """Run one transaction test case against a fresh replica-set client.

        ``test`` and ``scenario_def`` are free variables taken from an
        enclosing scope that is not visible here. The method does four things:
        - seeds the ``transaction-tests.test`` collection;
        - opens ``session0`` and ``session1``;
        - runs each operation, checking its expected error or result;
        - checks the recorded command events and the final collection
          contents.
        """
        listener = EventListener()
        # New client, to avoid interference from pooled sessions.
        # Convert test['clientOptions'] to dict to avoid a Jython bug using "**"
        # with ScenarioDict.
        client = rs_client(event_listeners=[listener],
                           **dict(test['clientOptions']))
        # Registered first so it runs last, after the sessions are ended.
        self.addCleanup(client.close)
        try:
            client.admin.command('killAllSessions', [])
        except OperationFailure:
            # "operation was interrupted" by killing the command's own session.
            pass

        write_concern_db = client.get_database(
            'transaction-tests', write_concern=WriteConcern(w='majority'))

        write_concern_db.test.drop()
        write_concern_db.create_collection('test')
        if scenario_def['data']:
            # Load data.
            write_concern_db.test.insert_many(scenario_def['data'])

        # Create session0 and session1.
        sessions = {}
        session_ids = {}
        for i in range(2):
            session_name = 'session%d' % i
            opts = camel_to_snake_args(test['sessionOptions'][session_name])
            if 'default_transaction_options' in opts:
                txn_opts = opts['default_transaction_options']
                if 'readConcern' in txn_opts:
                    read_concern = ReadConcern(**dict(txn_opts['readConcern']))
                else:
                    read_concern = None
                if 'writeConcern' in txn_opts:
                    write_concern = WriteConcern(
                        **dict(txn_opts['writeConcern']))
                else:
                    write_concern = None

                txn_opts = client_session.TransactionOptions(
                    read_concern=read_concern, write_concern=write_concern)
                opts['default_transaction_options'] = txn_opts

            s = client.start_session(**dict(opts))

            sessions[session_name] = s
            # Store lsid so we can access it after end_session, in check_events.
            session_ids[session_name] = s.session_id

        self.addCleanup(end_sessions, sessions)

        listener.results.clear()
        collection = client['transaction-tests'].test

        for op in test['operations']:
            expected_result = op.get('result')
            if expect_error_message(expected_result):
                with self.assertRaises(PyMongoError) as context:
                    self.run_operation(sessions, collection, op.copy())

                self.assertIn(expected_result['errorContains'].lower(),
                              str(context.exception).lower())
            elif expect_error_code(expected_result):
                with self.assertRaises(OperationFailure) as context:
                    self.run_operation(sessions, collection, op.copy())

                self.assertEqual(expected_result['errorCodeName'],
                                 context.exception.details.get('codeName'))
            else:
                result = self.run_operation(sessions, collection, op.copy())
                if 'result' in op:
                    self.check_result(expected_result, result)

        for s in sessions.values():
            s.end_session()

        self.check_events(test, listener, session_ids)

        # Assert final state is expected.
        expected_c = test['outcome'].get('collection')
        if expected_c is not None:
            self.assertEqual(list(collection.find()), expected_c['data'])
    def test_max_await_time_ms(self):
        """Check how max_await_time_ms and max_time_ms reach the server.

        First verifies cursor-side bookkeeping: ``__max_await_time_ms`` is
        only retained for TAILABLE_AWAIT cursors, and the last call wins.

        Then inspects the ``find`` and ``getMore`` commands actually sent:

        * ``max_time_ms`` is sent as ``maxTimeMS`` on ``find`` only.
        * ``max_await_time_ms`` is sent as ``maxTimeMS`` on ``getMore``
          only, and only for TAILABLE_AWAIT cursors.
        """
        db = self.db
        db.pymongo_test.drop()
        # Tailable cursors require a capped collection.
        coll = db.create_collection("pymongo_test", capped=True, size=4096)

        self.assertRaises(TypeError, coll.find().max_await_time_ms, 'foo')
        coll.insert_one({"amalia": 1})
        coll.insert_one({"amalia": 2})

        # None and integer values are accepted without raising.
        coll.find().max_await_time_ms(None)
        coll.find().max_await_time_ms(long(1))

        # When cursor is not tailable_await the value is discarded.
        cursor = coll.find()
        self.assertIsNone(cursor._Cursor__max_await_time_ms)
        cursor = coll.find().max_await_time_ms(99)
        self.assertIsNone(cursor._Cursor__max_await_time_ms)

        # If cursor is tailable_await and timeout is unset
        cursor = coll.find(cursor_type=CursorType.TAILABLE_AWAIT)
        self.assertIsNone(cursor._Cursor__max_await_time_ms)

        # If cursor is tailable_await and timeout is set
        cursor = coll.find(
            cursor_type=CursorType.TAILABLE_AWAIT).max_await_time_ms(99)
        self.assertEqual(99, cursor._Cursor__max_await_time_ms)

        # The most recent call overrides earlier ones.
        cursor = coll.find(cursor_type=CursorType.TAILABLE_AWAIT
                           ).max_await_time_ms(10).max_await_time_ms(90)
        self.assertEqual(90, cursor._Cursor__max_await_time_ms)

        listener = EventListener()
        # killCursors events are not relevant to these assertions.
        listener.add_command_filter('killCursors')
        results = listener.results
        saved_listeners = monitoring._LISTENERS
        # Clear the globally registered listeners while the client below is
        # created. Presumably this keeps only `listener` attached to it;
        # verify against monitoring/client internals. Restored in `finally`.
        monitoring._LISTENERS = monitoring._Listeners([])
        client = None

        try:
            client = single_client(event_listeners=[listener])
            coll = client[self.db.name].pymongo_test

            # Tailable_await defaults.
            list(coll.find(cursor_type=CursorType.TAILABLE_AWAIT))
            # find
            self.assertEqual('find', results['started'][0].command_name)
            self.assertNotIn('maxTimeMS', results['started'][0].command)
            # getMore
            self.assertEqual('getMore', results['started'][1].command_name)
            self.assertNotIn('maxTimeMS', results['started'][1].command)
            results.clear()

            # Tailable_await with max_await_time_ms set.
            list(
                coll.find(cursor_type=CursorType.TAILABLE_AWAIT).
                max_await_time_ms(99))
            # find
            self.assertEqual('find', results['started'][0].command_name)
            self.assertNotIn('maxTimeMS', results['started'][0].command)
            # getMore
            self.assertEqual('getMore', results['started'][1].command_name)
            self.assertIn('maxTimeMS', results['started'][1].command)
            self.assertEqual(99, results['started'][1].command['maxTimeMS'])
            results.clear()

            # Tailable_await with max_time_ms
            list(
                coll.find(
                    cursor_type=CursorType.TAILABLE_AWAIT).max_time_ms(1))
            # find
            self.assertEqual('find', results['started'][0].command_name)
            self.assertIn('maxTimeMS', results['started'][0].command)
            self.assertEqual(1, results['started'][0].command['maxTimeMS'])
            # getMore
            self.assertEqual('getMore', results['started'][1].command_name)
            self.assertNotIn('maxTimeMS', results['started'][1].command)
            results.clear()

            # Tailable_await with both max_time_ms and max_await_time_ms
            list(
                coll.find(cursor_type=CursorType.TAILABLE_AWAIT).max_time_ms(
                    1).max_await_time_ms(99))
            # find
            self.assertEqual('find', results['started'][0].command_name)
            self.assertIn('maxTimeMS', results['started'][0].command)
            self.assertEqual(1, results['started'][0].command['maxTimeMS'])
            # getMore
            self.assertEqual('getMore', results['started'][1].command_name)
            self.assertIn('maxTimeMS', results['started'][1].command)
            self.assertEqual(99, results['started'][1].command['maxTimeMS'])
            results.clear()

            # Non tailable_await with max_await_time_ms
            list(coll.find(batch_size=1).max_await_time_ms(99))
            # find
            self.assertEqual('find', results['started'][0].command_name)
            self.assertNotIn('maxTimeMS', results['started'][0].command)
            # getMore
            self.assertEqual('getMore', results['started'][1].command_name)
            self.assertNotIn('maxTimeMS', results['started'][1].command)
            results.clear()

            # Non tailable_await with max_time_ms
            list(coll.find(batch_size=1).max_time_ms(99))
            # find
            self.assertEqual('find', results['started'][0].command_name)
            self.assertIn('maxTimeMS', results['started'][0].command)
            self.assertEqual(99, results['started'][0].command['maxTimeMS'])
            # getMore
            self.assertEqual('getMore', results['started'][1].command_name)
            self.assertNotIn('maxTimeMS', results['started'][1].command)
            # Without this clear, the next section would re-check these
            # stale events instead of its own commands.
            results.clear()

            # Non tailable_await with both max_time_ms and max_await_time_ms
            list(coll.find(batch_size=1).max_time_ms(99).max_await_time_ms(88))
            # find
            self.assertEqual('find', results['started'][0].command_name)
            self.assertIn('maxTimeMS', results['started'][0].command)
            self.assertEqual(99, results['started'][0].command['maxTimeMS'])
            # getMore
            self.assertEqual('getMore', results['started'][1].command_name)
            self.assertNotIn('maxTimeMS', results['started'][1].command)
            results.clear()

        finally:
            if client is not None:
                client.close()
            monitoring._LISTENERS = saved_listeners
 def setUpClass(cls):
     cls.listener = EventListener()
     cls.listener.add_command_filter('killCursors')
     cls.saved_listeners = monitoring._LISTENERS
     monitoring._LISTENERS = monitoring._Listeners([])
     cls.client = single_client(event_listeners=[cls.listener])