Example #1
    def test_out(self, device, dtype, op):
        # Validates the out= contract of `op` using its first sample input.
        #   The op is called once without out= to obtain the expected result
        #   and is then re-run with a series of deliberately "off" out=
        #   arguments:
        #     Case 0: correct metadata, garbage values -> values are overwritten
        #     Case 1: noncontiguous out                -> strides are respected
        #     Case 2: wrong shape                      -> resize with a warning
        #     Case 3: empty out                        -> resize without warning
        #     Case 4: wrong device                     -> RuntimeError
        #     Case 5: dtype that can't be safely cast  -> RuntimeError
        #   The test is skipped when the op has no out= variant, produces no
        #   sample inputs, or returns something other than a single tensor or
        #   an iterable of tensors.

        # TODO: verify the op doesn't support the out= kwarg
        if not op.supports_out:
            self.skipTest("Skipped! Op doesn't support out= kwarg.")

        # NOTE: only tests on first sample
        samples = op.sample_inputs(device, dtype)
        if len(samples) == 0:
            # Without this guard samples[0] would raise an IndexError
            self.skipTest("Skipped! No sample inputs!")
        sample = samples[0]

        # calls it normally to get the expected result
        expected = op(sample.input, *sample.args, **sample.kwargs)
        op_out = partial(op, sample.input, *sample.args, **sample.kwargs)

        # Short-circuits if output is not a single tensor or an
        #   iterable of tensors

        # Returns True if iterable is an iterable of tensors (includes empty iterables)
        #   and False o.w.
        def _is_iterable_of_tensors(iterable):
            try:
                for t in iter(iterable):
                    if not isinstance(t, torch.Tensor):
                        return False
            except TypeError:
                # iter() raises TypeError for non-iterable objects
                return False

            return True

        if not isinstance(expected, torch.Tensor) and not _is_iterable_of_tensors(expected):
            self.skipTest(
                "Skipped! Only supports single tensor or iterable of tensor outputs.")

        # A wrapper around map that works with single tensors and always
        #   instantiates the map. Used below to apply transforms to
        #   single tensor and iterable tensor outputs.
        def _apply_out_transform(fn, out):
            if isinstance(out, torch.Tensor):
                return fn(out)

            # assumes (see above) that out is an iterable of tensors
            return tuple(map(fn, out))

        # Case 0: out= with the correct shape, dtype, and device
        #   but NaN values for floating point and complex tensors, and
        #   maximum values for integer tensors.
        #   Expected behavior: out= values have no effect on the computation.
        def _case_zero_transform(t):
            try:
                info = torch.iinfo(t.dtype)
                return torch.full_like(t, info.max)
            except TypeError:
                # for non-integer types fills with NaN
                return torch.full_like(t, float('nan'))

        out = _apply_out_transform(_case_zero_transform, expected)
        result = op_out(out=out)
        self.assertEqual(expected, out)

        # Checks that the returned value shares storage with out
        # NOTE: only checks on the CPU and CUDA device types since some
        #   device types don't have storage
        if self.device_type in ('cpu', 'cuda'):
            if isinstance(out, torch.Tensor):
                self.assertEqual(out.storage().data_ptr(),
                                 result.storage().data_ptr())
            else:
                for out_t, result_t in zip(out, result):
                    self.assertEqual(out_t.storage().data_ptr(),
                                     result_t.storage().data_ptr())

        # Case 1: out= with the correct shape, dtype, and device,
        #   but noncontiguous.
        #   Expected behavior: strides are respected.
        def _case_one_transform(t):
            # NOTE(review): the keyword arguments of this call were truncated
            #   in the reviewed source. dtype=/device= mirror the other
            #   make_tensor calls in this test; `discontiguous=True` is
            #   assumed to be the flag this make_tensor uses to request a
            #   noncontiguous tensor -- confirm against its signature (newer
            #   releases spell it `noncontiguous`).
            return make_tensor(t.shape,
                               dtype=t.dtype,
                               device=t.device,
                               discontiguous=True)

        # Extracts strides from a tensor or iterable of tensors into a tuple
        def _extract_strides(out):
            if isinstance(out, torch.Tensor):
                return (out.stride(), )

            # assumes (see above) that out is an iterable of tensors
            return tuple(map(lambda t: t.stride(), out))

        out = _apply_out_transform(_case_one_transform, expected)
        original_strides = _extract_strides(out)

        # The op must run between the two stride snapshots; otherwise the
        #   comparison below is vacuous
        op_out(out=out)
        final_strides = _extract_strides(out)

        self.assertEqual(expected, out)
        self.assertEqual(original_strides, final_strides)

        # Case 2: out= with the correct dtype and device, but the wrong shape
        #   Expected behavior: resize with a warning.
        def _case_two_transform(t):
            wrong_shape = list(t.shape)

            if len(wrong_shape) == 0:
                # Handles scalar tensor case (empty list)
                wrong_shape = [2]
            else:
                wrong_shape[-1] = wrong_shape[-1] + 1
            return make_tensor(wrong_shape, dtype=t.dtype, device=t.device)

        out = _apply_out_transform(_case_two_transform, expected)
        with self.assertWarnsRegex(UserWarning,
                                   "An output with one or more elements"):
            op_out(out=out)
        self.assertEqual(expected, out)

        # Case 3: out= with the correct dtype and device, but an empty
        #   tensor.
        #   Expected behavior: resize without warning.
        def _case_three_transform(t):
            return make_tensor((0, ), dtype=t.dtype, device=t.device)

        out = _apply_out_transform(_case_three_transform, expected)
        with warnings.catch_warnings(record=True) as caught:
            # "always" defeats the once-per-location filter so a resize
            #   warning already emitted by Case 2 would still be recorded here
            warnings.simplefilter("always")
            op_out(out=out)

        # Verifies no warning is a resize warning
        for w in caught:
            if "An output with one or more elements" in str(w.message):
                self.fail(
                    "Resizing an out= argument with no elements threw a resize warning!")

        self.assertEqual(expected, out)

        # Case 4: out= with correct shape and dtype, but wrong device.
        #   Expected behavior: error.
        #   Only runs when a second device type is available.
        wrong_device = None
        if torch.device(device).type != 'cpu':
            wrong_device = 'cpu'
        elif torch.cuda.is_available():
            wrong_device = 'cuda'

        if wrong_device is not None:

            def _case_four_transform(t):
                return make_tensor(t.shape, dtype=t.dtype, device=wrong_device)

            out = _apply_out_transform(_case_four_transform, expected)
            with self.assertRaises(RuntimeError):
                op_out(out=out)

        # Case 5: out= with correct shape and device, but a dtype
        #   that output cannot be "safely" cast to (long).
        #   Expected behavior: error.
        # NOTE: this case is filtered by dtype since some ops produce
        #   bool tensors, for example, which can be safely cast to any
        #   dtype. It is applied when single tensors are floating point or complex
        #   dtypes, or if an op returns multiple tensors when at least one such
        #   tensor is a floating point or complex dtype.
        _dtypes = floating_and_complex_types_and(torch.float16, torch.bfloat16)
        if isinstance(expected, torch.Tensor):
            cast_is_unsafe = expected.dtype in _dtypes
        else:
            cast_is_unsafe = any(t.dtype in _dtypes for t in expected)

        if cast_is_unsafe:

            def _case_five_transform(t):
                return make_tensor(t.shape, dtype=torch.long, device=t.device)

            out = _apply_out_transform(_case_five_transform, expected)
            with self.assertRaises(RuntimeError):
                op_out(out=out)
