Exemple #1
0
    def _parametrize_test(self, test, generic_cls, device_cls):
        """Generate one decorated test variant per (module_info, dtype) pair.

        Args:
            test: The test function to wrap.
            generic_cls: Generic test class; its ``__name__`` is passed to the
                skip / decorator queries on each module info.
            device_cls: Device-specific test class; its ``device_type`` is
                passed to the skip / decorator queries. Must not be None.

        Yields:
            ``(test_wrapper, test_name, param_kwargs)`` tuples, where
            ``test_name`` is the module info name with dots replaced by
            underscores and ``param_kwargs`` carries ``'module_info'`` plus
            whatever ``_update_param_kwargs`` records for the dtype.

        Raises:
            RuntimeError: If ``device_cls`` is None.
        """
        if device_cls is None:
            raise RuntimeError(
                'The @modules decorator is only intended to be used in a device-specific '
                'context; use it with instantiate_device_type_tests() instead of '
                'instantiate_parametrized_tests()')

        for info in self.module_info_list:
            # Only the module part of the name is built here; device / dtype
            # parts are handled outside.
            # See [Note: device and dtype suffix placement]
            test_name = info.name.replace('.', '_')

            # Restrict the module's dtypes to the allowed ones, if configured.
            candidate_dtypes = set(info.dtypes)
            selected_dtypes = (
                candidate_dtypes
                if self.allowed_dtypes is None
                else candidate_dtypes.intersection(self.allowed_dtypes)
            )

            for dtype in selected_dtypes:
                # Keyword arguments the instantiated test will receive.
                test_params = {'module_info': info}
                _update_param_kwargs(test_params, 'dtype', dtype)

                try:
                    to_apply = [set_single_threaded_if_parallel_tbb]

                    # Positional arguments shared by the skip / activity queries.
                    query = (generic_cls.__name__, test.__name__,
                             device_cls.device_type, dtype)

                    if info.should_skip(*query):
                        to_apply.append(skipIf(True, "Skipped!"))

                    if info.decorators is not None:
                        for entry in info.decorators:
                            # Class-name comparison instead of isinstance: importing
                            # DecorateInfo here would cause a circular import.
                            if entry.__class__.__name__ != 'DecorateInfo':
                                to_apply.append(entry)
                            elif entry.is_active(*query):
                                to_apply.extend(entry.decorators)

                    @wraps(test)
                    def test_wrapper(*args, **kwargs):
                        return test(*args, **kwargs)

                    # Decorators are applied in list order, so the first entry
                    # ends up innermost.
                    wrapped = test_wrapper
                    for apply_decorator in to_apply:
                        wrapped = apply_decorator(wrapped)

                    yield (wrapped, test_name, test_params)
                except Exception as err:
                    # Print which instantiation failed, then propagate the error.
                    print("Failed to instantiate {0} for module {1}!".format(
                        test_name, info.name))
                    raise err
Exemple #2
0
    def _parametrize_test(self, test, generic_cls, device_cls):
        """Yield a decorated test variant for every (module_info, dtype) pair.

        Args:
            test: The test function to wrap.
            generic_cls: Generic test class; its ``__name__`` is passed to the
                skip / decorator queries on each module info.
            device_cls: Device-specific test class; its ``device_type`` is used
                in the generated test name and in the skip / decorator queries.
                Must not be None.

        Yields:
            ``(test_wrapper, test_name, param_kwargs)`` tuples. ``test_name`` is
            built from the module info name (dots replaced by underscores), the
            device type and the suffix returned by ``_dtype_test_suffix(dtype)``.

        Raises:
            RuntimeError: If ``device_cls`` is None, i.e. the decorator is used
                outside a device-specific context.
        """
        # Fail early with a clear message instead of an AttributeError on
        # ``device_cls.device_type`` further down.
        if device_cls is None:
            raise RuntimeError(
                'The @modules decorator is only intended to be used in a device-specific '
                'context; use it with instantiate_device_type_tests() instead of '
                'instantiate_parametrized_tests()')

        for module_info in self.module_info_list:
            # TODO: Factor some of this out since it's similar to OpInfo.
            dtypes = set(module_info.dtypes)
            if self.allowed_dtypes is not None:
                dtypes = dtypes.intersection(self.allowed_dtypes)

            for dtype in dtypes:
                # Construct the test name.
                test_name = '{}_{}{}'.format(
                    module_info.name.replace('.', '_'), device_cls.device_type,
                    _dtype_test_suffix(dtype))

                # Construct parameter kwargs to pass to the test.
                param_kwargs = {'module_info': module_info}
                _update_param_kwargs(param_kwargs, 'dtype', dtype)

                try:
                    active_decorators = [set_single_threaded_if_parallel_tbb]
                    if module_info.should_skip(generic_cls.__name__,
                                               test.__name__,
                                               device_cls.device_type, dtype):
                        active_decorators.append(skipIf(True, "Skipped!"))

                    if module_info.decorators is not None:
                        for decorator in module_info.decorators:
                            # Can't use isinstance as it would cause a circular import
                            if decorator.__class__.__name__ == 'DecorateInfo':
                                if decorator.is_active(generic_cls.__name__,
                                                       test.__name__,
                                                       device_cls.device_type,
                                                       dtype):
                                    active_decorators += decorator.decorators
                            else:
                                active_decorators.append(decorator)

                    @wraps(test)
                    def test_wrapper(*args, **kwargs):
                        return test(*args, **kwargs)

                    # Applied in list order, so the first decorator is innermost.
                    for decorator in active_decorators:
                        test_wrapper = decorator(test_wrapper)
                except Exception:
                    # Provides an error message for debugging before rethrowing the exception
                    print("Failed to instantiate {0} for module {1}!".format(
                        test_name, module_info.name))
                    raise
                else:
                    # Yield outside the guarded region so that only failures while
                    # building the wrapper are reported as instantiation errors.
                    yield (test_wrapper, test_name, param_kwargs)