def instantiate_test(cls, name, test, *, generic_cls=None):
        """Attach concrete variants of the generic `test` to `cls`.

        Without an ``op_list`` attribute on `test`, one variant is created
        per dtype returned by ``cls._get_dtypes(test)`` (a single dtype-less
        variant when that returns ``None``). With an ``op_list``, one variant
        is created per (op, dtype) pair; the dtypes come from the op when the
        test declares none.

        Args:
            cls: class that receives the generated tests via ``setattr``.
            name: base name passed to ``_construct_test_name``.
            test: the generic test function being instantiated.
            generic_cls: class whose ``__name__`` is handed to
                ``op.should_skip`` when a generated test runs.
        """
        def instantiate_test_helper(cls, name, *, test, dtype, op):
            """Build the variant for one (op, dtype) pair and set it on `cls`."""
            test_name = _construct_test_name(name, op, cls.device_type, dtype)

            # Op-specific decorators must not be applied to the original
            # test, so they are stacked on a pass-through wrapper instead.
            # Test-specific decorators already live on the original test.
            test_fn = test
            if op is not None and op.decorators is not None:

                @wraps(test)
                def test_wrapper(*args, **kwargs):
                    return test(*args, **kwargs)

                for op_decorator in op.decorators:
                    test_wrapper = op_decorator(test_wrapper)

                test_fn = test_wrapper

            @wraps(test)
            def instantiated_test(self,
                                  name=name,
                                  test=test_fn,
                                  dtype=dtype,
                                  op=op):
                skip_requested = op is not None and op.should_skip(
                    generic_cls.__name__, name, self.device_type, dtype)
                if skip_requested:
                    self.skipTest("Skipped!")

                # Tests carrying `num_required_devices` receive the result
                # of get_all_devices() instead of the primary device.
                device_arg = cls.get_primary_device()
                if hasattr(test_fn, 'num_required_devices'):
                    device_arg = cls.get_all_devices()

                # Only the arguments that are not None are forwarded.
                call_args = [
                    arg for arg in (device_arg, dtype, op) if arg is not None
                ]

                # The precision override lasts for this run only; the
                # previous value is restored no matter how the test exits.
                saved_precision = self.precision
                try:
                    self.precision = self._get_precision_override(
                        test_fn, dtype)
                    return test_fn(self, *call_args)
                except RuntimeError as rte:
                    # Record whether this error should stop the whole suite,
                    # then let the runner see the error unchanged.
                    self._stop_test_suite = self._should_stop_test_suite(rte)
                    raise rte
                finally:
                    self.precision = saved_precision

            assert not hasattr(
                cls, test_name), f"Redefinition of test {test_name}"
            setattr(cls, test_name, instantiated_test)

        # Tests that don't use the ops decorator
        if not hasattr(test, "op_list"):
            declared = cls._get_dtypes(test)
            plain_dtypes = (None, ) if declared is None else tuple(declared)
            for dtype in plain_dtypes:
                instantiate_test_helper(cls,
                                        name,
                                        test=test,
                                        dtype=dtype,
                                        op=None)
            return

        # Tests using the ops decorator
        for op in test.op_list:
            dtypes = cls._get_dtypes(test)
            if dtypes is not None:
                assert test.allowed_dtypes is None, "ops(allowed_dtypes=[...]) and the dtypes decorator are incompatible"
                assert test.opinfo_dtypes == OpDTypes.basic, "ops(dtypes=...) and the dtypes decorator are incompatible"
            else:
                # No dtypes declared on the test: derive them from the op.
                mode = test.opinfo_dtypes
                if mode == OpDTypes.unsupported:
                    every_dtype = set(get_all_dtypes())
                    dtypes = every_dtype.difference(
                        op.supported_dtypes(cls.device_type))
                elif mode == OpDTypes.supported:
                    dtypes = op.supported_dtypes(cls.device_type)
                elif mode == OpDTypes.basic:
                    dtypes = op.default_test_dtypes(cls.device_type)
                else:
                    raise RuntimeError(f"Unknown OpDType: {mode}")

                if test.allowed_dtypes is not None:
                    dtypes = dtypes.intersection(test.allowed_dtypes)

            for dtype in dtypes:
                instantiate_test_helper(cls,
                                        name,
                                        test=test,
                                        dtype=dtype,
                                        op=op)