Example #2
    def test_dtypes(self, device, dtype, op):
        # Empirically detects which dtypes `op` supports in forward and in
        #   backward on `device`, then asserts that the detected sets match
        #   the ones the op reports through op.supported_dtypes(...) and
        #   op.supported_backward_dtypes(...).
        # NOTE: the `dtype` argument is not consulted; the test sweeps every
        #   dtype returned by get_all_dtypes() itself.

        # dtypes to try to backward in
        allowed_backward_dtypes = floating_and_complex_types_and(
            torch.bfloat16, torch.float16)

        # sets for (un)supported dtypes
        supported_dtypes = set()
        unsupported_dtypes = set()
        supported_backward_dtypes = set()
        unsupported_backward_dtypes = set()

        # Records that forward failed for a dtype. A dtype that can't run
        #   forward can't run backward either.
        def unsupported(failed_dtype):
            unsupported_dtypes.add(failed_dtype)
            if failed_dtype in allowed_backward_dtypes:
                unsupported_backward_dtypes.add(failed_dtype)

        for candidate in get_all_dtypes():
            # tries to acquire samples - failure indicates lack of support
            requires_grad = (candidate in allowed_backward_dtypes
                             and op.supports_autograd)
            try:
                samples = op.sample_inputs(device,
                                           candidate,
                                           requires_grad=requires_grad)
            except Exception:
                # Deliberately broad: any failure counts as "unsupported"
                unsupported(candidate)
                continue

            # Counts number of successful backward attempts
            # NOTE: This exists as a kludge because this only understands how to
            #   request a gradient if the output is a tensor or a sequence with
            #   a tensor as its first element.
            num_backward_successes = 0
            for sample in samples:
                # tries to call operator with the sample - failure indicates
                #   lack of support
                try:
                    result = op(sample.input, *sample.args, **sample.kwargs)
                except Exception:
                    # NOTE: some ops will fail in forward if their inputs
                    #   require grad but they don't support computing the gradient
                    #   in that type! This is a bug in the op!
                    unsupported(candidate)

                # Short-circuits testing this dtype -- it doesn't work
                if candidate in unsupported_dtypes:
                    break

                # Short-circuits if the dtype isn't a backward dtype or
                #   it's already identified as not supported
                if (candidate not in allowed_backward_dtypes
                        or candidate in unsupported_backward_dtypes):
                    continue

                # Checks for backward support in the same dtype
                try:
                    result = sample.output_process_fn_grad(result)
                    if isinstance(result, torch.Tensor):
                        backward_tensor = result
                    elif isinstance(result, Sequence) and isinstance(
                            result[0], torch.Tensor):
                        backward_tensor = result[0]
                    else:
                        # No tensor to call backward on; without this branch
                        #   backward_tensor would be unbound or stale
                        continue

                    # Note: this grad may not have the same dtype as dtype
                    # For functions like complex (float -> complex) or abs
                    #   (complex -> float) the grad tensor will have a
                    #   different dtype than the input.
                    #   For simplicity, this is still modeled as these ops
                    #   supporting grad in the input dtype.
                    grad = torch.randn_like(backward_tensor)
                    backward_tensor.backward(grad)
                    num_backward_successes += 1
                except Exception:
                    unsupported_backward_dtypes.add(candidate)

            if candidate not in unsupported_dtypes:
                supported_dtypes.add(candidate)
            if num_backward_successes > 0 and candidate not in unsupported_backward_dtypes:
                supported_backward_dtypes.add(candidate)

        # Checks that dtypes are listed correctly and generates an informative
        #   error message
        device_type = torch.device(device).type
        claimed_supported = set(op.supported_dtypes(device_type))

        supported_but_unclaimed = supported_dtypes - claimed_supported
        claimed_but_unsupported = claimed_supported - supported_dtypes
        msg = """The supported dtypes for {0} on {1} according to its OpInfo are
        {2}, but the detected supported dtypes are {3}.
        """.format(op.name, device_type, claimed_supported, supported_dtypes)

        if len(supported_but_unclaimed) > 0:
            msg += "The following dtypes should be added to the OpInfo: {0}. ".format(
                supported_but_unclaimed)
        if len(claimed_but_unsupported) > 0:
            msg += "The following dtypes should be removed from the OpInfo: {0}.".format(
                claimed_but_unsupported)

        self.assertEqual(supported_dtypes, claimed_supported, msg=msg)

        # Checks that backward dtypes are listed correctly and generates an
        #   informative error message
        # NOTE: this code is nearly identical to the check + msg generation
        #   for the forward dtypes above
        # NOTE(review): the argument of this set(...) call was truncated in
        #   the reviewed source; op.supported_backward_dtypes(device_type) is
        #   assumed by analogy with op.supported_dtypes(device_type) -- confirm.
        claimed_backward_supported = set(
            op.supported_backward_dtypes(device_type))

        supported_but_unclaimed = supported_backward_dtypes - claimed_backward_supported
        claimed_but_unsupported = claimed_backward_supported - supported_backward_dtypes
        msg = """The supported backward dtypes for {0} on {1} according to its OpInfo are
        {2}, but the detected supported backward dtypes are {3}.
        """.format(op.name, device_type, claimed_backward_supported,
                   supported_backward_dtypes)

        if len(supported_but_unclaimed) > 0:
            msg += "The following backward dtypes should be added to the OpInfo: {0}. ".format(
                supported_but_unclaimed)
        if len(claimed_but_unsupported) > 0:
            msg += "The following backward dtypes should be removed from the OpInfo: {0}.".format(
                claimed_but_unsupported)

        self.assertEqual(supported_backward_dtypes,
                         claimed_backward_supported,
                         msg=msg)

