Example #1
def wrap_fn(self, *args, **kwargs):
    # Inner wrapper of a version-based skip decorator: `fn` (the wrapped test) and
    # `versions` (the list of CUDA versions to skip) come from the enclosing scope.
    version = _get_torch_cuda_version()
    if version == (0, 0):  # cpu
        return fn(self, *args, **kwargs)
    if version in (versions or []):
        reason = "test skipped for CUDA version {0}".format(version)
        raise unittest.SkipTest(reason)
    return fn(self, *args, **kwargs)
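
A minimal sketch of the enclosing decorator factory this wrapper presumably belongs to; the factory name and the import path for _get_torch_cuda_version are assumptions, and the body shown in Example #1 slots into wrap_fn unchanged:

import unittest
from functools import wraps

from torch.testing._internal.common_cuda import _get_torch_cuda_version  # assumed path

def skip_for_cuda_versions(versions=None):  # hypothetical name for the decorator factory
    """Skip the wrapped test when the active CUDA version is in `versions`."""
    def dec_fn(fn):
        @wraps(fn)
        def wrap_fn(self, *args, **kwargs):
            # ... skip logic from Example #1 goes here ...
            return fn(self, *args, **kwargs)
        return wrap_fn
    return dec_fn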
Example #2
def skipCUDAIfNoCusolver(fn):
    version = _get_torch_cuda_version()
    return skipCUDAIf(version < (10, 2), "cuSOLVER not available")(fn)
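
For context, a hedged sketch of how such a skip decorator is typically applied in a device-generic test; the test class, test body, and the import location of the decorator are illustrative:

import torch
from torch.testing._internal.common_utils import TestCase, run_tests
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests, skipCUDAIfNoCusolver)  # assumed location of the decorator

class TestCusolverBackedOps(TestCase):  # hypothetical test class
    @skipCUDAIfNoCusolver               # skipped on CUDA builds older than 10.2
    def test_inverse_identity(self, device):
        x = torch.eye(3, device=device)
        self.assertEqual(torch.linalg.inv(x), x)

instantiate_device_type_tests(TestCusolverBackedOps, globals())

if __name__ == "__main__":
    run_tests()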
Example #3
class TestPythonJiterator(TestCase):
    @skipCUDAIfRocm
    @parametrize("shape_strides", [
        (([3, 3], [3, 1]), ([3, 3], [3, 1])),  # contiguous
    ])
    @dtypes(*product(all_types_and_complex_and(torch.half, torch.bfloat16),
                     all_types_and_complex_and(torch.half, torch.bfloat16)))
    def test_all_dtype_contiguous(self, device, dtypes, shape_strides):
        a_buffer = torch.rand(9, device=device).mul(10).type(dtypes[0])
        b_buffer = torch.rand(9, device=device).mul(10).type(dtypes[1])

        a = a_buffer.as_strided(*shape_strides[0])
        b = b_buffer.as_strided(*shape_strides[1])

        expected = ref_fn(a, b)
        result = jitted_fn(a, b)

        self.assertEqual(expected, result)

    @skipCUDAIfRocm
    # See https://github.com/pytorch/pytorch/pull/76394#issuecomment-1118018287 for details
    @skipCUDAIf(_get_torch_cuda_version() < (11, 6), "On CUDA 11.3, nvrtcCompileProgram takes too long to "
                "compile jiterator-generated kernels for non-contiguous inputs that require dynamic casting.")
    @parametrize("shape_strides", [
        (([3, 3], [1, 3]), ([3, 1], [1, 3])),  # non-contiguous
    ])
    @dtypes(*product(all_types_and_complex_and(torch.half, torch.bfloat16),
                     all_types_and_complex_and(torch.half, torch.bfloat16)))
    def test_all_dtype_noncontiguous(self, device, dtypes, shape_strides):
        a_buffer = torch.rand(9, device=device).mul(10).type(dtypes[0])
        b_buffer = torch.rand(9, device=device).mul(10).type(dtypes[1])

        a = a_buffer.as_strided(*shape_strides[0])
        b = b_buffer.as_strided(*shape_strides[1])

        expected = ref_fn(a, b)
        result = jitted_fn(a, b)

        self.assertEqual(expected, result)

    @skipCUDAIfRocm
    @dtypes(torch.float, torch.double, torch.float16, torch.bfloat16)
    @parametrize("alpha", [-1, 2.0, None])
    @parametrize("beta", [3, -4.2, None])
    @toleranceOverride({torch.float16 : tol(atol=1e-2, rtol=1e-3)})
    def test_extra_args(self, device, dtype, alpha, beta):
        a = torch.rand(3, device=device).mul(10).type(dtype)
        b = torch.rand(3, device=device).mul(10).type(dtype)

        extra_args = {}
        if alpha is not None:
            extra_args["alpha"] = alpha
        if beta is not None:
            extra_args["beta"] = beta

        expected = ref_fn(a, b, **extra_args)
        result = jitted_fn(a, b, **extra_args)

        self.assertEqual(expected, result)

    @skipCUDAIfRocm
    @parametrize("is_train", [True, False])
    def test_bool_extra_args(self, device, is_train):
        code_string = "template <typename T> T conditional(T x, T mask, bool is_train) { return is_train ? x * mask : x; }"
        jitted_fn = create_jit_fn(code_string, is_train=False)

        def ref_fn(x, mask, is_train):
            return x * mask if is_train else x

        a = torch.rand(3, device=device)
        b = torch.rand(3, device=device)

        expected = ref_fn(a, b, is_train=is_train)
        result = jitted_fn(a, b, is_train=is_train)
        self.assertEqual(expected, result)

    @skipCUDAIfRocm
    def test_multiple_functors(self, device):
        code_string = '''
        template <typename T> T fn(T x, T mask) { return x * mask; }
        template <typename T> T main_fn(T x, T mask, T y) { return fn(x, mask) + y; }
        '''
        jitted_fn = create_jit_fn(code_string)

        def ref_fn(x, mask, y):
            return x * mask + y

        a = torch.rand(3, device=device)
        b = torch.rand(3, device=device)
        c = torch.rand(3, device=device)

        expected = ref_fn(a, b, c)
        result = jitted_fn(a, b, c)
        self.assertEqual(expected, result)

    @skipCUDAIfRocm
    @parametrize("num_inputs", [1, 5, 8])
    def test_various_num_inputs(self, num_inputs):
        inputs = []
        for i in range(num_inputs):
            inputs.append(torch.rand(3, device='cuda').mul(10))

        input_string = ",".join([f"T i{i}" for i in range(num_inputs)])
        function_body = "+".join([f"i{i}" for i in range(num_inputs)])
        code_string = f"template <typename T> T my_kernel({input_string}) {{ return {function_body}; }}"
        jitted_fn = create_jit_fn(code_string)

        def ref_fn(*inputs):
            return torch.sum(torch.stack(inputs), dim=0)

        expected = ref_fn(*inputs)
        result = jitted_fn(*inputs)

        self.assertEqual(expected, result)

    @skipCUDAIfRocm
    @parametrize("num_outputs", [1, 4, 8])
    def test_various_num_outputs(self, num_outputs):
        input = torch.rand(3, device='cuda')

        output_string = ", ".join([f"T& out{i}" for i in range(num_outputs)])
        function_body = ""
        for i in range(num_outputs):
            function_body += f"out{i} = input + {i};\n"
        code_string = f"template <typename T> T my_kernel(T input, {output_string}) {{ {function_body} }}"

        jitted_fn = create_multi_output_jit_fn(code_string, num_outputs)

        def ref_fn(input):
            outputs = []
            for i in range(num_outputs):
                outputs.append(input + i)

            if num_outputs == 1:
                return outputs[0]
            return tuple(outputs)

        expected = ref_fn(input)
        result = jitted_fn(input)

        for i in range(num_outputs):
            self.assertEqual(expected[i], result[i])

    @skipCUDAIfRocm
    @parametrize("code_string", [
        "template <typename T> T my _kernel(T x) { return x; }",
        "template <typename T> Tmy_kernel(T x) { return x; }",
    ])
    def test_invalid_function_name(self, code_string):
        with self.assertRaises(Exception):
            jitted_fn = create_jit_fn(code_string)
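
The class above refers to module-level helpers (ref_fn, jitted_fn, create_jit_fn, create_multi_output_jit_fn) that are not shown. A plausible sketch of those definitions, assuming the wrappers in torch.cuda.jiterator; the kernel string and the default extra arguments are illustrative:

import torch
from torch.cuda.jiterator import _create_jit_fn as create_jit_fn
from torch.cuda.jiterator import _create_multi_output_jit_fn as create_multi_output_jit_fn

# Assumed module-level kernel shared by the dtype / strides / extra-args tests above.
code_string = "template <typename T> T my_kernel(T x, T mask, T alpha, T beta) { return x * mask + alpha + beta; }"
jitted_fn = create_jit_fn(code_string, alpha=1, beta=1)

def ref_fn(x, mask, alpha=1, beta=1):
    # Eager reference matching the assumed kernel above.
    return x * mask + alpha + beta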
Example #4
def _check_cusparse_spgemm_available():
    # cusparseSpGEMM was added in 11.0
    version = _get_torch_cuda_version()
    min_supported_version = (11, 0)
    return version >= min_supported_version
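
A hedged example of how such an availability check is typically consumed, gating a sparse-times-sparse test at runtime; the class and test names are made up, and the claim that sparse-sparse torch.sparse.mm routes through cusparseSpGEMM on CUDA is an assumption:

import unittest
import torch

class TestSpGEMMGated(unittest.TestCase):  # hypothetical test class
    @unittest.skipIf(not _check_cusparse_spgemm_available(),
                     "cusparseSpGEMM requires CUDA >= 11.0")
    def test_sparse_sparse_mm(self):
        a = torch.eye(4, device="cuda").to_sparse()
        b = torch.eye(4, device="cuda").to_sparse()
        # Sparse @ sparse on CUDA is assumed to dispatch to cusparseSpGEMM here.
        c = torch.sparse.mm(a, b)
        torch.testing.assert_close(c.to_dense(), torch.eye(4, device="cuda"))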
Example #5
def _check_cusparse_triangular_solve_available():
    version = _get_torch_cuda_version()
    # cusparseSpSM was added in CUDA 11.3.1, but the patch version cannot be queried here,
    # so require 11.4 to be safe
    min_supported_version = (11, 4)
    return version >= min_supported_version