Example #1
def partial_apply_nontensors(fn, args, **kwargs):
    source = ['t' if (isinstance(arg, torch.Tensor) or is_iterable_of_tensors(arg)) else 's' for arg in args]

    def new_fn(*tensors_):
        tensors = iter(tensors_)
        return fn(*(args[i] if s == 's' else next(tensors) for i, s in enumerate(source)), **kwargs)

    return new_fn, [arg for arg in args if isinstance(arg, torch.Tensor) or is_iterable_of_tensors(arg)]
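A hypothetical usage sketch of the helper above, assuming it is in scope together with is_iterable_of_tensors (found in torch.testing._internal.common_utils); torch.mul and the sample values are illustrative choices, not taken from the original tests:

import torch

# Hypothetical example: torch.mul(x, 3.0), where only `x` is a tensor and 3.0
# is a non-tensor "scalar" argument that new_fn captures and re-inserts.
x = torch.randn(3, dtype=torch.double, requires_grad=True)

new_fn, tensor_args = partial_apply_nontensors(torch.mul, (x, 3.0))

# Calling new_fn with only the tensor arguments reproduces the full call.
assert torch.equal(new_fn(*tensor_args), torch.mul(x, 3.0))

# Only the tensor arguments need to be handed to gradcheck.
torch.autograd.gradcheck(new_fn, tensor_args)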
Example #2
def test_supported_dtypes(self, device, dtype, op):
    for sample in op.sample_inputs(device, dtype):
        op(sample.input, *sample.args, **sample.kwargs)
        # NOTE: only check the first tensor in the iterable of tensors
        sample_input = sample.input[0] if is_iterable_of_tensors(sample.input) else sample.input
        self.assertTrue(sample_input.dtype == dtype)
        self.assertTrue(sample_input.device.type == self.device_type)
Example #3
def fn(*inputs):
    # Pack input back into TensorList since we splat it when passing to gradcheck
    if is_iterable_of_tensors(sample.input):
        n = len(sample.input)
        inputs = (inputs[:n], *inputs[n:])
    output = variant_out_fn(*inputs, **sample.kwargs)
    return op.output_func(output)
Example #4
def fn(*inputs):
    # Put tensors back into TensorList since we splat them when passing to gradcheck
    if is_iterable_of_tensors(sample.input):
        n = len(sample.input)
        inputs = (inputs[:n], *inputs[n:])
    output = op.gradcheck_wrapper(variant, *inputs, **sample.kwargs)
    if sample.output_process_fn_grad is not None:
        return sample.output_process_fn_grad(output)
    return output
Example #5
    def get_recording_tensors(args):
        recording_tensors: List[torch.Tensor] = []

        for arg in args:
            if isinstance(arg, torch.Tensor) and arg.requires_grad:
                recording_tensors.append(arg)
            elif is_iterable_of_tensors(arg):
                recording_tensors.extend(filter(lambda t: t.requires_grad,
                                                arg))

        return recording_tensors
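A small illustrative call, assuming the function above and is_iterable_of_tensors are in scope (the argument values are made up): only tensors that require grad, standalone or inside a tensor list, are recorded.

import torch

a = torch.randn(2, requires_grad=True)
b = torch.randn(2)                                         # no grad: skipped
lst = [torch.randn(2, requires_grad=True), torch.randn(2)]

recorded = get_recording_tensors((a, b, lst, "some_flag"))

# Only `a` and the first element of `lst` require grad, so only they are kept.
assert len(recorded) == 2
assert recorded[0] is a and recorded[1] is lst[0]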
Example #6
def clone_inputs(args):
    inputs: List[Union[torch.Tensor, List[torch.Tensor]]] = []

    for arg in args:
        if isinstance(arg, torch.Tensor):
            inputs.append(arg.detach().clone())
        elif is_iterable_of_tensors(arg):
            inputs.append([t.detach().clone() for t in arg])
        else:
            inputs.append(arg)

    return inputs
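An illustrative call (the values are made up, and is_iterable_of_tensors plus typing.List/Union are assumed to be imported): tensors and tensor lists come back as detached copies, while everything else passes through unchanged.

import torch

a = torch.randn(2, requires_grad=True)
lst = [torch.randn(2), torch.randn(2)]

cloned = clone_inputs((a, lst, "some_flag"))

# The clones hold the same values but live in fresh, detached storage.
assert torch.equal(cloned[0], a) and cloned[0].data_ptr() != a.data_ptr()
assert not cloned[0].requires_grad
assert all(c.data_ptr() != t.data_ptr() for c, t in zip(cloned[1], lst))
assert cloned[2] == "some_flag"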
Example #7
    def clone_inputs(preserve_requires_grad: bool):
        inputs: List[Union[torch.Tensor, List[torch.Tensor]]] = []

        for arg in args:
            if isinstance(arg, torch.Tensor):
                inputs.append(clone_tensor(arg, preserve_requires_grad))
            elif is_iterable_of_tensors(arg):
                inputs.append(
                    [clone_tensor(t, preserve_requires_grad) for t in arg])
            else:
                inputs.append(arg)

        return inputs
Example #8
def _input_recomposition_helper(inputs, inp, input_idx):
    if is_iterable_of_tensors(inp):
        tensor_list = []
        for x in inp:
            if isinstance(x, torch.Tensor) and x.requires_grad:
                tensor_list.append(inputs[input_idx])
                input_idx = input_idx + 1
            else:
                tensor_list.append(x)
        return tensor_list, input_idx
    elif isinstance(inp, torch.Tensor) and inp.requires_grad:
        return inputs[input_idx], input_idx + 1
    else:
        return inp, input_idx
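A worked illustration of the recomposition (the tensor values are made up): given the flat tuple of differentiable tensors that gradcheck passes back, the helper rebuilds a TensorList argument, filling differentiable slots from the flat inputs in order and keeping non-differentiable tensors as they were.

import torch

t1 = torch.randn(2, requires_grad=True)
t2 = torch.randn(2, requires_grad=True)
frozen = torch.randn(2)            # does not require grad, stays untouched

flat_inputs = (t1, t2)             # what gradcheck hands back

recomposed, next_idx = _input_recomposition_helper(flat_inputs, [t1, frozen, t2], 0)

assert recomposed[0] is t1 and recomposed[1] is frozen and recomposed[2] is t2
assert next_idx == 2               # two flat inputs were consumed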
Example #9
def diff_arg(arg, requires_grad=True):
    def is_differentiable_arg(arg):
        if requires_grad:
            return arg.requires_grad
        else:
            return arg.is_floating_point() or arg.is_complex()

    if is_iterable_of_tensors(arg):
        if all([is_differentiable_arg(a) for a in arg]):
            return True
        if all([not is_differentiable_arg(a) for a in arg]):
            return False
        raise RuntimeError("NYI: The test runner can't handle this")
    return isinstance(arg, Tensor) and is_differentiable_arg(arg)
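Illustrative behavior for tensor lists (the inputs are made up; Tensor and is_iterable_of_tensors are assumed to be imported): the check is all-or-nothing, and a mixed list is rejected.

import torch

with_grad = [torch.randn(2, requires_grad=True) for _ in range(3)]
without_grad = [torch.randn(2) for _ in range(3)]

assert diff_arg(with_grad) is True        # every element requires grad
assert diff_arg(without_grad) is False    # no element requires grad

# A list mixing tensors with and without requires_grad raises
# RuntimeError("NYI: The test runner can't handle this").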
Example #10
    def test_multiple_devices(self, devices, dtype, op):
        for cuda_device_str in devices:
            cuda_device = torch.device(cuda_device_str)
            # NOTE: only tests on first sample
            samples = op.sample_inputs(cuda_device, dtype)
            sample = samples[0]
            result = op(sample.input, *sample.args, **sample.kwargs)

            if isinstance(result, torch.Tensor):
                self.assertTrue(result.device == cuda_device)
            elif is_iterable_of_tensors(result):
                self.assertTrue(all(map(lambda t: t.device == cuda_device, result)))
            else:
                self.skipTest("Skipped! Only supports single tensor or iterable of tensor outputs.")
Example #11
def get_script_args(args):
    formals: List[str] = []
    tensors: List[Union[torch.Tensor, List[torch.Tensor]]] = []
    actuals: List[str] = []
    for arg in args:
        if isinstance(arg, torch.Tensor):
            name = 'i{}'.format(len(formals))
            formals.append(name)
            actuals.append(name)
            tensors.append(arg)
        elif is_iterable_of_tensors(arg):
            name = 'i{}'.format(len(formals))
            formals.append(name + ': List[torch.Tensor]')
            actuals.append(name)
            tensors.append(list(arg))
        elif isinstance(arg, str):
            actuals.append("'{}'".format(arg))
        else:
            actuals.append(str(get_constant(arg)))
    return (formals, tensors, actuals)
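An illustrative call (the argument tuple is made up): tensor arguments become formal parameters and are collected for later passing, while strings and other constants are inlined directly into the actuals.

import torch

args = (torch.randn(2), [torch.ones(2), torch.zeros(2)], "mean")
formals, tensors, actuals = get_script_args(args)

assert formals == ['i0', 'i1: List[torch.Tensor]']
assert actuals == ['i0', 'i1', "'mean'"]
assert len(tensors) == 2          # the single tensor plus the tensor list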
Example #12
    def test_out(self, device, dtype, op):
        # TODO: verify the op doesn't support the out= kwarg
        if not op.supports_out:
            self.skipTest("Skipped! Op doesn't support out= kwarg.")

        # NOTE: only tests on first sample
        samples = op.sample_inputs(device, dtype)
        sample = samples[0]

        # calls it normally to get the expected result
        expected = op(sample.input, *sample.args, **sample.kwargs)
        op_out = partial(op, sample.input, *sample.args, **sample.kwargs)

        # Short-circuits if output is not a single tensor or an
        #   iterable of tensors

        if not isinstance(expected,
                          torch.Tensor) and not is_iterable_of_tensors(
                              expected, include_empty=True):
            self.skipTest(
                "Skipped! Only supports single tensor or iterable of tensor outputs."
            )

        # A wrapper around map that works with single tensors and always
        #   instantiates the map. Used below to apply transforms to
        #   single tensor and iterable tensor outputs.
        def _apply_out_transform(fn, out):
            if isinstance(out, torch.Tensor):
                return fn(out)

            # assumes (see above) that out is an iterable of tensors
            return tuple(map(fn, out))

        # Case 0: out= with the correct shape, dtype, and device
        #   but NaN values for floating point and complex tensors, and
        #   maximum values for integer tensors.
        #   Expected behavior: out= values have no effect on the computation.
        def _case_zero_transform(t):
            try:
                info = torch.iinfo(t.dtype)
                return torch.full_like(t, info.max)
            except TypeError as te:
                # for non-integer types fills with NaN
                return torch.full_like(t, float('nan'))

        out = _apply_out_transform(_case_zero_transform, expected)
        result = op_out(out=out)
        self.assertEqual(expected, out)

        # Checks that the returned value shares storage with out
        # NOTE: only checks on the CPU and CUDA device types since some
        #   device types don't have storage
        if self.device_type == 'cpu' or self.device_type == 'cuda':
            if isinstance(out, torch.Tensor):
                self.assertEqual(out.storage().data_ptr(),
                                 result.storage().data_ptr())
            else:
                for out_t, result_t in zip(out, result):
                    self.assertEqual(out_t.storage().data_ptr(),
                                     result_t.storage().data_ptr())

        # Case 1: out= with the correct shape, dtype, and device,
        #   but noncontiguous.
        #   Expected behavior: strides are respected and `out` storage is not changed.
        def _case_one_transform(t):
            return make_tensor(t.shape,
                               dtype=t.dtype,
                               device=t.device,
                               noncontiguous=True)

        # Extracts strides from a tensor or iterable of tensors into a tuple
        def _extract_strides(out):
            if isinstance(out, torch.Tensor):
                return (out.stride(), )

            # assumes (see above) that out is an iterable of tensors
            return tuple(map(lambda t: t.stride(), out))

        def _extract_data_ptrs(out):
            if isinstance(out, torch.Tensor):
                return (out.data_ptr(), )

            # assumes (see above) that out is an iterable of tensors
            return tuple(map(lambda t: t.data_ptr(), out))

        out = _apply_out_transform(_case_one_transform, expected)
        original_strides = _extract_strides(out)
        original_ptrs = _extract_data_ptrs(out)

        op_out(out=out)
        final_strides = _extract_strides(out)
        final_ptrs = _extract_data_ptrs(out)

        self.assertEqual(expected, out)
        self.assertEqual(original_strides, final_strides)
        self.assertEqual(original_ptrs, final_ptrs)

        # Case 2: out= with the correct dtype and device, but the wrong shape
        #   Expected behavior: resize with a warning.
        def _case_two_transform(t):
            wrong_shape = list(t.shape)

            if len(wrong_shape) == 0:
                # Handles scalar tensor case (empty list)
                wrong_shape = [2]
            else:
                wrong_shape[-1] = wrong_shape[-1] + 1
            return make_tensor(wrong_shape, dtype=t.dtype, device=t.device)

        out = _apply_out_transform(_case_two_transform, expected)
        msg_fail = "Resized a non-empty tensor but did not warn about it."
        with self.assertWarnsRegex(UserWarning,
                                   "An output with one or more elements",
                                   msg=msg_fail):
            op_out(out=out)
        self.assertEqual(expected, out)

        # Case 3: out= with the correct dtype and device, but an empty
        #   tensor.
        #   Expected behavior: resize without warning.
        def _case_three_transform(t):
            return make_tensor((0, ), dtype=t.dtype, device=t.device)

        out = _apply_out_transform(_case_three_transform, expected)
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")
            op_out(out=out)

        # Verifies no warning is a resize warning
        for w in caught:
            if "An output with one or more elements" in str(w.message):
                self.fail(
                    "Resizing an out= argument with no elements threw a resize warning!"
                )

        self.assertEqual(expected, out)

        # Case 4: out= with correct shape and dtype, but wrong device.
        wrong_device = None
        if torch.device(device).type != 'cpu':
            wrong_device = 'cpu'
        elif torch.cuda.is_available():
            wrong_device = 'cuda'

        if wrong_device is not None:

            def _case_four_transform(t):
                return make_tensor(t.shape, dtype=t.dtype, device=wrong_device)

            out = _apply_out_transform(_case_four_transform, expected)
            msg_fail = f"Expected RuntimeError when calling with input.device={device} and out.device={wrong_device}"
            with self.assertRaises(RuntimeError, msg=msg_fail):
                op_out(out=out)

        # Case 5: out= with correct shape and device, but a dtype
        #   that output cannot be "safely" cast to (long).
        #   Expected behavior: error.
        # NOTE: this case is filtered by dtype since some ops produce
        #   bool tensors, for example, which can be safely cast to any
        #   dtype. It is applied when a single tensor output has a floating point
        #   or complex dtype, or when an op returns multiple tensors and at least
        #   one of them has a floating point or complex dtype.
        _dtypes = floating_and_complex_types_and(torch.float16, torch.bfloat16)
        if (isinstance(expected, torch.Tensor) and expected.dtype in _dtypes
                or (not isinstance(expected, torch.Tensor)
                    and any(t.dtype in _dtypes for t in expected))):

            def _case_five_transform(t):
                return make_tensor(t.shape, dtype=torch.long, device=t.device)

            out = _apply_out_transform(_case_five_transform, expected)
            msg_fail = "" if not isinstance(expected, torch.Tensor) else \
                       ("Expected RuntimeError when doing an unsafe cast from a result of dtype "
                        f"{expected.dtype} into an out= with dtype torch.long")
            with self.assertRaises(RuntimeError, msg=msg_fail):
                op_out(out=out)
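For reference, the resize-with-warning behavior exercised in Case 2 can be reproduced standalone; torch.add and the shapes below are an illustrative choice, and the warning text is the same one the test matches against.

import warnings
import torch

a, b = torch.randn(4), torch.randn(4)

# A non-empty out= tensor of the wrong shape is resized, but with a warning.
out = torch.empty(5)
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    torch.add(a, b, out=out)

assert out.shape == (4,)
assert any("An output with one or more elements" in str(w.message) for w in caught)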
Example #13
    def _check_helper(self, device, dtype, op, variant, check, *, check_forward_ad=False, check_backward_ad=True,
                      check_batched_grad=None, check_batched_forward_grad=False):
        assert check in ('gradcheck', 'bwgrad_bwgrad', 'fwgrad_bwgrad')
        # NB: check_backward_ad does not affect gradgradcheck (always True)
        if variant is None:
            self.skipTest("Skipped! Variant not implemented.")
        if not op.supports_dtype(dtype, torch.device(device).type):
            self.skipTest(f"Skipped! {op.name} does not support dtype {str(dtype)}")

        def is_inplace(variant):
            if hasattr(variant, "__wrapped__"):
                return variant.__wrapped__ is op.get_inplace()
            return variant is op.get_inplace()

        include_conjugated_inputs = op.test_conjugated_samples and dtype.is_complex

        samples = op.sample_inputs(device, dtype, requires_grad=True, include_conjugated_inputs=include_conjugated_inputs,
                                   small_inputs_only=is_slow_gradcheck_env())

        for sample in samples:
            if sample.broadcasts_input and is_inplace(variant):
                continue

            # Gradcheck expects tensors as its input, but autograd actually supports tensorlists
            #   and tensors passed as kwargs. The following creates a function that accepts just
            #   the tensors that require grad as varargs, and then recomposes them back into the
            #   original input.

            # Creates gradcheck inputs by identifying tensors requiring grad
            all_args = None
            if is_iterable_of_tensors(sample.input):
                all_args = chain(sample.input, sample.args, sample.kwargs.values())
            else:
                all_args = tuple(chain((sample.input,), sample.args, sample.kwargs.values()))
            gradcheck_args = tuple(x for x in all_args if (isinstance(x, torch.Tensor) and x.requires_grad))

            def _input_recomposition_helper(inputs, inp, input_idx):
                if is_iterable_of_tensors(inp):
                    tensor_list = []
                    for x in inp:
                        if isinstance(x, torch.Tensor) and x.requires_grad:
                            tensor_list.append(inputs[input_idx])
                            input_idx = input_idx + 1
                        else:
                            tensor_list.append(x)
                    return tensor_list, input_idx
                elif isinstance(inp, torch.Tensor) and inp.requires_grad:
                    return inputs[input_idx], input_idx + 1
                else:
                    return inp, input_idx

            def fn(*inputs):
                # Puts inputs back into sample properly
                positional_args = []
                input_idx = 0
                inp, input_idx = _input_recomposition_helper(inputs, sample.input, input_idx)
                positional_args.append(inp)

                for x in sample.args:
                    inp, input_idx = _input_recomposition_helper(inputs, x, input_idx)
                    positional_args.append(inp)

                # Recreates kwargs
                kwargs = {}
                for k, v in sample.kwargs.items():
                    inp, input_idx = _input_recomposition_helper(inputs, v, input_idx)
                    kwargs[k] = inp

                output = op.gradcheck_wrapper(variant, *positional_args, **kwargs)
                if sample.output_process_fn_grad is not None:
                    return sample.output_process_fn_grad(output)
                return output

            if check == 'gradcheck':
                if check_batched_grad is None:
                    check_batched_grad = op.check_batched_grad
                self.assertTrue(gradcheck(fn, gradcheck_args,
                                          check_batched_grad=check_batched_grad,
                                          check_grad_dtypes=True,
                                          nondet_tol=op.gradcheck_nondet_tol,
                                          fast_mode=op.gradcheck_fast_mode,
                                          check_forward_ad=check_forward_ad,
                                          check_backward_ad=check_backward_ad,
                                          check_undefined_grad=True,
                                          check_batched_forward_grad=check_batched_forward_grad))
            elif check in ('bwgrad_bwgrad', 'fwgrad_bwgrad'):  # gradgrad check
                self.assertFalse(check_forward_ad, msg="Cannot run forward AD check for gradgradcheck")
                for gen_non_contig_grad_outputs in (False, True):
                    kwargs = {
                        "gen_non_contig_grad_outputs": gen_non_contig_grad_outputs,
                        "check_batched_grad": op.check_batched_gradgrad,
                        "check_grad_dtypes": True,
                        "nondet_tol": op.gradcheck_nondet_tol,
                        "fast_mode": op.gradcheck_fast_mode
                    }
                    if check == "fwgrad_bwgrad":
                        kwargs["check_fwd_over_rev"] = True
                        kwargs["check_rev_over_rev"] = False
                        kwargs["check_batched_grad"] = False
                        kwargs["check_undefined_grad"] = False

                    self.assertTrue(gradgradcheck(fn, gradcheck_args, **kwargs))
            else:
                self.assertTrue(False, msg="Unknown check requested!")
Example #14
    def test_out_warning(self, device, op):
        # TODO: verify the op doesn't support the out= kwarg
        if not op.supports_out:
            self.skipTest("Skipped! Op doesn't support out= kwarg.")

        # Prefers running in float32 but falls back to the first listed supported dtype
        supported_dtypes = op.supported_dtypes(self.device_type)
        if len(supported_dtypes) == 0:
            self.skipTest(
                "Skipped! Op has not supported dtypes on this device.")
        dtype = torch.float32 if torch.float32 in supported_dtypes else list(
            supported_dtypes)[0]

        # NOTE: only tests on first sample
        samples = op.sample_inputs(device, dtype)
        sample = first_sample(self, samples)

        # calls it normally to get the expected result
        expected = op(sample.input, *sample.args, **sample.kwargs)
        op_out = partial(op, sample.input, *sample.args, **sample.kwargs)

        # Short-circuits if output is not a single tensor or an
        #   iterable of tensors

        if not isinstance(expected,
                          torch.Tensor) and not is_iterable_of_tensors(
                              expected, include_empty=True):
            self.skipTest(
                "Skipped! Only supports single tensor or iterable of tensor outputs."
            )

        # A wrapper around map that works with single tensors and always
        #   instantiates the map. Used below to apply transforms to
        #   single tensor and iterable tensor outputs.
        def _apply_out_transform(fn, out):
            if isinstance(out, torch.Tensor):
                return fn(out)

            # assumes (see above) that out is an iterable of tensors
            return tuple(map(fn, out))

        # Extracts strides from a tensor or iterable of tensors into a tuple
        def _extract_strides(out):
            if isinstance(out, torch.Tensor):
                return (out.stride(), )

            # assumes (see above) that out is an iterable of tensors
            return tuple(map(lambda t: t.stride(), out))

        # Extracts data pointers from a tensor or iterable of tensors into a tuple
        # NOTE: only extracts on the CPU and CUDA device types since some
        #   device types don't have storage
        def _extract_data_ptrs(out):
            if self.device_type != 'cpu' and self.device_type != 'cuda':
                return ()

            if isinstance(out, torch.Tensor):
                return (out.data_ptr(), )

            # assumes (see above) that out is an iterable of tensors
            return tuple(map(lambda t: t.data_ptr(), out))

        def _compare_out(transform, *, compare_strides_and_data_ptrs=True):
            out = _apply_out_transform(transform, expected)
            original_strides = _extract_strides(out)
            original_ptrs = _extract_data_ptrs(out)

            op_out(out=out)
            final_strides = _extract_strides(out)
            final_ptrs = _extract_data_ptrs(out)

            self.assertEqual(expected, out)

            if compare_strides_and_data_ptrs:
                self.assertEqual(original_strides, final_strides)
                self.assertEqual(original_ptrs, final_ptrs)

        # Case: out= with the correct dtype and device, but the wrong shape
        #   Expected behavior: resize with a warning.
        def _case_two_transform(t):
            wrong_shape = list(t.shape)

            if len(wrong_shape) == 0:
                # Handles scalar tensor case (empty list)
                wrong_shape = [2]
            else:
                wrong_shape[-1] = wrong_shape[-1] + 1
            return make_tensor(wrong_shape, dtype=t.dtype, device=t.device)

        _compare_out(_case_two_transform, compare_strides_and_data_ptrs=False)

        # Additionally validates that the appropriate warning is thrown
        out = _apply_out_transform(_case_two_transform, expected)
        msg_fail = "Resized a non-empty tensor but did not warn about it."
        with self.assertWarnsRegex(UserWarning,
                                   "An output with one or more elements",
                                   msg=msg_fail):
            op_out(out=out)
Example #15
def _is_tensor_input(arg):
    return isinstance(arg, torch.Tensor) or is_iterable_of_tensors(arg)
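All of these examples rely on is_iterable_of_tensors, which lives in torch.testing._internal.common_utils in recent PyTorch versions. A minimal sketch of the check it performs, as an approximation for orientation rather than the exact upstream implementation:

import torch

def is_iterable_of_tensors_sketch(iterable, include_empty=False):
    # A Tensor is itself iterable (over its first dimension), so exclude it first.
    if isinstance(iterable, torch.Tensor):
        return False
    try:
        if len(iterable) == 0:
            return include_empty
        return all(isinstance(t, torch.Tensor) for t in iterable)
    except TypeError:
        # Objects without a length are not treated as iterables of tensors.
        return False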
Example #16
    def test_out(self, device, op):
        # Prefers running in float32 but falls back to the first listed supported dtype
        supported_dtypes = op.supported_dtypes(self.device_type)
        if len(supported_dtypes) == 0:
            self.skipTest(
                "Skipped! Op has not supported dtypes on this device.")
        dtype = torch.float32 if torch.float32 in supported_dtypes else list(
            supported_dtypes)[0]

        samples = op.sample_inputs(device, dtype)
        for sample in samples:
            # calls it normally to get the expected result
            expected = op(sample.input, *sample.args, **sample.kwargs)
            op_out = partial(op, sample.input, *sample.args, **sample.kwargs)

            # Short-circuits if output is not a single tensor or an
            #   iterable of tensors
            if not isinstance(expected,
                              torch.Tensor) and not is_iterable_of_tensors(
                                  expected, include_empty=True):
                self.skipTest(
                    "Skipped! Only supports single tensor or iterable of tensor outputs."
                )

            # Validates the op doesn't support out if it claims not to
            if not op.supports_out:
                with self.assertRaises(Exception):
                    assert op_out(out=expected) != NotImplemented
                return

            # A wrapper around map that works with single tensors and always
            #   instantiates the map. Used below to apply transforms to
            #   single tensor and iterable tensor outputs.
            def _apply_out_transform(fn, out):
                if isinstance(out, torch.Tensor):
                    return fn(out)

                # assumes (see above) that out is an iterable of tensors
                return tuple(map(fn, out))

            # Extracts strides from a tensor or iterable of tensors into a tuple
            def _extract_strides(out):
                if isinstance(out, torch.Tensor):
                    return (out.stride(), )

                # assumes (see above) that out is an iterable of tensors
                return tuple(map(lambda t: t.stride(), out))

            # Extracts data pointers from a tensor or iterable of tensors into a tuple
            # NOTE: only extracts on the CPU and CUDA device types since some
            #   device types don't have storage
            def _extract_data_ptrs(out):
                if self.device_type != 'cpu' and self.device_type != 'cuda':
                    return ()

                if isinstance(out, torch.Tensor):
                    return (out.data_ptr(), )

                # assumes (see above) that out is an iterable of tensors
                return tuple(map(lambda t: t.data_ptr(), out))

            def _compare_out(transform, *, compare_strides_and_data_ptrs=True):
                out = _apply_out_transform(transform, expected)
                original_strides = _extract_strides(out)
                original_ptrs = _extract_data_ptrs(out)

                op_out(out=out)
                final_strides = _extract_strides(out)
                final_ptrs = _extract_data_ptrs(out)

                self.assertEqual(expected, out)

                if compare_strides_and_data_ptrs:
                    stride_msg = "Strides are not the same! Original strides were {0} and strides are now {1}".format(
                        original_strides, final_strides)
                    self.assertEqual(original_strides,
                                     final_strides,
                                     msg=stride_msg)
                    self.assertEqual(original_ptrs, final_ptrs)

            # Case 0: out= with the correct shape, dtype, and device
            #   but NaN values for floating point and complex tensors, and
            #   maximum values for integer tensors.
            #   Expected behavior: out= values have no effect on the computation.
            def _case_zero_transform(t):
                try:
                    info = torch.iinfo(t.dtype)
                    return torch.full_like(t, info.max)
                except TypeError as te:
                    # for non-integer types fills with NaN
                    return torch.full_like(t, float('nan'))

            _compare_out(_case_zero_transform)

            # Case 1: out= with the correct shape, dtype, and device,
            #   but noncontiguous.
            #   Expected behavior: strides are respected and `out` storage is not changed.
            def _case_one_transform(t):
                return make_tensor(t.shape,
                                   dtype=t.dtype,
                                   device=t.device,
                                   noncontiguous=True)

            _compare_out(_case_one_transform)

            # Case 2: out= with the correct dtype and device, but has no elements.
            #   Expected behavior: resize without warning.
            def _case_two_transform(t):
                return make_tensor((0, ), dtype=t.dtype, device=t.device)

            _compare_out(_case_two_transform,
                         compare_strides_and_data_ptrs=False)

            # Also validates that no warning is thrown when this out is resized
            out = _apply_out_transform(_case_two_transform, expected)
            with warnings.catch_warnings(record=True) as caught:
                warnings.simplefilter("always")
                op_out(out=out)

            # Verifies no warning is a resize warning
            for w in caught:
                if "An output with one or more elements" in str(w.message):
                    self.fail(
                        "Resizing an out= argument with no elements threw a resize warning!"
                    )

            # Case 3: out= with correct shape and dtype, but wrong device.
            wrong_device = None
            if torch.device(device).type != 'cpu':
                wrong_device = 'cpu'
            elif torch.cuda.is_available():
                wrong_device = 'cuda'

            if wrong_device is not None:

                def _case_three_transform(t):
                    return make_tensor(t.shape,
                                       dtype=t.dtype,
                                       device=wrong_device)

                out = _apply_out_transform(_case_three_transform, expected)
                msg_fail = f"Expected RuntimeError when calling with input.device={device} and out.device={wrong_device}"
                with self.assertRaises(RuntimeError, msg=msg_fail):
                    op_out(out=out)

            # Case 4: out= with correct shape and device, but a dtype
            #   that output cannot be "safely" cast to (long).
            #   Expected behavior: error.
            # NOTE: this case is filtered by dtype since some ops produce
            #   bool tensors, for example, which can be safely cast to any
            #   dtype. It is applied when a single tensor output has a floating point
            #   or complex dtype, or when an op returns multiple tensors and at least
            #   one of them has a floating point or complex dtype.
            _dtypes = floating_and_complex_types_and(torch.float16,
                                                     torch.bfloat16)
            if (isinstance(expected, torch.Tensor)
                    and expected.dtype in _dtypes
                    or (not isinstance(expected, torch.Tensor)
                        and any(t.dtype in _dtypes for t in expected))):

                def _case_four_transform(t):
                    return make_tensor(t.shape,
                                       dtype=torch.long,
                                       device=t.device)

                out = _apply_out_transform(_case_four_transform, expected)
                msg_fail = "Expected RuntimeError when doing an unsafe cast!"
                msg_fail = msg_fail if not isinstance(expected, torch.Tensor) else \
                    ("Expected RuntimeError when doing an unsafe cast from a result of dtype "
                        f"{expected.dtype} into an out= with dtype torch.long")
                with self.assertRaises(RuntimeError, msg=msg_fail):
                    op_out(out=out)