def _parametrize_test(self, test, generic_cls, device_cls):
    """Yield one parametrized variant of ``test`` per (module, dtype) pair.

    For every entry in ``self.module_info_list`` and every dtype in that
    entry's ``dtypes`` (intersected with ``self.allowed_dtypes`` when that is
    not ``None``), this builds a fresh wrapper around ``test``. It then applies
    the decorators active for that combination and yields the result.

    The yielded test name is only the module name with ``.`` replaced by
    ``_``; device and dtype parts are appended outside this function
    (see [Note: device and dtype suffix placement]).

    Args:
        test: the generic test function being parametrized.
        generic_cls: the generic test class. Its ``__name__`` is passed to
            ``module_info.should_skip`` and ``DecorateInfo.is_active``.
        device_cls: the device-specific test class. Its ``device_type`` is
            passed to the same lookups. Must not be ``None``.

    Yields:
        ``(test_wrapper, test_name, param_kwargs)`` tuples. ``param_kwargs``
        starts as ``{'module_info': module_info}`` and is then updated by
        ``_update_param_kwargs(param_kwargs, 'dtype', dtype)``.

    Raises:
        RuntimeError: if ``device_cls`` is ``None``. Because this is a
            generator, the error surfaces on the first iteration, not on the
            call itself.
    """
    if device_cls is None:
        raise RuntimeError(
            'The @modules decorator is only intended to be used in a device-specific '
            'context; use it with instantiate_device_type_tests() instead of '
            'instantiate_parametrized_tests()')

    for module_info in self.module_info_list:
        # Construct the test name; device / dtype parts are handled outside.
        # See [Note: device and dtype suffix placement]
        test_name = module_info.name.replace('.', '_')

        dtypes = set(module_info.dtypes)
        if self.allowed_dtypes is not None:
            dtypes = dtypes.intersection(self.allowed_dtypes)

        for dtype in dtypes:
            # Construct parameter kwargs to pass to the test.
            param_kwargs = {'module_info': module_info}
            _update_param_kwargs(param_kwargs, 'dtype', dtype)

            # Only the instantiation/decoration work is guarded. The yield
            # stays outside the try, so exceptions thrown into this generator
            # are not misreported as instantiation failures.
            try:
                active_decorators = [set_single_threaded_if_parallel_tbb]
                if module_info.should_skip(generic_cls.__name__, test.__name__,
                                           device_cls.device_type, dtype):
                    active_decorators.append(skipIf(True, "Skipped!"))

                if module_info.decorators is not None:
                    for decorator in module_info.decorators:
                        # Can't use isinstance as it would cause a circular import
                        if decorator.__class__.__name__ == 'DecorateInfo':
                            if decorator.is_active(generic_cls.__name__, test.__name__,
                                                   device_cls.device_type, dtype):
                                active_decorators += decorator.decorators
                        else:
                            active_decorators.append(decorator)

                # A fresh wrapper per combination. Presumably this keeps
                # decorators that mutate the function object from leaking
                # across variants.
                @wraps(test)
                def test_wrapper(*args, **kwargs):
                    return test(*args, **kwargs)

                # Applied in list order: the first decorator ends up innermost.
                for decorator in active_decorators:
                    test_wrapper = decorator(test_wrapper)
            except Exception:
                # Provides an error message for debugging before rethrowing the exception
                print("Failed to instantiate {0} for module {1}!".format(
                    test_name, module_info.name))
                raise

            yield (test_wrapper, test_name, param_kwargs)
def _parametrize_test(self, test, generic_cls, device_cls):
    """Yield one parametrized variant of ``test`` per (module, dtype) pair.

    For every entry in ``self.module_info_list`` and every dtype in that
    entry's ``dtypes`` (intersected with ``self.allowed_dtypes`` when that is
    not ``None``), this builds a fresh wrapper around ``test``. It then applies
    the decorators active for that combination and yields the result.

    The yielded test name is
    ``<module name with '.' replaced by '_'>_<device_type><dtype suffix>``,
    where the dtype suffix is ``_dtype_test_suffix(dtype)``.

    Args:
        test: the generic test function being parametrized.
        generic_cls: the generic test class. Its ``__name__`` is passed to
            ``module_info.should_skip`` and ``DecorateInfo.is_active``.
        device_cls: the device-specific test class. Its ``device_type`` is
            used in the test name and passed to the same lookups. Must not be
            ``None``.

    Yields:
        ``(test_wrapper, test_name, param_kwargs)`` tuples. ``param_kwargs``
        starts as ``{'module_info': module_info}`` and is then updated by
        ``_update_param_kwargs(param_kwargs, 'dtype', dtype)``.

    Raises:
        RuntimeError: if ``device_cls`` is ``None``. Because this is a
            generator, the error surfaces on the first iteration, not on the
            call itself.
    """
    # Fail with an explicit message instead of an opaque AttributeError on
    # ``device_cls.device_type`` when used outside a device-specific context.
    if device_cls is None:
        raise RuntimeError(
            'The @modules decorator is only intended to be used in a device-specific '
            'context; use it with instantiate_device_type_tests() instead of '
            'instantiate_parametrized_tests()')

    for module_info in self.module_info_list:
        # TODO: Factor some of this out since it's similar to OpInfo.
        module_part = module_info.name.replace('.', '_')

        dtypes = set(module_info.dtypes)
        if self.allowed_dtypes is not None:
            dtypes = dtypes.intersection(self.allowed_dtypes)

        for dtype in dtypes:
            # Construct the test name.
            test_name = '{}_{}{}'.format(
                module_part, device_cls.device_type, _dtype_test_suffix(dtype))

            # Construct parameter kwargs to pass to the test.
            param_kwargs = {'module_info': module_info}
            _update_param_kwargs(param_kwargs, 'dtype', dtype)

            # Only the instantiation/decoration work is guarded. The yield
            # stays outside the try, so exceptions thrown into this generator
            # are not misreported as instantiation failures.
            try:
                active_decorators = [set_single_threaded_if_parallel_tbb]
                if module_info.should_skip(generic_cls.__name__, test.__name__,
                                           device_cls.device_type, dtype):
                    active_decorators.append(skipIf(True, "Skipped!"))

                if module_info.decorators is not None:
                    for decorator in module_info.decorators:
                        # Can't use isinstance as it would cause a circular import
                        if decorator.__class__.__name__ == 'DecorateInfo':
                            if decorator.is_active(generic_cls.__name__, test.__name__,
                                                   device_cls.device_type, dtype):
                                active_decorators += decorator.decorators
                        else:
                            active_decorators.append(decorator)

                # A fresh wrapper per combination. Presumably this keeps
                # decorators that mutate the function object from leaking
                # across variants.
                @wraps(test)
                def test_wrapper(*args, **kwargs):
                    return test(*args, **kwargs)

                # Applied in list order: the first decorator ends up innermost.
                for decorator in active_decorators:
                    test_wrapper = decorator(test_wrapper)
            except Exception:
                # Provides an error message for debugging before rethrowing the exception
                print("Failed to instantiate {0} for module {1}!".format(
                    test_name, module_info.name))
                raise

            yield (test_wrapper, test_name, param_kwargs)