Example #2
0
    def instantiate_test(cls, name, test, *, generic_cls=None):
        """Attach concrete variants of the generic `test` to `cls`.

        Without an ``op_list`` attribute on `test`, one variant is created
        per dtype returned by ``cls._get_dtypes(test)`` (a single dtype-less
        variant when that returns ``None``). With an ``op_list``, one variant
        is created per (op, dtype) pair, or a single variant per op when the
        dtype selection is ``_NO_DTYPES``.

        Args:
            cls: class that receives the generated tests via ``setattr``.
            name: base name passed to ``_construct_test_name``.
            test: the generic test function being instantiated.
            generic_cls: class whose ``__name__`` is handed to
                ``op.should_skip`` and ``DecorateInfo.is_active`` while the
                variants are being built.
        """
        def instantiate_test_helper(cls, name, *, test, dtype, op):
            """Build the variant for one (op, dtype) pair and set it on `cls`."""
            test_name = _construct_test_name(name, op, cls.device_type, dtype)

            # Op-specific decorators must not be applied to the original
            # test, so they are stacked on a pass-through wrapper instead.
            # Test-specific decorators already live on the original test.
            if op is None:
                test_fn = test
            else:
                try:
                    # A skip requested by the op is expressed as a decorator.
                    active_decorators = []
                    if op.should_skip(generic_cls.__name__, name,
                                      cls.device_type, dtype):
                        active_decorators.append(skipIf(True, "Skipped!"))

                    if op.decorators is not None:
                        for candidate in op.decorators:
                            # Can't use isinstance as it would cause a circular import
                            if candidate.__class__.__name__ != 'DecorateInfo':
                                active_decorators.append(candidate)
                            elif candidate.is_active(generic_cls.__name__,
                                                     name, cls.device_type,
                                                     dtype):
                                active_decorators.extend(candidate.decorators)

                    @wraps(test)
                    def test_wrapper(*args, **kwargs):
                        return test(*args, **kwargs)

                    for active in active_decorators:
                        test_wrapper = active(test_wrapper)

                    test_fn = test_wrapper
                except Exception as ex:
                    # Say which test/op failed before rethrowing.
                    print(f"Failed to instantiate {test_name} for op {op.name}!")
                    raise ex

            @wraps(test_fn)
            def instantiated_test(self,
                                  name=name,
                                  test=test_fn,
                                  dtype=dtype,
                                  op=op):
                # Tests carrying `num_required_devices` receive the result
                # of get_all_devices() instead of the primary device.
                device_arg = cls.get_primary_device()
                if hasattr(test_fn, 'num_required_devices'):
                    device_arg = cls.get_all_devices()

                # Only the arguments that are not None are forwarded.
                call_args = [
                    arg for arg in (device_arg, dtype, op) if arg is not None
                ]

                # The precision override lasts for this run only; the
                # previous value is restored no matter how the test exits.
                saved_precision = self.precision
                try:
                    self.precision = self._get_precision_override(
                        test_fn, dtype)
                    return test_fn(self, *call_args)
                except RuntimeError as rte:
                    # Record whether the whole suite should stop, then let
                    # the runner see the error unchanged.
                    self._stop_test_suite = self._should_stop_test_suite()
                    raise rte
                finally:
                    self.precision = saved_precision

            assert not hasattr(
                cls, test_name), f"Redefinition of test {test_name}"
            setattr(cls, test_name, instantiated_test)

        # Tests that don't use the ops decorator
        if not hasattr(test, "op_list"):
            declared = cls._get_dtypes(test)
            plain_dtypes = (None, ) if declared is None else tuple(declared)
            for dtype in plain_dtypes:
                instantiate_test_helper(cls,
                                        name,
                                        test=test,
                                        dtype=dtype,
                                        op=None)
            return

        # Tests using the ops decorator
        for op in test.op_list:
            dtypes = cls._get_dtypes(test)
            if dtypes is not None:
                assert test.allowed_dtypes is None, "ops(allowed_dtypes=[...]) and the dtypes decorator are incompatible"
                assert test.opinfo_dtypes == OpDTypes.basic, "ops(dtypes=...) and the dtypes decorator are incompatible"
            else:
                # No dtypes declared on the test: derive them from the op.
                mode = test.opinfo_dtypes
                if mode == OpDTypes.unsupported_backward:
                    every_dtype = set(get_all_dtypes())
                    dtypes = every_dtype.difference(
                        op.supported_backward_dtypes(cls.device_type))
                elif mode == OpDTypes.supported_backward:
                    dtypes = op.supported_backward_dtypes(cls.device_type)
                elif mode == OpDTypes.unsupported:
                    every_dtype = set(get_all_dtypes())
                    dtypes = every_dtype.difference(
                        op.supported_dtypes(cls.device_type))
                elif mode == OpDTypes.supported:
                    dtypes = op.supported_dtypes(cls.device_type)
                elif mode == OpDTypes.basic:
                    dtypes = op.default_test_dtypes(cls.device_type)
                elif mode == OpDTypes.none:
                    dtypes = _NO_DTYPES
                else:
                    raise RuntimeError(f"Unknown OpDType: {mode}")

                if test.allowed_dtypes is not None:
                    dtypes = dtypes.intersection(test.allowed_dtypes)

            # _NO_DTYPES yields exactly one variant, with the sentinel
            # itself passed along as the dtype.
            dtype_choices = (_NO_DTYPES, ) if dtypes is _NO_DTYPES else dtypes
            for dtype in dtype_choices:
                instantiate_test_helper(cls,
                                        name,
                                        test=test,
                                        dtype=dtype,
                                        op=op)
