Example #1
    def run_operation(self, collection, test):
        # Iterate over all operations.
        for opdef in test['operations']:
            # Convert command from CamelCase to pymongo.collection method.
            operation = camel_to_snake(opdef['name'])

            # Get command handle on target entity (collection/database).
            target_object = opdef.get('object', 'collection')
            if target_object == 'database':
                cmd = getattr(collection.database, operation)
            elif target_object == 'collection':
                collection = collection.with_options(**dict(
                    parse_collection_options(opdef.get('collectionOptions',
                                                       {}))))
                cmd = getattr(collection, operation)
            else:
                self.fail("Unknown object name %s" % (target_object, ))

            # Convert arguments to snake_case and handle special cases.
            arguments = opdef['arguments']
            options = arguments.pop("options", {})

            for option_name in options:
                arguments[camel_to_snake(option_name)] = options[option_name]

            if operation == "bulk_write":
                # Parse each request into a bulk write model.
                requests = []
                for request in arguments["requests"]:
                    bulk_model = camel_to_upper_camel(request["name"])
                    bulk_class = getattr(operations, bulk_model)
                    bulk_arguments = camel_to_snake_args(request["arguments"])
                    requests.append(bulk_class(**bulk_arguments))
                arguments["requests"] = requests
            else:
                for arg_name in list(arguments):
                    c2s = camel_to_snake(arg_name)
                    # PyMongo accepts sort as list of tuples.
                    if arg_name == "sort":
                        sort_dict = arguments[arg_name]
                        arguments[arg_name] = list(iteritems(sort_dict))
                    # Named "key" instead not fieldName.
                    if arg_name == "fieldName":
                        arguments["key"] = arguments.pop(arg_name)
                    # Aggregate uses "batchSize", while find uses batch_size.
                    elif arg_name == "batchSize" and operation == "aggregate":
                        continue
                    # PyMongo requires a boolean return_document value.
                    elif arg_name == "returnDocument":
                        arguments[c2s] = arguments.pop(arg_name) == "After"
                    else:
                        arguments[c2s] = arguments.pop(arg_name)

            if opdef.get('error') is True:
                with self.assertRaises(PyMongoError):
                    cmd(**arguments)
            else:
                result = cmd(**arguments)
                self.check_result(opdef.get('result'), result)
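Example #2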
    def run_operation(self, collection, test):
        # Iterate over all operations.
        for opdef in test['operations']:
            # Convert command from CamelCase to pymongo.collection method.
            operation = camel_to_snake(opdef['name'])

            # Get command handle on target entity (collection/database).
            target_object = opdef.get('object', 'collection')
            if target_object == 'database':
                cmd = getattr(collection.database, operation)
            elif target_object == 'collection':
                collection = collection.with_options(**dict(
                    parse_collection_options(opdef.get(
                        'collectionOptions', {}))))
                cmd = getattr(collection, operation)
            else:
                self.fail("Unknown object name %s" % (target_object,))

            # Convert arguments to snake_case and handle special cases.
            arguments = opdef['arguments']
            options = arguments.pop("options", {})

            for option_name in options:
                arguments[camel_to_snake(option_name)] = options[option_name]

            if operation == "bulk_write":
                # Parse each request into a bulk write model.
                requests = []
                for request in arguments["requests"]:
                    bulk_model = camel_to_upper_camel(request["name"])
                    bulk_class = getattr(operations, bulk_model)
                    bulk_arguments = camel_to_snake_args(request["arguments"])
                    requests.append(bulk_class(**bulk_arguments))
                arguments["requests"] = requests
            else:
                for arg_name in list(arguments):
                    c2s = camel_to_snake(arg_name)
                    # PyMongo accepts sort as list of tuples.
                    if arg_name == "sort":
                        sort_dict = arguments[arg_name]
                        arguments[arg_name] = list(iteritems(sort_dict))
                    # Named "key" instead not fieldName.
                    if arg_name == "fieldName":
                        arguments["key"] = arguments.pop(arg_name)
                    # Aggregate uses "batchSize", while find uses batch_size.
                    elif arg_name == "batchSize" and operation == "aggregate":
                        continue
                    # PyMongo requires a boolean return_document value.
                    elif arg_name == "returnDocument":
                        arguments[c2s] = arguments.pop(arg_name) == "After"
                    else:
                        arguments[c2s] = arguments.pop(arg_name)

            if opdef.get('error') is True:
                with self.assertRaises(PyMongoError):
                    cmd(**arguments)
            else:
                result = cmd(**arguments)
                self.check_result(opdef.get('result'), result)
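All of these runners lean on naming helpers from PyMongo's test utilities. As a rough, illustrative sketch of what they do (a reimplementation for reference, not the library code itself):

import re

def camel_to_snake(camel):
    # "insertOne" -> "insert_one", "maxTimeMS" -> "max_time_ms"
    snake = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', camel)
    return re.sub('([a-z0-9])([A-Z])', r'\1_\2', snake).lower()

def camel_to_upper_camel(camel):
    # "insertOne" -> "InsertOne", matching the pymongo.operations class names.
    return camel[:1].upper() + camel[1:]

def camel_to_snake_args(arguments):
    # Convert every top-level key of an arguments dict to snake_case.
    return {camel_to_snake(key): value for key, value in arguments.items()}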
Example #3
    def parse_auto_encrypt_opts(self, opts):
        """Parse clientOptions.autoEncryptOpts."""
        opts = camel_to_snake_args(opts)
        kms_providers = opts['kms_providers']
        if 'aws' in kms_providers:
            kms_providers['aws'] = AWS_CREDS
            if not any(AWS_CREDS.values()):
                self.skipTest('AWS environment credentials are not set')
        if 'key_vault_namespace' not in opts:
            opts['key_vault_namespace'] = 'keyvault.datakeys'
        opts = dict(opts)
        return AutoEncryptionOpts(**opts)
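For context, a hypothetical clientOptions.autoEncryptOpts document of the shape this parser expects (the field names follow the client-side encryption spec tests; the values are made up for illustration):

auto_encrypt_opts = {
    'kmsProviders': {'aws': {}},  # placeholder creds, replaced with AWS_CREDS
    'schemaMap': {'db.coll': {'bsonType': 'object'}},
    # keyVaultNamespace omitted: defaulted to 'keyvault.datakeys' above.
}
# parse_auto_encrypt_opts(auto_encrypt_opts) is then roughly equivalent to:
# AutoEncryptionOpts(kms_providers={'aws': AWS_CREDS},
#                    key_vault_namespace='keyvault.datakeys',
#                    schema_map={'db.coll': {'bsonType': 'object'}})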
Example #4
def run_operation(collection, test):
    # Convert command from CamelCase to pymongo.collection method.
    operation = camel_to_snake(test['operation']['name'])
    cmd = getattr(collection, operation)

    # Convert arguments to snake_case and handle special cases.
    arguments = test['operation']['arguments']
    options = arguments.pop("options", {})
    for option_name in options:
        arguments[camel_to_snake(option_name)] = options[option_name]
    if operation == "bulk_write":
        # Parse each request into a bulk write model.
        requests = []
        for request in arguments["requests"]:
            bulk_model = camel_to_upper_camel(request["name"])
            bulk_class = getattr(operations, bulk_model)
            bulk_arguments = camel_to_snake_args(request["arguments"])
            requests.append(bulk_class(**bulk_arguments))
        arguments["requests"] = requests
    else:
        for arg_name in list(arguments):
            c2s = camel_to_snake(arg_name)
            # PyMongo accepts sort as list of tuples.
            if arg_name == "sort":
                sort_dict = arguments[arg_name]
                arguments[arg_name] = list(iteritems(sort_dict))
            # Named "key" instead not fieldName.
            if arg_name == "fieldName":
                arguments["key"] = arguments.pop(arg_name)
            # Aggregate uses "batchSize", while find uses batch_size.
            elif arg_name == "batchSize" and operation == "aggregate":
                continue
            # PyMongo requires a boolean return_document value.
            elif arg_name == "returnDocument":
                arguments[c2s] = arguments.pop(arg_name) == "After"
            else:
                arguments[c2s] = arguments.pop(arg_name)

    result = cmd(**arguments)

    if operation == "aggregate":
        if arguments["pipeline"] and "$out" in arguments["pipeline"][-1]:
            out = collection.database[arguments["pipeline"][-1]["$out"]]
            result = out.find()

    if isinstance(result, (Cursor, CommandCursor)):
        return list(result)

    return result
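Example #5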
def run_operation(collection, test):
    # Convert command from CamelCase to pymongo.collection method.
    operation = camel_to_snake(test['operation']['name'])
    cmd = getattr(collection, operation)

    # Convert arguments to snake_case and handle special cases.
    arguments = test['operation']['arguments']
    options = arguments.pop("options", {})
    for option_name in options:
        arguments[camel_to_snake(option_name)] = options[option_name]
    if operation == "bulk_write":
        # Parse each request into a bulk write model.
        requests = []
        for request in arguments["requests"]:
            bulk_model = camel_to_upper_camel(request["name"])
            bulk_class = getattr(operations, bulk_model)
            bulk_arguments = camel_to_snake_args(request["arguments"])
            requests.append(bulk_class(**bulk_arguments))
        arguments["requests"] = requests
    else:
        for arg_name in list(arguments):
            c2s = camel_to_snake(arg_name)
            # PyMongo accepts sort as list of tuples.
            if arg_name == "sort":
                sort_dict = arguments[arg_name]
                arguments[arg_name] = list(iteritems(sort_dict))
            # Named "key" instead not fieldName.
            if arg_name == "fieldName":
                arguments["key"] = arguments.pop(arg_name)
            # Aggregate uses "batchSize", while find uses batch_size.
            elif arg_name == "batchSize" and operation == "aggregate":
                continue
            # PyMongo requires a boolean return_document value.
            elif arg_name == "returnDocument":
                arguments[c2s] = arguments.pop(arg_name) == "After"
            else:
                arguments[c2s] = arguments.pop(arg_name)

    result = cmd(**arguments)

    if operation == "aggregate":
        if arguments["pipeline"] and "$out" in arguments["pipeline"][-1]:
            out = collection.database[arguments["pipeline"][-1]["$out"]]
            return out.find()
    return result
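A hypothetical test document of the shape these run_operation variants consume (the keys follow the CRUD spec test format; the values are illustrative):

test = {
    'operation': {
        'name': 'findOneAndUpdate',
        'arguments': {
            'filter': {'_id': 1},
            'update': {'$inc': {'x': 1}},
            'sort': {'_id': 1},          # becomes [('_id', 1)]
            'returnDocument': 'After',   # becomes return_document=True
        }
    }
}
# run_operation(collection, test) then calls the equivalent of:
# collection.find_one_and_update(filter={'_id': 1}, update={'$inc': {'x': 1}},
#                                sort=[('_id', 1)], return_document=True)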
Example #6
    def run_scenario(self, scenario_def, test):
        self.maybe_skip_scenario(test)

        # Kill all sessions before and after each test to prevent an open
        # transaction (from a test failure) from blocking collection/database
        # operations during test set up and tear down.
        self.kill_all_sessions()
        self.addCleanup(self.kill_all_sessions)
        self.setup_scenario(scenario_def)
        database_name = self.get_scenario_db_name(scenario_def)
        collection_name = self.get_scenario_coll_name(scenario_def)
        # SPEC-1245: run distinct on each mongos to work around StaleDbVersion.
        for c in self.mongos_clients:
            c[database_name][collection_name].distinct("x")

        # Configure the fail point before creating the client.
        if 'failPoint' in test:
            fp = test['failPoint']
            self.set_fail_point(fp)
            self.addCleanup(self.set_fail_point, {
                'configureFailPoint': fp['configureFailPoint'],
                'mode': 'off'
            })

        listener = OvertCommandListener()
        pool_listener = CMAPListener()
        server_listener = ServerAndTopologyEventListener()
        # Create a new client, to avoid interference from pooled sessions.
        client_options = self.parse_client_options(test['clientOptions'])
        # MMAPv1 does not support retryable writes.
        if (client_options.get('retryWrites') is True
                and client_context.storage_engine == 'mmapv1'):
            self.skipTest("MMAPv1 does not support retryWrites=True")
        use_multi_mongos = test['useMultipleMongoses']
        if client_context.is_mongos and use_multi_mongos:
            client = rs_client(
                client_context.mongos_seeds(),
                event_listeners=[listener, pool_listener, server_listener],
                **client_options)
        else:
            client = rs_client(
                event_listeners=[listener, pool_listener, server_listener],
                **client_options)
        self.scenario_client = client
        self.listener = listener
        self.pool_listener = pool_listener
        self.server_listener = server_listener
        # Close the client explicitly to avoid having too many threads open.
        self.addCleanup(client.close)

        # Create session0 and session1.
        sessions = {}
        session_ids = {}
        for i in range(2):
            # Don't attempt to create sessions if they are not supported by
            # the running server version.
            if not client_context.sessions_enabled:
                break
            session_name = 'session%d' % i
            opts = camel_to_snake_args(test['sessionOptions'][session_name])
            if 'default_transaction_options' in opts:
                txn_opts = self.parse_options(
                    opts['default_transaction_options'])
                txn_opts = client_session.TransactionOptions(**txn_opts)
                opts['default_transaction_options'] = txn_opts

            s = client.start_session(**dict(opts))

            sessions[session_name] = s
            # Store the lsid so check_events can access it after end_session.
            session_ids[session_name] = s.session_id

        self.addCleanup(end_sessions, sessions)

        collection = client[database_name][collection_name]
        self.run_test_ops(sessions, collection, test)

        end_sessions(sessions)

        self.check_events(test, listener, session_ids)

        # Disable fail points.
        if 'failPoint' in test:
            fp = test['failPoint']
            self.set_fail_point({
                'configureFailPoint': fp['configureFailPoint'],
                'mode': 'off'
            })

        # Assert final state is expected.
        outcome = test['outcome']
        expected_c = outcome.get('collection')
        if expected_c is not None:
            outcome_coll_name = self.get_outcome_coll_name(outcome, collection)

            # Read from the primary with local read concern to ensure causal
            # consistency.
            outcome_coll = client_context.client[
                collection.database.name].get_collection(
                    outcome_coll_name,
                    read_preference=ReadPreference.PRIMARY,
                    read_concern=ReadConcern('local'))
            actual_data = list(outcome_coll.find(sort=[('_id', 1)]))

            # The expected data needs to be the left hand side here otherwise
            # CompareType(Binary) doesn't work.
            self.assertEqual(wrap_types(expected_c['data']), actual_data)
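The 'failPoint' documents configured here are configureFailPoint commands run against the server. A hypothetical example using the failCommand fail point (the values are illustrative):

fail_point = {
    'configureFailPoint': 'failCommand',
    'mode': {'times': 1},
    'data': {'failCommands': ['commitTransaction'], 'errorCode': 11600},
}
self.set_fail_point(fail_point)
# The cleanup registered above later disables it with:
# {'configureFailPoint': 'failCommand', 'mode': 'off'}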
Example #7
    def run_operation(self, sessions, collection, operation):
        original_collection = collection
        name = camel_to_snake(operation['name'])
        if name == 'run_command':
            name = 'command'
        elif name == 'download_by_name':
            name = 'open_download_stream_by_name'
        elif name == 'download':
            name = 'open_download_stream'

        database = collection.database
        collection = database.get_collection(collection.name)
        if 'collectionOptions' in operation:
            collection = collection.with_options(
                **self.parse_options(operation['collectionOptions']))

        object_name = self.get_object_name(operation)
        if object_name == 'gridfsbucket':
            # Only create the GridFSBucket when we need it (for the gridfs
            # retryable reads tests).
            obj = GridFSBucket(database,
                               bucket_name=collection.name,
                               disable_md5=True)
        else:
            objects = {
                'client': database.client,
                'database': database,
                'collection': collection,
                'testRunner': self
            }
            objects.update(sessions)
            obj = objects[object_name]

        # Combine arguments with options and handle special cases.
        arguments = operation.get('arguments', {})
        arguments.update(arguments.pop("options", {}))
        self.parse_options(arguments)

        cmd = getattr(obj, name)

        for arg_name in list(arguments):
            c2s = camel_to_snake(arg_name)
            # PyMongo accepts sort as list of tuples.
            if arg_name == "sort":
                sort_dict = arguments[arg_name]
                arguments[arg_name] = list(iteritems(sort_dict))
            # Named "key" instead not fieldName.
            if arg_name == "fieldName":
                arguments["key"] = arguments.pop(arg_name)
            # Aggregate uses "batchSize", while find uses batch_size.
            elif ((arg_name == "batchSize" or arg_name == "allowDiskUse")
                  and name == "aggregate"):
                continue
            # PyMongo requires a boolean return_document value.
            elif arg_name == "returnDocument":
                arguments[c2s] = arguments.pop(arg_name) == "After"
            elif c2s == "requests":
                # Parse each request into a bulk write model.
                requests = []
                for request in arguments["requests"]:
                    bulk_model = camel_to_upper_camel(request["name"])
                    bulk_class = getattr(operations, bulk_model)
                    bulk_arguments = camel_to_snake_args(request["arguments"])
                    requests.append(bulk_class(**dict(bulk_arguments)))
                arguments["requests"] = requests
            elif arg_name == "session":
                arguments['session'] = sessions[arguments['session']]
            elif (name in ('command', 'run_admin_command')
                  and arg_name == 'command'):
                # Ensure the first key is the command name.
                ordered_command = SON([(operation['command_name'], 1)])
                ordered_command.update(arguments['command'])
                arguments['command'] = ordered_command
            elif name == 'open_download_stream' and arg_name == 'id':
                arguments['file_id'] = arguments.pop(arg_name)
            elif name != 'find' and c2s == 'max_time_ms':
                # find is the only method that accepts snake_case max_time_ms.
                # All other methods take kwargs which must use the server's
                # camelCase maxTimeMS. See PYTHON-1855.
                arguments['maxTimeMS'] = arguments.pop('max_time_ms')
            elif name == 'with_transaction' and arg_name == 'callback':
                callback_ops = arguments[arg_name]['operations']
                arguments['callback'] = lambda _: self.run_operations(
                    sessions,
                    original_collection,
                    copy.deepcopy(callback_ops),
                    in_with_transaction=True)
            elif name == 'drop_collection' and arg_name == 'collection':
                arguments['name_or_collection'] = arguments.pop(arg_name)
            elif name == 'create_collection' and arg_name == 'collection':
                arguments['name'] = arguments.pop(arg_name)
            elif name == 'create_index' and arg_name == 'keys':
                arguments['keys'] = list(arguments.pop(arg_name).items())
            elif name == 'drop_index' and arg_name == 'name':
                arguments['index_or_name'] = arguments.pop(arg_name)
            else:
                arguments[c2s] = arguments.pop(arg_name)

        if name == 'run_on_thread':
            args = {'sessions': sessions, 'collection': collection}
            args.update(arguments)
            arguments = args
        result = cmd(**dict(arguments))

        if name == "aggregate":
            if arguments["pipeline"] and "$out" in arguments["pipeline"][-1]:
                # Read from the primary to ensure causal consistency.
                out = collection.database.get_collection(
                    arguments["pipeline"][-1]["$out"],
                    read_preference=ReadPreference.PRIMARY)
                return out.find()
        if name == "map_reduce":
            if isinstance(result, dict) and 'results' in result:
                return result['results']
        if 'download' in name:
            result = Binary(result.read())

        if isinstance(result, (Cursor, CommandCursor)):
            return list(result)

        return result
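To illustrate the "requests" branch above, here are hypothetical spec-test bulk write requests and the pymongo.operations models they are parsed into:

requests = [
    {'name': 'insertOne', 'arguments': {'document': {'_id': 1}}},
    {'name': 'deleteMany', 'arguments': {'filter': {'x': {'$gt': 0}}}},
]
# camel_to_upper_camel turns 'insertOne' into 'InsertOne', so the loop
# builds the equivalent of:
# [operations.InsertOne(document={'_id': 1}),
#  operations.DeleteMany(filter={'x': {'$gt': 0}})]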
Example #8
    def run_scenario(self, scenario_def, test):
        self.maybe_skip_scenario(test)
        listener = OvertCommandListener()
        # Create a new client, to avoid interference from pooled sessions.
        # Convert test['clientOptions'] to dict to avoid a Jython bug using
        # "**" with ScenarioDict.
        client_options = dict(test['clientOptions'])
        use_multi_mongos = test['useMultipleMongoses']
        if client_context.is_mongos and use_multi_mongos:
            client = rs_client(client_context.mongos_seeds(),
                               event_listeners=[listener],
                               **client_options)
        else:
            client = rs_client(event_listeners=[listener], **client_options)
        # Close the client explicitly to avoid having too many threads open.
        self.addCleanup(client.close)

        # Kill all sessions before and after each test to prevent an open
        # transaction (from a test failure) from blocking collection/database
        # operations during test set up and tear down.
        self.kill_all_sessions()
        self.addCleanup(self.kill_all_sessions)

        database_name = scenario_def['database_name']
        write_concern_db = client_context.client.get_database(
            database_name, write_concern=WriteConcern(w='majority'))
        if 'bucket_name' in scenario_def:
            # Create a bucket for the retryable reads GridFS tests.
            collection_name = scenario_def['bucket_name']
            client_context.client.drop_database(database_name)
            if scenario_def['data']:
                data = scenario_def['data']
                # Load data.
                write_concern_db['fs.chunks'].insert_many(data['fs.chunks'])
                write_concern_db['fs.files'].insert_many(data['fs.files'])
        else:
            collection_name = scenario_def['collection_name']
            write_concern_coll = write_concern_db[collection_name]
            write_concern_coll.drop()
            write_concern_db.create_collection(collection_name)
            if scenario_def['data']:
                # Load data.
                write_concern_coll.insert_many(scenario_def['data'])

        # SPEC-1245: run distinct on each mongos to work around StaleDbVersion.
        for c in self.mongos_clients:
            c[database_name][collection_name].distinct("x")

        # Create session0 and session1.
        sessions = {}
        session_ids = {}
        for i in range(2):
            session_name = 'session%d' % i
            opts = camel_to_snake_args(test['sessionOptions'][session_name])
            if 'default_transaction_options' in opts:
                txn_opts = opts['default_transaction_options']
                if 'readConcern' in txn_opts:
                    read_concern = ReadConcern(**dict(txn_opts['readConcern']))
                else:
                    read_concern = None
                if 'writeConcern' in txn_opts:
                    write_concern = WriteConcern(
                        **dict(txn_opts['writeConcern']))
                else:
                    write_concern = None

                if 'readPreference' in txn_opts:
                    read_pref = parse_read_preference(
                        txn_opts['readPreference'])
                else:
                    read_pref = None

                txn_opts = client_session.TransactionOptions(
                    read_concern=read_concern,
                    write_concern=write_concern,
                    read_preference=read_pref,
                )
                opts['default_transaction_options'] = txn_opts

            s = client.start_session(**dict(opts))

            sessions[session_name] = s
            # Store the lsid so check_events can access it after end_session.
            session_ids[session_name] = s.session_id

        self.addCleanup(end_sessions, sessions)

        if 'failPoint' in test:
            self.set_fail_point(test['failPoint'])
            self.addCleanup(self.set_fail_point, {
                'configureFailPoint': 'failCommand',
                'mode': 'off'
            })

        listener.results.clear()

        collection = client[database_name][collection_name]
        self.run_operations(sessions, collection, test['operations'])

        end_sessions(sessions)

        self.check_events(test, listener, session_ids)

        # Disable fail points.
        if 'failPoint' in test:
            self.set_fail_point({
                'configureFailPoint': 'failCommand',
                'mode': 'off'
            })

        # Assert final state is expected.
        expected_c = test['outcome'].get('collection')
        if expected_c is not None:
            # Read from the primary with local read concern to ensure causal
            # consistency.
            primary_coll = collection.with_options(
                read_preference=ReadPreference.PRIMARY,
                read_concern=ReadConcern('local'))
            self.assertEqual(list(primary_coll.find()), expected_c['data'])
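Example #9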
    def run_scenario(self):
        if test.get('skipReason'):
            raise unittest.SkipTest(test.get('skipReason'))

        listener = OvertCommandListener()
        # Create a new client, to avoid interference from pooled sessions.
        # Convert test['clientOptions'] to dict to avoid a Jython bug using
        # "**" with ScenarioDict.
        client_options = dict(test['clientOptions'])
        use_multi_mongos = test['useMultipleMongoses']
        if client_context.is_mongos and use_multi_mongos:
            client = rs_client(client_context.mongos_seeds(),
                               event_listeners=[listener],
                               **client_options)
        else:
            client = rs_client(event_listeners=[listener], **client_options)
        # Close the client explicitly to avoid having too many threads open.
        self.addCleanup(client.close)

        # Kill all sessions before and after each test to prevent an open
        # transaction (from a test failure) from blocking collection/database
        # operations during test set up and tear down.
        self.kill_all_sessions()
        self.addCleanup(self.kill_all_sessions)

        database_name = scenario_def['database_name']
        collection_name = scenario_def['collection_name']
        write_concern_db = client.get_database(
            database_name, write_concern=WriteConcern(w='majority'))
        write_concern_coll = write_concern_db[collection_name]
        write_concern_coll.drop()
        write_concern_db.create_collection(collection_name)
        if scenario_def['data']:
            # Load data.
            write_concern_coll.insert_many(scenario_def['data'])

        # Create session0 and session1.
        sessions = {}
        session_ids = {}
        for i in range(2):
            session_name = 'session%d' % i
            opts = camel_to_snake_args(test['sessionOptions'][session_name])
            if 'default_transaction_options' in opts:
                txn_opts = opts['default_transaction_options']
                if 'readConcern' in txn_opts:
                    read_concern = ReadConcern(**dict(txn_opts['readConcern']))
                else:
                    read_concern = None
                if 'writeConcern' in txn_opts:
                    write_concern = WriteConcern(
                        **dict(txn_opts['writeConcern']))
                else:
                    write_concern = None

                if 'readPreference' in txn_opts:
                    read_pref = parse_read_preference(
                        txn_opts['readPreference'])
                else:
                    read_pref = None

                txn_opts = client_session.TransactionOptions(
                    read_concern=read_concern,
                    write_concern=write_concern,
                    read_preference=read_pref,
                )
                opts['default_transaction_options'] = txn_opts

            s = client.start_session(**dict(opts))

            sessions[session_name] = s
            # Store the lsid so check_events can access it after end_session.
            session_ids[session_name] = s.session_id

        self.addCleanup(end_sessions, sessions)

        if 'failPoint' in test:
            self.set_fail_point(test['failPoint'])
            self.addCleanup(self.set_fail_point, {
                'configureFailPoint': 'failCommand',
                'mode': 'off'
            })

        listener.results.clear()
        collection = client[database_name][collection_name]

        for op in test['operations']:
            expected_result = op.get('result')
            if expect_error(expected_result):
                with self.assertRaises(PyMongoError,
                                       msg=op['name']) as context:
                    self.run_operation(sessions, collection, op.copy())

                if expect_error_message(expected_result):
                    self.assertIn(expected_result['errorContains'].lower(),
                                  str(context.exception).lower())
                if expect_error_code(expected_result):
                    self.assertEqual(expected_result['errorCodeName'],
                                     context.exception.details.get('codeName'))
                if expect_error_labels_contain(expected_result):
                    self.assertErrorLabelsContain(
                        context.exception,
                        expected_result['errorLabelsContain'])
                if expect_error_labels_omit(expected_result):
                    self.assertErrorLabelsOmit(
                        context.exception, expected_result['errorLabelsOmit'])
            else:
                result = self.run_operation(sessions, collection, op.copy())
                if 'result' in op:
                    if op['name'] == 'runCommand':
                        self.check_command_result(expected_result, result)
                    else:
                        self.check_result(expected_result, result)

        for s in sessions.values():
            s.end_session()

        self.check_events(test, listener, session_ids)

        # Disable fail points.
        self.set_fail_point({
            'configureFailPoint': 'failCommand',
            'mode': 'off'
        })

        # Assert final state is expected.
        expected_c = test['outcome'].get('collection')
        if expected_c is not None:
            # Read from the primary with local read concern to ensure causal
            # consistency.
            primary_coll = collection.with_options(
                read_preference=ReadPreference.PRIMARY,
                read_concern=ReadConcern('local'))
            self.assertEqual(list(primary_coll.find()), expected_c['data'])
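A hypothetical sessionOptions entry and what the parsing block above turns it into (the field names follow the transactions spec tests; the values are illustrative):

session_options = {
    'defaultTransactionOptions': {
        'readConcern': {'level': 'snapshot'},
        'writeConcern': {'w': 'majority'},
        'readPreference': {'mode': 'primary'},
    }
}
# After camel_to_snake_args, the code above builds the equivalent of:
# client_session.TransactionOptions(
#     read_concern=ReadConcern('snapshot'),
#     write_concern=WriteConcern(w='majority'),
#     read_preference=ReadPreference.PRIMARY)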
Example #10
    def run_operation(self, sessions, collection, operation):
        name = camel_to_snake(operation['name'])
        if name == 'run_command':
            name = 'command'
        self.transaction_test_debug(name)

        def parse_options(opts):
            if 'readPreference' in opts:
                opts['read_preference'] = parse_read_preference(
                    opts.pop('readPreference'))

            if 'writeConcern' in opts:
                opts['write_concern'] = WriteConcern(
                    **dict(opts.pop('writeConcern')))

            if 'readConcern' in opts:
                opts['read_concern'] = ReadConcern(
                    **dict(opts.pop('readConcern')))
            return opts

        database = collection.database
        collection = database.get_collection(collection.name)
        if 'collectionOptions' in operation:
            collection = collection.with_options(
                **dict(parse_options(operation['collectionOptions'])))

        objects = {
            'database': database,
            'collection': collection,
            'testRunner': self
        }
        objects.update(sessions)
        obj = objects[operation['object']]

        # Combine arguments with options and handle special cases.
        arguments = operation.get('arguments', {})
        arguments.update(arguments.pop("options", {}))
        parse_options(arguments)

        cmd = getattr(obj, name)

        for arg_name in list(arguments):
            c2s = camel_to_snake(arg_name)
            # PyMongo accepts sort as list of tuples.
            if arg_name == "sort":
                sort_dict = arguments[arg_name]
                arguments[arg_name] = list(iteritems(sort_dict))
            # Named "key" instead not fieldName.
            if arg_name == "fieldName":
                arguments["key"] = arguments.pop(arg_name)
            # Aggregate uses "batchSize", while find uses batch_size.
            elif arg_name == "batchSize" and name == "aggregate":
                continue
            # PyMongo requires a boolean return_document value.
            elif arg_name == "returnDocument":
                arguments[c2s] = arguments.pop(arg_name) == "After"
            elif c2s == "requests":
                # Parse each request into a bulk write model.
                requests = []
                for request in arguments["requests"]:
                    bulk_model = camel_to_upper_camel(request["name"])
                    bulk_class = getattr(operations, bulk_model)
                    bulk_arguments = camel_to_snake_args(request["arguments"])
                    requests.append(bulk_class(**dict(bulk_arguments)))
                arguments["requests"] = requests
            elif arg_name == "session":
                arguments['session'] = sessions[arguments['session']]
            elif name == 'command' and arg_name == 'command':
                # Ensure the first key is the command name.
                ordered_command = SON([(operation['command_name'], 1)])
                ordered_command.update(arguments['command'])
                arguments['command'] = ordered_command
            else:
                arguments[c2s] = arguments.pop(arg_name)

        result = cmd(**dict(arguments))

        if name == "aggregate":
            if arguments["pipeline"] and "$out" in arguments["pipeline"][-1]:
                # Read from the primary to ensure causal consistency.
                out = collection.database.get_collection(
                    arguments["pipeline"][-1]["$out"],
                    read_preference=ReadPreference.PRIMARY)
                return out.find()

        if isinstance(result, (Cursor, CommandCursor)):
            return list(result)

        return result
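Example #11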
    def run_scenario(self):
        if test.get('skipReason'):
            raise unittest.SkipTest(test.get('skipReason'))

        listener = OvertCommandListener()
        # Create a new client, to avoid interference from pooled sessions.
        # Convert test['clientOptions'] to dict to avoid a Jython bug using
        # "**" with ScenarioDict.
        client_options = dict(test['clientOptions'])
        use_multi_mongos = test['useMultipleMongoses']
        if client_context.is_mongos and use_multi_mongos:
            client = rs_client(client_context.mongos_seeds(),
                               event_listeners=[listener], **client_options)
        else:
            client = rs_client(event_listeners=[listener], **client_options)
        # Close the client explicitly to avoid having too many threads open.
        self.addCleanup(client.close)

        # Kill all sessions before and after each test to prevent an open
        # transaction (from a test failure) from blocking collection/database
        # operations during test set up and tear down.
        self.kill_all_sessions()
        self.addCleanup(self.kill_all_sessions)

        database_name = scenario_def['database_name']
        collection_name = scenario_def['collection_name']
        # Don't use the test client to load data.
        write_concern_db = client_context.client.get_database(
            database_name, write_concern=WriteConcern(w='majority'))
        write_concern_coll = write_concern_db[collection_name]
        write_concern_coll.drop()
        write_concern_db.create_collection(collection_name)
        if scenario_def['data']:
            # Load data.
            write_concern_coll.insert_many(scenario_def['data'])

        # SPEC-1245: run distinct on each mongos to work around StaleDbVersion.
        for c in self.mongos_clients:
            c[database_name][collection_name].distinct("x")

        # Create session0 and session1.
        sessions = {}
        session_ids = {}
        for i in range(2):
            session_name = 'session%d' % i
            opts = camel_to_snake_args(test['sessionOptions'][session_name])
            if 'default_transaction_options' in opts:
                txn_opts = opts['default_transaction_options']
                if 'readConcern' in txn_opts:
                    read_concern = ReadConcern(
                        **dict(txn_opts['readConcern']))
                else:
                    read_concern = None
                if 'writeConcern' in txn_opts:
                    write_concern = WriteConcern(
                        **dict(txn_opts['writeConcern']))
                else:
                    write_concern = None

                if 'readPreference' in txn_opts:
                    read_pref = parse_read_preference(
                        txn_opts['readPreference'])
                else:
                    read_pref = None

                txn_opts = client_session.TransactionOptions(
                    read_concern=read_concern,
                    write_concern=write_concern,
                    read_preference=read_pref,
                )
                opts['default_transaction_options'] = txn_opts

            s = client.start_session(**dict(opts))

            sessions[session_name] = s
            # Store the lsid so check_events can access it after end_session.
            session_ids[session_name] = s.session_id

        self.addCleanup(end_sessions, sessions)

        if 'failPoint' in test:
            self.set_fail_point(test['failPoint'])
            self.addCleanup(self.set_fail_point, {
                'configureFailPoint': 'failCommand', 'mode': 'off'})

        listener.results.clear()
        collection = client[database_name][collection_name]

        self.run_operations(sessions, collection, test['operations'])

        for s in sessions.values():
            s.end_session()

        self.check_events(test, listener, session_ids)

        # Disable fail points.
        self.set_fail_point({
            'configureFailPoint': 'failCommand', 'mode': 'off'})

        # Assert final state is expected.
        expected_c = test['outcome'].get('collection')
        if expected_c is not None:
            # Read from the primary with local read concern to ensure causal
            # consistency.
            primary_coll = collection.with_options(
                read_preference=ReadPreference.PRIMARY,
                read_concern=ReadConcern('local'))
            self.assertEqual(list(primary_coll.find()), expected_c['data'])
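Example #12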
    def run_operation(self, sessions, collection, operation):
        original_collection = collection
        name = camel_to_snake(operation['name'])
        if name == 'run_command':
            name = 'command'
        self.transaction_test_debug(name)

        def parse_options(opts):
            if 'readPreference' in opts:
                opts['read_preference'] = parse_read_preference(
                    opts.pop('readPreference'))

            if 'writeConcern' in opts:
                opts['write_concern'] = WriteConcern(
                    **dict(opts.pop('writeConcern')))

            if 'readConcern' in opts:
                opts['read_concern'] = ReadConcern(
                    **dict(opts.pop('readConcern')))
            return opts

        database = collection.database
        collection = database.get_collection(collection.name)
        if 'collectionOptions' in operation:
            collection = collection.with_options(
                **dict(parse_options(operation['collectionOptions'])))

        objects = {
            'database': database,
            'collection': collection,
            'testRunner': self
        }
        objects.update(sessions)
        obj = objects[operation['object']]

        # Combine arguments with options and handle special cases.
        arguments = operation.get('arguments', {})
        arguments.update(arguments.pop("options", {}))
        parse_options(arguments)

        cmd = getattr(obj, name)

        for arg_name in list(arguments):
            c2s = camel_to_snake(arg_name)
            # PyMongo accepts sort as list of tuples.
            if arg_name == "sort":
                sort_dict = arguments[arg_name]
                arguments[arg_name] = list(iteritems(sort_dict))
            # Named "key" instead not fieldName.
            if arg_name == "fieldName":
                arguments["key"] = arguments.pop(arg_name)
            # Aggregate uses "batchSize", while find uses batch_size.
            elif arg_name == "batchSize" and name == "aggregate":
                continue
            # PyMongo requires a boolean return_document value.
            elif arg_name == "returnDocument":
                arguments[c2s] = arguments.pop(arg_name) == "After"
            elif c2s == "requests":
                # Parse each request into a bulk write model.
                requests = []
                for request in arguments["requests"]:
                    bulk_model = camel_to_upper_camel(request["name"])
                    bulk_class = getattr(operations, bulk_model)
                    bulk_arguments = camel_to_snake_args(request["arguments"])
                    requests.append(bulk_class(**dict(bulk_arguments)))
                arguments["requests"] = requests
            elif arg_name == "session":
                arguments['session'] = sessions[arguments['session']]
            elif name == 'command' and arg_name == 'command':
                # Ensure the first key is the command name.
                ordered_command = SON([(operation['command_name'], 1)])
                ordered_command.update(arguments['command'])
                arguments['command'] = ordered_command
            elif name == 'with_transaction' and arg_name == 'callback':
                callback_ops = arguments[arg_name]['operations']
                arguments['callback'] = lambda _: self.run_operations(
                    sessions, original_collection, copy.deepcopy(callback_ops),
                    in_with_transaction=True)
            else:
                arguments[c2s] = arguments.pop(arg_name)

        result = cmd(**dict(arguments))

        if name == "aggregate":
            if arguments["pipeline"] and "$out" in arguments["pipeline"][-1]:
                # Read from the primary to ensure causal consistency.
                out = collection.database.get_collection(
                    arguments["pipeline"][-1]["$out"],
                    read_preference=ReadPreference.PRIMARY)
                return out.find()

        if isinstance(result, (Cursor, CommandCursor)):
            return list(result)

        return result
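The SON handling for runCommand above ensures the command name stays the first key, since the server reads the command name from the first field of the document. A small sketch:

from bson.son import SON

command = {'count': 'coll', 'query': {}}   # arbitrary key order
ordered = SON([('count', 1)])              # command name placed first
ordered.update(command)
# -> SON([('count', 'coll'), ('query', {})]); 'count' keeps its position
#    even though its placeholder value 1 is overwritten by update().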
Example #13
    def _create_entity(self, entity_spec):
        if len(entity_spec) != 1:
            self._test_class.fail(
                "Entity spec %s did not contain exactly one top-level key" %
                (entity_spec, ))

        entity_type, spec = next(iteritems(entity_spec))
        if entity_type == 'client':
            kwargs = {}
            observe_events = spec.get('observeEvents', [])
            ignore_commands = spec.get('ignoreCommandMonitoringEvents', [])
            if observe_events or ignore_commands:
                ignore_commands = [cmd.lower() for cmd in ignore_commands]
                listener = EventListenerUtil(observe_events, ignore_commands)
                self._listeners[spec['id']] = listener
                kwargs['event_listeners'] = [listener]
            if client_context.is_mongos and spec.get('useMultipleMongoses'):
                kwargs['h'] = client_context.mongos_seeds()
            kwargs.update(spec.get('uriOptions', {}))
            server_api = spec.get('serverApi')
            if server_api:
                kwargs['server_api'] = ServerApi(
                    server_api['version'],
                    strict=server_api.get('strict'),
                    deprecation_errors=server_api.get('deprecationErrors'))
            client = rs_or_single_client(**kwargs)
            self[spec['id']] = client
            self._test_class.addCleanup(client.close)
            return
        elif entity_type == 'database':
            client = self[spec['client']]
            if not isinstance(client, MongoClient):
                self._test_class.fail(
                    'Expected entity %s to be of type MongoClient, got %s' %
                    (spec['client'], type(client)))
            options = parse_collection_or_database_options(
                spec.get('databaseOptions', {}))
            self[spec['id']] = client.get_database(spec['databaseName'],
                                                   **options)
            return
        elif entity_type == 'collection':
            database = self[spec['database']]
            if not isinstance(database, Database):
                self._test_class.fail(
                    'Expected entity %s to be of type Database, got %s' %
                    (spec['database'], type(database)))
            options = parse_collection_or_database_options(
                spec.get('collectionOptions', {}))
            self[spec['id']] = database.get_collection(spec['collectionName'],
                                                       **options)
            return
        elif entity_type == 'session':
            client = self[spec['client']]
            if not isinstance(client, MongoClient):
                self._test_class.fail(
                    'Expected entity %s to be of type MongoClient, got %s' %
                    (spec['client'], type(client)))
            opts = camel_to_snake_args(spec.get('sessionOptions', {}))
            if 'default_transaction_options' in opts:
                txn_opts = parse_spec_options(
                    opts['default_transaction_options'])
                txn_opts = TransactionOptions(**txn_opts)
                opts = copy.deepcopy(opts)
                opts['default_transaction_options'] = txn_opts
            session = client.start_session(**dict(opts))
            self[spec['id']] = session
            self._session_lsids[spec['id']] = copy.deepcopy(session.session_id)
            self._test_class.addCleanup(session.end_session)
            return
        elif entity_type == 'bucket':
            # TODO: implement the 'bucket' entity type
            self._test_class.skipTest(
                'GridFS is not currently supported (PYTHON-2459)')
        self._test_class.fail('Unable to create entity of unknown type %s' %
                              (entity_type, ))
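Finally, hypothetical entity specs of the kind _create_entity consumes (the structure follows the unified test format; the ids and names are made up):

entity_specs = [
    {'client': {'id': 'client0',
                'observeEvents': ['commandStartedEvent'],
                'uriOptions': {'retryWrites': True}}},
    {'database': {'id': 'database0', 'client': 'client0',
                  'databaseName': 'test'}},
    {'collection': {'id': 'collection0', 'database': 'database0',
                    'collectionName': 'coll'}},
]
# Each spec has exactly one top-level key (the entity type). Creating the
# client first satisfies the later lookups self[spec['client']] and
# self[spec['database']].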