def is_run_on_requirement_satisfied(requirement):
    """Return whether the connected deployment satisfies one runOnRequirement.

    ``requirement`` is a single entry of a ``runOnRequirements`` array. The
    ``topologies``, ``minServerVersion``, ``maxServerVersion`` and
    ``serverParameters`` keys are inspected; a key that is missing (or empty)
    places no constraint on the deployment.

    Returns a truthy value only when every present constraint holds.
    """
    topology_satisfied = True
    req_topologies = requirement.get('topologies')
    if req_topologies:
        topology_satisfied = client_context.is_topology_type(req_topologies)

    # Compare on major.minor.patch only. Any further component carried by
    # client_context.version (presumably a pre-release/build marker -- confirm
    # against Version.from_string) must not push the server below a
    # minServerVersion it actually meets.
    server_version = Version(*client_context.version[:3])

    min_version_satisfied = True
    req_min_server_version = requirement.get('minServerVersion')
    if req_min_server_version:
        min_version_satisfied = Version.from_string(
            req_min_server_version) <= server_version

    max_version_satisfied = True
    req_max_server_version = requirement.get('maxServerVersion')
    if req_max_server_version:
        max_version_satisfied = Version.from_string(
            req_max_server_version) >= server_version

    # Every requested server parameter must be present with an equal value.
    params_satisfied = True
    params = requirement.get('serverParameters')
    if params:
        server_params = client_context.server_parameters
        for param, val in params.items():
            if param not in server_params or server_params[param] != val:
                params_satisfied = False
                break

    return (topology_satisfied and min_version_satisfied and
            max_version_satisfied and params_satisfied)
def is_run_on_requirement_satisfied(requirement):
    """Check one runOnRequirement object against the connected deployment.

    Looks at ``topologies``, ``minServerVersion``, ``maxServerVersion``,
    ``serverless``, ``serverParameters`` and ``auth``. A missing (or empty)
    key imposes no constraint. The result is the ``and`` of all the
    individual checks, evaluated in that order.
    """
    wanted_topologies = requirement.get('topologies')
    topology_ok = (client_context.is_topology_type(wanted_topologies)
                   if wanted_topologies else True)

    # Only major.minor.patch of the running server takes part in comparisons.
    current = Version(*client_context.version[:3])

    lower_bound = requirement.get('minServerVersion')
    min_ok = (Version.from_string(lower_bound) <= current
              if lower_bound else True)

    upper_bound = requirement.get('maxServerVersion')
    max_ok = (Version.from_string(upper_bound) >= current
              if upper_bound else True)

    # "require" / "forbid" constrain the deployment; unset or "allow" do not.
    serverless_mode = requirement.get('serverless')
    serverless_ok = (
        client_context.serverless if serverless_mode == "require"
        else not client_context.serverless if serverless_mode == "forbid"
        else True)

    # Each requested server parameter must exist with an equal value.
    expected_params = requirement.get('serverParameters')
    params_ok = True
    if expected_params:
        actual_params = client_context.server_parameters
        params_ok = not any(
            name not in actual_params or actual_params[name] != value
            for name, value in expected_params.items())

    # auth: True -> auth must be on, False -> auth must be off, absent -> any.
    wants_auth = requirement.get('auth')
    if wants_auth is None:
        auth_ok = True
    elif wants_auth:
        auth_ok = client_context.auth_enabled
    else:
        auth_ok = not client_context.auth_enabled

    return (topology_ok and min_ok and max_ok and serverless_ok and
            params_ok and auth_ok)