Example #3
0
import torch
from torch.testing import get_all_dtypes

dtypes = get_all_dtypes()
dtypes.remove(torch.half)

grad_dtypes = [torch.float, torch.double]

devices = [torch.device('cpu')]
if torch.cuda.is_available():
    devices += [torch.device('cuda:{}'.format(torch.cuda.current_device()))]


def tensor(x, dtype, device):
    """Build a ``torch.tensor`` from `x`, passing ``None`` through unchanged.

    Args:
        x: data handed to ``torch.tensor``, or ``None``.
        dtype: dtype forwarded to ``torch.tensor``.
        device: device forwarded to ``torch.tensor``.

    Returns:
        ``None`` when `x` is ``None``, otherwise the new tensor.
    """
    if x is None:
        return None
    return torch.tensor(x, dtype=dtype, device=device)
Example #4
0
    def instantiate_test(cls, name, test, *, generic_cls=None):
        """Attach concrete variants of the generic `test` to `cls`.

        Without an ``op_list`` attribute on `test`, one variant is created
        per dtype returned by ``cls._get_dtypes(test)`` (a single dtype-less
        variant when that returns ``None``). With an ``op_list``, one variant
        is created per (op, dtype) pair; the dtypes come from the op's
        per-device dtype attributes when the test declares none.

        Args:
            cls: class that receives the generated tests via ``setattr``.
            name: base name passed to ``_construct_test_name``.
            test: the generic test function being instantiated.
            generic_cls: class whose ``__name__`` is handed to
                ``op.should_skip`` when a generated test runs.
        """
        def instantiate_test_helper(cls, name, *, test, dtype, op):
            """Build the variant for one (op, dtype) pair and set it on `cls`."""
            test_name = _construct_test_name(name, op, cls.device_type, dtype)

            # Op-specific decorators must not be applied to the original
            # test, so they are stacked on a pass-through wrapper instead.
            # Test-specific decorators already live on the original test.
            test_fn = test
            if op is not None and op.decorators is not None:

                @wraps(test)
                def test_wrapper(*args, **kwargs):
                    return test(*args, **kwargs)

                for op_decorator in op.decorators:
                    test_wrapper = op_decorator(test_wrapper)

                test_fn = test_wrapper

            @wraps(test)
            def instantiated_test(self,
                                  name=name,
                                  test=test_fn,
                                  dtype=dtype,
                                  op=op):
                skip_requested = op is not None and op.should_skip(
                    generic_cls.__name__, name, self.device_type, dtype)
                if skip_requested:
                    self.skipTest("Skipped!")

                # Tests carrying `num_required_devices` receive the result
                # of get_all_devices() instead of the primary device.
                device_arg = cls.get_primary_device()
                if hasattr(test_fn, 'num_required_devices'):
                    device_arg = cls.get_all_devices()

                # Only the arguments that are not None are forwarded.
                call_args = [
                    arg for arg in (device_arg, dtype, op) if arg is not None
                ]

                # The precision override lasts for this run only; the
                # previous value is restored no matter how the test exits.
                saved_precision = self.precision
                try:
                    self.precision = self._get_precision_override(
                        test_fn, dtype)
                    return test_fn(self, *call_args)
                finally:
                    self.precision = saved_precision

            assert not hasattr(
                cls, test_name), "Redefinition of test {0}".format(test_name)
            setattr(cls, test_name, instantiated_test)

        # Tests that don't use the ops decorator
        if not hasattr(test, "op_list"):
            declared = cls._get_dtypes(test)
            plain_dtypes = (None, ) if declared is None else tuple(declared)
            for dtype in plain_dtypes:
                instantiate_test_helper(cls,
                                        name,
                                        test=test,
                                        dtype=dtype,
                                        op=None)
            return

        # Tests using the ops decorator
        for op in test.op_list:
            dtypes = cls._get_dtypes(test)
            if dtypes is None:
                # Prefer the op's device-specific dtype list; fall back to
                # op.dtypes when that list is None or the device is neither
                # 'cpu' nor 'cuda'.
                device_type = cls.device_type
                if device_type == 'cpu':
                    dtypes = op.dtypesIfCPU
                elif device_type == 'cuda':
                    if TEST_WITH_ROCM:
                        dtypes = op.dtypesIfROCM
                    else:
                        dtypes = op.dtypesIfCUDA
                if dtypes is None:
                    dtypes = op.dtypes

            # Inverts dtypes if the function wants unsupported dtypes
            if test.unsupported_dtypes_only is True:
                dtypes = [d for d in get_all_dtypes() if d not in dtypes]

            if dtypes is None:
                dtypes = (None, )
            for dtype in dtypes:
                instantiate_test_helper(cls,
                                        name,
                                        test=test,
                                        dtype=dtype,
                                        op=op)