Example #3
    def test_variant_consistency_jit(self, device, dtype, op):
        # For every sample input, checks that the scripted and the traced
        #   form of each op variant agree with the eager function variant,
        #   that alias annotations hold, and (in float32) that the expected
        #   autodiff nodes appear in the recorded graphs.
        #   Skips when the op has no sample inputs.
        samples = op.sample_inputs(device, dtype, requires_grad=True)
        if len(samples) == 0:
            self.skipTest("Skipped! No sample inputs!")

        # Acquires variants to test (independent of the sample, so this is
        #   done once up front)
        func = op.get_op()
        method = op.get_method()
        inplace = op.get_inplace()
        variants = {
            'function': func, 'method': method,
            # TODO: inplace tests currently fail
            # 'inplace': inplace,
        }

        # bfloat16 grad doesn't work for some operators
        dtypes_to_grad_check = floating_and_complex_types_and(torch.half) \
            if op.skip_bfloat16_grad else floating_and_complex_types_and(torch.half, torch.bfloat16)

        for sample in samples:

            # Test traced and scripted consistency
            for func_type, variant in variants.items():
                if variant is None:
                    # The op doesn't provide this variant
                    continue

                # Create accessor for script function variant
                name = op.name + '_' if func_type == 'inplace' else op.name

                # run with disable_autodiff_subgraph_inlining(True) to test
                #   autodiff support. Context manager forces the graph to contain
                #   DifferentiableGraph nodes if they are present
                with disable_autodiff_subgraph_inlining():
                    # Eager reference: always the function variant
                    def fn(*inputs, **kwargs):
                        output = func(*inputs, **kwargs)
                        return op.output_func(output)

                    # Check scripted forward, grad, and grad grad
                    # NOTE(review): the head of this call (callee and leading
                    #   arguments) was truncated in the reviewed source; it
                    #   is restored as check_against_reference(self,
                    #   <jit fn>, <eager fn>, args, kwargs, no_grad=...) --
                    #   confirm against the helper's signature.
                    script_fn = create_script_fn(self, name, func_type, op.output_func)

                    check_against_reference(self,
                                            script_fn,
                                            fn,
                                            (*sample.input,) + sample.args,
                                            sample.kwargs,
                                            no_grad=(dtype not in dtypes_to_grad_check))

                    # Check traced forward, grad, and grad grad
                    traced_fn = create_traced_fn(self, variant)
                    check_against_reference(self,
                                            traced_fn,
                                            fn,
                                            (*sample.input,) + sample.args,
                                            sample.kwargs,
                                            no_grad=(dtype not in dtypes_to_grad_check))

                    # Check alias annotation schema for correctness (make
                    #   sure inputs that aren't supposed to be modified aren't)
                    # Note: only runs in float32 and int32 because schema isn't affected by dtype,
                    #   so running it on all dtypes would be excessive
                    if dtype in [torch.float32, torch.int32]:
                        check_alias_annotation(name, (*sample.input,) + sample.args, sample.kwargs,
                                               func_type=func_type, aten_name=op.aten_name)

                    # Check autodifferentiation of nodes for traced and scripted graphs, only need to check once per sample
                    if dtype is torch.float32:
                        # Sandcastle doesn't fuse nodes
                        if IS_SANDCASTLE:
                            # fusible nodes are expected to be found in FusionGroups in the DifferentiableGraphs
                            nonfusible_nodes = op.autodiff_nonfusible_nodes + op.autodiff_fusible_nodes
                            fusible_nodes = []
                        else:
                            nonfusible_nodes = op.autodiff_nonfusible_nodes
                            fusible_nodes = op.autodiff_fusible_nodes

                        self.assertAutodiffNode(traced_fn.last_graph, op.assert_autodiffed, nonfusible_nodes, fusible_nodes)
                        self.assertAutodiffNode(script_fn.last_graph, op.assert_autodiffed, nonfusible_nodes, fusible_nodes)
