Example #1
import torch
# NOTE: import path assumes a recent PyTorch, where this helper lives in
#   torch.testing._internal.common_dtype
from torch.testing._internal.common_dtype import floating_and_complex_types_and


def allSum(vs):
    if isinstance(vs, torch.Tensor):
        vs = (vs, )
    return sum(
        (i + 1) * v.sum() for i, v in enumerate(vs)
        if v is not None and v.dtype in floating_and_complex_types_and(
            torch.half, torch.bfloat16))
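
A minimal usage sketch with hypothetical tensors (assuming allSum as defined above): each tensor's sum is weighted by its 1-based position, and non-floating, non-complex tensors are filtered out.

a = torch.randn(3)                       # float32: counted
b = torch.randn(3, dtype=torch.float64)  # float64: counted
c = torch.arange(3)                      # int64: filtered out
print(allSum((a, b, c)))  # equals 1 * a.sum() + 2 * b.sum()
print(allSum(a))          # a bare tensor is wrapped into a 1-tuple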
Example #2
    def test_floating_inputs_are_differentiable(self, device, dtype, op):
        # Nothing to check if the operation is not differentiable
        if not op.supports_autograd:
            return

        floating_dtypes = list(
            floating_and_complex_types_and(torch.bfloat16, torch.float16))

        def check_tensor_floating_is_differentiable(t):
            if isinstance(t, torch.Tensor) and t.dtype in floating_dtypes:
                msg = (
                    f"Found a sampled tensor of floating-point dtype {t.dtype} with "
                    "requires_grad=False. If this is intended, please skip/xfail this test. "
                    "Remember that sample inputs are generated under a torch.no_grad context manager."
                )
                self.assertTrue(t.requires_grad, msg)

        samples = op.sample_inputs(device, dtype, requires_grad=True)
        for sample in samples:
            check_tensor_floating_is_differentiable(sample.input)
            for arg in sample.args:
                check_tensor_floating_is_differentiable(arg)
            for arg in sample.kwargs.values():
                check_tensor_floating_is_differentiable(arg)
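
For reference, a sketch of what the helper itself returns (assuming the torch.testing._internal.common_dtype import path used above): the base floating/complex dtypes plus whatever extras are passed in.

import torch
from torch.testing._internal.common_dtype import floating_and_complex_types_and

dtypes = floating_and_complex_types_and(torch.bfloat16, torch.float16)
# base set: float32, float64, complex64, complex128, plus the two extras
print(torch.float32 in dtypes)  # True
print(torch.float16 in dtypes)  # True (passed explicitly)
print(torch.int64 in dtypes)    # False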
Example #3
    def test_dtypes(self, device, op):
        # dtypes to try to backward in
        allowed_backward_dtypes = floating_and_complex_types_and(
            torch.bfloat16, torch.float16)

        # lists for (un)supported dtypes
        supported_dtypes = []
        unsupported_dtypes = []
        supported_backward_dtypes = []
        unsupported_backward_dtypes = []

        def unsupported(dtype):
            unsupported_dtypes.append(dtype)
            if dtype in allowed_backward_dtypes:
                unsupported_backward_dtypes.append(dtype)

        for dtype in all_types_and_complex_and(torch.half, torch.bfloat16,
                                               torch.bool):
            # tries to acquire samples - failure indicates lack of support
            requires_grad = (dtype in allowed_backward_dtypes
                             and op.supports_autograd)
            try:
                samples = list(
                    op.sample_inputs(device,
                                     dtype,
                                     requires_grad=requires_grad))
            except Exception:
                unsupported(dtype)
                continue

            # Counts number of successful backward attempts
            # NOTE: This exists as a kludge because this only understands how to
            #   request a gradient if the output is a tensor or a sequence with
            #   a tensor as its first element.
            num_backward_successes = 0
            for sample in samples:
                # tries to call operator with the sample - failure indicates
                #   lack of support
                try:
                    result = op(sample.input, *sample.args, **sample.kwargs)
                except Exception:
                    # NOTE: some ops will fail in forward if their inputs
                    #   require grad but they don't support computing the gradient
                    #   in that type! This is a bug in the op!
                    unsupported(dtype)

                # Short-circuits testing this dtype -- it doesn't work
                if dtype in unsupported_dtypes:
                    break

                # Short-circuits if the dtype isn't a backward dtype or
                #   it's already identified as not supported
                if dtype not in allowed_backward_dtypes or dtype in unsupported_backward_dtypes:
                    continue

                # Checks for backward support in the same dtype
                try:
                    result = sample.output_process_fn_grad(result)
                    if isinstance(result, torch.Tensor):
                        backward_tensor = result
                    elif isinstance(result, Sequence) and isinstance(
                            result[0], torch.Tensor):
                        backward_tensor = result[0]
                    else:
                        continue

                    # NOTE: the grad tensor may not have the same dtype as the
                    #   input dtype. For functions like torch.complex
                    #   (float -> complex) or torch.abs (complex -> float) the
                    #   grad has a different dtype than the input. For
                    #   simplicity, such ops are still modeled as supporting
                    #   grad in the input dtype.
                    grad = torch.randn_like(backward_tensor)
                    backward_tensor.backward(grad)
                    num_backward_successes += 1
                except Exception:
                    unsupported_backward_dtypes.append(dtype)

            if dtype not in unsupported_dtypes:
                supported_dtypes.append(dtype)
            if num_backward_successes > 0 and dtype not in unsupported_backward_dtypes:
                supported_backward_dtypes.append(dtype)

        # Checks that dtypes are listed correctly and generates an informative
        #   error message
        device_type = torch.device(device).type
        claimed_supported = set(op.supported_dtypes(device_type))
        supported_dtypes = set(supported_dtypes)
        supported_but_unclaimed = supported_dtypes - claimed_supported
        claimed_but_unsupported = claimed_supported - supported_dtypes
        msg = """The supported dtypes for {0} on {1} according to its OpInfo are
        {2}, but the detected supported dtypes are {3}.
        """.format(op.name, device_type, claimed_supported, supported_dtypes)

        if len(supported_but_unclaimed) > 0:
            msg += "The following dtypes should be added to the OpInfo: {0}. ".format(
                supported_but_unclaimed)
        if len(claimed_but_unsupported) > 0:
            msg += "The following dtypes should be removed from the OpInfo: {0}.".format(
                claimed_but_unsupported)

        self.assertEqual(supported_dtypes, claimed_supported, msg=msg)

        # Checks that backward dtypes are listed correctly and generates an
        #   informative error message
        # NOTE: this code is nearly identical to the forward dtype check and
        #   msg generation above
        claimed_backward_supported = set(
            op.supported_backward_dtypes(device_type))
        supported_backward_dtypes = set(supported_backward_dtypes)

        supported_but_unclaimed = supported_backward_dtypes - claimed_backward_supported
        claimed_but_unsupported = claimed_backward_supported - supported_backward_dtypes
        msg = """The supported backward dtypes for {0} on {1} according to its OpInfo are
        {2}, but the detected supported backward dtypes are {3}.
        """.format(op.name, device_type, claimed_backward_supported,
                   supported_backward_dtypes)

        if len(supported_but_unclaimed) > 0:
            msg += "The following backward dtypes should be added to the OpInfo: {0}. ".format(
                supported_but_unclaimed)
        if len(claimed_but_unsupported) > 0:
            msg += "The following backward dtypes should be removed from the OpInfo: {0}.".format(
                claimed_but_unsupported)

        self.assertEqual(supported_backward_dtypes,
                         claimed_backward_supported,
                         msg=msg)
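
The claimed-vs-detected comparison above is plain set arithmetic; a standalone illustration with hypothetical values:

import torch

claimed = {torch.float32, torch.float64}
detected = {torch.float32, torch.bfloat16}
print(detected - claimed)  # supported but unclaimed -> add to the OpInfo
print(claimed - detected)  # claimed but unsupported -> remove from the OpInfo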
Example #4
    def test_out(self, device, op):
        # TODO: verify that ops claiming not to support the out= kwarg
        #   actually reject it
        if not op.supports_out:
            self.skipTest("Skipped! Op doesn't support out= kwarg.")

        # Prefers running in float32 but falls back to the first listed supported dtype
        supported_dtypes = op.supported_dtypes(self.device_type)
        if len(supported_dtypes) == 0:
            self.skipTest(
                "Skipped! Op has not supported dtypes on this device.")
        dtype = torch.float32 if torch.float32 in supported_dtypes else list(
            supported_dtypes)[0]

        # NOTE: only tests on first sample
        samples = op.sample_inputs(device, dtype)
        sample = first_sample(self, samples)

        # calls it normally to get the expected result
        expected = op(sample.input, *sample.args, **sample.kwargs)
        op_out = partial(op, sample.input, *sample.args, **sample.kwargs)

        # Short-circuits if output is not a single tensor or an
        #   iterable of tensors
        if not isinstance(expected, torch.Tensor) and not is_iterable_of_tensors(
                expected, include_empty=True):
            self.skipTest(
                "Skipped! Only supports single tensor or iterable of tensor outputs."
            )

        # A wrapper around map that works with single tensors and always
        #   instantiates the map. Used below to apply transforms to
        #   single tensor and iterable tensor outputs.
        def _apply_out_transform(fn, out):
            if isinstance(out, torch.Tensor):
                return fn(out)

            # assumes (see above) that out is an iterable of tensors
            return tuple(map(fn, out))

        # Extracts strides from a tensor or iterable of tensors into a tuple
        def _extract_strides(out):
            if isinstance(out, torch.Tensor):
                return (out.stride(), )

            # assumes (see above) that out is an iterable of tensors
            return tuple(t.stride() for t in out)

        # Extracts data pointers from a tensor or iterable of tensors into a tuple
        # NOTE: only extracts on the CPU and CUDA device types since some
        #   device types don't have storage
        def _extract_data_ptrs(out):
            if self.device_type not in ('cpu', 'cuda'):
                return ()

            if isinstance(out, torch.Tensor):
                return (out.data_ptr(), )

            # assumes (see above) that out is an iterable of tensors
            return tuple(t.data_ptr() for t in out)

        def _compare_out(transform, *, compare_strides_and_data_ptrs=True):
            out = _apply_out_transform(transform, expected)
            original_strides = _extract_strides(out)
            original_ptrs = _extract_data_ptrs(out)

            op_out(out=out)
            final_strides = _extract_strides(out)
            final_ptrs = _extract_data_ptrs(out)

            self.assertEqual(expected, out)

            if compare_strides_and_data_ptrs:
                self.assertEqual(original_strides, final_strides)
                self.assertEqual(original_ptrs, final_ptrs)

        # Case 0: out= with the correct shape, dtype, and device
        #   but NaN values for floating point and complex tensors, and
        #   maximum values for integer tensors.
        #   Expected behavior: out= values have no effect on the computation.
        def _case_zero_transform(t):
            try:
                info = torch.iinfo(t.dtype)
                return torch.full_like(t, info.max)
            except TypeError:
                # torch.iinfo raises TypeError for non-integer dtypes;
                #   fill those with NaN instead
                return torch.full_like(t, float('nan'))

        _compare_out(_case_zero_transform)

        # Case 1: out= with the correct shape, dtype, and device,
        #   but noncontiguous.
        #   Expected behavior: strides are respected and `out` storage is not changed.
        def _case_one_transform(t):
            return make_tensor(t.shape,
                               dtype=t.dtype,
                               device=t.device,
                               noncontiguous=True)

        _compare_out(_case_one_transform)

        # Case 2: out= with the correct dtype and device, but has no elements.
        #   Expected behavior: resize without warning.
        def _case_two_transform(t):
            return make_tensor((0, ), dtype=t.dtype, device=t.device)

        _compare_out(_case_two_transform, compare_strides_and_data_ptrs=False)

        # Also validates that no warning is thrown when this out is resized
        out = _apply_out_transform(_case_two_transform, expected)
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")
            op_out(out=out)

        # Verifies that none of the caught warnings is a resize warning
        for w in caught:
            if "An output with one or more elements" in str(w.message):
                self.fail(
                    "Resizing an out= argument with no elements threw a resize warning!"
                )

        # Case 3: out= with correct shape and dtype, but wrong device.
        wrong_device = None
        if torch.device(device).type != 'cpu':
            wrong_device = 'cpu'
        elif torch.cuda.is_available():
            wrong_device = 'cuda'

        if wrong_device is not None:

            def _case_three_transform(t):
                return make_tensor(t.shape, dtype=t.dtype, device=wrong_device)

            out = _apply_out_transform(_case_three_transform, expected)
            msg_fail = f"Expected RuntimeError when calling with input.device={device} and out.device={wrong_device}"
            with self.assertRaises(RuntimeError, msg=msg_fail):
                op_out(out=out)

        # Case 4: out= with correct shape and device, but a dtype
        #   that the output cannot be "safely" cast to (long).
        #   Expected behavior: error.
        # NOTE: this case is filtered by dtype since some ops produce
        #   bool tensors, for example, which can be safely cast to any
        #   dtype. The check applies when a single tensor output has a
        #   floating-point or complex dtype, or when an op returns multiple
        #   tensors and at least one of them has such a dtype.
        _dtypes = floating_and_complex_types_and(torch.float16, torch.bfloat16)
        if (isinstance(expected, torch.Tensor) and expected.dtype in _dtypes
                or (not isinstance(expected, torch.Tensor)
                    and any(t.dtype in _dtypes for t in expected))):

            def _case_four_transform(t):
                return make_tensor(t.shape, dtype=torch.long, device=t.device)

            out = _apply_out_transform(_case_four_transform, expected)
            msg_fail = "" if not isinstance(expected, torch.Tensor) else \
                       ("Expected RuntimeError when doing an unsafe cast from a result of dtype "
                        f"{expected.dtype} into an out= with dtype torch.long")
            with self.assertRaises(RuntimeError, msg=msg_fail):
                op_out(out=out)
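
As a concrete illustration of the out= contract this test exercises, a sketch using torch.add (which supports out=): pre-existing values in out must not affect the result, and an unsafe dtype cast must raise.

import torch

a, b = torch.randn(3), torch.randn(3)
out = torch.full((3, ), float('nan'))  # pre-existing values are ignored
torch.add(a, b, out=out)
assert torch.equal(out, a + b)

bad = torch.empty(3, dtype=torch.long)
try:
    torch.add(a, b, out=bad)  # float result can't be safely cast to long
except RuntimeError as e:
    print("unsafe cast rejected:", e)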
Example #5
             unittest.expectedFailure,
             "TestNormalizeOperators",
             "test_normalize_operator_exhaustive",
         ),
         DecorateInfo(
             unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
         ),
     ),
     gradcheck_wrapper=gradcheck_wrapper_masked_operation,
     supports_forward_ad=True,
     supports_out=False,
 ),
 OpInfo(
     "_masked.normalize",
     method_variant=None,
     dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16),
     sample_inputs_func=sample_inputs_masked_normalize,
     skips=(
         DecorateInfo(
             unittest.expectedFailure,
             "TestNormalizeOperators",
             "test_normalize_operator_exhaustive",
         ),
         DecorateInfo(
             unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
         ),
         # RuntimeError: "clamp_min_cpu" not implemented for 'Half'
         DecorateInfo(
             unittest.expectedFailure,
             "TestMasked",
             "test_reference_masked",
Example #6
    def test_dtypes(self, device, op):
        # Check complex32 support only if the op claims to support it.
        # TODO: once complex32 support matures, check complex32 unconditionally.
        device_type = torch.device(device).type
        include_complex32 = ((torch.complex32, ) if op.supports_dtype(
            torch.complex32, device_type) else ())

        # dtypes to try to backward in
        allowed_backward_dtypes = floating_and_complex_types_and(
            *((torch.half, torch.bfloat16) + include_complex32))

        # lists for (un)supported dtypes
        supported_dtypes = set()
        unsupported_dtypes = set()
        supported_backward_dtypes = set()
        unsupported_backward_dtypes = set()

        def unsupported(dtype):
            unsupported_dtypes.add(dtype)
            if dtype in allowed_backward_dtypes:
                unsupported_backward_dtypes.add(dtype)

        for dtype in all_types_and_complex_and(
                *((torch.half, torch.bfloat16, torch.bool) +
                  include_complex32)):
            # tries to acquire samples - failure indicates lack of support
            requires_grad = (dtype in allowed_backward_dtypes)
            try:
                samples = tuple(
                    op.sample_inputs(device,
                                     dtype,
                                     requires_grad=requires_grad))
            except Exception:
                unsupported(dtype)
                continue

            for sample in samples:
                # tries to call operator with the sample - failure indicates
                #   lack of support
                try:
                    result = op(sample.input, *sample.args, **sample.kwargs)
                    supported_dtypes.add(dtype)
                except Exception:
                    # NOTE: some ops will fail in forward if their inputs
                    #   require grad but they don't support computing the gradient
                    #   in that type! This is a bug in the op!
                    unsupported(dtype)
                    continue

                # Checks for backward support in the same dtype
                try:
                    result = sample.output_process_fn_grad(result)
                    if isinstance(result, torch.Tensor):
                        backward_tensor = result
                    elif isinstance(result, Sequence) and isinstance(
                            result[0], torch.Tensor):
                        backward_tensor = result[0]
                    else:
                        continue

                    # NOTE: the grad tensor may not have the same dtype as the
                    #   input dtype. For functions like torch.complex
                    #   (float -> complex) or torch.abs (complex -> float) the
                    #   grad has a different dtype than the input. For
                    #   simplicity, such ops are still modeled as supporting
                    #   grad in the input dtype.
                    grad = torch.randn_like(backward_tensor)
                    backward_tensor.backward(grad)
                    supported_backward_dtypes.add(dtype)
                except Exception:
                    unsupported_backward_dtypes.add(dtype)

        # Checks that dtypes are listed correctly and generates an informative
        #   error message
        supported_forward = supported_dtypes - unsupported_dtypes
        partially_supported_forward = supported_dtypes & unsupported_dtypes
        unsupported_forward = unsupported_dtypes - supported_dtypes
        supported_backward = supported_backward_dtypes - unsupported_backward_dtypes
        partially_supported_backward = supported_backward_dtypes & unsupported_backward_dtypes
        unsupported_backward = unsupported_backward_dtypes - supported_backward_dtypes

        device_type = torch.device(device).type

        claimed_forward = set(op.supported_dtypes(device_type))
        supported_but_unclaimed_forward = supported_forward - claimed_forward
        claimed_but_unsupported_forward = claimed_forward & unsupported_forward

        claimed_backward = set(op.supported_backward_dtypes(device_type))
        supported_but_unclaimed_backward = supported_backward - claimed_backward
        claimed_but_unsupported_backward = claimed_backward & unsupported_backward

        # Partially supporting a dtype is not an error, but we print a warning
        if (len(partially_supported_forward) +
                len(partially_supported_backward)) > 0:
            msg = "Some dtypes for {0} on device type {1} are only partially supported!\n".format(
                op.name, device_type)
            if len(partially_supported_forward) > 0:
                msg += "The following dtypes only worked on some samples during forward: {0}.\n".format(
                    partially_supported_forward)
            if len(partially_supported_backward) > 0:
                msg += "The following dtypes only worked on some samples during backward: {0}.\n".format(
                    partially_supported_backward)
            print(msg)

        if (len(supported_but_unclaimed_forward) +
                len(claimed_but_unsupported_forward) +
                len(supported_but_unclaimed_backward) +
                len(claimed_but_unsupported_backward)) == 0:
            return

        # Generates error msg
        msg = "The supported dtypes for {0} on device type {1} are incorrect!\n".format(
            op.name, device_type)
        if len(supported_but_unclaimed_forward) > 0:
            msg += "The following dtypes worked in forward but are not listed by the OpInfo: {0}.\n".format(
                supported_but_unclaimed_forward)
        if len(supported_but_unclaimed_backward) > 0:
            msg += "The following dtypes worked in backward but are not listed by the OpInfo: {0}.\n".format(
                supported_but_unclaimed_backward)
        if len(claimed_but_unsupported_forward) > 0:
            msg += "The following dtypes did not work in forward but are listed by the OpInfo: {0}.\n".format(
                claimed_but_unsupported_forward)
        if len(claimed_but_unsupported_backward) > 0:
            msg += "The following dtypes did not work in backward but are listed by the OpInfo: {0}.\n".format(
                claimed_but_unsupported_backward)

        self.fail(msg)
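
The partial-support bookkeeping above reduces to set operations on the dtypes collected per sample; a small illustration with hypothetical values:

import torch

supported = {torch.float32, torch.float64, torch.bfloat16}  # worked on some sample
unsupported = {torch.bfloat16, torch.int64}                 # failed on some sample
print(supported - unsupported)  # fully supported: float32, float64
print(supported & unsupported)  # partially supported (warning only): bfloat16
print(unsupported - supported)  # never worked: int64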