Example #1
    def __init__(self,
                 module_cls,  # Class object for the module under test
                 *,
                 module_inputs_func,  # Function to generate module inputs
                 skips=(),  # Indicates which tests to skip
                 decorators=None,  # Additional decorators to apply to generated tests
                 dtypes=floating_types(),  # dtypes this function is expected to work with
                 supports_gradgrad=True,  # whether the op supports second order gradients
                 ):
        self.module_cls = module_cls
        self.module_inputs_func = module_inputs_func
        self.skips = skips
        self.decorators = decorators
        self.dtypes = dtypes
        self.supports_gradgrad = supports_gradgrad
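Below is a minimal, self-contained sketch of how an entry built on this constructor might be declared. The stand-in `ModuleInfo` class body, the `module_inputs_nn_linear` helper, and the concrete dtypes are illustrative assumptions, not part of the snippet above.

import torch
import torch.nn as nn


class ModuleInfo:
    """Stand-in mirroring the constructor shown above (hypothetical)."""

    def __init__(self, module_cls, *, module_inputs_func, skips=(),
                 decorators=None, dtypes=(torch.float32, torch.float64),
                 supports_gradgrad=True):
        self.module_cls = module_cls
        self.module_inputs_func = module_inputs_func
        self.skips = skips
        self.decorators = decorators
        self.dtypes = dtypes
        self.supports_gradgrad = supports_gradgrad


def module_inputs_nn_linear(device, dtype):
    # Hypothetical input generator: one (args, kwargs) pair of forward inputs
    # for an nn.Linear module with in_features=8.
    return [((torch.randn(2, 8, device=device, dtype=dtype),), {})]


linear_info = ModuleInfo(nn.Linear, module_inputs_func=module_inputs_nn_linear)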
Example #2
    def _parametrize_test(self, test, generic_cls, device_cls):
        for module_info in self.module_info_list:
            # TODO: Factor some of this out since it's similar to OpInfo.
            for dtype in floating_types():
                # Construct the test name.
                test_name = '{}_{}_{}{}'.format(
                    test.__name__, module_info.name.replace('.', '_'),
                    device_cls.device_type, _dtype_test_suffix(dtype))

                # Construct parameter kwargs to pass to the test.
                param_kwargs = {'module_info': module_info}
                _update_param_kwargs(param_kwargs, 'dtype', dtype)

                try:
                    active_decorators = []
                    if module_info.should_skip(generic_cls.__name__,
                                               test.__name__,
                                               device_cls.device_type, dtype):
                        active_decorators.append(skipIf(True, "Skipped!"))

                    if module_info.decorators is not None:
                        for decorator in module_info.decorators:
                            # Can't use isinstance as it would cause a circular import
                            if decorator.__class__.__name__ == 'DecorateInfo':
                                if decorator.is_active(generic_cls.__name__,
                                                       test.__name__,
                                                       device_cls.device_type,
                                                       dtype):
                                    active_decorators += decorator.decorators
                            else:
                                active_decorators.append(decorator)

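                    # Wrap the test in a fresh function object so the decorators below are
                    # applied to this parametrization's copy rather than to the original test.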
                    @wraps(test)
                    def test_wrapper(*args, **kwargs):
                        return test(*args, **kwargs)

                    for decorator in active_decorators:
                        test_wrapper = decorator(test_wrapper)

                    yield (test_wrapper, test_name, param_kwargs)
                except Exception as ex:
                    # Provides an error message for debugging before rethrowing the exception
                    print("Failed to instantiate {0} for module {1}!".format(
                        test_name, module_info.name))
                    raise ex
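As a rough standalone illustration of the pattern above (build a per-dtype test name, wrap the test in a fresh function, apply the active decorators to the wrapper, and attach the result to a test class), here is a hedged sketch. The `_dtype_suffix`, `attach_parametrized`, and `test_ones` names are hypothetical; the real `_dtype_test_suffix` and `_update_param_kwargs` helpers are not reproduced.

import unittest
from functools import wraps

import torch


def _dtype_suffix(dtype):
    # Rough stand-in for the idea behind `_dtype_test_suffix`: torch.float32 -> '_float32'.
    return '_' + str(dtype).split('.')[-1]


def attach_parametrized(cls, test, name, dtypes, decorators=()):
    # Generate one wrapped test per dtype and attach it to the test class.
    for dtype in dtypes:
        test_name = '{}_{}_cpu{}'.format(test.__name__, name, _dtype_suffix(dtype))

        @wraps(test)
        def wrapper(self, _test=test, _dtype=dtype):
            return _test(self, _dtype)

        for decorator in decorators:
            wrapper = decorator(wrapper)
        setattr(cls, test_name, wrapper)


class MyTests(unittest.TestCase):
    pass


def test_ones(self, dtype):
    self.assertEqual(torch.ones(2, dtype=dtype).dtype, dtype)


attach_parametrized(MyTests, test_ones, 'ones', (torch.float32, torch.float64),
                    decorators=(unittest.skipIf(False, 'example decorator'),))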
Example #3
    def test_sparse_csr_print(self, device):
        orig_maxDiff = self.maxDiff
        self.maxDiff = None
        shape_nnz = [((10, 10), 10), ((100, 10), 10), ((1000, 10), 10)]
        printed = []
        for shape, nnz in shape_nnz:
            values_shape = torch.Size((nnz, ))
            col_indices_shape = torch.Size((nnz, ))
            crow_indices_shape = torch.Size((shape[0] + 1, ))
            printed.append("# shape: {}".format(torch.Size(shape)))
            printed.append("# nnz: {}".format(nnz))
            printed.append(
                "# crow_indices shape: {}".format(crow_indices_shape))
            printed.append("# col_indices shape: {}".format(col_indices_shape))
            printed.append("# values_shape: {}".format(values_shape))
            for index_dtype in [torch.int32, torch.int64]:
                for dtype in floating_types():
                    printed.append("########## {}/{} ##########".format(
                        dtype, index_dtype))
                    x = torch.sparse_csr_tensor(
                        torch.tensor([0, 2, 4], dtype=index_dtype),
                        torch.tensor([0, 1, 0, 1], dtype=index_dtype),
                        torch.tensor([1, 2, 3, 4]),
                        dtype=dtype,
                        device=device)
                    printed.append("# sparse tensor")
                    printed.append(str(x))
                    printed.append("# _crow_indices")
                    printed.append(str(x.crow_indices()))
                    printed.append("# _col_indices")
                    printed.append(str(x.col_indices()))
                    printed.append("# _values")
                    printed.append(str(x.values()))
                    printed.append('')
                printed.append('')
        self.assertExpected('\n'.join(printed))
        self.maxDiff = orig_maxDiff
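For a quick standalone look at the object this test prints, a small sketch (using only the public `torch.sparse_csr_tensor` API that already appears above) constructing the same 2x2 CSR tensor outside the test harness:

import torch

crow_indices = torch.tensor([0, 2, 4], dtype=torch.int64)
col_indices = torch.tensor([0, 1, 0, 1], dtype=torch.int64)
values = torch.tensor([1, 2, 3, 4])

# The shape is inferred as (2, 2): len(crow_indices) - 1 rows and col_indices.max() + 1 columns.
x = torch.sparse_csr_tensor(crow_indices, col_indices, values, dtype=torch.float64)
print(x)
print(x.crow_indices())
print(x.col_indices())
print(x.values())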
Example #4
class TestForeach(TestCase):
    @property
    def is_cuda(self):
        return self.device_type == 'cuda'

    # note(mkozuki): It might be the case that the expected number of `cudaLaunchKernel`s
    # is greater than 1 once foreach functions internally separate their input `TensorList`s by
    # devices & dtypes into vectors of tensors.
    def _get_funcs(self, op, n_expected_cudaLaunchKernels: int):
        return (
            ForeachFuncWrapper(op.method_variant,
                               n_expected_cudaLaunchKernels),
            RegularFuncWrapper(op.ref),
            ForeachFuncWrapper(op.inplace_variant,
                               n_expected_cudaLaunchKernels),
            RegularFuncWrapper(op.ref_inplace),
        )

    def _binary_test(self,
                     dtype,
                     op,
                     ref,
                     inputs,
                     is_fastpath,
                     is_inplace,
                     *,
                     alpha=None):
        ref_inputs = [[t.clone().detach() for t in inputs[0]], inputs[1]
                      ] if is_inplace else inputs
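        # If the foreach op raises, the reference is expected to raise the same error;
        # otherwise the foreach result must match the reference result.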
        try:
            actual = op(inputs, self.is_cuda, is_fastpath)
        except RuntimeError as e:
            with self.assertRaisesRegex(type(e), re.escape(str(e))):
                ref(ref_inputs)
        else:
            expected = ref(ref_inputs)
            self.assertEqual(actual, expected)
        if alpha is not None:
            kwargs = {'alpha': alpha}
            ref_inputs = inputs
            try:
                actual = op(inputs, self.is_cuda, is_fastpath, **kwargs)
            except RuntimeError as e:
                with self.assertRaisesRegex(type(e), re.escape(str(e))):
                    ref(ref_inputs, **kwargs)
            else:
                expected = ref(ref_inputs, **kwargs)
                if dtype in (torch.float16, torch.bfloat16) and TEST_WITH_ROCM:
                    self.assertEqual(expected,
                                     actual,
                                     atol=1.e-3,
                                     rtol=default_tolerances(dtype)[0])
                else:
                    self.assertEqual(expected, actual)

    def _test_binary_op_tensorlists(self, device, dtype, opinfo, N,
                                    is_fastpath, disable_fastpath):
        n_expected_cudaLaunchKernels = N if disable_fastpath else 1
        op, ref, inplace_op, inplace_ref = self._get_funcs(
            opinfo, n_expected_cudaLaunchKernels)
        inputs = [
            opinfo.sample_inputs(device,
                                 dtype,
                                 N,
                                 noncontiguous=not is_fastpath),
            opinfo.sample_inputs(device,
                                 dtype,
                                 N,
                                 noncontiguous=not is_fastpath),
        ]
        self._binary_test(dtype,
                          op,
                          ref,
                          inputs,
                          is_fastpath,
                          is_inplace=False)
        self._binary_test(dtype,
                          inplace_op,
                          inplace_ref,
                          inputs,
                          is_fastpath,
                          is_inplace=True)
        if opinfo.supports_alpha_param:
            alpha = None
            if dtype in integral_types():
                alpha = 3
            elif dtype.is_complex:
                alpha = complex(3, 3)
            else:
                alpha = 3.14
            self._binary_test(dtype,
                              op,
                              ref,
                              inputs,
                              is_fastpath,
                              is_inplace=False,
                              alpha=alpha)
            self._binary_test(dtype,
                              inplace_op,
                              inplace_ref,
                              inputs,
                              is_fastpath,
                              is_inplace=True,
                              alpha=alpha)

        # Tests of implicit broadcasting.
        # When tensor sizes don't match, foreach functions are supposed to choose the slow path
        # even if this method's `is_fastpath` argument is True, so the number of `cudaLaunchKernel`
        # calls will be equal to `N`. For the assert in `ForeachFuncWrapper` to pass, we pass
        # `is_fastpath and disable_fastpath` as `_binary_test`'s `is_fastpath` argument, since
        # `n_expected_cudaLaunchKernels` is `N` when `disable_fastpath` is True.
        inputs = [
            opinfo.sample_inputs(device,
                                 dtype,
                                 N,
                                 noncontiguous=not is_fastpath),
            [
                make_tensor((N - i, 1),
                            device=device,
                            dtype=dtype,
                            noncontiguous=not is_fastpath) for i in range(N)
            ],
        ]
        self._binary_test(dtype,
                          op,
                          ref,
                          inputs,
                          is_fastpath and disable_fastpath,
                          is_inplace=False)
        self._binary_test(dtype,
                          inplace_op,
                          inplace_ref,
                          inputs,
                          is_fastpath and disable_fastpath,
                          is_inplace=True)

    @skipMeta
    @ops(foreach_binary_op_db)
    def test_binary_op_tensorlists_fastpath(self, device, dtype, op):
        for N in N_values:
            disable_fastpath = op.ref == torch.div and dtype in integral_types_and(
                torch.bool)
            if op.ref == torch.add and dtype == torch.bool:
                disable_fastpath = True
            self._test_binary_op_tensorlists(device, dtype, op, N, True,
                                             disable_fastpath)

    @ops(foreach_binary_op_db)
    def test_binary_op_tensorlists_slowpath(self, device, dtype, op):
        for N in N_values:
            self._test_binary_op_tensorlists(device, dtype, op, N, False,
                                             False)

    def _test_binary_op_scalar(self, device, dtype, opinfo, N, scalar,
                               is_fastpath, disable_fastpath):
        n_expected_cudaLaunchKernels = N if disable_fastpath else 1
        op, ref, inplace_op, inplace_ref = self._get_funcs(
            opinfo, n_expected_cudaLaunchKernels)
        inputs = [
            opinfo.sample_inputs(device,
                                 dtype,
                                 N,
                                 noncontiguous=not is_fastpath), scalar
        ]
        self._binary_test(dtype,
                          op,
                          ref,
                          inputs,
                          is_fastpath,
                          is_inplace=False)
        self._binary_test(dtype,
                          inplace_op,
                          inplace_ref,
                          inputs,
                          is_fastpath,
                          is_inplace=True)

    @skipMeta
    @ops(foreach_binary_op_db)
    def test_binary_op_scalar_fastpath(self, device, dtype, op):
        for N, scalar in itertools.product(N_values, Scalars):
            disable_fastpath = op.ref == torch.div and dtype in integral_types_and(
                torch.bool)
            if isinstance(scalar, int):
                disable_fastpath |= dtype == torch.bool
            if isinstance(scalar, float):
                disable_fastpath |= dtype in integral_types_and(torch.bool)
            if isinstance(scalar, bool):
                disable_fastpath |= dtype == torch.bool
                if op.ref in (torch.add, torch.mul):
                    disable_fastpath = False
            if isinstance(scalar, complex):
                disable_fastpath |= dtype not in complex_types()
            self._test_binary_op_scalar(device, dtype, op, N, scalar, True,
                                        disable_fastpath)

    @ops(foreach_binary_op_db)
    def test_binary_op_scalar_slowpath(self, device, dtype, op):
        for N, scalar in itertools.product(N_values, Scalars):
            self._test_binary_op_scalar(device, dtype, op, N, scalar, False,
                                        False)

    def _test_binary_op_scalarlist(self, device, dtype, opinfo, N, scalarlist,
                                   is_fastpath, disable_fastpath):
        n_expected_cudaLaunchKernels = N if disable_fastpath else 1
        op, ref, inplace_op, inplace_ref = self._get_funcs(
            opinfo, n_expected_cudaLaunchKernels)
        inputs = [
            opinfo.sample_inputs(device,
                                 dtype,
                                 N,
                                 noncontiguous=not is_fastpath), scalarlist
        ]
        self._binary_test(dtype,
                          op,
                          ref,
                          inputs,
                          is_fastpath,
                          is_inplace=False)
        self._binary_test(dtype,
                          inplace_op,
                          inplace_ref,
                          inputs,
                          is_fastpath,
                          is_inplace=True)

    # note(mkozuki): Why two functions depending on whether bool is involved?
    # `foreach_sub` & `foreach_sub_` run `sub_check(tensors[i], scalars[i])` for i=1...N,
    # so if the scalarlist has one or more bool values they raise the bool-subtraction
    # error before doing any math, while regular `sub` and `sub_` do some math until they
    # encounter a bool. In other words, the foreach subs throw the bool-subtraction error
    # first, whereas the regular subs throw different errors depending on the order of the
    # scalarlist. To keep the actual unit test implementation simple, the mixed-scalarlist
    # tests are separated out; by setting the first element of the scalarlist to a bool,
    # they are expected to throw the bool-subtraction error even in the inplace test.
    @skipMeta
    @ops(foreach_binary_op_db)
    def test_binary_op_scalarlist_fastpath(self, device, dtype, op):
        for N in N_values:
            for type_str, scalarlist in getScalarLists(N):
                bool_int_div = op.ref == torch.div and dtype in integral_types_and(
                    torch.bool)
                disable_fastpath = bool_int_div
                if type_str == "int":
                    disable_fastpath |= dtype == torch.bool
                if type_str == "float":
                    disable_fastpath |= dtype in integral_types_and(torch.bool)
                if type_str == "complex":
                    disable_fastpath |= dtype not in complex_types()
                if type_str == "mixed":
                    disable_fastpath |= dtype not in complex_types()
                self._test_binary_op_scalarlist(device, dtype, op, N,
                                                scalarlist, True,
                                                disable_fastpath)

    @ops(foreach_binary_op_db)
    def test_binary_op_scalarlist_slowpath(self, device, dtype, op):
        for N in N_values:
            for _, scalarlist in getScalarLists(N):
                self._test_binary_op_scalarlist(device, dtype, op, N,
                                                scalarlist, False, False)

    def _pointwise_test(self,
                        dtype,
                        op,
                        ref,
                        inputs,
                        is_fastpath,
                        is_inplace,
                        *,
                        values=None):
        ref_inputs = [[t.clone().detach() for t in inputs[0]], inputs[1],
                      inputs[2]] if is_inplace else inputs
        try:
            actual = op(inputs, self.is_cuda, is_fastpath)
        except RuntimeError as e:
            with self.assertRaisesRegex(type(e), re.escape(str(e))):
                ref(ref_inputs)
        else:
            expected = ref(ref_inputs)
            self.assertEqual(expected, actual)
        if values is not None:
            try:
                actual = op(inputs + [values], self.is_cuda, is_fastpath)
            except RuntimeError as e:
                with self.assertRaisesRegex(type(e), re.escape(str(e))):
                    ref(ref_inputs, values=values)
            else:
                expected = ref(ref_inputs, values=values)
                self.assertEqual(expected, actual)

    def _test_pointwise_op(self,
                           device,
                           dtype,
                           opinfo,
                           N,
                           is_fastpath,
                           disable_fastpath,
                           *,
                           values=None):
        n_expected_cudaLaunchKernels = N if disable_fastpath else 1
        op, ref, inplace_op, inplace_ref = self._get_funcs(
            opinfo, n_expected_cudaLaunchKernels)
        inputs = [
            opinfo.sample_inputs(device,
                                 dtype,
                                 N,
                                 noncontiguous=not is_fastpath),
            opinfo.sample_inputs(device,
                                 dtype,
                                 N,
                                 noncontiguous=not is_fastpath),
            opinfo.sample_inputs(device,
                                 dtype,
                                 N,
                                 noncontiguous=not is_fastpath),
        ]
        self._pointwise_test(dtype,
                             op,
                             ref,
                             inputs,
                             is_fastpath,
                             is_inplace=False,
                             values=values)
        self._pointwise_test(dtype,
                             inplace_op,
                             inplace_ref,
                             inputs,
                             is_fastpath,
                             is_inplace=True,
                             values=values)

        # Tests of implicit broadcasting
        inputs = [
            opinfo.sample_inputs(device,
                                 dtype,
                                 N,
                                 noncontiguous=not is_fastpath,
                                 same_size=True),
            [
                make_tensor((N - i, 1),
                            device=device,
                            dtype=dtype,
                            noncontiguous=not is_fastpath) for i in range(N)
            ],
            [
                make_tensor((1, N - i),
                            device=device,
                            dtype=dtype,
                            noncontiguous=not is_fastpath) for i in range(N)
            ],
        ]
        self._pointwise_test(dtype,
                             op,
                             ref,
                             inputs,
                             is_fastpath and disable_fastpath,
                             is_inplace=False,
                             values=values)
        self._pointwise_test(dtype,
                             inplace_op,
                             inplace_ref,
                             inputs,
                             is_fastpath and disable_fastpath,
                             is_inplace=True,
                             values=values)

    @skipMeta
    @ops(foreach_pointwise_op_db)
    def test_pointwise_op_fastpath(self, device, dtype, op):
        disable_fastpath = dtype in integral_types_and(torch.bool)
        # for N, scalar in itertools.product(N_values, Scalars):
        for N in N_values:
            self._test_pointwise_op(device, dtype, op, N, True,
                                    disable_fastpath)
            for scalar in Scalars:
                self._test_pointwise_op(device,
                                        dtype,
                                        op,
                                        N,
                                        True,
                                        disable_fastpath,
                                        values=scalar)
            for _, scalarlist in getScalarLists(N):
                self._test_pointwise_op(device,
                                        dtype,
                                        op,
                                        N,
                                        True,
                                        disable_fastpath,
                                        values=scalarlist)

    @ops(foreach_pointwise_op_db)
    def test_pointwise_op_slowpath(self, device, dtype, op):
        # for N, scalar in itertools.product(N_values, Scalars):
        for N in N_values:
            self._test_pointwise_op(device, dtype, op, N, False, False)
            for scalar in Scalars:
                self._test_pointwise_op(device,
                                        dtype,
                                        op,
                                        N,
                                        False,
                                        False,
                                        values=scalar)
            for _, scalarlist in getScalarLists(N):
                self._test_pointwise_op(device,
                                        dtype,
                                        op,
                                        N,
                                        False,
                                        False,
                                        values=scalarlist)

    # note(mkozuki): The fastpath test uses only dtypes that the fastpath implementation supports.
    # To confirm that the dtypes of `OpInfo` cover the dtypes the function supports,
    # this test does not use `try-except` for the fastpath.
    def _regular_unary_test(self, dtype, op, ref, inputs, is_fastpath):
        if is_fastpath:
            self.assertEqual(ref(inputs), op(inputs, self.is_cuda,
                                             is_fastpath))
            return
        try:
            actual = op(inputs, self.is_cuda, is_fastpath)
        except RuntimeError as e:
            with self.assertRaisesRegex(type(e), re.escape(str(e))):
                ref(inputs)
        else:
            expected = ref(inputs)
            self.assertEqual(actual, expected)

    # note(mkozuki): Why `try-except` even for the fastpath?
    # - Inputs for the fastpath can be integer tensors,
    #   because the opinfo dtypes are configured for the out-of-place implementation.
    # - For integer inputs, trigonometric functions and the exponential function return float
    #   outputs, which causes a "result type Float can't be cast to the desired type" error.
    # Thus, `try-except` is used even when `is_fastpath` is `True`.
    def _inplace_unary_test(self, dtype, inplace, inplace_ref, inputs,
                            is_fastpath):
        copied_inputs = [[t.clone().detach() for t in tensors]
                         for tensors in inputs]
        try:
            inplace(inputs, self.is_cuda, is_fastpath)
        except RuntimeError as e:
            with self.assertRaisesRegex(type(e), re.escape(str(e))):
                inplace_ref(copied_inputs)
        else:
            inplace_ref(copied_inputs)
            self.assertEqual(copied_inputs, inputs)

    def _test_unary(self, device, dtype, opinfo, N, is_fastpath):
        op, ref, inplace_op, inplace_ref = self._get_funcs(opinfo, 1)
        inputs = opinfo.sample_inputs(device,
                                      dtype,
                                      N,
                                      noncontiguous=not is_fastpath),
        # note(mkozuki): Complex inputs for `_foreach_abs` go through slowpath.
        if opinfo.name == "_foreach_abs" and dtype in complex_types():
            is_fastpath = False
        self._regular_unary_test(dtype, op, ref, inputs, is_fastpath)
        self._inplace_unary_test(dtype, inplace_op, inplace_ref, inputs,
                                 is_fastpath)

    @skipMeta
    @ops(foreach_unary_op_db)
    def test_unary_fastpath(self, device, dtype, op):
        for N in N_values:
            self._test_unary(device, dtype, op, N, is_fastpath=True)

    @ops(foreach_unary_op_db,
         dtypes=all_types_and_complex_and(torch.half, torch.bfloat16,
                                          torch.bool))
    def test_unary_slowpath(self, device, dtype, op):
        for N in N_values:
            self._test_unary(device, dtype, op, N, is_fastpath=False)

    # note(crcrpar): `torch.maximum` and `torch.minimum` support the `out` arg but appear to have
    # no inplace variants, so the `inplace_op` results are compared against `ref`'s outputs.
    def _minmax_test(self, opinfo, inputs, is_fastpath,
                     n_expected_cudaLaunchKernels):
        op, ref, inplace_op, _ = self._get_funcs(opinfo,
                                                 n_expected_cudaLaunchKernels)
        expected = ref(inputs)
        self.assertEqual(expected, op(inputs, self.is_cuda, is_fastpath))

        inplace_inputs = [[t.clone() for t in inputs[0]], inputs[1]]
        inplace_op(inplace_inputs, self.is_cuda, is_fastpath)
        self.assertEqual(expected, inplace_inputs[0])

    @ops(foreach_minmax_op_db)
    def test_minmax_fastpath(self, device, dtype, op):
        for N in N_values:
            inputs = tuple(
                op.sample_inputs(device, dtype, N) for _ in range(2))
            self._minmax_test(op, inputs, True,
                              N if dtype == torch.bool else 1)

    @ops(foreach_minmax_op_db,
         dtypes=all_types_and(torch.half, torch.bfloat16, torch.bool))
    def test_minmax_slowpath(self, device, dtype, op):
        for N in N_values:
            inputs = tuple(
                op.sample_inputs(device, dtype, N, noncontiguous=True)
                for _ in range(2))
            self._minmax_test(op, inputs, False, 1)

    # note(mkozuki): The ForeachFuncInfo entries for both `_foreach_maximum` and `_foreach_minimum`
    # include integer types, so dtypes are manually limited to floating-point types for the inf/nan tests.
    @ops(foreach_minmax_op_db,
         dtypes=floating_types_and(torch.half, torch.bfloat16))
    def test_minmax_float_inf_nan(self, device, dtype, op):
        inputs = (
            [
                torch.tensor([float('inf')], device=device, dtype=dtype),
                torch.tensor([-float('inf')], device=device, dtype=dtype),
                torch.tensor([float('nan')], device=device, dtype=dtype),
                torch.tensor([float('nan')], device=device, dtype=dtype)
            ],
            [
                torch.tensor([-float('inf')], device=device, dtype=dtype),
                torch.tensor([float('inf')], device=device, dtype=dtype),
                torch.tensor([float('inf')], device=device, dtype=dtype),
                torch.tensor([float('nan')], device=device, dtype=dtype)
            ],
        )
        self._minmax_test(op, inputs, True, 1)

    def _reduce_test(self, opinfo, inputs, ord, is_fastpath,
                     n_expected_cudaLaunchKernels):
        op, ref, _, _ = self._get_funcs(opinfo, n_expected_cudaLaunchKernels)
        self.assertEqual(ref(inputs, ord=ord),
                         op(inputs, self.is_cuda, is_fastpath, ord=ord))

    @ops(foreach_reduce_op_db)
    def test_reduce_fastpath(self, device, dtype, op):
        for N, ord in itertools.product(N_values, (0, 1, 2, -1, -2)):
            if ord in (1, 2) and dtype in floating_types_and(
                    torch.half, torch.bfloat16):
                n_expected_cudaLaunchKernels = 3
            else:
                n_expected_cudaLaunchKernels = N
            inputs = op.sample_inputs(device, dtype, N, noncontiguous=False),
            self._reduce_test(op, inputs, ord, True,
                              n_expected_cudaLaunchKernels)

    @ops(foreach_reduce_op_db)
    def test_reduce_slowpath(self, device, dtype, op):
        for N, ord in itertools.product(N_values, (0, 1, 2, -1, -2)):
            inputs = op.sample_inputs(device, dtype, N, noncontiguous=True),
            self._reduce_test(op, inputs, ord, False, 1)

    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
    def test_add_scalar_with_empty_list_and_empty_tensor(self, device, dtype):
        # TODO: enable empty list case
        for tensors in [[torch.randn([0])]]:
            res = torch._foreach_add(tensors, 1)
            self.assertEqual(res, tensors)

            torch._foreach_add_(tensors, 1)
            self.assertEqual(res, tensors)

    @ops(foreach_binary_op_db,
         dtypes=all_types_and_complex_and(torch.half, torch.bfloat16,
                                          torch.bool))
    def test_binary_op_scalar_with_overlapping_tensors(self, device, dtype,
                                                       op):
        foreach_op, ref = op.method_variant, op.ref
        tensors = [
            torch.ones(1, 1, device=device, dtype=dtype).expand(2, 1, 3)
        ]

        if ref == torch.sub and dtype == torch.bool:
            with self.assertRaisesRegex(RuntimeError,
                                        re.escape(_BOOL_SUB_ERR_MSG)):
                [ref(t, 1) for t in tensors]
            with self.assertRaisesRegex(RuntimeError,
                                        re.escape(_BOOL_SUB_ERR_MSG)):
                foreach_op(tensors, 1)
            return

        expected = [ref(t, 1) for t in tensors]
        res = foreach_op(tensors, 1)
        self.assertEqual(res, expected)

    # note(mkozuki): This test case fails with the Meta backend, at least in my local environment.
    # The error message was
    # `AssertionError: NotImplementedError("Could not run 'aten::_foreach_add.Scalar' with arguments from the 'Meta' backend.`
    @skipMeta
    @ops(foreach_binary_op_db, allowed_dtypes=[torch.float])
    def test_binary_op_scalar_with_different_tensor_dtypes(
            self, device, dtype, op):
        foreach_op = op.method_variant
        tensors = [
            torch.tensor([1.1], dtype=torch.float, device=device),
            torch.tensor([1], dtype=torch.long, device=device)
        ]
        runtime_error = None
        try:
            foreach_op(tensors, 1)
        except RuntimeError as e:
            runtime_error = e
        self.assertIsNone(runtime_error)

    @ops(foreach_binary_op_db,
         dtypes=all_types_and_complex_and(torch.half, torch.bfloat16,
                                          torch.bool))
    def test_binary_op_list_error_cases(self, device, dtype, op):
        foreach_op, foreach_op_, ref, ref_ = op.method_variant, op.inplace_variant, op.ref, op.ref_inplace
        tensors1 = []
        tensors2 = []

        # Empty lists
        with self.assertRaisesRegex(
                RuntimeError,
                "There were no tensor arguments to this function"):
            foreach_op(tensors1, tensors2)
        with self.assertRaisesRegex(
                RuntimeError,
                "There were no tensor arguments to this function"):
            foreach_op_(tensors1, tensors2)

        # One empty list
        tensors1.append(torch.tensor([1], device=device, dtype=dtype))
        with self.assertRaisesRegex(
                RuntimeError,
                "Tensor list must have same number of elements as scalar list."
        ):
            foreach_op(tensors1, tensors2)
        with self.assertRaisesRegex(
                RuntimeError,
                "Tensor list must have same number of elements as scalar list."
        ):
            foreach_op_(tensors1, tensors2)

        # Lists have different amount of tensors
        tensors2.append(torch.tensor([1], device=device))
        tensors2.append(torch.tensor([1], device=device))
        with self.assertRaisesRegex(
                RuntimeError,
                "Tensor lists must have the same number of tensors, got 1 and 2"
        ):
            foreach_op(tensors1, tensors2)
        with self.assertRaisesRegex(
                RuntimeError,
                "Tensor lists must have the same number of tensors, got 1 and 2"
        ):
            foreach_op_(tensors1, tensors2)

        # Corresponding tensors with different sizes that aren't compatible with broadcast
        # If sizes are different then foreach chooses slow path, thus error messages are expected
        # to be the same as torch regular function.
        tensors1 = [
            torch.zeros(10, 10, device=device, dtype=dtype) for _ in range(10)
        ]
        tensors2 = [
            torch.ones(11, 11, device=device, dtype=dtype) for _ in range(10)
        ]
        try:
            foreach_op(tensors1, tensors2)
        except RuntimeError as e:
            with self.assertRaisesRegex(type(e), re.escape(str(e))):
                [ref(t1, t2) for t1, t2 in zip(tensors1, tensors2)]
        try:
            foreach_op_(tensors1, tensors2)
        except RuntimeError as e:
            with self.assertRaisesRegex(type(e), re.escape(str(e))):
                [ref_(t1, t2) for t1, t2 in zip(tensors1, tensors2)]

        # different devices
        if self.device_type == "cuda" and torch.cuda.device_count() > 1:
            tensor1 = torch.zeros(10, 10, device="cuda:0", dtype=dtype)
            tensor2 = torch.ones(10, 10, device="cuda:1", dtype=dtype)
            if dtype == torch.bool and foreach_op == torch._foreach_sub:
                with self.assertRaisesRegex(RuntimeError,
                                            re.escape(_BOOL_SUB_ERR_MSG)):
                    foreach_op([tensor1], [tensor2])
                with self.assertRaisesRegex(RuntimeError,
                                            re.escape(_BOOL_SUB_ERR_MSG)):
                    foreach_op_([tensor1], [tensor2])
                return
            with self.assertRaisesRegex(
                    RuntimeError,
                    "Expected all tensors to be on the same device"):
                foreach_op([tensor1], [tensor2])
            if dtype in integral_types_and(
                    torch.bool) and foreach_op == torch._foreach_div:
                with self.assertRaisesRegex(RuntimeError, "result type"):
                    foreach_op_([tensor1], [tensor2])
            else:
                with self.assertRaisesRegex(
                        RuntimeError,
                        "Expected all tensors to be on the same device"):
                    foreach_op_([tensor1], [tensor2])

    @skipMeta
    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not found")
    @ops(foreach_binary_op_db,
         dtypes=all_types_and_complex_and(torch.half, torch.bfloat16,
                                          torch.bool))
    def test_binary_op_list_slow_path(self, device, dtype, op):
        # note(mkozuki): Why `n_expected_cudaLaunchKernels=0`?
        # In this test, foreach functions don't go through the fast path, but since there is
        # only one tensor in each tensor list, the `cudaLaunchKernel` count is 1 and the
        # internal assert of `ForeachFuncWrapper` would fail with a higher expected value.
        foreach_op, native_op, foreach_op_, native_op_ = self._get_funcs(
            op, n_expected_cudaLaunchKernels=0)
        # 0-strides
        tensor1 = make_tensor((10, 10), dtype=dtype, device=device)
        tensor2 = make_tensor((1, ), device=device,
                              dtype=dtype).expand_as(tensor1)
        inputs = ([tensor1], [tensor2])
        self._binary_test(dtype,
                          foreach_op,
                          native_op,
                          inputs,
                          is_fastpath=False,
                          is_inplace=False)
        self._binary_test(dtype,
                          foreach_op_,
                          native_op_,
                          inputs,
                          is_fastpath=False,
                          is_inplace=True)

        # different strides
        tensor1 = torch.zeros(10, 10, device=device, dtype=dtype)
        tensor2 = torch.ones(10, 10, device=device, dtype=dtype)
        inputs = ([tensor1], [tensor2.t()])
        self._binary_test(dtype,
                          foreach_op,
                          native_op,
                          inputs,
                          is_fastpath=False,
                          is_inplace=False)
        self._binary_test(dtype,
                          foreach_op_,
                          native_op_,
                          inputs,
                          is_fastpath=False,
                          is_inplace=True)

        # non contiguous
        tensor1 = make_tensor((5, 2, 1, 3),
                              device=device,
                              dtype=dtype,
                              noncontiguous=True)
        tensor2 = make_tensor((5, 2, 1, 3),
                              device=device,
                              dtype=dtype,
                              noncontiguous=True)
        self.assertFalse(tensor1.is_contiguous())
        self.assertFalse(tensor2.is_contiguous())
        inputs = ([tensor1], [tensor2])
        self._binary_test(dtype,
                          foreach_op,
                          native_op,
                          inputs,
                          is_fastpath=False,
                          is_inplace=False)
        self._binary_test(dtype,
                          foreach_op_,
                          native_op_,
                          inputs,
                          is_fastpath=False,
                          is_inplace=True)

        # sliced tensor
        tensor1 = make_tensor((5, 2, 1, 3), device=device, dtype=dtype)
        tensor2 = make_tensor((5, 2, 1, 3 * 7), device=device,
                              dtype=dtype)[:, :, :, ::7]
        inputs = ([tensor1], [tensor2])
        self._binary_test(dtype,
                          foreach_op,
                          native_op,
                          inputs,
                          is_fastpath=False,
                          is_inplace=False)
        self._binary_test(dtype,
                          foreach_op_,
                          native_op_,
                          inputs,
                          is_fastpath=False,
                          is_inplace=True)

    # note: The three tests below (suffixed with `_tensors_on_different_devices`) check whether
    # foreach works with lists of tensors on different devices, where tensors at the same index
    # are on the same device, e.g., ['cuda', 'cpu'].
    @onlyCUDA
    @ops(foreach_unary_op_db)
    def test_unary_op_tensors_on_different_devices(self, device, dtype, op):
        method, ref, inplace_method, ref_inplace = self._get_funcs(op, 1)
        # tensors: ['cuda', 'cpu']
        tensors = op.sample_inputs(device, dtype, 2)
        tensors[1] = tensors[1].to('cpu')
        try:
            actual = method((tensors, ), False, False)
        except RuntimeError as e:
            with self.assertRaisesRegex(type(e), str(e)):
                ref((tensors, ))
        else:
            expected = ref((tensors, ))
            self.assertEqual(expected, actual)

        try:
            inplace_method((tensors, ), False, False)
        except RuntimeError as e:
            with self.assertRaisesRegex(type(e), str(e)):
                ref_inplace((tensors, ))
        else:
            self.assertEqual(expected, tensors)

    @onlyCUDA
    @ops(foreach_binary_op_db)
    def test_binary_op_tensors_on_different_devices(self, device, dtype, op):
        # `tensors1`: ['cuda', 'cpu']
        # `tensors2`: ['cuda', 'cpu']
        _cuda_tensors = op.sample_inputs(device, dtype, 2, same_size=True)
        _cpu_tensors = op.sample_inputs('cpu', dtype, 2, same_size=True)
        tensors1, tensors2 = list(
            tensors for tensors in zip(_cuda_tensors, _cpu_tensors))

        foreach_op, foreach_op_ = op.method_variant, op.inplace_variant
        native_op, native_op_ = op.ref, op.ref_inplace
        try:
            actual = foreach_op(tensors1, tensors2)
        except RuntimeError as e:
            with self.assertRaisesRegex(type(e), re.escape(str(e))):
                [native_op(t1, t2) for t1, t2 in zip(tensors1, tensors2)]
        else:
            expected = [
                native_op(t1, t2) for t1, t2 in zip(tensors1, tensors2)
            ]
            self.assertEqual(expected, actual)
        try:
            foreach_op_(tensors1, tensors2)
        except RuntimeError as e:
            with self.assertRaisesRegex(type(e), re.escape(str(e))):
                [native_op_(t1, t2) for t1, t2 in zip(tensors1, tensors2)]
        else:
            self.assertEqual(actual, tensors1)

    @onlyCUDA
    @ops(foreach_pointwise_op_db, allowed_dtypes=floating_types())
    def test_pointwise_op_tensors_on_different_devices(self, device, dtype,
                                                       op):
        # tensors1: ['cuda', 'cpu']
        # tensors2: ['cuda', 'cpu']
        # tensors3: ['cuda', 'cpu']
        _cuda_tensors = op.sample_inputs(device, dtype, 3, same_size=True)
        _cpu_tensors = op.sample_inputs('cpu', dtype, 3, same_size=True)
        tensors1, tensors2, tensors3 = list(
            tensors for tensors in zip(_cuda_tensors, _cpu_tensors))

        foreach_op, foreach_op_, native_op = op.method_variant, op.inplace_variant, op.ref
        actual = foreach_op(tensors1, tensors2, tensors3)
        expected = [native_op(*_cuda_tensors), native_op(*_cpu_tensors)]
        self.assertEqual(expected, actual)

        # note(mkozuki): Limiting dtypes to FP32&FP64, we can safely run inplace ops.
        foreach_op_(tensors1, tensors2, tensors3)
        self.assertEqual(expected, tensors1)

    # note: BFloat16 has the same number of exponent bits as FP32
    # so if squared L2 norm overflows in BF16, then it also overflows in FP32.
    @onlyCUDA
    @ops(foreach_reduce_op_db, allowed_dtypes=(torch.half, torch.bfloat16))
    def test_foreach_l2_large_value_input(self, device, dtype, op):
        ord, N = 2, 10
        max_value = torch.finfo(dtype).max
        scaler = torch.tensor([max_value]).sqrt().to(device=device,
                                                     dtype=dtype)
        inputs = [
            t * scaler for t in op.sample_inputs(
                device, dtype, N, noncontiguous=False, low=1)
        ],
        # Make sure that the minimum squared L2 norm per tensor is greater than the max value of `dtype`.
        self.assertTrue(scaler * scaler * N > max_value)
        fn, ref_fn, *_ = self._get_funcs(op, 3)
        actual = fn(inputs, is_cuda=True, is_fastpath=True, ord=ord)
        expect = ref_fn(inputs, ord=ord)
        if dtype == torch.float16:
            # making sure the reference L2 norm values are in the range of FP16.
            self.assertFalse(any(torch.isinf(e) for e in expect))
        else:
            self.assertTrue(all(torch.isinf(e) for e in expect))
        self.assertEqual(expect, actual, equal_nan=False)
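To make the pattern exercised by the tests above concrete, here is a minimal hedged sketch comparing a foreach call against the per-tensor reference loop that the tests use; it relies only on `torch._foreach_add` and `torch._foreach_add_`, which already appear in the class above.

import torch

tensors = [torch.ones(3), torch.arange(3.0)]

# The foreach variant operates on the whole list at once...
foreach_result = torch._foreach_add(tensors, 1)
# ...and the reference is simply the regular op applied per tensor.
reference = [t + 1 for t in tensors]
assert all(torch.equal(a, e) for a, e in zip(foreach_result, reference))

# The inplace variant mutates the input tensors themselves.
torch._foreach_add_(tensors, 1)
assert all(torch.equal(t, e) for t, e in zip(tensors, reference))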
Example #5
class TestSparseCSR(TestCase):
    @onlyCPU
    def test_csr_layout(self):
        self.assertEqual(str(torch.sparse_csr), 'torch.sparse_csr')
        self.assertEqual(type(torch.sparse_csr), torch.layout)

    @dtypes(*get_all_dtypes())
    def test_sparse_csr_constructor_shape_inference(self, device, dtype):
        crow_indices = [0, 2, 4]
        col_indices = [0, 1, 0, 1]
        values = [1, 2, 3, 4]
        sparse = torch.sparse_csr_tensor(torch.tensor(crow_indices,
                                                      dtype=torch.int64),
                                         torch.tensor(col_indices,
                                                      dtype=torch.int64),
                                         torch.tensor(values),
                                         dtype=dtype,
                                         device=device)
        self.assertEqual(torch.tensor(crow_indices, dtype=torch.int64),
                         sparse.crow_indices())
        self.assertEqual((len(crow_indices) - 1, max(col_indices) + 1),
                         sparse.shape)
        self.assertEqual(dtype, sparse.dtype)
        self.assertEqual(torch.device(device), sparse.device)

    @dtypes(*get_all_dtypes())
    def test_sparse_csr_constructor(self, device, dtype):
        crow_indices = [0, 2, 4]
        col_indices = [0, 1, 0, 1]
        values = [1, 2, 3, 4]
        for index_dtype in [torch.int32, torch.int64]:
            sparse = torch.sparse_csr_tensor(torch.tensor(crow_indices,
                                                          dtype=index_dtype),
                                             torch.tensor(col_indices,
                                                          dtype=index_dtype),
                                             torch.tensor(values),
                                             size=(2, 10),
                                             dtype=dtype,
                                             device=device)
            self.assertEqual((2, 10), sparse.shape)
            self.assertEqual(torch.tensor(crow_indices, dtype=index_dtype),
                             sparse.crow_indices())
            self.assertEqual(torch.tensor(col_indices, dtype=index_dtype),
                             sparse.col_indices())
            self.assertEqual(torch.tensor(values, dtype=dtype),
                             sparse.values())

    @dtypes(*get_all_dtypes())
    def test_sparse_csr_constructor_from_lists(self, device, dtype):
        # without size
        sparse = torch.sparse_csr_tensor([0, 2, 4], [0, 1, 0, 1], [1, 2, 3, 4],
                                         dtype=dtype,
                                         device=device)

        self.assertEqual((2, 2), sparse.shape)
        self.assertEqual(4, sparse.numel())
        self.assertEqual(
            torch.tensor([0, 2, 4], dtype=torch.int64, device=device),
            sparse.crow_indices())
        self.assertEqual(
            torch.tensor([0, 1, 0, 1], dtype=torch.int64, device=device),
            sparse.col_indices())
        self.assertEqual(
            torch.tensor([1, 2, 3, 4], dtype=dtype, device=device),
            sparse.values())

        # with size
        for sparse_csr_tensor in [
                torch.sparse_csr_tensor, torch._sparse_csr_tensor_unsafe
        ]:
            sparse = sparse_csr_tensor([0, 2, 4], [0, 1, 0, 1], [1, 2, 3, 4],
                                       size=(2, 10),
                                       dtype=dtype,
                                       device=device)

            self.assertEqual((2, 10), sparse.shape)
            self.assertEqual(
                torch.tensor([0, 2, 4], dtype=torch.int64, device=device),
                sparse.crow_indices())
            self.assertEqual(
                torch.tensor([0, 1, 0, 1], dtype=torch.int64, device=device),
                sparse.col_indices())
            self.assertEqual(
                torch.tensor([1, 2, 3, 4], dtype=dtype, device=device),
                sparse.values())

    def test_factory_type_invariants_check(self, device):
        with self.assertRaisesRegex(
                RuntimeError,
                "both crow_indices and col_indices should have the same type."
        ):
            torch.sparse_csr_tensor(torch.tensor([0, 2, 4], dtype=torch.int64),
                                    torch.tensor([0, 1, 0, 1],
                                                 dtype=torch.int32),
                                    torch.tensor([1, 2, 3, 4]),
                                    device=device)

        with self.assertRaisesRegex(
                RuntimeError,
                r"\"csr_construct_check\" not implemented for 'Short'"):
            torch.sparse_csr_tensor(torch.tensor([0, 2, 4], dtype=torch.int16),
                                    torch.tensor([0, 1, 0, 1],
                                                 dtype=torch.int16),
                                    torch.tensor([1, 2, 3, 4]),
                                    device=device)

    def test_factory_layout_invariants_check(self, device):
        with self.assertRaisesRegex(
                RuntimeError,
                "expected values to be a strided and contiguous tensor"):
            values = torch.tensor([1.], device=device).expand(4, )
            torch.sparse_csr_tensor(torch.tensor([0, 2, 4], device=device),
                                    torch.tensor([0, 1, 0, 1], device=device),
                                    values)

        with self.assertRaisesRegex(
                RuntimeError,
                "expected col_indices to be a strided and contiguous tensor"):
            col_indices = torch.tensor([0], device=device).expand(4, )
            torch.sparse_csr_tensor(torch.tensor([0, 2, 4]), col_indices,
                                    torch.tensor([1, 2, 3, 4]))

        with self.assertRaisesRegex(
                RuntimeError,
                "expected crow_indices to be a strided and contiguous tensor"):
            crow_indices = torch.arange(6, device=device)
            torch.sparse_csr_tensor(crow_indices[::2],
                                    torch.tensor([0, 1, 0, 1], device=device),
                                    torch.tensor([1, 2, 3, 4]))

    def test_factory_shape_invariants_check(self, device):
        crow_indices = [0, 2, 4]
        col_indices = [0, 1, 0, 1]
        values = [1, 2, 3, 4]
        size = (2, 10)
        torch.sparse_csr_tensor(torch.tensor(crow_indices),
                                torch.tensor(col_indices),
                                torch.tensor(values),
                                size,
                                device=device)

        with self.assertRaisesRegex(
                RuntimeError,
                r"size of a CSR tensor must be of length 2, but got: 3"):
            torch.sparse_csr_tensor(torch.tensor(crow_indices),
                                    torch.tensor(col_indices),
                                    torch.tensor(values),
                                    size=(2, 10, 2),
                                    device=device)

        with self.assertRaisesRegex(
                RuntimeError,
                r"crow_indices must have dim\=1 but got crow_indices\.dim\(\)\=2"
        ):
            torch.sparse_csr_tensor(torch.tensor(crow_indices).repeat(2, 1),
                                    torch.tensor(col_indices),
                                    torch.tensor(values),
                                    size,
                                    device=device)

        with self.assertRaisesRegex(
                RuntimeError,
                r"col_indices must have dim\=1 but got col_indices\.dim\(\)\=2"
        ):
            torch.sparse_csr_tensor(torch.tensor(crow_indices),
                                    torch.tensor(col_indices).repeat(2, 1),
                                    torch.tensor(values),
                                    size,
                                    device=device)

        with self.assertRaisesRegex(
                RuntimeError,
                r"values must have dim\=1 but got values\.dim\(\)\=2"):
            torch.sparse_csr_tensor(torch.tensor(crow_indices),
                                    torch.tensor(col_indices),
                                    torch.tensor(values).repeat(2, 1),
                                    size,
                                    device=device)

        with self.assertRaisesRegex(
                RuntimeError,
                r"crow_indices\.numel\(\) must be size\(0\) \+ 1, but got: 3"):
            torch.sparse_csr_tensor(torch.tensor(crow_indices),
                                    torch.tensor(col_indices),
                                    torch.tensor(values), (1, 1),
                                    device=device)

        with self.assertRaisesRegex(
                RuntimeError,
                r"col_indices and values must have equal sizes, " +
                r"but got col_indices\.numel\(\): 3, values\.numel\(\): 4"):
            torch.sparse_csr_tensor(torch.tensor(crow_indices),
                                    torch.tensor([0, 1, 0]),
                                    torch.tensor(values),
                                    size,
                                    device=device)

    def test_factory_indices_invariants_check(self, device):
        crow_indices = [0, 2, 4]
        col_indices = [0, 1, 0, 1]
        values = [1, 2, 3, 4]
        size = (2, 10)
        with self.assertRaisesRegex(RuntimeError,
                                    "0th value of crow_indices must be 0."):
            torch.sparse_csr_tensor(torch.tensor([-1, 0, 4]),
                                    torch.tensor(col_indices),
                                    torch.tensor(values),
                                    size,
                                    device=device)

        with self.assertRaisesRegex(
                RuntimeError,
                "last value of crow_indices should be equal to the length of col_indices."
        ):
            torch.sparse_csr_tensor(torch.tensor([0, 2, 5]),
                                    torch.tensor(col_indices),
                                    torch.tensor(values),
                                    size,
                                    device=device)

        with self.assertRaisesRegex(
                RuntimeError, r"at position i \= 2," +
                r" this condition crow_indices\[i - 1\] <\= crow_indices\[i\] fails"
        ):
            torch.sparse_csr_tensor(torch.tensor([0, 5, 4]),
                                    torch.tensor(col_indices),
                                    torch.tensor(values),
                                    size,
                                    device=device)

        with self.assertRaisesRegex(
                RuntimeError,
                r"col_indices\.min\(\) should be greater or equal to zero"):
            torch.sparse_csr_tensor(torch.tensor(crow_indices),
                                    torch.tensor([0, -1, 0, 1]),
                                    torch.tensor(values),
                                    size,
                                    device=device)

        with self.assertRaisesRegex(
                RuntimeError,
                r"size\(1\) should be greater than col_indices\.max\(\)"):
            torch.sparse_csr_tensor(torch.tensor(crow_indices),
                                    torch.tensor([0, 11, 0, 1]),
                                    torch.tensor(values),
                                    size,
                                    device=device)

    @onlyCUDA
    @dtypes(*get_all_dtypes())
    def test_factory_device_type_inference(self, device, dtype):
        cpu_cuda = ('cpu', 'cuda')
        cpu_cuda_none = cpu_cuda + (None, )
        for crow_indices_device, col_indices_device, values_device, device in itertools.product(
                cpu_cuda, cpu_cuda, cpu_cuda, cpu_cuda_none):
            for index_dtype in [torch.int32, torch.int64]:
                crow_indices = torch.tensor([0, 2, 4],
                                            dtype=index_dtype,
                                            device=crow_indices_device)
                col_indices = torch.tensor([0, 1, 0, 1],
                                           dtype=index_dtype,
                                           device=col_indices_device)
                values = torch.tensor([1, 2, 3, 4],
                                      dtype=dtype,
                                      device=values_device)
                if device is None and (
                        crow_indices_device != col_indices_device
                        or crow_indices_device != values_device):
                    with self.assertRaises(RuntimeError):
                        torch.sparse_csr_tensor(crow_indices,
                                                col_indices,
                                                values,
                                                size=(2, 10),
                                                device=device)
                else:
                    t = torch.sparse_csr_tensor(crow_indices,
                                                col_indices,
                                                values,
                                                size=(2, 10),
                                                device=device)
                    should_be_cuda = (device == 'cuda'
                                      or (device is None
                                          and values_device == 'cuda'))
                    self.assertEqual(should_be_cuda, t.is_cuda)
                    self.assertEqual(t.crow_indices().dtype, index_dtype)
                    self.assertEqual(t.col_indices().dtype, index_dtype)
                    self.assertEqual(t.values().dtype, dtype)
                    self.assertEqual(t.crow_indices().device, t.values().device)
                    self.assertEqual(t.col_indices().device, t.values().device)

    def test_sparse_csr_print(self, device):
        orig_maxDiff = self.maxDiff
        self.maxDiff = None
        shape_nnz = [((10, 10), 10), ((100, 10), 10), ((1000, 10), 10)]
        printed = []
        for shape, nnz in shape_nnz:
            values_shape = torch.Size((nnz, ))
            col_indices_shape = torch.Size((nnz, ))
            crow_indices_shape = torch.Size((shape[0] + 1, ))
            printed.append("# shape: {}".format(torch.Size(shape)))
            printed.append("# nnz: {}".format(nnz))
            printed.append(
                "# crow_indices shape: {}".format(crow_indices_shape))
            printed.append("# col_indices shape: {}".format(col_indices_shape))
            printed.append("# values_shape: {}".format(values_shape))
            for index_dtype in [torch.int32, torch.int64]:
                for dtype in floating_types():
                    printed.append("########## {}/{} ##########".format(
                        dtype, index_dtype))
                    x = torch.sparse_csr_tensor(
                        torch.tensor([0, 2, 4], dtype=index_dtype),
                        torch.tensor([0, 1, 0, 1], dtype=index_dtype),
                        torch.tensor([1, 2, 3, 4]),
                        dtype=dtype,
                        device=device)
                    printed.append("# sparse tensor")
                    printed.append(str(x))
                    printed.append("# _crow_indices")
                    printed.append(str(x.crow_indices()))
                    printed.append("# _col_indices")
                    printed.append(str(x.col_indices()))
                    printed.append("# _values")
                    printed.append(str(x.values()))
                    printed.append('')
                printed.append('')
        self.assertExpected('\n'.join(printed))
        self.maxDiff = orig_maxDiff

    @dtypes(*get_all_dtypes())
    def test_sparse_csr_from_dense(self, device, dtype):
        dense = torch.tensor([[4, 5, 0], [0, 0, 0], [1, 0, 0]],
                             dtype=dtype,
                             device=device)
        sparse = dense.to_sparse_csr()
        self.assertEqual(torch.tensor([0, 2, 2, 3], dtype=torch.int64),
                         sparse.crow_indices())
        self.assertEqual(torch.tensor([0, 1, 0], dtype=torch.int64),
                         sparse.col_indices())
        self.assertEqual(torch.tensor([4, 5, 1], dtype=dtype), sparse.values())

        dense = torch.tensor([[0, 0, 0], [0, 0, 1], [1, 0, 0]],
                             dtype=dtype,
                             device=device)
        sparse = dense.to_sparse_csr()
        self.assertEqual(torch.tensor([0, 0, 1, 2], dtype=torch.int64),
                         sparse.crow_indices())
        self.assertEqual(torch.tensor([2, 0], dtype=torch.int64),
                         sparse.col_indices())
        self.assertEqual(torch.tensor([1, 1], dtype=dtype), sparse.values())

        dense = torch.tensor([[2, 2, 2], [2, 2, 2], [2, 2, 2]],
                             dtype=dtype,
                             device=device)
        sparse = dense.to_sparse_csr()
        self.assertEqual(torch.tensor([0, 3, 6, 9], dtype=torch.int64),
                         sparse.crow_indices())
        self.assertEqual(torch.tensor([0, 1, 2] * 3, dtype=torch.int64),
                         sparse.col_indices())
        self.assertEqual(torch.tensor([2] * 9, dtype=dtype), sparse.values())

    @dtypes(*get_all_dtypes())
    def test_sparse_csr_to_dense(self, device, dtype):
        mn = [5, 2, 0]
        for (m, n) in itertools.product(mn, mn):
            size = (m, n)
            dense = make_tensor(size, dtype=dtype, device=device)
            sparse = dense.to_sparse_csr()
            self.assertEqual(sparse.to_dense(), dense)

        crow_indices = torch.tensor([0, 3, 5])
        col_indices = torch.tensor([0, 1, 2, 0, 1])
        values = torch.tensor([1, 2, 1, 3, 4], dtype=dtype)
        csr = torch.sparse_csr_tensor(crow_indices,
                                      col_indices,
                                      values,
                                      dtype=dtype,
                                      device=device)
        dense = torch.tensor([[1, 2, 1], [3, 4, 0]],
                             dtype=dtype,
                             device=device)
        self.assertEqual(csr.to_dense(), dense)
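
    # Minimal reference sketch (an illustration added here, not the library's
    # implementation) of what to_dense() produces from the CSR components: row i
    # owns the slice crow_indices[i]:crow_indices[i + 1] of col_indices and values.
    # The helper name is hypothetical.
    def _csr_components_to_dense_sketch(self, crow_indices, col_indices, values, size):
        dense = torch.zeros(size, dtype=values.dtype, device=values.device)
        for row in range(size[0]):
            for j in range(int(crow_indices[row]), int(crow_indices[row + 1])):
                dense[row, int(col_indices[j])] = values[j]
        return dense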

    @coalescedonoff
    @dtypes(torch.double)
    def test_coo_to_csr_convert(self, device, dtype, coalesced):
        with self.assertRaisesRegex(RuntimeError,
                                    "Input is supposed to be a vector"):
            torch._convert_indices_from_coo_to_csr(torch.randint(
                100, (5, 5), device=device),
                                                   size=100)

        size = (5, 5)
        sparse_dim = 2
        nnz = 10
        sparse_coo, _, _ = self.genSparseTensor(size, sparse_dim, nnz,
                                                coalesced, device, dtype)
        sparse_csr = sparse_coo.to_sparse_csr()

        self.assertTrue(sparse_csr.is_sparse_csr)
        self.assertEqual(sparse_csr.to_dense(), sparse_coo.to_dense())

        vec = torch.randn((5, 1), dtype=dtype, device=device)
        coo_product = sparse_coo.matmul(vec)
        csr_product = sparse_csr.matmul(vec)

        self.assertEqual(coo_product, csr_product)

        vec = torch.randn((100, 1), dtype=dtype, device=device)
        index = torch.tensor([
            [1, 0, 35, 14, 39, 6, 71, 66, 40, 27],
            [92, 31, 62, 50, 22, 65, 89, 74, 56, 34],
        ],
                             dtype=torch.int32)
        values = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
                              dtype=dtype,
                              device=device)
        coo = torch.sparse_coo_tensor(index,
                                      values,
                                      torch.Size([100, 100]),
                                      dtype=dtype,
                                      device=device)
        csr = coo.to_sparse_csr()

        self.assertEqual(coo.matmul(vec), csr.matmul(vec))

        col_indices = torch.tensor([31, 92, 65, 50, 34, 62, 22, 56, 74, 89],
                                   dtype=torch.int64,
                                   device=device)
        self.assertEqual(csr.col_indices(), col_indices)

        values = torch.tensor([2, 1, 6, 4, 10, 3, 5, 9, 8, 7],
                              dtype=dtype,
                              device=device)
        self.assertEqual(csr.values(), values)

    @onlyCPU
    @unittest.skipIf(IS_MACOS or IS_WINDOWS,
                     "MKL doesn't work on windows or mac")
    @dtypes(torch.float, torch.double)
    def test_mkl_matvec_warnings(self, device, dtype):
        if torch.has_mkl:
            for index_dtype in [torch.int32, torch.int64]:
                sp = torch.sparse_csr_tensor(
                    torch.tensor([0, 2, 4]), torch.tensor([0, 1, 0, 1]),
                    torch.tensor([1, 2, 3, 4], dtype=dtype, device=device))
                vec = torch.randn((2, 1), dtype=dtype, device=device)
                with warnings.catch_warnings(record=True) as w:
                    sp.matmul(vec)
                    self.assertEqual(len(w), 2)
                    self.assertIn(
                        "Pytorch is compiled with MKL LP64 and will convert crow_indices to int32",
                        str(w[0].message))
                    self.assertIn(
                        "Pytorch is compiled with MKL LP64 and will convert col_indices to int32",
                        str(w[1].message))

    @dtypes(*get_all_dtypes())
    def test_sparse_csr_from_dense_convert_error(self, device, dtype):
        size = (4, 2, 4)
        dense = make_tensor(size, dtype=dtype, device=device)

        with self.assertRaisesRegex(RuntimeError, "Only 2D"):
            dense.to_sparse_csr()

    # TODO: Support auto generation of device check for sparse tensors
    # See: https://github.com/pytorch/pytorch/issues/59058
    @onlyCUDA
    @dtypes(torch.double)
    def test_matmul_device_mismatch(self, device, dtype):
        cpu = torch.rand((10, 10))
        cuda = cpu.cuda()
        for s, m1, m2 in itertools.product((cpu, cuda), repeat=3):
            csr = m1.to_sparse()
            if s.device == csr.device == m2.device:
                torch.addmm(s, csr, m2)
            else:
                with self.assertRaisesRegex(
                        RuntimeError,
                        "Expected all tensors to be on the same device"):
                    torch.addmm(s, csr, m2)

    @dtypes(torch.float, torch.double)
    def test_csr_matvec(self, device, dtype):
        side = 100
        for index_dtype in [torch.int32, torch.int64]:
            csr = self.genSparseCSRTensor((side, side),
                                          1000,
                                          device=device,
                                          dtype=dtype,
                                          index_dtype=index_dtype)
            vec = torch.randn(side, dtype=dtype, device=device)

            res = csr.matmul(vec)
            expected = csr.to_dense().matmul(vec)

            self.assertEqual(res, expected)

            bad_vec = torch.randn(side + 10, dtype=dtype, device=device)
            with self.assertRaisesRegex(RuntimeError, "mv: expected"):
                csr.matmul(bad_vec)

    @dtypes(torch.double)
    def test_mm(self, device, dtype):
        def test_shape(di, dj, dk, nnz):
            for index_dtype in [torch.int32, torch.int64]:
                x = self.genSparseCSRTensor((di, dj),
                                            nnz,
                                            device=device,
                                            dtype=dtype,
                                            index_dtype=index_dtype)
                t = torch.randn(di, dk, dtype=dtype, device=device)
                y = torch.randn(dj, dk, dtype=dtype, device=device)
                alpha = random.random()
                beta = random.random()

                # res = beta * t  + alpha * (x @ y)
                res = torch.addmm(t, x, y, beta=beta, alpha=alpha)
                expected = torch.addmm(t,
                                       x.to_dense(),
                                       y,
                                       beta=beta,
                                       alpha=alpha)
                self.assertEqual(res, expected)

                res = torch.addmm(t, x, y)
                expected = torch.addmm(t, x.to_dense(), y)
                self.assertEqual(res, expected)

                res = torch.mm(x, y)
                expected = torch.mm(x.to_dense(), y)
                self.assertEqual(res, expected)

        for i in range(2, 5):
            for j in range(2, 8):
                for k in range(2, 8):
                    test_shape(i, j, k, i * j // 2)
        test_shape(4, 4, 4, 0)
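
    # Minimal numeric sketch (an illustration, not part of the original tests) of
    # the identity noted in the comment above: addmm computes
    # beta * input + alpha * (mat1 @ mat2), with mat1 allowed to be sparse CSR.
    # The helper name is hypothetical.
    def _addmm_formula_sketch(self):
        t = torch.ones(2, 2, dtype=torch.float64)
        x = torch.eye(2, dtype=torch.float64).to_sparse_csr()
        y = torch.full((2, 2), 3.0, dtype=torch.float64)
        res = torch.addmm(t, x, y, beta=0.5, alpha=2.0)
        expected = 0.5 * t + 2.0 * (torch.eye(2, dtype=torch.float64) @ y)
        return torch.allclose(res, expected)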

    @dtypes(*floating_types())
    def test_sparse_mm(self, device, dtype):
        def test_shape(d1, d2, d3, nnz, transposed):
            if transposed:
                D = torch.randn(d3, d2, dtype=dtype, device=device).t_()
            else:
                D = torch.randn(d2, d3, dtype=dtype, device=device)
            S = self.genSparseCSRTensor((d1, d2),
                                        nnz,
                                        device=device,
                                        dtype=dtype,
                                        index_dtype=torch.int32)
            S_dense = S.to_dense()
            self.assertEqual(torch.sparse.mm(S, D), torch.mm(S_dense, D))

        test_shape(7, 8, 9, 20, False)
        test_shape(7, 8, 9, 20, True)

    @dtypes(*floating_types())
    def test_sparse_addmm(self, device, dtype):
        def test_shape(m, n, p, nnz, broadcast, alpha_beta=None):
            if alpha_beta is None:
                alpha = random.random()
                beta = random.random()
            else:
                alpha, beta = alpha_beta
            if broadcast:
                D1 = make_tensor((), dtype=dtype, device=device)
            else:
                D1 = make_tensor([n, p], dtype=dtype, device=device)
            D2 = make_tensor([m, p], dtype=dtype, device=device)
            S = self.genSparseCSRTensor([n, m],
                                        nnz,
                                        dtype=dtype,
                                        device=device,
                                        index_dtype=torch.int32)
            S_dense = S.to_dense()
            Y = torch.sparse.addmm(D1, S, D2, beta=beta, alpha=alpha)
            Y_dense = torch.addmm(D1, S_dense, D2, beta=beta, alpha=alpha)
            self.assertEqual(Y, Y_dense)

        test_shape(7, 8, 9, 20, False, None)
        test_shape(7, 8, 9, 20, True, None)
        test_shape(7, 8, 9, 20, False, (1, 0))
        test_shape(7, 8, 9, 20, True, (1, 0))
        test_shape(7, 8, 9, 20, False, (1, 1))
        test_shape(7, 8, 9, 20, True, (1, 1))

    @dtypes(torch.float, torch.double)
    def test_add(self, device, dtype):
        def _test_spadd_shape(nnz, shape):
            x = self.genSparseCSRTensor(shape,
                                        nnz,
                                        dtype=dtype,
                                        device=device,
                                        index_dtype=torch.int32)
            y = torch.randn(*shape, dtype=dtype, device=device)
            r = random.random()

            res = torch.add(y, x, alpha=r)
            expected = y + r * x.to_dense()
            self.assertEqual(res, expected)

            # Non contiguous dense tensor
            s = list(shape)
            s[0] = shape[-1]
            s[-1] = shape[0]
            y = torch.randn(*s, dtype=torch.double, device=device)
            y.transpose_(0, len(s) - 1)
            r = random.random()

            res = torch.add(y, x, alpha=r)
            expected = y + r * x.to_dense()

            self.assertEqual(res, expected)

        _test_spadd_shape(10, [100, 100])
        _test_spadd_shape(0, [100, 100])
        _test_spadd_shape(10, [100, 1])
        _test_spadd_shape(10, [1, 100])

    @dtypes(*get_all_dtypes())
    def test_coo_csr_conversion(self, device, dtype):
        for m, n in itertools.product([5, 2, 0], [5, 2, 0]):
            size = (m, n)
            dense = make_tensor(size, dtype=dtype, device=device)
            coo_sparse = dense.to_sparse()
            csr_sparse = coo_sparse.to_sparse_csr()

            self.assertEqual(csr_sparse.to_dense(), dense)
Exemplo n.º 6
0
                        requires_grad=requires_grad)),
    )


op_db: List[OpInfo] = [
    UnaryUfuncInfo(
        "special.i0e",
        aten_name="special_i0e",
        ref=scipy.special.i0e if TEST_SCIPY else None,
        decorators=(precisionOverride({
            torch.bfloat16: 3e-1,
            torch.float16: 3e-1
        }), ),
        dtypes=all_types_and(torch.bool, torch.bfloat16),
        dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16),
        backward_dtypes=floating_types(),
        sample_inputs_func=sample_inputs_i0_i1,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
    ),
    UnaryUfuncInfo(
        "special.i1",
        aten_name="special_i1",
        ref=np_unary_ufunc_integer_promotion_wrapper(scipy.special.i1)
        if TEST_SCIPY else None,
        dtypes=all_types_and(torch.bool),
        dtypesIfCUDA=all_types_and(torch.bool),
        sample_inputs_func=sample_inputs_i0_i1,
        decorators=(DecorateInfo(
            toleranceOverride({
                torch.float32: tol(atol=1e-4, rtol=0),
Exemplo n.º 7
0
class TestSparseCSR(TestCase):

    @onlyCPU
    def test_csr_layout(self):
        self.assertEqual(str(torch.sparse_csr), 'torch.sparse_csr')
        self.assertEqual(type(torch.sparse_csr), torch.layout)

    @dtypes(*get_all_dtypes())
    def test_sparse_csr_constructor_shape_inference(self, device, dtype):
        crow_indices = [0, 2, 4]
        col_indices = [0, 1, 0, 1]
        values = [1, 2, 3, 4]
        sparse = torch.sparse_csr_tensor(torch.tensor(crow_indices, dtype=torch.int64),
                                         torch.tensor(col_indices, dtype=torch.int64),
                                         torch.tensor(values), dtype=dtype, device=device)
        self.assertEqual(torch.tensor(crow_indices, dtype=torch.int64), sparse.crow_indices())
        self.assertEqual((len(crow_indices) - 1, max(col_indices) + 1), sparse.shape)
        self.assertEqual(dtype, sparse.dtype)
        self.assertEqual(torch.device(device), sparse.device)

    @dtypes(*get_all_dtypes())
    def test_sparse_csr_constructor(self, device, dtype):
        crow_indices = [0, 2, 4]
        col_indices = [0, 1, 0, 1]
        values = [1, 2, 3, 4]
        for index_dtype in [torch.int32, torch.int64]:
            sparse = torch.sparse_csr_tensor(torch.tensor(crow_indices, dtype=index_dtype),
                                             torch.tensor(col_indices, dtype=index_dtype),
                                             torch.tensor(values),
                                             size=(2, 10),
                                             dtype=dtype,
                                             device=device)
            self.assertEqual((2, 10), sparse.shape)
            self.assertEqual(torch.tensor(crow_indices, dtype=index_dtype), sparse.crow_indices())
            self.assertEqual(torch.tensor(col_indices, dtype=index_dtype), sparse.col_indices())
            self.assertEqual(torch.tensor(values, dtype=dtype), sparse.values())

    @dtypes(*get_all_dtypes())
    def test_sparse_csr_constructor_from_lists(self, device, dtype):
        # without size
        sparse = torch.sparse_csr_tensor([0, 2, 4],
                                         [0, 1, 0, 1],
                                         [1, 2, 3, 4],
                                         dtype=dtype,
                                         device=device)

        self.assertEqual((2, 2), sparse.shape)
        self.assertEqual(4, sparse.numel())
        self.assertEqual(torch.tensor([0, 2, 4], dtype=torch.int64, device=device), sparse.crow_indices())
        self.assertEqual(torch.tensor([0, 1, 0, 1], dtype=torch.int64, device=device), sparse.col_indices())
        self.assertEqual(torch.tensor([1, 2, 3, 4], dtype=dtype, device=device), sparse.values())

        # with size
        for sparse_csr_tensor in [torch.sparse_csr_tensor, torch._sparse_csr_tensor_unsafe]:
            sparse = sparse_csr_tensor([0, 2, 4],
                                       [0, 1, 0, 1],
                                       [1, 2, 3, 4],
                                       size=(2, 10),
                                       dtype=dtype,
                                       device=device)

            self.assertEqual((2, 10), sparse.shape)
            self.assertEqual(torch.tensor([0, 2, 4], dtype=torch.int64, device=device), sparse.crow_indices())
            self.assertEqual(torch.tensor([0, 1, 0, 1], dtype=torch.int64, device=device), sparse.col_indices())
            self.assertEqual(torch.tensor([1, 2, 3, 4], dtype=dtype, device=device), sparse.values())

    @skipMeta
    @dtypes(*get_all_dtypes())
    def test_empty(self, device, dtype):
        ns = [5, 2, 0]
        for shape in itertools.product(ns, ns):
            result = torch.empty(shape, dtype=dtype, device=device, layout=torch.sparse_csr)
            self.assertEqual(result.shape, shape)
            self.assertEqual(result.dtype, dtype)
            self.assertEqual(result.device, torch.device(device))
            self.assertEqual(result.layout, torch.sparse_csr)
            self.assertEqual(result.crow_indices().shape, (shape[0] + 1,))
            self.assertEqual(result.col_indices().shape, (0,))
            self.assertEqual(result.values().shape, (0,))
            self.assertEqual(result._nnz(), 0)
            self.assertEqual(result.crow_indices().device, torch.device(device))
            self.assertEqual(result.col_indices().device, torch.device(device))
            self.assertEqual(result.values().device, torch.device(device))
            self.assertEqual(result.crow_indices().dtype, torch.int64)
            self.assertEqual(result.col_indices().dtype, torch.int64)
            self.assertEqual(result.values().dtype, dtype)

    @skipMeta
    @dtypes(*get_all_dtypes())
    def test_empty_errors(self, device, dtype):
        with self.assertRaisesRegex(RuntimeError, "torch.empty: Only 2D sparse CSR tensors are supported."):
            torch.empty((5,), dtype=dtype, device=device, layout=torch.sparse_csr)

        with self.assertRaisesRegex(RuntimeError, "torch.empty: Only 2D sparse CSR tensors are supported."):
            torch.empty((2, 3, 4), dtype=dtype, device=device, layout=torch.sparse_csr)

    @skipMeta
    @dtypes(*get_all_dtypes())
    def test_copy(self, device, dtype):

        def run_test(shape, nnz, index_dtype):
            a = self.genSparseCSRTensor(shape, nnz, dtype=dtype, device=device, index_dtype=index_dtype)
            b = self.genSparseCSRTensor(shape, nnz, dtype=dtype, device=device, index_dtype=index_dtype)

            a.copy_(b)

            self.assertEqual(a.crow_indices(), b.crow_indices())
            self.assertEqual(a.col_indices(), b.col_indices())
            self.assertEqual(a.values(), b.values())

        ns = [5, 2, 0]
        for shape, index_dtype in zip(itertools.product(ns, ns), [torch.int32, torch.int64]):
            run_test(shape, 0, index_dtype)
            run_test(shape, shape[0] * shape[1], index_dtype)

    @skipMeta
    @dtypes(*get_all_dtypes())
    def test_copy_errors(self, device, dtype):
        for index_dtype in [torch.int32, torch.int64]:
            shape1 = (2, 3)
            shape2 = (3, 2)
            a = self.genSparseCSRTensor(shape1, 0, dtype=dtype, device=device, index_dtype=index_dtype)
            b = self.genSparseCSRTensor(shape2, 0, dtype=dtype, device=device, index_dtype=index_dtype)

            with self.assertRaisesRegex(RuntimeError, "only same size tensors are supported."):
                a.copy_(b)

            with self.assertRaisesRegex(RuntimeError, "copy between different layouts is not supported."):
                a.copy_(torch.empty(a.shape, dtype=dtype, device=device))

            b = self.genSparseCSRTensor(shape1, 1, dtype=dtype, device=device, index_dtype=index_dtype)
            with self.assertRaisesRegex(RuntimeError, "only tensors with the same number of specified elements are supported."):
                a.copy_(b)

    @skipMeta
    @dtypes(*get_all_dtypes())
    def test_resize(self, device, dtype):
        for index_dtype in [torch.int32, torch.int64]:
            shape = (2, 3)
            nnz = 6
            a = self.genSparseCSRTensor(shape, nnz, dtype=dtype, device=device, index_dtype=index_dtype)

            new_shape = (4, 5)
            a.resize_(new_shape)

            self.assertEqual(a.shape, new_shape)
            # resize to larger shape doesn't add specified elements
            self.assertEqual(a._nnz(), nnz)

            new_shape = (1, 5)
            a.resize_(new_shape)

            self.assertEqual(a.shape, new_shape)
            # resize to smaller shape trims specified elements
            self.assertEqual(a._nnz(), 5)

    @skipMeta
    @dtypes(*get_all_dtypes())
    def test_resize_errors(self, device, dtype):
        for index_dtype in [torch.int32, torch.int64]:
            shape = (2, 3)
            nnz = 6
            a = self.genSparseCSRTensor(shape, nnz, dtype=dtype, device=device, index_dtype=index_dtype)

            with self.assertRaisesRegex(RuntimeError, "torch.resize_: Only 2D sparse CSR tensors are supported."):
                new_shape = (4,)
                a.resize_(new_shape)

            # resizing of columns to smaller size is not implemented
            with self.assertRaisesRegex(
                RuntimeError,
                "torch.resize_: Resizing columns of sparse CSR tensors to a smaller value is not supported.",
            ):
                new_shape = (2, 2)
                a.resize_(new_shape)

    def test_factory_type_invariants_check(self, device):
        with self.assertRaisesRegex(RuntimeError, "both crow_indices and col_indices should have the same type."):
            torch.sparse_csr_tensor(torch.tensor([0, 2, 4], dtype=torch.int64),
                                    torch.tensor([0, 1, 0, 1], dtype=torch.int32),
                                    torch.tensor([1, 2, 3, 4]),
                                    device=device)

        with self.assertRaisesRegex(RuntimeError, r"\"csr_construct_check\" not implemented for 'Short'"):
            torch.sparse_csr_tensor(torch.tensor([0, 2, 4], dtype=torch.int16),
                                    torch.tensor([0, 1, 0, 1], dtype=torch.int16),
                                    torch.tensor([1, 2, 3, 4]),
                                    device=device)

    def test_factory_layout_invariants_check(self, device):
        with self.assertRaisesRegex(RuntimeError, "expected values to be a strided and contiguous tensor"):
            values = torch.tensor([1.], device=device).expand(4,)
            torch.sparse_csr_tensor(torch.tensor([0, 2, 4], device=device),
                                    torch.tensor([0, 1, 0, 1], device=device),
                                    values)

        with self.assertRaisesRegex(RuntimeError, "expected col_indices to be a strided and contiguous tensor"):
            col_indices = torch.tensor([0], device=device).expand(4,)
            torch.sparse_csr_tensor(torch.tensor([0, 2, 4]),
                                    col_indices,
                                    torch.tensor([1, 2, 3, 4]))

        with self.assertRaisesRegex(RuntimeError, "expected crow_indices to be a strided and contiguous tensor"):
            crow_indices = torch.arange(6, device=device)
            torch.sparse_csr_tensor(crow_indices[::2],
                                    torch.tensor([0, 1, 0, 1], device=device),
                                    torch.tensor([1, 2, 3, 4]))

    def test_factory_shape_invariants_check(self, device):
        crow_indices = [0, 2, 4]
        col_indices = [0, 1, 0, 1]
        values = [1, 2, 3, 4]
        size = (2, 10)
        torch.sparse_csr_tensor(torch.tensor(crow_indices), torch.tensor(col_indices), torch.tensor(values), size,
                                device=device)

        with self.assertRaisesRegex(RuntimeError, r"size of a CSR tensor must be of length 2, but got: 3"):
            torch.sparse_csr_tensor(torch.tensor(crow_indices), torch.tensor(col_indices), torch.tensor(values),
                                    size=(2, 10, 2),
                                    device=device)

        with self.assertRaisesRegex(RuntimeError, r"crow_indices must have dim\=1 but got crow_indices\.dim\(\)\=2"):
            torch.sparse_csr_tensor(torch.tensor(crow_indices).repeat(2, 1),
                                    torch.tensor(col_indices),
                                    torch.tensor(values),
                                    size,
                                    device=device)

        with self.assertRaisesRegex(RuntimeError, r"col_indices must have dim\=1 but got col_indices\.dim\(\)\=2"):
            torch.sparse_csr_tensor(torch.tensor(crow_indices),
                                    torch.tensor(col_indices).repeat(2, 1),
                                    torch.tensor(values),
                                    size,
                                    device=device)

        with self.assertRaisesRegex(RuntimeError, r"values must have dim\=1 but got values\.dim\(\)\=2"):
            torch.sparse_csr_tensor(torch.tensor(crow_indices),
                                    torch.tensor(col_indices),
                                    torch.tensor(values).repeat(2, 1),
                                    size,
                                    device=device)

        with self.assertRaisesRegex(RuntimeError,
                                    r"crow_indices\.numel\(\) must be size\(0\) \+ 1, but got: 3"):
            torch.sparse_csr_tensor(torch.tensor(crow_indices), torch.tensor(col_indices), torch.tensor(values), (1, 1),
                                    device=device)


        with self.assertRaisesRegex(RuntimeError,
                                    r"col_indices and values must have equal sizes, " +
                                    r"but got col_indices\.numel\(\): 3, values\.numel\(\): 4"):
            torch.sparse_csr_tensor(torch.tensor(crow_indices), torch.tensor([0, 1, 0]), torch.tensor(values), size,
                                    device=device)

    def test_factory_indices_invariants_check(self, device):
        crow_indices = [0, 2, 4]
        col_indices = [0, 1, 0, 1]
        values = [1, 2, 3, 4]
        size = (2, 10)
        with self.assertRaisesRegex(RuntimeError, "0th value of crow_indices must be 0."):
            torch.sparse_csr_tensor(torch.tensor([-1, 0, 4]), torch.tensor(col_indices), torch.tensor(values), size,
                                    device=device)

        with self.assertRaisesRegex(RuntimeError,
                                    "last value of crow_indices should be equal to the length of col_indices."):
            torch.sparse_csr_tensor(torch.tensor([0, 2, 5]), torch.tensor(col_indices), torch.tensor(values), size,
                                    device=device)

        with self.assertRaisesRegex(RuntimeError,
                                    r"at position i \= 2," +
                                    r" this condition crow_indices\[i - 1\] <\= crow_indices\[i\] fails"):
            torch.sparse_csr_tensor(torch.tensor([0, 5, 4]), torch.tensor(col_indices), torch.tensor(values), size,
                                    device=device)

        with self.assertRaisesRegex(RuntimeError, r"col_indices\.min\(\) should be greater or equal to zero"):
            torch.sparse_csr_tensor(torch.tensor(crow_indices), torch.tensor([0, -1, 0, 1]), torch.tensor(values), size,
                                    device=device)

        with self.assertRaisesRegex(RuntimeError, r"size\(1\) should be greater than col_indices\.max\(\)"):
            torch.sparse_csr_tensor(torch.tensor(crow_indices), torch.tensor([0, 11, 0, 1]), torch.tensor(values), size,
                                    device=device)

    @onlyCUDA
    @dtypes(*get_all_dtypes())
    def test_factory_device_type_inference(self, device, dtype):
        cpu_cuda = ('cpu', 'cuda')
        cpu_cuda_none = cpu_cuda + (None,)
        for crow_indices_device, col_indices_device, values_device, device in itertools.product(cpu_cuda,
                                                                                                cpu_cuda,
                                                                                                cpu_cuda,
                                                                                                cpu_cuda_none):
            for index_dtype in [torch.int32, torch.int64]:
                crow_indices = torch.tensor([0, 2, 4], dtype=index_dtype, device=crow_indices_device)
                col_indices = torch.tensor([0, 1, 0, 1], dtype=index_dtype, device=col_indices_device)
                values = torch.tensor([1, 2, 3, 4], dtype=dtype, device=values_device)
                if device is None and (crow_indices_device != col_indices_device or
                                       crow_indices_device != values_device):
                    with self.assertRaises(RuntimeError):
                        torch.sparse_csr_tensor(crow_indices,
                                                col_indices,
                                                values,
                                                size=(2, 10),
                                                device=device)
                else:
                    t = torch.sparse_csr_tensor(crow_indices,
                                                col_indices,
                                                values,
                                                size=(2, 10),
                                                device=device)
                    should_be_cuda = (device == 'cuda' or (device is None and values_device == 'cuda'))
                    self.assertEqual(should_be_cuda, t.is_cuda)
                    self.assertEqual(t.crow_indices().dtype, index_dtype)
                    self.assertEqual(t.col_indices().dtype, index_dtype)
                    self.assertEqual(t.values().dtype, dtype)
                    self.assertEqual(t.crow_indices().device, t.values().device)
                    self.assertEqual(t.col_indices().device, t.values().device)

    def test_sparse_csr_print(self, device):
        orig_maxDiff = self.maxDiff
        self.maxDiff = None
        shape_nnz = [
            ((10, 10), 10),
            ((100, 10), 10),
            ((1000, 10), 10)
        ]
        printed = []
        for shape, nnz in shape_nnz:
            values_shape = torch.Size((nnz,))
            col_indices_shape = torch.Size((nnz,))
            crow_indices_shape = torch.Size((shape[0] + 1,))
            printed.append("# shape: {}".format(torch.Size(shape)))
            printed.append("# nnz: {}".format(nnz))
            printed.append("# crow_indices shape: {}".format(crow_indices_shape))
            printed.append("# col_indices shape: {}".format(col_indices_shape))
            printed.append("# values_shape: {}".format(values_shape))
            for index_dtype in [torch.int32, torch.int64]:
                for dtype in floating_types():
                    printed.append("########## {}/{} ##########".format(dtype, index_dtype))
                    x = torch.sparse_csr_tensor(torch.tensor([0, 2, 4], dtype=index_dtype),
                                                torch.tensor([0, 1, 0, 1], dtype=index_dtype),
                                                torch.tensor([1, 2, 3, 4]), dtype=dtype, device=device)
                    printed.append("# sparse tensor")
                    printed.append(str(x))
                    printed.append("# _crow_indices")
                    printed.append(str(x.crow_indices()))
                    printed.append("# _col_indices")
                    printed.append(str(x.col_indices()))
                    printed.append("# _values")
                    printed.append(str(x.values()))
                    printed.append('')
                printed.append('')
        self.assertExpected('\n'.join(printed))
        self.maxDiff = orig_maxDiff

    @dtypes(*get_all_dtypes())
    def test_sparse_csr_from_dense(self, device, dtype):
        dense = torch.tensor([[4, 5, 0], [0, 0, 0], [1, 0, 0]], dtype=dtype, device=device)
        sparse = dense.to_sparse_csr()
        self.assertEqual(torch.tensor([0, 2, 2, 3], dtype=torch.int64), sparse.crow_indices())
        self.assertEqual(torch.tensor([0, 1, 0], dtype=torch.int64), sparse.col_indices())
        self.assertEqual(torch.tensor([4, 5, 1], dtype=dtype), sparse.values())

        dense = torch.tensor([[0, 0, 0], [0, 0, 1], [1, 0, 0]], dtype=dtype, device=device)
        sparse = dense.to_sparse_csr()
        self.assertEqual(torch.tensor([0, 0, 1, 2], dtype=torch.int64), sparse.crow_indices())
        self.assertEqual(torch.tensor([2, 0], dtype=torch.int64), sparse.col_indices())
        self.assertEqual(torch.tensor([1, 1], dtype=dtype), sparse.values())

        dense = torch.tensor([[2, 2, 2], [2, 2, 2], [2, 2, 2]], dtype=dtype, device=device)
        sparse = dense.to_sparse_csr()
        self.assertEqual(torch.tensor([0, 3, 6, 9], dtype=torch.int64), sparse.crow_indices())
        self.assertEqual(torch.tensor([0, 1, 2] * 3, dtype=torch.int64), sparse.col_indices())
        self.assertEqual(torch.tensor([2] * 9, dtype=dtype), sparse.values())
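
    # Hedged sketch (not part of the original suite) showing how the components
    # checked above can be recovered from a dense matrix: crow_indices is the
    # cumulative count of nonzeros per row, and col_indices/values follow in
    # row-major order. The helper name is hypothetical.
    def _dense_to_csr_components_sketch(self, dense):
        mask = dense != 0
        crow_indices = torch.cat([torch.zeros(1, dtype=torch.int64, device=dense.device),
                                  mask.sum(dim=1).cumsum(0)])
        col_indices = mask.nonzero()[:, 1]
        values = dense[mask]
        return crow_indices, col_indices, values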

    @dtypes(*get_all_dtypes())
    def test_sparse_csr_to_dense(self, device, dtype):
        mn = [5, 2, 0]
        for (m, n) in itertools.product(mn, mn):
            size = (m, n)
            dense = make_tensor(size, dtype=dtype, device=device)
            sparse = dense.to_sparse_csr()
            self.assertEqual(sparse.to_dense(), dense)

        crow_indices = torch.tensor([0, 3, 5])
        col_indices = torch.tensor([0, 1, 2, 0, 1])
        values = torch.tensor([1, 2, 1, 3, 4], dtype=dtype)
        csr = torch.sparse_csr_tensor(crow_indices, col_indices,
                                      values, dtype=dtype, device=device)
        dense = torch.tensor([[1, 2, 1], [3, 4, 0]], dtype=dtype, device=device)
        self.assertEqual(csr.to_dense(), dense)

    @coalescedonoff
    @dtypes(torch.double)
    def test_coo_to_csr_convert(self, device, dtype, coalesced):
        with self.assertRaisesRegex(RuntimeError, "Input is supposed to be a vector"):
            torch._convert_indices_from_coo_to_csr(
                torch.randint(100, (5, 5), device=device),
                size=100)

        size = (5, 5)
        sparse_dim = 2
        nnz = 10
        sparse_coo, _, _ = self.genSparseTensor(size, sparse_dim, nnz, coalesced, device, dtype)
        sparse_csr = sparse_coo.to_sparse_csr()

        self.assertTrue(sparse_csr.is_sparse_csr)
        self.assertEqual(sparse_csr.to_dense(), sparse_coo.to_dense())

        vec = torch.randn((5, 1), dtype=dtype, device=device)
        coo_product = sparse_coo.matmul(vec)
        csr_product = sparse_csr.matmul(vec)

        self.assertEqual(coo_product, csr_product)

        vec = torch.randn((100, 1), dtype=dtype, device=device)
        index = torch.tensor([
            [1, 0, 35, 14, 39, 6, 71, 66, 40, 27],
            [92, 31, 62, 50, 22, 65, 89, 74, 56, 34],
        ], dtype=torch.int32)
        values = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype=dtype, device=device)
        coo = torch.sparse_coo_tensor(index, values, torch.Size([100, 100]), dtype=dtype, device=device)
        csr = coo.to_sparse_csr()

        self.assertEqual(coo.matmul(vec), csr.matmul(vec))

        col_indices = torch.tensor([
            31, 92, 65, 50, 34, 62, 22, 56, 74, 89
        ], dtype=torch.int64, device=device)
        self.assertEqual(csr.col_indices(), col_indices)

        values = torch.tensor([2, 1, 6, 4, 10, 3, 5, 9, 8, 7], dtype=dtype, device=device)
        self.assertEqual(csr.values(), values)
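
    # Illustrative sketch (an assumption about the conversion semantics, not the
    # implementation behind _convert_indices_from_coo_to_csr): a CSR row pointer is
    # the per-row count of COO entries followed by a cumulative sum. The helper
    # name is hypothetical.
    def _coo_rows_to_crow_indices_sketch(self, row_indices, num_rows):
        counts = torch.bincount(row_indices, minlength=num_rows)
        return torch.cat([torch.zeros(1, dtype=torch.int64, device=row_indices.device),
                          counts.cumsum(0)])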

    @onlyCPU
    @unittest.skipIf(IS_MACOS or IS_WINDOWS, "MKL doesn't work on windows or mac")
    @dtypes(torch.float, torch.double)
    def test_mkl_matvec_warnings(self, device, dtype):
        if torch.has_mkl:
            for index_dtype in [torch.int32, torch.int64]:
                sp = torch.sparse_csr_tensor(torch.tensor([0, 2, 4]),
                                             torch.tensor([0, 1, 0, 1]),
                                             torch.tensor([1, 2, 3, 4], dtype=dtype, device=device))
                vec = torch.randn((2, 1), dtype=dtype, device=device)
                with warnings.catch_warnings(record=True) as w:
                    sp.matmul(vec)
                    self.assertEqual(len(w), 2)
                    self.assertIn("Pytorch is compiled with MKL LP64 and will convert crow_indices to int32",
                                  str(w[0].message))
                    self.assertIn("Pytorch is compiled with MKL LP64 and will convert col_indices to int32",
                                  str(w[1].message))

    @dtypes(*get_all_dtypes())
    def test_sparse_csr_from_dense_convert_error(self, device, dtype):
        size = (4, 2, 4)
        dense = make_tensor(size, dtype=dtype, device=device)

        with self.assertRaisesRegex(RuntimeError, "Only 2D"):
            dense.to_sparse_csr()

    # TODO: Support auto generation of device check for sparse tensors
    # See: https://github.com/pytorch/pytorch/issues/59058
    @onlyCUDA
    @dtypes(torch.double)
    def test_matmul_device_mismatch(self, device, dtype):
        cpu = torch.rand((10, 10))
        cuda = cpu.cuda()
        for s, m1, m2 in itertools.product((cpu, cuda), repeat=3):
            csr = m1.to_sparse()
            if s.device == csr.device == m2.device:
                torch.addmm(s, csr, m2)
            else:
                with self.assertRaisesRegex(RuntimeError, "Expected all tensors to be on the same device"):
                    torch.addmm(s, csr, m2)

    @skipCUDAIfNoCusparseGeneric
    @dtypes(*torch.testing.floating_types())
    @dtypesIfCUDA(*get_all_complex_dtypes(),
                  *get_all_fp_dtypes(include_half=SM53OrLater, include_bfloat16=SM80OrLater))
    def test_csr_matvec(self, device, dtype):
        side = 100
        for index_dtype in [torch.int32, torch.int64]:
            csr = self.genSparseCSRTensor((side, side), 1000, device=device, dtype=dtype, index_dtype=index_dtype)
            vec = torch.randn(side, dtype=dtype, device=device)

            res = csr.matmul(vec)
            expected = csr.to_dense().matmul(vec)

            self.assertEqual(res, expected)

            bad_vec = torch.randn(side + 10, dtype=dtype, device=device)
            err_msg = "mv: expected"
            # CUDA path now uses generic meta/structured implementation
            # TODO: move CPU path to not use `mv_sparse` function
            if self.device_type == 'cuda':
                err_msg = "size mismatch, got"
            with self.assertRaisesRegex(RuntimeError, err_msg):
                csr.matmul(bad_vec)
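
    # Naive reference sketch (added for illustration, not part of the original
    # tests) of the sparse matrix-vector product exercised above; this hypothetical
    # helper walks each CSR row and accumulates values[j] * vec[col_indices[j]].
    def _reference_csr_matvec_sketch(self, csr, vec):
        crow, col, val = csr.crow_indices(), csr.col_indices(), csr.values()
        out = torch.zeros(csr.shape[0], dtype=val.dtype, device=val.device)
        for row in range(csr.shape[0]):
            for j in range(int(crow[row]), int(crow[row + 1])):
                out[row] += val[j] * vec[int(col[j])]
        return out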

    @dtypes(torch.double)
    def test_mm(self, device, dtype):
        def test_shape(di, dj, dk, nnz):
            for index_dtype in [torch.int32, torch.int64]:
                x = self.genSparseCSRTensor((di, dj), nnz, device=device, dtype=dtype, index_dtype=index_dtype)
                t = torch.randn(di, dk, dtype=dtype, device=device)
                y = torch.randn(dj, dk, dtype=dtype, device=device)
                alpha = random.random()
                beta = random.random()

                # res = beta * t  + alpha * (x @ y)
                res = torch.addmm(t, x, y, beta=beta, alpha=alpha)
                expected = torch.addmm(t, x.to_dense(), y, beta=beta, alpha=alpha)
                self.assertEqual(res, expected)

                res = torch.addmm(t, x, y)
                expected = torch.addmm(t, x.to_dense(), y)
                self.assertEqual(res, expected)

                res = torch.mm(x, y)
                expected = torch.mm(x.to_dense(), y)
                self.assertEqual(res, expected)

        for i in range(2, 5):
            for j in range(2, 8):
                for k in range(2, 8):
                    test_shape(i, j, k, i * j // 2)
        test_shape(4, 4, 4, 0)

    @dtypes(*floating_types())
    @dtypesIfCUDA(*get_all_complex_dtypes(),
                  *get_all_fp_dtypes(include_half=SM53OrLater and TEST_CUSPARSE_GENERIC,
                                     include_bfloat16=SM80OrLater and TEST_CUSPARSE_GENERIC))
    @precisionOverride({torch.bfloat16: 1e-2, torch.float16: 1e-2})
    def test_sparse_mm(self, device, dtype):
        def test_shape(d1, d2, d3, nnz, transposed, index_dtype):
            if transposed:
                D = torch.randn(d3, d2, dtype=dtype, device=device).t_()
            else:
                D = torch.randn(d2, d3, dtype=dtype, device=device)
            S = self.genSparseCSRTensor((d1, d2), nnz, device=device, dtype=dtype, index_dtype=index_dtype)
            S_dense = S.to_dense()
            self.assertEqual(torch.sparse.mm(S, D), torch.mm(S_dense, D))

        for index_dtype in [torch.int32, torch.int64]:
            test_shape(7, 8, 9, 20, False, index_dtype)
            test_shape(7, 8, 9, 20, True, index_dtype)

    @dtypes(*floating_types())
    @dtypesIfCUDA(*get_all_complex_dtypes(),
                  *get_all_fp_dtypes(include_half=SM53OrLater and TEST_CUSPARSE_GENERIC,
                                     include_bfloat16=SM80OrLater and TEST_CUSPARSE_GENERIC))
    @precisionOverride({torch.bfloat16: 1e-2, torch.float16: 1e-2})
    def test_sparse_addmm(self, device, dtype):
        def test_shape(m, n, p, nnz, broadcast, index_dtype, alpha_beta=None):
            if alpha_beta is None:
                alpha = random.random()
                beta = random.random()
            else:
                alpha, beta = alpha_beta
            if broadcast:
                D1 = make_tensor((), dtype=dtype, device=device)
            else:
                D1 = make_tensor([n, p], dtype=dtype, device=device)
            D2 = make_tensor([m, p], dtype=dtype, device=device)
            S = self.genSparseCSRTensor([n, m], nnz, dtype=dtype, device=device, index_dtype=index_dtype)
            S_dense = S.to_dense()
            Y = torch.sparse.addmm(D1, S, D2, beta=beta, alpha=alpha)
            Y_dense = torch.addmm(D1, S_dense, D2, beta=beta, alpha=alpha)
            self.assertEqual(Y, Y_dense)

        for index_dtype in [torch.int32, torch.int64]:
            test_shape(7, 8, 9, 20, False, index_dtype, None)
            test_shape(7, 8, 9, 20, True, index_dtype, None)
            test_shape(7, 8, 9, 20, False, index_dtype, (1, 0))
            test_shape(7, 8, 9, 20, True, index_dtype, (1, 0))
            test_shape(7, 8, 9, 20, False, index_dtype, (1, 1))
            test_shape(7, 8, 9, 20, True, index_dtype, (1, 1))

    @onlyCUDA
    @precisionOverride({torch.double: 1e-8, torch.float: 1e-4, torch.bfloat16: 0.6,
                        torch.half: 1e-1, torch.cfloat: 1e-4, torch.cdouble: 1e-8})
    @dtypesIfCUDA(torch.complex64,
                  *((torch.complex128,) if CUSPARSE_SPMM_COMPLEX128_SUPPORTED else ()),
                  *torch.testing.get_all_fp_dtypes(include_bfloat16=SM80OrLater,
                                                   include_half=SM53OrLater))
    @skipCUDAIf(
        not _check_cusparse_spgemm_available(),
        "cuSparse Generic API SpGEMM is not available"
    )
    def test_addmm_all_sparse_csr(self, device, dtype):
        M = torch.randn(10, 25, device=device).to(dtype)
        m1 = torch.randn(10, 50, device=device).to(dtype)
        m2 = torch.randn(50, 25, device=device).to(dtype)
        _test_addmm_addmv(self, torch.addmm, M, m1, m2, layout=torch.sparse_csr, all_sparse=True)

        # Test 0-strided
        M = torch.randn(10, 1, device=device).to(dtype).expand(10, 25)
        m1 = torch.randn(10, 1, device=device).to(dtype).expand(10, 50)
        m2 = torch.randn(50, 25, device=device).to(dtype)
        _test_addmm_addmv(self, torch.addmm, M, m1, m2, layout=torch.sparse_csr, all_sparse=True)

        # Test beta=0, M=nan
        M = torch.full((10, 25), float('nan'), device=device).to(dtype)
        m1 = torch.randn(10, 50, device=device).to(dtype)
        m2 = torch.randn(50, 25, device=device).to(dtype)
        _test_addmm_addmv(self, torch.addmm, M, m1, m2, beta=0, layout=torch.sparse_csr, all_sparse=True)

        # Test transpose
        for t1, t2, t3, t4 in itertools.product([True, False], repeat=4):
            def maybe_transpose(cond, m):
                if not cond:
                    return m
                return m.t().clone(memory_format=torch.contiguous_format).t()

            M = maybe_transpose(t1, torch.randn(10, 25, device=device).to(dtype))
            m1 = maybe_transpose(t2, torch.randn(10, 50, device=device).to(dtype))
            m2 = maybe_transpose(t3, torch.randn(50, 25, device=device).to(dtype))
            _test_addmm_addmv(self, torch.addmm, M, m1, m2, transpose_out=t4, layout=torch.sparse_csr, all_sparse=True)

    @onlyCUDA
    @dtypesIfCUDA(torch.complex64,
                  *((torch.complex128,) if CUSPARSE_SPMM_COMPLEX128_SUPPORTED else ()),
                  *torch.testing.get_all_fp_dtypes(include_bfloat16=SM80OrLater,
                                                   include_half=SM53OrLater))
    @skipCUDAIf(
        not _check_cusparse_spgemm_available(),
        "cuSparse Generic API SpGEMM is not available"
    )
    @precisionOverride({torch.double: 1e-8, torch.float: 1e-4, torch.bfloat16: 0.6,
                        torch.half: 1e-1, torch.cfloat: 1e-4, torch.cdouble: 1e-8})
    def test_addmm_sizes_all_sparse_csr(self, device, dtype):
        for m in [0, 1, 25]:
            for n in [0, 1, 10]:
                for k in [0, 1, 8]:
                    M = torch.randn(n, m, device=device).to(dtype)
                    m1 = torch.randn(n, k, device=device).to(dtype)
                    m2 = torch.randn(k, m, device=device).to(dtype)
                    _test_addmm_addmv(self, torch.addmm, M, m1, m2, layout=torch.sparse_csr, all_sparse=True)

                    M = torch.randn(n, m, device=device).to(dtype).to_sparse_csr()
                    m1 = torch.randn(n, k + 1, device=device).to(dtype).to_sparse_csr()
                    m2 = torch.randn(k, m, device=device).to(dtype).to_sparse_csr()
                    self.assertRaisesRegex(RuntimeError, f"{n}x{k + 1}.*{k}x{m}", lambda: torch.addmm(M, m1, m2))
                    self.assertRaisesRegex(RuntimeError, f"{n}x{k + 1}.*{k}x{m}", lambda: torch.mm(m1, m2))

    @onlyCUDA
    @dtypes(torch.float)
    def test_addmm_errors(self, device, dtype):
        # test that the errors are the same for dense and sparse versions
        import re

        def test1(*, is_sparse):
            # shapes must be compatible for matrix multiplication
            a = make_tensor((2, 3), dtype=dtype, device=device)
            if is_sparse:
                a_sparse = a.to_sparse_csr()
                return torch.addmm(a, a_sparse, a)
            else:
                return torch.addmm(a, a, a)

        def test2(*, is_sparse):
            # mat2 must be a matrix
            a = make_tensor((2, 3), dtype=dtype, device=device)
            if is_sparse:
                a_sparse = a.to_sparse_csr()
                return torch.addmm(a, a_sparse, a.unsqueeze(0))
            else:
                return torch.addmm(a, a, a.unsqueeze(0))

        def test3(*, is_sparse):
            # the first input needs to be 1D or 2D
            a = make_tensor((3, 3), dtype=dtype, device=device)
            if is_sparse:
                a_sparse = a.to_sparse_csr()
                return torch.addmm(a.unsqueeze(0), a_sparse, a)
            else:
                return torch.addmm(a.unsqueeze(0), a, a)

        for test in (test1, test2, test3):
            try:
                test(is_sparse=False)
            except RuntimeError as msg:
                with self.assertRaisesRegex(RuntimeError, re.escape(str(msg))):
                    test(is_sparse=True)

    @onlyCUDA
    @dtypes(torch.float)
    def test_mm_errors(self, device, dtype):
        # test that the errors are the same for dense and sparse versions
        import re

        def test1(*, is_sparse):
            # shapes must be compatible for matrix multiplication
            a = make_tensor((2, 3), dtype=dtype, device=device)
            if is_sparse:
                a_sparse = a.to_sparse_csr()
                return torch.mm(a_sparse, a)
            else:
                return torch.mm(a, a)

        def test2(*, is_sparse):
            # mat2 must be a matrix
            a = make_tensor((2, 3), dtype=dtype, device=device)
            if is_sparse:
                a_sparse = a.to_sparse_csr()
                return torch.mm(a_sparse, a.unsqueeze(0))
            else:
                return torch.mm(a, a.unsqueeze(0))

        for test in (test1, test2):
            try:
                test(is_sparse=False)
            except RuntimeError as msg:
                with self.assertRaisesRegex(RuntimeError, re.escape(str(msg))):
                    test(is_sparse=True)

    @dtypes(torch.float, torch.double)
    def test_add(self, device, dtype):
        def _test_spadd_shape(nnz, shape):
            x = self.genSparseCSRTensor(shape, nnz, dtype=dtype, device=device, index_dtype=torch.int32)
            y = torch.randn(*shape, dtype=dtype, device=device)
            r = random.random()

            res = torch.add(y, x, alpha=r)
            expected = y + r * x.to_dense()
            self.assertEqual(res, expected)

            # Non contiguous dense tensor
            s = list(shape)
            s[0] = shape[-1]
            s[-1] = shape[0]
            y = torch.randn(*s, dtype=torch.double, device=device)
            y.transpose_(0, len(s) - 1)
            r = random.random()

            res = torch.add(y, x, alpha=r)
            expected = y + r * x.to_dense()

            self.assertEqual(res, expected)

        _test_spadd_shape(10, [100, 100])
        _test_spadd_shape(0, [100, 100])
        _test_spadd_shape(10, [100, 1])
        _test_spadd_shape(10, [1, 100])

    @onlyCUDA
    @skipCUDAIf(
        not _check_cusparse_triangular_solve_available(),
        "cuSparse Generic API SpSV is not available"
    )
    @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
    @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3,
                        torch.float64: 1e-8, torch.complex128: 1e-8})
    def test_sparse_triangular_solve(self, device, dtype):

        def run_test(n, k, upper, unitriangular, transpose):
            triangle_function = torch.triu if upper else torch.tril
            A = make_tensor((n, n), dtype=dtype, device=device)
            A = triangle_function(A)
            A_sparse = A.to_sparse_csr()
            B = make_tensor((n, k), dtype=dtype, device=device)

            expected = torch.triangular_solve(B, A, upper=upper, unitriangular=unitriangular, transpose=transpose)
            expected_X = expected.solution

            actual = torch.triangular_solve(B, A_sparse, upper=upper, unitriangular=unitriangular, transpose=transpose)
            actual_X = actual.solution
            actual_A_clone = actual.cloned_coefficient
            self.assertTrue(actual_A_clone.numel() == 0)
            self.assertEqual(actual_X, expected_X)

            # test out with C contiguous strides
            out = torch.empty_strided((n, k), (k, 1), dtype=dtype, device=device)
            torch.triangular_solve(
                B, A_sparse,
                upper=upper, unitriangular=unitriangular, transpose=transpose, out=(out, actual_A_clone)
            )
            self.assertEqual(out, expected_X)

            # test out with F contiguous strides
            # TODO (@ivanyashchuk): mixed memory format doesn't work yet for cuda
            # out is F contiguous but B is C contiguous
            if self.device_type == 'cuda' and (n > 0 and k > 1):
                with self.assertRaisesRegex(RuntimeError, "INTERNAL ASSERT FAILED"):
                    out = torch.empty_strided((n, k), (1, n), dtype=dtype, device=device)
                    torch.triangular_solve(
                        B, A_sparse,
                        upper=upper, unitriangular=unitriangular, transpose=transpose, out=(out, actual_A_clone)
                    )
            else:
                out = torch.empty_strided((n, k), (1, n), dtype=dtype, device=device)
                torch.triangular_solve(
                    B, A_sparse,
                    upper=upper, unitriangular=unitriangular, transpose=transpose, out=(out, actual_A_clone)
                )
                self.assertEqual(out, expected_X)
                self.assertEqual(out.stride(), (1, n))

            # test out with discontiguous strides
            out = torch.empty_strided((2 * n, k), (1, 2 * n), dtype=dtype, device=device)[::2]
            if n > 0 and k > 0:
                self.assertFalse(out.is_contiguous())
                self.assertFalse(out.t().is_contiguous())
            before_stride = out.stride()
            torch.triangular_solve(
                B, A_sparse,
                upper=upper, unitriangular=unitriangular, transpose=transpose, out=(out, actual_A_clone)
            )
            self.assertEqual(out, expected_X)
            self.assertEqual(out.stride(), before_stride)

        ks = [0, 1, 3]
        ns = [5, 3, 0]
        for (k, n), (upper, unitriangular, transpose) in itertools.product(itertools.product(ks, ns),
                                                                           itertools.product([True, False], repeat=3)):
            run_test(n, k, upper, unitriangular, transpose)

    @dtypes(*get_all_dtypes())
    def test_coo_csr_conversion(self, device, dtype):
        for m, n in itertools.product([5, 2, 0], [5, 2, 0]):
            size = (m, n)
            dense = make_tensor(size, dtype=dtype, device=device)
            coo_sparse = dense.to_sparse()
            csr_sparse = coo_sparse.to_sparse_csr()

            self.assertEqual(csr_sparse.to_dense(), dense)