Example #5
0
    def test_dtypes(self, device, op):
        """Compare an op's claimed dtype support with what actually runs.

        Each dtype from ``get_all_dtypes()`` is probed: sample inputs are
        requested, the op is called on every sample and, for dtypes that are
        candidates for backward, ``backward`` is attempted on the result.
        The dtypes that worked are then asserted equal to
        ``op.supported_dtypes`` and ``op.supported_backward_dtypes`` for the
        device type, with a message listing the differences.
        """
        # dtypes to try to backward in
        allowed_backward_dtypes = floating_and_complex_types_and(
            torch.bfloat16, torch.float16)

        # detected (un)supported dtypes, forward and backward
        forward_ok = set()
        forward_failed = set()
        backward_ok = set()
        backward_failed = set()

        def mark_unsupported(dtype):
            """Record a forward failure; it also rules out backward."""
            forward_failed.add(dtype)
            if dtype in allowed_backward_dtypes:
                backward_failed.add(dtype)

        for dtype in get_all_dtypes():
            wants_backward = dtype in allowed_backward_dtypes

            # tries to acquire samples - failure indicates lack of support
            requires_grad = wants_backward and op.supports_autograd
            try:
                samples = op.sample_inputs(device,
                                           dtype,
                                           requires_grad=requires_grad)
            except Exception:
                mark_unsupported(dtype)
                continue

            # Counts number of successful backward attempts
            # NOTE: This exists as a kludge because this only understands how
            #   to request a gradient if the output is a tensor or a sequence
            #   with a tensor as its first element.
            num_backward_successes = 0
            for sample in samples:
                # tries to call operator with the sample - failure indicates
                #   lack of support
                try:
                    result = op(sample.input, *sample.args, **sample.kwargs)
                except Exception:
                    # NOTE: some ops will fail in forward if their inputs
                    #   require grad but they don't support computing the
                    #   gradient in that type! This is a bug in the op!
                    mark_unsupported(dtype)

                # Stop testing this dtype once it is known not to work
                if dtype in forward_failed:
                    break

                # Nothing more to do for this sample if the dtype isn't a
                #   backward dtype or backward is already known to fail
                if not wants_backward or dtype in backward_failed:
                    continue

                # Checks for backward support in the same dtype
                try:
                    result = sample.output_process_fn_grad(result)
                    if isinstance(result, torch.Tensor):
                        backward_tensor = result
                    elif isinstance(result, Sequence) and isinstance(
                            result[0], torch.Tensor):
                        backward_tensor = result[0]
                    else:
                        continue

                    # Note: this grad may not have the same dtype as dtype
                    # For functions like complex (float -> complex) or abs
                    #   (complex -> float) the grad tensor will have a
                    #   different dtype than the input.
                    #   For simplicity, this is still modeled as these ops
                    #   supporting grad in the input dtype.
                    backward_tensor.backward(
                        torch.randn_like(backward_tensor))
                    num_backward_successes += 1
                except Exception:
                    backward_failed.add(dtype)

            if dtype not in forward_failed:
                forward_ok.add(dtype)
            if num_backward_successes > 0 and dtype not in backward_failed:
                backward_ok.add(dtype)

        device_type = torch.device(device).type

        def assert_claims_match(claimed, detected, qualifier):
            """Assert `detected` equals `claimed`, explaining any difference.

            `qualifier` is spliced in front of "dtypes" in the message
            ("" for forward, "backward " for backward).
            """
            supported_but_unclaimed = detected - claimed
            claimed_but_unsupported = claimed - detected
            msg = ("The supported {0}dtypes for {1} on {2} according to its OpInfo are\n"
                   "        {3}, but the detected supported {0}dtypes are {4}.\n"
                   "        ").format(qualifier, op.name, device_type,
                                      claimed, detected)

            if len(supported_but_unclaimed) > 0:
                msg += "The following {0}dtypes should be added to the OpInfo: {1}. ".format(
                    qualifier, supported_but_unclaimed)
            if len(claimed_but_unsupported) > 0:
                msg += "The following {0}dtypes should be removed from the OpInfo: {1}.".format(
                    qualifier, claimed_but_unsupported)

            self.assertEqual(detected, claimed, msg=msg)

        # Forward dtypes first, then backward dtypes
        assert_claims_match(set(op.supported_dtypes(device_type)),
                            forward_ok, "")
        assert_claims_match(set(op.supported_backward_dtypes(device_type)),
                            backward_ok, "backward ")