def generate_test_classes(test_path, module=__name__, class_name_prefix='',
                          expected_failures=None,
                          bypass_test_generation_errors=False, **kwargs):
    """Method for generating test classes.

    Walks ``test_path`` recursively and loads every file found as extended
    JSON. For each file it builds one test class that combines the mixin
    selected by the file's major ``schemaVersion`` with a base class that
    carries the parsed spec as ``TEST_SPEC``.

    :Parameters:
      - `test_path`: directory containing the spec files.
      - `module`: value used for the generated classes' ``__module__``.
      - `class_name_prefix`: inserted after ``Test`` in each class name.
      - `expected_failures`: stored on every generated class as
        ``EXPECTED_FAILURES``. Defaults to a fresh empty list per call.
      - `bypass_test_generation_errors`: if True, a file whose class cannot
        be built is skipped instead of raising.
      - `**kwargs`: extra attributes added to every generated class.

    Returns a dictionary where keys are the names of test classes and values
    are the test class objects.
    """
    # A literal ``[]`` default would be created once and shared by every
    # call (and every generated class); use a sentinel instead.
    if expected_failures is None:
        expected_failures = []

    test_klasses = {}

    def test_base_class_factory(test_spec):
        """Utility that creates the base class to use for test generation.
        This is needed to ensure that cls.TEST_SPEC is appropriately set when
        the metaclass __init__ is invoked."""
        class SpecTestBase(with_metaclass(UnifiedSpecTestMeta)):
            TEST_SPEC = test_spec
            EXPECTED_FAILURES = expected_failures
        return SpecTestBase

    for dirpath, _, filenames in os.walk(test_path):
        dirname = os.path.split(dirpath)[-1]

        for filename in filenames:
            fpath = os.path.join(dirpath, filename)
            with open(fpath) as scenario_stream:
                # Use tz_aware=False to match how CodecOptions decodes
                # dates.
                opts = json_util.JSONOptions(tz_aware=False)
                scenario_def = json_util.loads(
                    scenario_stream.read(), json_options=opts)

            test_type = os.path.splitext(filename)[0]
            snake_class_name = 'Test%s_%s_%s' % (
                class_name_prefix, dirname.replace('-', '_'),
                test_type.replace('-', '_').replace('.', '_'))
            class_name = snake_to_camel(snake_class_name)

            try:
                schema_version = Version.from_string(
                    scenario_def['schemaVersion'])
                # Only the major schema version selects the mixin here; the
                # full version is checked later, per test.
                mixin_class = _SCHEMA_VERSION_MAJOR_TO_MIXIN_CLASS.get(
                    schema_version[0])
                if mixin_class is None:
                    raise ValueError(
                        "test file '%s' has unsupported schemaVersion '%s'" % (
                            fpath, schema_version))
                module_dict = {'__module__': module}
                module_dict.update(kwargs)
                test_klasses[class_name] = type(
                    class_name,
                    (mixin_class, test_base_class_factory(scenario_def),),
                    module_dict)
            except Exception:
                if bypass_test_generation_errors:
                    continue
                raise

    return test_klasses
def is_run_on_requirement_satisfied(requirement):
    """Return whether the connected deployment satisfies one runOnRequirement.

    ``requirement`` is a single entry of a ``runOnRequirements`` array. Only
    the ``topologies``, ``minServerVersion`` and ``maxServerVersion`` keys are
    inspected; a key that is missing (or empty) places no constraint on the
    deployment.

    Returns a truthy value only when every present constraint holds.
    """
    topology_satisfied = True
    req_topologies = requirement.get('topologies')
    if req_topologies:
        topology_satisfied = client_context.is_topology_type(req_topologies)

    # Compare on major.minor.patch only. Any further component carried by
    # client_context.version (presumably a pre-release/build marker -- confirm
    # against Version.from_string) must not push the server below a
    # minServerVersion it actually meets.
    server_version = Version(*client_context.version[:3])

    min_version_satisfied = True
    req_min_server_version = requirement.get('minServerVersion')
    if req_min_server_version:
        min_version_satisfied = Version.from_string(
            req_min_server_version) <= server_version

    max_version_satisfied = True
    req_max_server_version = requirement.get('maxServerVersion')
    if req_max_server_version:
        max_version_satisfied = Version.from_string(
            req_max_server_version) >= server_version

    return (topology_satisfied and min_version_satisfied and
            max_version_satisfied)
