Example #1
def _make_tensor(size, device: torch.device, dtype: torch.dtype, *, low,
                 high) -> torch.Tensor:
    """Returns a tensor of the given size, device, and dtype with random values.

    Values are drawn from a small dtype-dependent default range (random bools
    for torch.bool, [0, 9] for uint8, roughly [-9, 9] otherwise) and are then
    clamped to [low, high] when those bounds narrow the default range.
    """
    if dtype is torch.bool:
        return torch.randint(0, 2, size, device=device, dtype=dtype)

    def _maybe_clamp(t, low_default, low, high_default, high):
        low = low_default if low is None else max(low_default, low)
        high = high_default if high is None else min(high_default, high)
        if low != low_default or high != high_default:
            return torch.clamp(t, low, high)
        return t

    if dtype is torch.uint8:
        t = torch.randint(0, 10, size, device=device, dtype=dtype)
        return _maybe_clamp(t, 0, low, 9, high)
    elif dtype in integral_types():
        t = torch.randint(-9, 10, size, device=device, dtype=dtype)
        return _maybe_clamp(t, -9, low, 9, high)
    elif dtype in floating_types_and(torch.half, torch.bfloat16):
        # Windows doesn't support torch.rand(bfloat16) on CUDA
        if IS_WINDOWS and torch.device(
                device).type == 'cuda' and dtype is torch.bfloat16:
            t = (torch.rand(size, device=device, dtype=torch.float32) * 18 -
                 9).to(torch.bfloat16)
        else:
            t = torch.rand(size, device=device, dtype=dtype) * 18 - 9
        return _maybe_clamp(t, -9, low, 9, high)
    else:
        assert dtype in complex_types()
        float_dtype = torch.float if dtype is torch.cfloat else torch.double
        real = torch.rand(size, device=device, dtype=float_dtype) * 18 - 9
        real = _maybe_clamp(real, -9, low, 9, high)
        imag = torch.rand(size, device=device, dtype=float_dtype) * 18 - 9
        imag = _maybe_clamp(imag, -9, low, 9, high)
        return torch.complex(real, imag)
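

# A minimal usage sketch of _make_tensor (illustrative only; it assumes torch and
# the torch.testing._internal dtype helpers referenced above are importable, and
# the _demo_make_tensor name is hypothetical, not part of the original snippet):
def _demo_make_tensor() -> None:
    # Floating dtypes are drawn from roughly [-9, 9); low/high only narrow
    # that default range.
    t = _make_tensor((3, 3), device='cpu', dtype=torch.float32, low=None, high=None)
    assert t.min().item() >= -9 and t.max().item() < 9

    # Integral dtypes come from randint(-9, 10) and are then clamped to [0, 5] here.
    t = _make_tensor((3, 3), device='cpu', dtype=torch.int64, low=0, high=5)
    assert t.min().item() >= 0 and t.max().item() <= 5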
Example #2
class TestUnaryUfuncs(TestCase):
    exact_dtype = True

    # Tests that negating a bool tensor raises the correct error
    def test_neg_error_message(self, device):
        msg = ("Negation, the `\\-` operator, on a bool tensor is not supported."
               " If you are trying to invert a mask, use the `\\~` or"
               " `logical_not\\(\\)` operator instead.")

        t = torch.tensor((False, True), device=device)

        with self.assertRaisesRegex(RuntimeError, msg):
            torch.neg(t)

    @dtypes(*floating_types_and(torch.bfloat16, torch.half))
    @ops((_fn for _fn in unary_ufuncs if _fn.domain != (None, None)))
    def test_float_domains(self, device, dtype, op):
        if not op.supports_dtype(dtype, torch.device(device).type):
            raise unittest.SkipTest('unsupported dtype')

        eps = (1e-5, 1e-3, 1e-1, 1, 2, 10, 20, 50, 100)

        low, high = op.domain
        # NOTE: the following two loops are separated for readability
        if low is not None:
            low_tensor = torch.tensor(low, device=device, dtype=dtype)
            for epsilon in eps:
                lower_tensor = low_tensor - epsilon

                # Skips the test if the difference is not representable,
                #   which can occur if, for example, the difference is small
                #   and the dtype is imprecise (like bfloat16 is)
                if lower_tensor.item() == low_tensor.item():
                    continue

                result = op(lower_tensor)
                self.assertEqual(result.item(), float('nan'),
                                 msg=("input of {0} outside lower domain boundary"
                                      " {1} produced {2}, not nan!").format(lower_tensor.item(),
                                                                            low,
                                                                            result.item()))

        if high is not None:
            high_tensor = torch.tensor(high, device=device, dtype=dtype)
            for epsilon in eps:
                higher_tensor = high_tensor + epsilon

                # See above comment
                if higher_tensor.item() == high_tensor.item():
                    continue

                result = op(higher_tensor)
                self.assertEqual(result.item(), float('nan'),
                                 msg=("input of {0} outside upper domain boundary"
                                      " {1} produced {2}, not nan!").format(higher_tensor.item(),
                                                                            high,
                                                                            result.item()))
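
    # Illustrative sketch of the two behaviors exercised above (asin and the
    # concrete values are assumptions for illustration, not part of this test):
    def _demo_float_domain_behavior(self):
        # Inputs outside an op's domain produce nan, e.g. asin's domain is [-1, 1].
        self.assertTrue(torch.isnan(torch.asin(torch.tensor(1.5))))
        # In imprecise dtypes the epsilon may vanish: bfloat16 rounds 1 - 1e-5
        # back to exactly 1, which is why such cases are skipped above.
        self.assertEqual((torch.tensor(1., dtype=torch.bfloat16) - 1e-5).item(), 1.)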

    # Tests that fn == method == inplace == jit on a simple single tensor input
    # TODO: should this also jit the method and inplace variants?
    @ops(unary_ufuncs)
    def test_variant_consistency(self, device, dtype, op):
        def _fn(t):
            return op(t)

        t = make_tensor((5, 5), device, dtype, low=op.domain[0], high=op.domain[1])
        expected = op(t)

        for alt, inplace in ((op.get_method(), False), (op.get_inplace(), True),
                             (torch.jit.script(_fn), False)):
            if alt is None:
                # Skips variants the op doesn't provide (e.g. no inplace variant)
                continue

            if inplace and op.promotes_integers_to_float and dtype in integral_types() + (torch.bool,):
                # Assert that RuntimeError is raised
                # for inplace variant of Operators that
                # promote integer input to floating dtype.
                with self.assertRaises(RuntimeError):
                    alt(t.clone())
                continue

            actual = alt(t.clone())
            self.assertEqual(actual, expected, rtol=0, atol=0)

    # Helper for comparing torch tensors and numpy arrays
    # TODO: should this or assertEqual also validate that strides are equal?
    def assertEqualHelper(self, actual, expected, msg, *, dtype, exact_dtype=True, **kwargs):
        assert isinstance(actual, torch.Tensor)

        # Some NumPy functions return scalars, not arrays
        if isinstance(expected, Number):
            self.assertEqual(actual.item(), expected, **kwargs)
        elif isinstance(expected, np.ndarray):
            # Handles exact dtype comparisons between arrays and tensors
            if exact_dtype:
                # Allows array dtype to be float32 when comparing with bfloat16 tensors
                #   since NumPy doesn't support the bfloat16 dtype
                if expected.dtype == np.float32:
                    assert actual.dtype in (torch.bfloat16, torch.float32)
                else:
                    assert expected.dtype == torch_to_numpy_dtype_dict[actual.dtype]

            self.assertEqual(actual,
                             torch.from_numpy(expected).to(actual.dtype),
                             msg,
                             exact_device=False,
                             **kwargs)
        else:
            self.assertEqual(actual, expected, msg, exact_device=False, **kwargs)

    # Tests that the function and its (array-accepting) reference produce the same
    #   values on a range of tensors, including empty tensors, scalar tensors,
    #   1D tensors and a large 2D tensor with interesting and extremal values
    #   and discontiguities.
    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    @suppress_warnings
    @ops(unary_ufuncs)
    def test_reference_numerics(self, device, dtype, op):
        include_extremals = (op.handles_complex_extremals if
                             dtype in (torch.cfloat, torch.cdouble) else op.handles_extremals)

        tensors = generate_numeric_tensors(device, dtype,
                                           domain=op.domain,
                                           include_large_values=op.handles_large_floats,
                                           include_extremal_values=include_extremals)
        for t in tensors:
            if dtype is torch.bfloat16:
                a = t.cpu().to(torch.float32).numpy()
            else:
                a = t.cpu().numpy()

            actual = op(t)
            expected = op.ref(a)

            # Crafts a custom error message for smaller, printable tensors
            if t.numel() < 10:
                msg = ("Failed to produce expected results! Input tensor was"
                       " {0}, torch result is {1}, and reference result is"
                       " {2}.").format(t, actual, expected)
            else:
                msg = None

            exact_dtype = True
            if op.promotes_integers_to_float and dtype in integral_types() + (torch.bool,):
                exact_dtype = False

                if dtype in [torch.uint8, torch.int8, torch.bool]:
                    # NOTE: For these dtypes, PyTorch computes in the default scalar type (float)
                    # while NumPy computes in float16
                    self.assertEqualHelper(actual, expected, msg, dtype=dtype,
                                           exact_dtype=exact_dtype, rtol=1e-3, atol=1e-2)
                    continue

            self.assertEqualHelper(actual, expected, msg, dtype=dtype, exact_dtype=exact_dtype)

    # Tests for (dis)contiguity consistency

    @ops(unary_ufuncs)
    def test_contig_vs_every_other(self, device, dtype, op):
        contig = make_tensor((1026,), device=device, dtype=dtype,
                             low=op.domain[0], high=op.domain[1])
        non_contig = contig[::2]

        self.assertTrue(contig.is_contiguous())
        self.assertFalse(non_contig.is_contiguous())

        self.assertEqual(op(contig)[::2], op(non_contig))

    @ops(unary_ufuncs)
    def test_contig_vs_transposed(self, device, dtype, op):
        contig = make_tensor((789, 357), device=device, dtype=dtype,
                             low=op.domain[0], high=op.domain[1])
        non_contig = contig.T

        self.assertTrue(contig.is_contiguous())
        self.assertFalse(non_contig.is_contiguous())

        self.assertEqual(op(contig).T, op(non_contig))

    @ops(unary_ufuncs)
    def test_non_contig(self, device, dtype, op):
        shapes = [(5, 7), (1024,)]
        for shape in shapes:
            contig = make_tensor(shape, device, dtype,
                                 low=op.domain[0], high=op.domain[1])
            non_contig = torch.empty(shape + (2,), device=device, dtype=dtype)[..., 0]
            non_contig.copy_(contig)

            self.assertTrue(contig.is_contiguous())
            self.assertFalse(non_contig.is_contiguous())

            self.assertEqual(op(contig), op(non_contig))

    @ops(unary_ufuncs)
    def test_non_contig_index(self, device, dtype, op):
        contig = make_tensor((2, 2, 1, 2), device, dtype,
                             low=op.domain[0], high=op.domain[1])
        non_contig = contig[:, 1, ...]
        contig = non_contig.contiguous()

        self.assertTrue(contig.is_contiguous())
        self.assertFalse(non_contig.is_contiguous())

        self.assertEqual(op(contig), op(non_contig))

    @ops(unary_ufuncs)
    def test_non_contig_expand(self, device, dtype, op):
        shapes = [(1, 3), (1, 7), (5, 7)]
        for shape in shapes:
            contig = make_tensor(shape, device, dtype,
                                 low=op.domain[0], high=op.domain[1])
            non_contig = contig.clone().expand(3, -1, -1)

            self.assertTrue(contig.is_contiguous())
            self.assertFalse(non_contig.is_contiguous())

            contig = op(contig)
            non_contig = op(non_contig)
            for i in range(3):
                self.assertEqual(contig, non_contig[i],
                                 msg='non-contiguous expand[' + str(i) + ']')

    @ops(unary_ufuncs)
    def test_contig_size1(self, device, dtype, op):
        contig = make_tensor((5, 100), device, dtype,
                             low=op.domain[0], high=op.domain[1])
        contig = contig[:1, :50]
        contig2 = torch.empty(contig.size(), device=device, dtype=dtype)
        contig2.copy_(contig)

        self.assertTrue(contig.is_contiguous())
        self.assertTrue(contig2.is_contiguous())

        self.assertEqual(op(contig), op(contig2))

    @ops(unary_ufuncs)
    def test_contig_size1_large_dim(self, device, dtype, op):
        contig = make_tensor((5, 2, 3, 1, 4, 5, 3, 2, 1, 2, 3, 4), device, dtype,
                             low=op.domain[0], high=op.domain[1])
        contig = contig[:1, :, :, :, :, :, :, :, :, :, :, :]
        contig2 = torch.empty(contig.size(), device=device, dtype=dtype)
        contig2.copy_(contig)

        self.assertTrue(contig.is_contiguous())
        self.assertTrue(contig2.is_contiguous())

        self.assertEqual(op(contig), op(contig2))

    # Tests that computation on multiple batches is the same as
    # per-batch computation.
    @ops(unary_ufuncs)
    def test_batch_vs_slicing(self, device, dtype, op):
        input = make_tensor((1024, 512), dtype=dtype, device=device,
                            low=op.domain[0], high=op.domain[1])

        actual = op(input)
        expected = torch.stack([op(slice) for slice in input])

        self.assertEqual(actual, expected)

    def _test_out_arg(self, op, input, output):
        dtype = input.dtype
        out_dtype = output.dtype
        if dtype is out_dtype:
            expected = op(input)
            op(input, out=output)
            self.assertEqual(output, expected)
        else:
            with self.assertRaises(RuntimeError):
                op(input, out=output)

    def _test_out_promote_int_to_float_op(self, op, input, output):
        def compare_out(op, input, out):
            out_dtype = out.dtype
            expected = op(input)
            op(input, out=out)
            self.assertEqual(out, expected.to(out_dtype))

        dtype = input.dtype
        out_dtype = output.dtype
        if out_dtype.is_floating_point and not dtype.is_complex:
            compare_out(op, input, output)
        elif out_dtype.is_floating_point and dtype.is_complex:
            # Can't cast complex to float
            with self.assertRaises(RuntimeError):
                op(input, out=output)
        elif out_dtype.is_complex:
            compare_out(op, input, output)
        else:
            # Can't cast to Integral types
            with self.assertRaises(RuntimeError):
                op(input, out=output)
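
    # Illustrative sketch of the out= casting rules exercised by the two helpers
    # above (torch.exp is merely an assumed example of an op that promotes
    # integer inputs to float; this demo method is not part of the original tests):
    def _demo_out_dtype_casting(self):
        x = torch.arange(4)  # int64 input
        torch.exp(x, out=torch.empty(4, dtype=torch.float))   # ok: floating out dtype
        torch.exp(x, out=torch.empty(4, dtype=torch.cfloat))  # ok: complex out dtype
        with self.assertRaises(RuntimeError):
            # the floating result cannot be cast down to an integral out dtype
            torch.exp(x, out=torch.empty(4, dtype=torch.long))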

    @ops(unary_ufuncs)
    def test_out_arg_all_dtypes(self, device, dtype, op):
        input = make_tensor((64, 64), dtype=dtype, device=device,
                            low=op.domain[0], high=op.domain[1])

        for out_dtype in all_types_and_complex_and(torch.bool, torch.half):
            out = torch.empty_like(input, dtype=out_dtype)
            if op.promotes_integers_to_float:
                self._test_out_promote_int_to_float_op(op, input, out)
            else:
                self._test_out_arg(op, input, out)

    @dtypes(*(torch.testing.get_all_int_dtypes() + [torch.bool] +
              torch.testing.get_all_fp_dtypes(include_bfloat16=False)))
    def test_nan_to_num(self, device, dtype):
        for contiguous in [False, True]:
            x = make_tensor((64, 64), low=0., high=100., dtype=dtype, device=device)

            if dtype.is_floating_point:
                # Add extremal values.
                extremals = [float('nan'), float('inf'), -float('inf')]
                for idx, extremal in zip(torch.randint(0, 63, (3,)), extremals):
                    x[idx, :] = extremal

            if not contiguous:
                x = x.T

            # With args
            nan = random.random()
            posinf = random.random() * 5
            neginf = random.random() * 10

            self.compare_with_numpy(lambda x: x.nan_to_num(nan=nan, posinf=posinf),
                                    lambda x: np.nan_to_num(x, nan=nan, posinf=posinf),
                                    x)
            self.compare_with_numpy(lambda x: x.nan_to_num(posinf=posinf, neginf=neginf),
                                    lambda x: np.nan_to_num(x, posinf=posinf, neginf=neginf),
                                    x)

            # Out Variant
            out = torch.empty_like(x)
            result = torch.nan_to_num(x)
            torch.nan_to_num(x, out=out)
            self.assertEqual(result, out)

            result = torch.nan_to_num(x, nan=nan, posinf=posinf, neginf=neginf)
            torch.nan_to_num(x, out=out, nan=nan, posinf=posinf, neginf=neginf)
            self.assertEqual(result, out)
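

# For context, device-generic TestCase classes like the one above are normally
# materialized per device type and run via the standard PyTorch test harness.
# A minimal sketch of that boilerplate (assuming the usual
# torch.testing._internal imports; not shown in the original snippet):
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_utils import run_tests

instantiate_device_type_tests(TestUnaryUfuncs, globals())

if __name__ == '__main__':
    run_tests()
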
Example #3
class TestUnaryUfuncs(TestCase):
    exact_dtype = True

    # Helper for comparing torch tensors and numpy arrays
    # TODO: should this or assertEqual also validate that strides are equal?
    def assertEqualHelper(self,
                          actual,
                          expected,
                          *,
                          dtype,
                          exact_dtype=True,
                          **kwargs):
        assert isinstance(actual, torch.Tensor)

        # Some NumPy functions return scalars, not arrays
        if isinstance(expected, Number):
            self.assertEqual(actual.item(), expected, **kwargs)
        elif isinstance(expected, np.ndarray):
            # Handles exact dtype comparisons between arrays and tensors
            if exact_dtype:
                # Allows array dtype to be float32 when comparing with bfloat16 tensors
                #   since NumPy doesn't support the bfloat16 dtype
                if expected.dtype == np.float32:
                    assert actual.dtype in (torch.bfloat16, torch.float32)
                else:
                    assert expected.dtype == torch_to_numpy_dtype_dict[
                        actual.dtype]

            self.assertEqual(actual,
                             torch.from_numpy(expected).to(actual.dtype),
                             exact_device=False,
                             **kwargs)
        else:
            self.assertEqual(actual, expected, exact_device=False, **kwargs)

    # Verifies that the unary ufuncs have their supported dtypes
    #   registered correctly by testing that each unlisted dtype
    #   throws a runtime error
    @skipCUDAIfRocm
    @onlyOnCPUAndCUDA
    @ops(unary_ufuncs, unsupported_dtypes_only=True)
    def test_unsupported_dtypes(self, device, dtype, op):
        t = torch.empty(1, device=device, dtype=dtype)
        with self.assertRaises(RuntimeError):
            op(t)

    # Tests that negating a bool tensor raises the correct error
    def test_neg_error_message(self, device):
        msg = (
            "Negation, the `\\-` operator, on a bool tensor is not supported."
            " If you are trying to invert a mask, use the `\\~` or"
            " `logical_not\\(\\)` operator instead.")

        t = torch.tensor((False, True), device=device)

        with self.assertRaisesRegex(RuntimeError, msg):
            torch.neg(t)

    @dtypes(*floating_types_and(torch.bfloat16, torch.half))
    @ops((_fn for _fn in unary_ufuncs if _fn.domain != (None, None)))
    def test_float_domains(self, device, dtype, op):
        if not op.supports_dtype(dtype, torch.device(device).type):
            raise unittest.SkipTest('unsupported dtype')

        eps = (1e-5, 1e-3, 1e-1, 1, 2, 10, 20, 50, 100)

        low, high = op.domain
        # NOTE: the following two loops are separated for readability
        if low is not None:
            low_tensor = torch.tensor(low, device=device, dtype=dtype)
            for epsilon in eps:
                lower_tensor = low_tensor - epsilon

                # Skips the test if the difference is not representable,
                #   which can occur if, for example, the difference is small
                #   and the dtype is imprecise (like bfloat16 is)
                if lower_tensor.item() == low_tensor.item():
                    continue

                result = op(lower_tensor)
                self.assertEqual(
                    result.item(),
                    float('nan'),
                    msg=("input of {0} outside lower domain boundary"
                         " {1} produced {2}, not nan!").format(
                             lower_tensor.item(), low, result.item()))

        if high is not None:
            high_tensor = torch.tensor(high, device=device, dtype=dtype)
            for epsilon in eps:
                higher_tensor = high_tensor + epsilon

                # See above comment
                if higher_tensor.item() == high_tensor.item():
                    continue

                result = op(higher_tensor)
                self.assertEqual(
                    result.item(),
                    float('nan'),
                    msg=("input of {0} outside upper domain boundary"
                         " {1} produced {2}, not nan!").format(
                             higher_tensor.item(), high, result.item()))

    # Tests that fn == method == inplace == jit on a simple single tensor input
    # TODO: should this also jit the method and inplace variants?
    @ops(unary_ufuncs)
    def test_variant_consistency(self, device, dtype, op):
        def _fn(t):
            return op(t)

        t = _make_tensor((5, 5),
                         device,
                         dtype,
                         low=op.domain[0],
                         high=op.domain[1])
        expected = op(t)

        for alt in (op.get_method(), op.get_inplace(), torch.jit.script(_fn)):
            if alt is None:
                # Skips variants the op doesn't provide (e.g. no inplace variant)
                continue

            actual = alt(t.clone())
            self.assertEqual(actual, expected, rtol=0, atol=0)

    # Tests that the function and its (array-accepting) reference produce the same
    #   values on a range of tensors, including empty tensors, scalar tensors,
    #   1D tensors and a large 2D tensor with interesting and extremal values
    #   and discontiguities.
    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    @suppress_warnings
    @ops(unary_ufuncs)
    def test_reference_numerics(self, device, dtype, op):
        include_extremals = (op.handles_complex_extremals if dtype
                             in (torch.cfloat,
                                 torch.cdouble) else op.handles_extremals)

        tensors = generate_numeric_tensors(
            device,
            dtype,
            domain=op.domain,
            include_large_values=op.handles_large_floats,
            include_extremal_values=include_extremals)
        for t in tensors:
            if dtype is torch.bfloat16:
                a = t.cpu().to(torch.float32).numpy()
            else:
                a = t.cpu().numpy()

            actual = op(t)
            expected = op.ref(a)

            # Crafts a custom error message for smaller, printable tensors
            if t.numel() < 10:
                msg = ("Failed to produce expected results! Input tensor was"
                       " {0}, torch result is {1}, and reference result is"
                       " {2}.").format(t, actual, expected)
            else:
                msg = None

            self.assertEqualHelper(actual, expected, dtype=dtype, msg=msg)

    # Tests for (dis)contiguity consistency

    @ops(unary_ufuncs)
    def test_contig_vs_every_other(self, device, dtype, op):
        contig = _make_tensor((1026, ),
                              device=device,
                              dtype=dtype,
                              low=op.domain[0],
                              high=op.domain[1])
        non_contig = contig[::2]

        self.assertTrue(contig.is_contiguous())
        self.assertFalse(non_contig.is_contiguous())

        self.assertEqual(op(contig)[::2], op(non_contig))

    @ops(unary_ufuncs)
    def test_contig_vs_transposed(self, device, dtype, op):
        contig = _make_tensor((789, 357),
                              device=device,
                              dtype=dtype,
                              low=op.domain[0],
                              high=op.domain[1])
        non_contig = contig.T

        self.assertTrue(contig.is_contiguous())
        self.assertFalse(non_contig.is_contiguous())

        self.assertEqual(op(contig).T, op(non_contig))

    @ops(unary_ufuncs)
    def test_non_contig(self, device, dtype, op):
        shapes = [(5, 7), (1024, )]
        for shape in shapes:
            contig = _make_tensor(shape,
                                  device,
                                  dtype,
                                  low=op.domain[0],
                                  high=op.domain[1])
            non_contig = torch.empty(shape + (2, ), device=device,
                                     dtype=dtype)[..., 0]
            non_contig.copy_(contig)

            self.assertTrue(contig.is_contiguous())
            self.assertFalse(non_contig.is_contiguous())

            self.assertEqual(op(contig), op(non_contig))

    @ops(unary_ufuncs)
    def test_non_contig_index(self, device, dtype, op):
        contig = _make_tensor((2, 2, 1, 2),
                              device,
                              dtype,
                              low=op.domain[0],
                              high=op.domain[1])
        non_contig = contig[:, 1, ...]
        contig = non_contig.contiguous()

        self.assertTrue(contig.is_contiguous())
        self.assertFalse(non_contig.is_contiguous())

        self.assertEqual(op(contig), op(non_contig))

    @ops(unary_ufuncs)
    def test_non_contig_expand(self, device, dtype, op):
        shapes = [(1, 3), (1, 7), (5, 7)]
        for shape in shapes:
            contig = _make_tensor(shape,
                                  device,
                                  dtype,
                                  low=op.domain[0],
                                  high=op.domain[1])
            non_contig = contig.clone().expand(3, -1, -1)

            self.assertTrue(contig.is_contiguous())
            self.assertFalse(non_contig.is_contiguous())

            contig = op(contig)
            non_contig = op(non_contig)
            for i in range(3):
                self.assertEqual(contig,
                                 non_contig[i],
                                 msg='non-contiguous expand[' + str(i) + ']')

    @ops(unary_ufuncs)
    def test_contig_size1(self, device, dtype, op):
        contig = _make_tensor((5, 100),
                              device,
                              dtype,
                              low=op.domain[0],
                              high=op.domain[1])
        contig = contig[:1, :50]
        contig2 = torch.empty(contig.size(), device=device, dtype=dtype)
        contig2.copy_(contig)

        self.assertTrue(contig.is_contiguous())
        self.assertTrue(contig2.is_contiguous())

        self.assertEqual(op(contig), op(contig2))

    @ops(unary_ufuncs)
    def test_contig_size1_large_dim(self, device, dtype, op):
        contig = _make_tensor((5, 2, 3, 1, 4, 5, 3, 2, 1, 2, 3, 4),
                              device,
                              dtype,
                              low=op.domain[0],
                              high=op.domain[1])
        contig = contig[:1, :, :, :, :, :, :, :, :, :, :, :]
        contig2 = torch.empty(contig.size(), device=device, dtype=dtype)
        contig2.copy_(contig)

        self.assertTrue(contig.is_contiguous())
        self.assertTrue(contig2.is_contiguous())

        self.assertEqual(op(contig), op(contig2))

    # Tests that computation on multiple batches is the same as
    # per-batch computation.
    @ops(unary_ufuncs)
    def test_batch_vs_slicing(self, device, dtype, op):
        input = _make_tensor((1024, 512),
                             dtype=dtype,
                             device=device,
                             low=op.domain[0],
                             high=op.domain[1])

        actual = op(input)
        expected = torch.stack([op(slice) for slice in input])

        self.assertEqual(actual, expected)
                           dtype=dtype,
                           requires_grad=requires_grad,
                           low=-1,
                           high=1)
    return [SampleInput(inp, args=((weight, bias) + extra_args))]


additional_op_db.extend([
    OpInfo('nn.functional.conv2d',
           aten_name="conv2d",
           variant_test_name='no_bias',
           supports_autograd=True,
           supports_forward_ad=True,
           sample_inputs_func=partial(sample_inputs_conv2d, False),
           dtypes=floating_types(),
           dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
           supports_out=False),
    OpInfo('nn.functional.conv2d',
           aten_name="conv2d",
           variant_test_name='with_bias',
           supports_autograd=True,
           supports_forward_ad=True,
           sample_inputs_func=partial(sample_inputs_conv2d, True),
           dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
           dtypes=floating_types(),
           supports_out=False),
    OpInfo('nn.functional.conv2d',
           aten_name="conv2d",
           variant_test_name='stride_with_bias',
           supports_autograd=True,
           supports_forward_ad=True,