Example #4
class TestOpInfo(TestCase):
    # Sanity checks on the metadata each OpInfo entry in op_db declares
    #   (supported / unsupported dtypes, autograd support, legacy test
    #   registration).
    exact_dtype = True

    # Verifies that ops have their unsupported dtypes
    #   registered correctly by testing that each claimed unsupported dtype
    #   throws a runtime error
    @ops(op_db, dtypes=OpDTypes.unsupported)
    def test_unsupported_dtypes(self, device, dtype, op):
        # sample_inputs can have a function for generating the input that doesn't work for specified dtype
        # https://github.com/pytorch/pytorch/issues/49024
        # Hence sample generation is inside the assertRaises block as well.
        #   skipTest still works from in here: the exception it raises is not
        #   a RuntimeError, so assertRaises lets it propagate.
        with self.assertRaises(RuntimeError):
            samples = op.sample_inputs(device, dtype)
            if len(samples) == 0:
                self.skipTest("Skipped! No sample inputs!")

            # NOTE: only tests on first sample
            sample = samples[0]
            op(sample.input, *sample.args, **sample.kwargs)

    # Verifies that ops have their supported dtypes
    #   registered correctly by testing that each claimed supported dtype
    #   does NOT throw a runtime error
    # In addition verifies that the generated sample_inputs have the requested device and dtype
    @ops(op_db, dtypes=OpDTypes.supported)
    def test_supported_dtypes(self, device, dtype, op):
        for sample in op.sample_inputs(device, dtype):
            op(sample.input, *sample.args, **sample.kwargs)
            # NOTE: only check the first tensor in the iterable of tensors
            sample_input = sample.input[0] if is_iterable_of_tensors(
                sample.input) else sample.input
            # assertEqual (rather than assertTrue(a == b)) reports both
            #   values on failure
            self.assertEqual(sample_input.dtype, dtype)
            self.assertEqual(sample_input.device.type, self.device_type)

    # Verifies that backward for each supported floating or complex dtype
    #   does NOT throw a runtime error.
    # TODO: support multi-tensor outputs
    # NOTE(review): unlike its siblings, this test has no @ops decorator in
    #   the reviewed source, so nothing visible here supplies `dtype` and
    #   `op` -- confirm whether a decorator line was lost.
    def test_supported_backward(self, device, dtype, op):
        if not op.supports_autograd:
            self.skipTest("Skipped! Autograd not supported.")
        if not op.supports_complex_autograd and dtype.is_complex:
            self.skipTest("Skipped! Complex autograd not supported.")

        for sample in op.sample_inputs(device, dtype, requires_grad=True):
            result = op(sample.input, *sample.args, **sample.kwargs)
            if not isinstance(result, torch.Tensor):
                # Multi-tensor outputs are not handled yet (see TODO above)
                continue

            # NOTE(review): the remainder of this loop body was missing from
            #   the reviewed source. Reducing to a scalar and calling
            #   backward() is the minimal way to exercise the backward pass
            #   the comment above describes -- confirm this matches the
            #   intended check.
            result.sum().backward()

    # Verifies that ops do not have an entry in
    # `method_tests` (legacy testing infra).
    @ops(op_db, allowed_dtypes=[torch.float32])
    def test_duplicate_method_tests(self, device, dtype, op):
        self.assertNotIn(op.name, method_tested_operators)