def setUp(self):
    """Per-test setup: validate the spec's schemaVersion and build a matcher."""
    super(UnifiedSpecTestMixinV1, self).setUp()

    # Only the major schema version is vetted while classes are generated;
    # the full comparison happens here because setUpClass cannot run
    # assertions.
    spec_schema_version = Version.from_string(self.TEST_SPEC['schemaVersion'])
    supported_version = self.SCHEMA_VERSION
    failure_message = 'expected schema version %s or lower, got %s' % (
        supported_version, spec_schema_version)
    self.assertLessEqual(spec_schema_version, supported_version,
                         failure_message)

    # Matcher bound to this test case, used when comparing expectations.
    self.match_evaluator = MatchEvaluatorUtil(self)
class UnifiedSpecTestMixinV1(IntegrationTest):
    """Mixin class to run test cases from test specification files.

    Assumes that tests conform to the `unified test format
    <https://github.com/mongodb/specifications/blob/master/source/unified-test-format/unified-test-format.rst>`_.

    Specification of the test suite being currently run is available as
    a class attribute ``TEST_SPEC``.
    """
    # Highest schemaVersion this runner accepts; enforced in setUp.
    SCHEMA_VERSION = Version.from_string('1.1')

    @staticmethod
    def should_run_on(run_on_spec):
        """Return True if *run_on_spec* is empty or any of its requirement
        objects is satisfied by the connected deployment."""
        if not run_on_spec:
            # Always run these tests.
            return True

        for req in run_on_spec:
            if is_run_on_requirement_satisfied(req):
                return True
        return False

    def insert_initial_data(self, initial_data):
        """Drop and re-populate every collection listed in *initial_data*.

        Writes use w="majority". A collection with an empty document list is
        still created by inserting and then deleting a placeholder document.
        """
        for collection_data in initial_data:
            coll_name = collection_data['collectionName']
            db_name = collection_data['databaseName']
            documents = collection_data['documents']

            coll = self.client.get_database(db_name).get_collection(
                coll_name, write_concern=WriteConcern(w="majority"))
            coll.drop()

            if documents:
                coll.insert_many(documents)
            else:
                # ensure collection exists
                result = coll.insert_one({})
                coll.delete_one({'_id': result.inserted_id})

    @classmethod
    def setUpClass(cls):
        """Run base-class setup, then skip the whole file when its
        file-level runOnRequirements are unmet or a known limitation
        applies."""
        # super call creates internal client cls.client
        super(UnifiedSpecTestMixinV1, cls).setUpClass()

        # process file-level runOnRequirements
        run_on_spec = cls.TEST_SPEC.get('runOnRequirements', [])
        if not cls.should_run_on(run_on_spec):
            raise unittest.SkipTest('%s runOnRequirements not satisfied' %
                                    (cls.__name__, ))

        # add any special-casing for skipping tests here
        if client_context.storage_engine == 'mmapv1':
            if 'retryable-writes' in cls.TEST_SPEC['description']:
                raise unittest.SkipTest(
                    "MMAPv1 does not support retryWrites=True")

    @classmethod
    def tearDownClass(cls):
        """Run base-class teardown, then close the shared client."""
        super(UnifiedSpecTestMixinV1, cls).tearDownClass()
        cls.client.close()

    def setUp(self):
        """Validate the spec's schemaVersion and create the matcher."""
        super(UnifiedSpecTestMixinV1, self).setUp()

        # process schemaVersion
        # note: we check major schema version during class generation
        # note: we do this here because we cannot run assertions in setUpClass
        version = Version.from_string(self.TEST_SPEC['schemaVersion'])
        self.assertLessEqual(
            version, self.SCHEMA_VERSION,
            'expected schema version %s or lower, got %s' % (
                self.SCHEMA_VERSION, version))

        # initialize internals
        self.match_evaluator = MatchEvaluatorUtil(self)

    def maybe_skip_test(self, spec):
        """Skip individual test cases known not to work on this deployment."""
        # add any special-casing for skipping tests here
        if client_context.storage_engine == 'mmapv1':
            if 'Dirty explicit session is discarded' in spec['description']:
                raise unittest.SkipTest(
                    "MMAPv1 does not support retryWrites=True")

    def process_error(self, exception, spec):
        """Assert that *exception* satisfies the expectError object *spec*.

        Every assertion key present in *spec* is checked independently;
        absent keys are not checked.
        """
        is_error = spec.get('isError')
        is_client_error = spec.get('isClientError')
        error_contains = spec.get('errorContains')
        error_code = spec.get('errorCode')
        error_code_name = spec.get('errorCodeName')
        error_labels_contain = spec.get('errorLabelsContain')
        error_labels_omit = spec.get('errorLabelsOmit')
        expect_result = spec.get('expectResult')

        if is_error:
            # already satisfied because exception was raised
            pass

        if is_client_error:
            # A client error is anything that is not a PyMongoError.
            self.assertNotIsInstance(exception, PyMongoError)

        if error_contains:
            if isinstance(exception, BulkWriteError):
                errmsg = str(exception.details).lower()
            else:
                errmsg = str(exception).lower()
            self.assertIn(error_contains.lower(), errmsg)

        # NOTE(review): the code checks below assume the exception exposes a
        # ``details`` mapping (server errors); confirm for other error types.
        if error_code:
            self.assertEqual(error_code, exception.details.get('code'))

        if error_code_name:
            self.assertEqual(error_code_name,
                             exception.details.get('codeName'))

        if error_labels_contain:
            labels = [
                err_label for err_label in error_labels_contain
                if exception.has_error_label(err_label)
            ]
            self.assertEqual(labels, error_labels_contain)

        if error_labels_omit:
            for err_label in error_labels_omit:
                if exception.has_error_label(err_label):
                    self.fail("Exception '%s' unexpectedly had label '%s'" %
                              (exception, err_label))

        if expect_result:
            if isinstance(exception, BulkWriteError):
                result = parse_bulk_write_error_result(exception)
                self.match_evaluator.match_result(expect_result, result)
            else:
                self.fail("expectResult can only be specified with %s "
                          "exceptions" % (BulkWriteError, ))

    def __raise_if_unsupported(self, opname, target, *target_types):
        """Fail the test unless *target* is an instance of *target_types*."""
        if not isinstance(target, target_types):
            self.fail('Operation %s not supported for entity '
                      'of type %s' % (opname, type(target)))

    def __entityOperation_createChangeStream(self, target, *args, **kwargs):
        """Shared createChangeStream implementation for all entity types."""
        if client_context.storage_engine == 'mmapv1':
            self.skipTest("MMAPv1 does not support change streams")
        self.__raise_if_unsupported('createChangeStream', target, MongoClient,
                                    Database, Collection)
        return target.watch(*args, **kwargs)

    def _clientOperation_createChangeStream(self, target, *args, **kwargs):
        return self.__entityOperation_createChangeStream(
            target, *args, **kwargs)

    def _databaseOperation_createChangeStream(self, target, *args, **kwargs):
        return self.__entityOperation_createChangeStream(
            target, *args, **kwargs)

    def _collectionOperation_createChangeStream(self, target, *args,
                                                **kwargs):
        return self.__entityOperation_createChangeStream(
            target, *args, **kwargs)

    def _databaseOperation_runCommand(self, target, **kwargs):
        """Run a database command built from the spec arguments."""
        self.__raise_if_unsupported('runCommand', target, Database)
        # Ensure the first key is the command name.
        ordered_command = SON([(kwargs.pop('command_name'), 1)])
        ordered_command.update(kwargs['command'])
        kwargs['command'] = ordered_command
        return target.command(**kwargs)

    def __entityOperation_aggregate(self, target, *args, **kwargs):
        """Shared aggregate implementation; materializes the cursor."""
        self.__raise_if_unsupported('aggregate', target, Database, Collection)
        return list(target.aggregate(*args, **kwargs))

    def _databaseOperation_aggregate(self, target, *args, **kwargs):
        return self.__entityOperation_aggregate(target, *args, **kwargs)

    def _collectionOperation_aggregate(self, target, *args, **kwargs):
        return self.__entityOperation_aggregate(target, *args, **kwargs)

    def _collectionOperation_bulkWrite(self, target, *args, **kwargs):
        self.__raise_if_unsupported('bulkWrite', target, Collection)
        write_result = target.bulk_write(*args, **kwargs)
        return parse_bulk_write_result(write_result)

    def _collectionOperation_find(self, target, *args, **kwargs):
        self.__raise_if_unsupported('find', target, Collection)
        find_cursor = target.find(*args, **kwargs)
        return list(find_cursor)

    def _collectionOperation_findOneAndReplace(self, target, *args, **kwargs):
        self.__raise_if_unsupported('findOneAndReplace', target, Collection)
        return target.find_one_and_replace(*args, **kwargs)

    def _collectionOperation_findOneAndUpdate(self, target, *args, **kwargs):
        # Report the operation actually being run in the failure message.
        self.__raise_if_unsupported('findOneAndUpdate', target, Collection)
        return target.find_one_and_update(*args, **kwargs)

    def _collectionOperation_insertMany(self, target, *args, **kwargs):
        self.__raise_if_unsupported('insertMany', target, Collection)
        result = target.insert_many(*args, **kwargs)
        # Map each inserted document's position to its _id.
        return {idx: _id for idx, _id in enumerate(result.inserted_ids)}

    def _collectionOperation_insertOne(self, target, *args, **kwargs):
        self.__raise_if_unsupported('insertOne', target, Collection)
        result = target.insert_one(*args, **kwargs)
        return {'insertedId': result.inserted_id}

    def _sessionOperation_withTransaction(self, target, *args, **kwargs):
        if client_context.storage_engine == 'mmapv1':
            self.skipTest('MMAPv1 does not support document-level locking')
        self.__raise_if_unsupported('withTransaction', target, ClientSession)
        return target.with_transaction(*args, **kwargs)

    def _sessionOperation_startTransaction(self, target, *args, **kwargs):
        if client_context.storage_engine == 'mmapv1':
            self.skipTest('MMAPv1 does not support document-level locking')
        self.__raise_if_unsupported('startTransaction', target, ClientSession)
        return target.start_transaction(*args, **kwargs)

    def _changeStreamOperation_iterateUntilDocumentOrError(self, target,
                                                           *args, **kwargs):
        self.__raise_if_unsupported('iterateUntilDocumentOrError', target,
                                    ChangeStream)
        return next(target)

    def run_entity_operation(self, spec):
        """Run one operation against an entity and check its outcome.

        Dispatches to a ``_<entityType>Operation_<name>`` method when one
        exists, otherwise to the snake_case method of the target itself.
        Handles ``expectError``, ``expectResult`` and ``saveResultAsEntity``.
        """
        target = self.entity_map[spec['object']]
        opname = spec['name']
        opargs = spec.get('arguments')
        expect_error = spec.get('expectError')
        if opargs:
            arguments = parse_spec_options(copy.deepcopy(opargs))
            prepare_spec_arguments(spec, arguments, camel_to_snake(opname),
                                   self.entity_map, self.run_operations)
        else:
            arguments = tuple()

        if isinstance(target, MongoClient):
            method_name = '_clientOperation_%s' % (opname, )
        elif isinstance(target, Database):
            method_name = '_databaseOperation_%s' % (opname, )
        elif isinstance(target, Collection):
            method_name = '_collectionOperation_%s' % (opname, )
        elif isinstance(target, ChangeStream):
            method_name = '_changeStreamOperation_%s' % (opname, )
        elif isinstance(target, ClientSession):
            method_name = '_sessionOperation_%s' % (opname, )
        elif isinstance(target, GridFSBucket):
            raise NotImplementedError
        else:
            method_name = 'doesNotExist'

        try:
            method = getattr(self, method_name)
        except AttributeError:
            try:
                cmd = getattr(target, camel_to_snake(opname))
            except AttributeError:
                self.fail('Unsupported operation %s on entity %s' %
                          (opname, target))
        else:
            cmd = functools.partial(method, target)

        try:
            result = cmd(**dict(arguments))
        except Exception as exc:
            if expect_error:
                return self.process_error(exc, expect_error)
            raise

        if expect_error:
            # An expected error that never materialized is a test failure;
            # silently falling through would let the test pass.
            self.fail('Expected error %s but operation %s succeeded: %s' %
                      (expect_error, opname, result))

        if 'expectResult' in spec:
            self.match_evaluator.match_result(spec['expectResult'], result)

        save_as_entity = spec.get('saveResultAsEntity')
        if save_as_entity:
            self.entity_map[save_as_entity] = result

    def __set_fail_point(self, client, command_args):
        """Enable a failCommand fail point and schedule turning it off."""
        if not client_context.test_commands_enabled:
            self.skipTest('Test commands must be enabled')

        cmd_on = SON([('configureFailPoint', 'failCommand')])
        cmd_on.update(command_args)
        client.admin.command(cmd_on)
        self.addCleanup(client.admin.command, 'configureFailPoint',
                        cmd_on['configureFailPoint'], mode='off')

    def _testOperation_failPoint(self, spec):
        self.__set_fail_point(client=self.entity_map[spec['client']],
                              command_args=spec['failPoint'])

    def _testOperation_targetedFailPoint(self, spec):
        """Set a fail point on the server the session is pinned to."""
        session = self.entity_map[spec['session']]
        if not session._pinned_address:
            self.fail("Cannot use targetedFailPoint operation with unpinned "
                      "session %s" % (spec['session'], ))

        client = single_client('%s:%s' % session._pinned_address)
        # Cleanups run last-in, first-out: registering close() first makes
        # it run after the fail point is turned off, and guarantees the
        # client is closed even if setting the fail point fails or skips.
        self.addCleanup(client.close)
        self.__set_fail_point(client=client, command_args=spec['failPoint'])

    def _testOperation_assertSessionTransactionState(self, spec):
        session = self.entity_map[spec['session']]
        expected_state = getattr(_TxnState, spec['state'].upper())
        self.assertEqual(expected_state, session._transaction.state)

    def _testOperation_assertSessionPinned(self, spec):
        session = self.entity_map[spec['session']]
        self.assertIsNotNone(session._pinned_address)

    def _testOperation_assertSessionUnpinned(self, spec):
        session = self.entity_map[spec['session']]
        self.assertIsNone(session._pinned_address)

    def __get_last_two_command_lsids(self, listener):
        """Return the lsids of the two most recent CommandStartedEvents,
        newest first."""
        cmd_started_events = []
        for event in reversed(listener.results):
            if isinstance(event, CommandStartedEvent):
                cmd_started_events.append(event)
        if len(cmd_started_events) < 2:
            self.fail('Needed 2 CommandStartedEvents to compare lsids, '
                      'got %s' % (len(cmd_started_events)))
        return tuple([e.command['lsid'] for e in cmd_started_events][:2])

    def _testOperation_assertDifferentLsidOnLastTwoCommands(self, spec):
        listener = self.entity_map.get_listener_for_client(spec['client'])
        self.assertNotEqual(*self.__get_last_two_command_lsids(listener))

    def _testOperation_assertSameLsidOnLastTwoCommands(self, spec):
        listener = self.entity_map.get_listener_for_client(spec['client'])
        self.assertEqual(*self.__get_last_two_command_lsids(listener))

    def _testOperation_assertSessionDirty(self, spec):
        session = self.entity_map[spec['session']]
        self.assertTrue(session._server_session.dirty)

    def _testOperation_assertSessionNotDirty(self, spec):
        session = self.entity_map[spec['session']]
        self.assertFalse(session._server_session.dirty)

    def _testOperation_assertCollectionExists(self, spec):
        database_name = spec['databaseName']
        collection_name = spec['collectionName']
        collection_name_list = list(
            self.client.get_database(database_name).list_collection_names())
        self.assertIn(collection_name, collection_name_list)

    def _testOperation_assertCollectionNotExists(self, spec):
        database_name = spec['databaseName']
        collection_name = spec['collectionName']
        collection_name_list = list(
            self.client.get_database(database_name).list_collection_names())
        self.assertNotIn(collection_name, collection_name_list)

    def _testOperation_assertIndexExists(self, spec):
        collection = self.client[spec['databaseName']][spec['collectionName']]
        index_names = [idx['name'] for idx in collection.list_indexes()]
        self.assertIn(spec['indexName'], index_names)

    def _testOperation_assertIndexNotExists(self, spec):
        collection = self.client[spec['databaseName']][spec['collectionName']]
        for index in collection.list_indexes():
            self.assertNotEqual(spec['indexName'], index['name'])

    def run_special_operation(self, spec):
        """Dispatch a testRunner operation to its _testOperation_ method."""
        opname = spec['name']
        method_name = '_testOperation_%s' % (opname, )
        try:
            method = getattr(self, method_name)
        except AttributeError:
            self.fail('Unsupported special test operation %s' % (opname, ))
        else:
            method(spec['arguments'])

    def run_operations(self, spec):
        """Run each operation in *spec*, routing testRunner ones specially."""
        for op in spec:
            target = op['object']
            if target != 'testRunner':
                self.run_entity_operation(op)
            else:
                self.run_special_operation(op)

    def check_events(self, spec):
        """Compare observed events per client against expectEvents."""
        for event_spec in spec:
            client_name = event_spec['client']
            events = event_spec['events']
            listener = self.entity_map.get_listener_for_client(client_name)

            if len(events) == 0:
                self.assertEqual(listener.results, [])
                continue

            if len(events) > len(listener.results):
                self.fail('Expected to see %s events, got %s' % (
                    len(events), len(listener.results)))

            for idx, expected_event in enumerate(events):
                self.match_evaluator.match_event(expected_event,
                                                 listener.results[idx])

    def verify_outcome(self, spec):
        """Assert each listed collection holds exactly the expected documents.

        An empty ``documents`` list asserts that the collection is empty.
        """
        for collection_data in spec:
            coll_name = collection_data['collectionName']
            db_name = collection_data['databaseName']
            expected_documents = collection_data['documents']

            coll = self.client.get_database(db_name).get_collection(
                coll_name, read_preference=ReadPreference.PRIMARY,
                read_concern=ReadConcern(level='local'))

            # Both sides are ordered by _id so the comparison is stable.
            sorted_expected_documents = sorted(expected_documents,
                                               key=lambda doc: doc['_id'])
            actual_documents = list(coll.find({}, sort=[('_id', ASCENDING)]))
            self.assertListEqual(sorted_expected_documents, actual_documents)

    def run_scenario(self, spec):
        """Execute a single test case from the spec file."""
        # maybe skip test manually
        self.maybe_skip_test(spec)

        # process test-level runOnRequirements
        run_on_spec = spec.get('runOnRequirements', [])
        if not self.should_run_on(run_on_spec):
            raise unittest.SkipTest('runOnRequirements not satisfied')

        # process skipReason
        skip_reason = spec.get('skipReason', None)
        if skip_reason is not None:
            raise unittest.SkipTest('%s' % (skip_reason, ))

        # process createEntities
        self.entity_map = EntityMapUtil(self)
        self.entity_map.create_entities_from_spec(
            self.TEST_SPEC.get('createEntities', []))

        # process initialData
        self.insert_initial_data(self.TEST_SPEC.get('initialData', []))

        # process operations
        self.run_operations(spec['operations'])

        # process expectEvents
        self.check_events(spec.get('expectEvents', []))

        # process outcome
        self.verify_outcome(spec.get('outcome', []))