Example #6
0
    def instantiate_test(cls, name, test):
        """Attach concrete variants of the generic `test` to `cls`.

        Without an ``op_list`` attribute on `test`, one variant is created
        per dtype returned by ``cls._get_dtypes(test)`` (a single dtype-less
        variant when that returns ``None``). With an ``op_list``, one variant
        is created per (op, dtype) pair; the dtypes come from the op's
        per-device dtype attributes when the test declares none.

        Args:
            cls: class that receives the generated tests via ``setattr``.
            name: base name of the generated tests.
            test: the generic test function being instantiated.
        """
        def instantiate_test_helper(cls, name, *, test, dtype, op):
            """Build the variant for one (op, dtype) pair and set it on `cls`."""
            # Name layout: <name>[_<op name>]_<device type>[_<dtype>...],
            # where each dtype contributes the text after the first '.' of
            # its str() form.
            name_parts = [name]
            if op is not None:
                name_parts.append(op.name)
            name_parts.append(cls.device_type)
            if dtype is not None:
                if isinstance(dtype, (list, tuple)):
                    dtype_group = dtype
                else:
                    dtype_group = (dtype, )
                name_parts.extend(str(d).split('.')[1] for d in dtype_group)
            test_name = "_".join(name_parts)

            @wraps(test)
            def instantiated_test(self, test=test, dtype=dtype, op=op):
                # Tests carrying `num_required_devices` receive the result
                # of get_all_devices() instead of the primary device.
                device_arg = cls.get_primary_device()
                if hasattr(test, 'num_required_devices'):
                    device_arg = cls.get_all_devices()

                # Only the arguments that are not None are forwarded.
                call_args = [
                    arg for arg in (device_arg, dtype, op) if arg is not None
                ]

                # The precision override lasts for this run only; the
                # previous value is restored no matter how the test exits.
                saved_precision = self.precision
                try:
                    self.precision = self._get_precision_override(test, dtype)
                    return test(self, *call_args)
                finally:
                    self.precision = saved_precision

            # Op decorators are applied to the generated test itself.
            if op is not None and op.decorators is not None:
                for op_decorator in op.decorators:
                    instantiated_test = op_decorator(instantiated_test)

            assert not hasattr(
                cls, test_name), "Redefinition of test {0}".format(test_name)
            setattr(cls, test_name, instantiated_test)

        # Tests that don't use the ops decorator
        if not hasattr(test, "op_list"):
            declared = cls._get_dtypes(test)
            plain_dtypes = (None, ) if declared is None else tuple(declared)
            for dtype in plain_dtypes:
                instantiate_test_helper(cls,
                                        name,
                                        test=test,
                                        dtype=dtype,
                                        op=None)
            return

        # Tests using the ops decorator
        for op in test.op_list:
            dtypes = cls._get_dtypes(test)
            if dtypes is None:
                # Prefer the op's device-specific dtype list; fall back to
                # op.dtypes when that list is None or the device is neither
                # 'cpu' nor 'cuda'.
                device_type = cls.device_type
                if device_type == 'cpu':
                    dtypes = op.dtypesIfCPU
                elif device_type == 'cuda':
                    if TEST_WITH_ROCM:
                        dtypes = op.dtypesIfROCM
                    else:
                        dtypes = op.dtypesIfCUDA
                if dtypes is None:
                    dtypes = op.dtypes

            # Inverts dtypes if the function wants unsupported dtypes
            if test.unsupported_dtypes_only is True:
                dtypes = [d for d in get_all_dtypes() if d not in dtypes]

            if dtypes is None:
                dtypes = (None, )
            for dtype in dtypes:
                instantiate_test_helper(cls,
                                        name,
                                        test=test,
                                        dtype=dtype,
                                        op=op)