def test_factory_device_type_inference(self, device, dtype):
    """Check how sparse_csr_tensor chooses (or rejects) the result device.

    Every CPU/CUDA placement of crow_indices, col_indices and values is
    combined with an explicit ``device`` of 'cpu', 'cuda' or None, for both
    int32 and int64 index dtypes. With no explicit device and components on
    different devices, construction must raise RuntimeError. Otherwise the
    result must be on CUDA exactly when 'cuda' was requested (or, with no
    request, when values are on CUDA), keep the input dtypes, and hold all
    of its components on one device.

    Note: the loop variable ``device`` intentionally shadows the parameter.
    """
    cpu_cuda = ('cpu', 'cuda')
    cpu_cuda_none = cpu_cuda + (None,)
    for crow_indices_device, col_indices_device, values_device, device in itertools.product(
            cpu_cuda, cpu_cuda, cpu_cuda, cpu_cuda_none):
        for index_dtype in [torch.int32, torch.int64]:
            crow_indices = torch.tensor([0, 2, 4], dtype=index_dtype, device=crow_indices_device)
            col_indices = torch.tensor([0, 1, 0, 1], dtype=index_dtype, device=col_indices_device)
            values = torch.tensor([1, 2, 3, 4], dtype=dtype, device=values_device)
            if device is None and (crow_indices_device != col_indices_device or
                                   crow_indices_device != values_device):
                # Nothing to infer a single device from: must fail.
                with self.assertRaises(RuntimeError):
                    torch.sparse_csr_tensor(crow_indices,
                                            col_indices,
                                            values,
                                            size=(2, 10),
                                            device=device)
            else:
                t = torch.sparse_csr_tensor(crow_indices,
                                            col_indices,
                                            values,
                                            size=(2, 10),
                                            device=device)
                should_be_cuda = (device == 'cuda' or (device is None and values_device == 'cuda'))
                self.assertEqual(should_be_cuda, t.is_cuda)
                # Assert these explicitly; as bare comparisons their results
                # were discarded and nothing was checked.
                self.assertEqual(t.crow_indices().dtype, index_dtype)
                self.assertEqual(t.col_indices().dtype, index_dtype)
                self.assertEqual(t.values().dtype, dtype)
                self.assertEqual(t.crow_indices().device, t.values().device)
                self.assertEqual(t.col_indices().device, t.values().device)
    def test_factory_indices_invariants_check(self, device):
        """Each malformed crow/col index pattern must fail with its own message."""
        good_crow = [0, 2, 4]
        good_col = [0, 1, 0, 1]
        value_list = [1, 2, 3, 4]
        shape = (2, 10)

        # (expected error regex, crow_indices data, col_indices data)
        bad_index_cases = [
            ("0th value of crow_indices must be 0.", [-1, 0, 4], good_col),
            ("last value of crow_indices should be equal to the length of col_indices.",
             [0, 2, 5], good_col),
            (r"at position i \= 2," +
             r" this condition crow_indices\[i - 1\] <\= crow_indices\[i\] fails",
             [0, 5, 4], good_col),
            (r"col_indices\.min\(\) should be greater or equal to zero", good_crow, [0, -1, 0, 1]),
            (r"size\(1\) should be greater than col_indices\.max\(\)", good_crow, [0, 11, 0, 1]),
        ]
        for regex, crow_data, col_data in bad_index_cases:
            with self.assertRaisesRegex(RuntimeError, regex):
                torch.sparse_csr_tensor(torch.tensor(crow_data), torch.tensor(col_data),
                                        torch.tensor(value_list), shape, device=device)
    def test_sparse_csr_constructor(self, device, dtype):
        """Build a 2x10 CSR tensor from int32 and int64 index tensors and check
        its shape and components; lists mixed with a tensor must be rejected."""
        crow_data = [0, 2, 4]
        col_data = [0, 1, 0, 1]
        value_data = [1, 2, 3, 4]
        for idx_dtype in (torch.int32, torch.int64):
            csr = torch.sparse_csr_tensor(
                torch.tensor(crow_data, dtype=idx_dtype),
                torch.tensor(col_data, dtype=idx_dtype),
                torch.tensor(value_data),
                size=(2, 10),
                dtype=dtype,
                device=device,
            )
            self.assertEqual((2, 10), csr.shape)
            expected_crow = torch.tensor(crow_data, dtype=idx_dtype)
            self.assertEqual(expected_crow, csr.crow_indices())
            expected_col = torch.tensor(col_data, dtype=idx_dtype)
            self.assertEqual(expected_col, csr.col_indices())
            expected_values = torch.tensor(value_data, dtype=dtype)
            self.assertEqual(expected_values, csr.values())

        # crow_indices and values as plain lists, col_indices as a tensor.
        with self.assertRaises(RuntimeError):
            torch.sparse_csr_tensor(crow_data, torch.tensor(col_data), value_data, size=(2, 10))
    def test_factory_type_invariants_check(self, device):
        """Mismatched index dtypes and int16 (unsupported) indices are rejected."""
        # (expected error regex, crow_indices dtype, col_indices dtype)
        dtype_cases = (
            ("both crow_indices and col_indices should have the same type.", torch.int64, torch.int32),
            (r"\"csr_construct_check\" not implemented for 'Short'", torch.int16, torch.int16),
        )
        for regex, crow_dtype, col_dtype in dtype_cases:
            with self.assertRaisesRegex(RuntimeError, regex):
                torch.sparse_csr_tensor(torch.tensor([0, 2, 4], dtype=crow_dtype),
                                        torch.tensor([0, 1, 0, 1], dtype=col_dtype),
                                        torch.tensor([1, 2, 3, 4]),
                                        device=device)
    def test_factory_size_check(self, device, dtype):
        """A consistent 2x10 CSR tensor constructs; inconsistent sizes and
        component lengths are rejected with their specific messages."""
        good_crow = [0, 2, 4]
        good_col = [0, 1, 0, 1]
        good_values = [1, 2, 3, 4]
        shape = (2, 10)

        def build(crow_data, col_data, value_data, csr_shape):
            # Tensors are created here so that, for the failing cases, they
            # are built inside the assertRaisesRegex context.
            return torch.sparse_csr_tensor(torch.tensor(crow_data),
                                           torch.tensor(col_data),
                                           torch.tensor(value_data),
                                           csr_shape,
                                           dtype=dtype,
                                           device=device)

        build(good_crow, good_col, good_values, shape)

        failing_cases = [
            (r"crow_indices\.numel\(\) must be size\(0\) \+ 1, but got: 3",
             good_crow, good_col, good_values, (1, 1)),
            ("0th value of crow_indices must be 0",
             [-1, -1, -1], good_col, good_values, shape),
            ("last value of crow_indices should be less than length of col_indices.",
             good_crow, [0, 0, 0], good_values, shape),
            (r"col_indices and values must have equal sizes, " +
             r"but got col_indices\.size\(0\): 4, values\.size\(0\): 5",
             good_crow, good_col, [0, 0, 0, 0, 0], shape),
        ]
        for regex, crow_data, col_data, value_data, csr_shape in failing_cases:
            with self.assertRaisesRegex(RuntimeError, regex):
                build(crow_data, col_data, value_data, csr_shape)
# Example #6
# 0
    def to_sparse_csr(self):
        """ Convert a tensor to compressed row storage format. Only works with 2D tensors.

        Examples::

            >>> dense = torch.randn(5, 5)
            >>> sparse = dense.to_sparse_csr()
            >>> sparse._nnz()
            25

        A sparse COO tensor is coalesced and its row indices are compressed
        into ``crow_indices`` on the device of its values. A tensor that is
        already CSR is returned unchanged. Any other tensor is converted to
        COO first.

        Raises:
            RuntimeError: if the tensor is not 2-dimensional.
        """
        shape = self.size()
        if len(shape) != 2:
            # Format the shape into the message; passing it as a second
            # exception argument rendered the message as a tuple.
            raise RuntimeError(
                f"Only 2D tensors can be converted to the CSR format but got shape: {shape}")

        if self.is_sparse:
            coalesced_self = self.coalesce()
            row_indices = coalesced_self.indices()[0]
            values = coalesced_self.values()
            # Compress the (sorted) row indices with a tensor op on their own
            # device instead of walking them element by element in Python and
            # building a CPU tensor, which was slow and broke CUDA inputs.
            crow_indices = torch._convert_indices_from_coo_to_csr(
                row_indices, shape[0], out_int32=row_indices.dtype == torch.int32)
            return torch.sparse_csr_tensor(crow_indices,
                                           coalesced_self.indices()[1].contiguous(),
                                           values,
                                           size=coalesced_self.shape,
                                           dtype=coalesced_self.dtype,
                                           device=values.device)
        elif self.is_sparse_csr:
            return self
        else:
            return self.to_sparse().to_sparse_csr()
    def test_sparse_csr_constructor_from_lists(self, device, dtype):
        """CSR factories accept plain Python lists.

        Indices default to int64. Without a size the shape is inferred from
        the indices; with a size both the checked and the unsafe factory
        honour it.
        """
        crow_list = [0, 2, 4]
        col_list = [0, 1, 0, 1]
        value_list = [1, 2, 3, 4]

        def check_components(csr):
            self.assertEqual(torch.tensor(crow_list, dtype=torch.int64, device=device), csr.crow_indices())
            self.assertEqual(torch.tensor(col_list, dtype=torch.int64, device=device), csr.col_indices())
            self.assertEqual(torch.tensor(value_list, dtype=dtype, device=device), csr.values())

        # No size given: shape is inferred.
        inferred = torch.sparse_csr_tensor(crow_list, col_list, value_list, dtype=dtype, device=device)
        self.assertEqual((2, 2), inferred.shape)
        self.assertEqual(4, inferred.numel())
        check_components(inferred)

        # Explicit size, through both factories.
        for factory in (torch.sparse_csr_tensor, torch._sparse_csr_tensor_unsafe):
            sized = factory(crow_list, col_list, value_list, size=(2, 10), dtype=dtype, device=device)
            self.assertEqual((2, 10), sized.shape)
            check_components(sized)
# Example #8
# 0
    def to_sparse_csr(self):
        """ Convert a tensor to compressed row storage format. Only works with 2D tensors.

        Examples::

            >>> dense = torch.randn(5, 5)
            >>> sparse = dense.to_sparse_csr()
            >>> sparse._nnz()
            25

        A sparse COO tensor is coalesced and its row indices are compressed
        into ``crow_indices`` on the device of its values. A tensor that is
        already CSR is returned unchanged. Any other tensor is converted to
        COO first.

        Raises:
            RuntimeError: if the tensor is not 2-dimensional.
        """
        shape = self.size()
        if len(shape) != 2:
            # Format the shape into the message; passing it as a second
            # exception argument rendered the message as a tuple.
            raise RuntimeError(
                f"Only 2D tensors can be converted to the CSR format but got shape: {shape}")

        if self.is_sparse:
            coalesced_self = self.coalesce()
            row_indices = coalesced_self.indices()[0]
            device = coalesced_self.values().device
            # Keep crow_indices int32 only when the row indices already are.
            crow_indices = torch._convert_indices_from_coo_to_csr(
                row_indices, self.shape[0], out_int32=row_indices.dtype == torch.int32)
            return torch.sparse_csr_tensor(crow_indices,
                                           coalesced_self.indices()[1].contiguous(),
                                           coalesced_self.values(),
                                           size=coalesced_self.shape,
                                           dtype=coalesced_self.dtype,
                                           device=device)
        elif self.is_sparse_csr:
            return self
        else:
            return self.to_sparse().to_sparse_csr()
 def test_mkl_matvec_warnings(self):
     """On MKL builds, check that a CSR matrix-vector product emits exactly two warnings."""
     if torch.has_mkl:
         sp = torch.sparse_csr_tensor(
             torch.tensor([0, 2, 4]), torch.tensor([0, 1, 0, 1]),
             torch.tensor([1, 2, 3, 4], dtype=torch.double))
         vec = torch.randn((2, 1))
         with warnings.catch_warnings(record=True) as w:
             # Record every warning regardless of the filters configured
             # outside this block, so the count does not depend on them.
             warnings.simplefilter("always")
             sp.matmul(vec)
             self.assertEqual(len(w), 2)
# Example #10
# 0
def get_clusters(k, values, indices, pointers, num_col):
    """Densify the first ``k`` rows of a CSR matrix given by its raw arrays.

    ``pointers``, ``indices`` and ``values`` play the roles of crow_indices,
    col_indices and values. Only the leading ``k`` rows and their
    ``pointers[k]`` stored entries are kept. The result is the dense
    ``(k, num_col)`` float32 tensor obtained from a CSR tensor created with
    ``requires_grad=True``.
    """
    kept_nnz = pointers[k]
    leading_rows = torch.sparse_csr_tensor(
        pointers[:k + 1],
        indices[:kept_nnz],
        values[:kept_nnz],
        size=(k, num_col),
        dtype=torch.float32,
        requires_grad=True,
    )
    return leading_rows.to_dense()
 def test_sparse_csr_constructor_shape_inference(self, device, dtype):
     """Without a size, the CSR shape is (len(crow) - 1, max(col) + 1) and the
     requested dtype and device are used."""
     crow_indices = [0, 2, 4]
     col_indices = [0, 1, 0, 1]
     values = [1, 2, 3, 4]
     inferred_shape = (len(crow_indices) - 1, max(col_indices) + 1)
     csr = torch.sparse_csr_tensor(
         torch.tensor(crow_indices, dtype=torch.int64),
         torch.tensor(col_indices, dtype=torch.int64),
         torch.tensor(values),
         dtype=dtype,
         device=device,
     )
     self.assertEqual(torch.tensor(crow_indices, dtype=torch.int64), csr.crow_indices())
     self.assertEqual(inferred_shape, csr.shape)
     self.assertEqual(dtype, csr.dtype)
     self.assertEqual(torch.device(device), csr.device)
# Example #12
# 0
    def test_factory_override(self):
        """Every tensor factory call is intercepted by an active TorchFunctionMode."""
        class AlwaysMinusOne(TorchFunctionMode):
            def __torch_function__(self, *args, **kwargs):
                return -1

        # Each call is deferred so that it runs inside the mode.
        factory_calls = (
            lambda: torch.tensor([1]),
            lambda: torch.sparse_coo_tensor(1, 1, 1),
            lambda: torch.sparse_csr_tensor(1, 1, 1),
            lambda: torch._sparse_coo_tensor_unsafe(1, 1, (1, 1)),
            lambda: torch._sparse_csr_tensor_unsafe(1, 1, 1, (1, 1)),
            lambda: torch.as_tensor([1]),
        )
        with torch.overrides.push_torch_function_mode(AlwaysMinusOne):
            for call in factory_calls:
                self.assertEqual(call(), -1)
# Example #13
# 0
 def test_mkl_matvec_warnings(self, device, dtype):
     """On MKL builds, check that a CSR matrix-vector product warns exactly twice.

     The two warnings announce converting crow_indices and col_indices to
     int32.
     """
     if torch.has_mkl:
         # torch.tensor on Python ints yields int64 index tensors; presumably
         # that is what triggers the int32 conversion warnings.
         sp = torch.sparse_csr_tensor(torch.tensor([0, 2, 4]),
                                      torch.tensor([0, 1, 0, 1]),
                                      torch.tensor([1, 2, 3, 4], dtype=dtype, device=device))
         vec = torch.randn((2, 1), dtype=dtype, device=device)
         with warnings.catch_warnings(record=True) as w:
             # Record every warning regardless of the filters configured
             # outside this block, so the count does not depend on them.
             warnings.simplefilter("always")
             sp.matmul(vec)
             self.assertEqual(len(w), 2)
             self.assertIn("Pytorch is compiled with MKL LP64 and will convert crow_indices to int32",
                           str(w[0].message))
             self.assertIn("Pytorch is compiled with MKL LP64 and will convert col_indices to int32",
                           str(w[1].message))
    def test_sparse_csr_to_dense(self, device, dtype):
        """Dense -> CSR -> dense round-trips for all size pairs (including empty
        dimensions), and a hand-built CSR tensor densifies correctly."""
        side_lengths = [5, 2, 0]
        for rows, cols in itertools.product(side_lengths, side_lengths):
            original = make_tensor((rows, cols), dtype=dtype, device=device)
            round_tripped = original.to_sparse_csr().to_dense()
            self.assertEqual(round_tripped, original)

        csr = torch.sparse_csr_tensor(torch.tensor([0, 3, 5]),
                                      torch.tensor([0, 1, 2, 0, 1]),
                                      torch.tensor([1, 2, 1, 3, 4], dtype=dtype),
                                      dtype=dtype, device=device)
        expected = torch.tensor([[1, 2, 1], [3, 4, 0]], dtype=dtype, device=device)
        self.assertEqual(csr.to_dense(), expected)
    def test_sparse_csr_constructor(self):
        """A float CSR tensor built from int32 index tensors keeps the requested
        size and its int32 crow_indices."""
        crow_data = [0, 2, 4]
        int32_index = {"dtype": torch.int32}
        csr = torch.sparse_csr_tensor(
            torch.tensor(crow_data, **int32_index),
            torch.tensor([0, 1, 0, 1], **int32_index),
            torch.tensor([1, 2, 3, 4]),
            size=(2, 10),
            dtype=torch.float,
        )

        self.assertEqual((2, 10), csr.shape)
        self.assertEqual(torch.tensor(crow_data, dtype=torch.int32), csr.crow_indices())
# Example #16
# 0
def _torch_from_scipy(a):
    """Convert a SciPy sparse CSR or COO matrix/array to a torch sparse tensor.

    CSR input becomes a ``torch.sparse_csr`` tensor. Its index arrays are
    int32 when the module-level ``USE_INT32`` flag is true and int64
    otherwise. COO input becomes a ``torch.sparse_coo`` tensor with int64
    indices. Index and value arrays are copied, so the result does not share
    memory with ``a``.

    Raises:
        ValueError: if ``a`` is neither a CSR nor a COO sparse object.
    """
    # ``issparse`` + ``format`` also accepts the newer scipy sparse arrays,
    # not just the ``spmatrix`` classes that ``isspmatrix_*`` recognises.
    if sparse.issparse(a) and a.format == "csr":
        index_dtype = np.int32 if USE_INT32 else np.int64
        # ``astype`` already returns a fresh copy.
        return torch.sparse_csr_tensor(
            crow_indices=a.indptr.astype(index_dtype),
            col_indices=a.indices.astype(index_dtype),
            values=a.data.copy(),
            size=a.shape)
    elif sparse.issparse(a) and a.format == "coo":
        # Stack into one (2, nnz) array: a tuple of separate ndarrays goes
        # through torch's slow list-of-arrays conversion.
        indices = np.vstack((a.row, a.col)).astype(np.int64)
        return torch.sparse_coo_tensor(indices=indices,
                                       values=a.data.copy(),
                                       size=a.shape)
    else:
        raise ValueError('invalid input type {} for _torch_from_scipy'.format(
            type(a)))
# Example #17
# 0
    def test_dense_convert(self, device, dtype):
        """Random dense matrices survive a dense -> CSR -> dense round trip,
        and a hand-built CSR tensor densifies to the expected matrix."""
        for shape in ((5, 5), (4, 6)):
            original = torch.randn(shape, dtype=dtype, device=device)
            as_csr = original.to_sparse_csr()
            self.assertEqual(as_csr.to_dense(), original)

        csr = torch.sparse_csr_tensor(torch.tensor([0, 3, 5]),
                                      torch.tensor([0, 1, 2, 0, 1]),
                                      torch.tensor([1, 2, 1, 3, 4], dtype=dtype),
                                      dtype=dtype, device=device)
        expected = torch.tensor([[1, 2, 1], [3, 4, 0]], dtype=dtype, device=device)
        self.assertEqual(csr.to_dense(), expected)
# Example #18
# 0
def _unary_helper(fn, args, kwargs, inplace):
    """Apply the unary op ``fn`` to a masked tensor's data and keep its mask.

    ``args[0]`` is the masked operand (presumably a MaskedTensor exposing
    ``_masked_data``/``_masked_mask``). No later positional argument may be a
    tensor, and ``kwargs`` must be empty. For sparse COO/CSR data, ``fn`` is
    applied to the stored values only, and the result is rebuilt with the
    input's sparsity pattern and size. Strided data is passed to ``fn``
    directly. With ``inplace`` the result is written back into ``args[0]``,
    which is returned; otherwise a new wrapped result is returned.

    Raises:
        ValueError: if ``kwargs`` is not empty.
        TypeError: if a positional argument after the first is a tensor.
    """
    if len(kwargs) != 0:
        raise ValueError(
            "MaskedTensor unary ops require that len(kwargs) == 0. "
            "If you need support for this, please open an issue on Github.")
    for a in args[1:]:
        if torch.is_tensor(a):
            raise TypeError(
                "MaskedTensor unary ops do not support additional Tensor arguments"
            )

    mask_args, _ = _map_mt_args_kwargs(args, kwargs,
                                       lambda x: x._masked_mask)
    data_args, _ = _map_mt_args_kwargs(args, kwargs,
                                       lambda x: x._masked_data)

    # ``Tensor.layout`` is a property, not a method. Read it from the
    # underlying data tensor, as the binary helper does.
    layout = data_args[0].layout

    if layout == torch.sparse_coo:
        data = data_args[0].coalesce()
        s = data.size()
        i = data.indices()
        data_args[0] = data.values()
        v = fn(*data_args)
        result_data = torch.sparse_coo_tensor(i, v, size=s)

    elif layout == torch.sparse_csr:
        data = data_args[0]
        crow = data.crow_indices()
        col = data.col_indices()
        s = data.size()
        data_args[0] = data.values()
        v = fn(*data_args)
        # Pass the size explicitly. Otherwise it is inferred from the indices
        # as (rows, max column + 1), dropping trailing empty columns.
        result_data = torch.sparse_csr_tensor(crow, col, v, size=s)

    else:
        result_data = fn(*data_args)

    if inplace:
        args[0]._set_data_mask(result_data, mask_args[0])
        return args[0]
    else:
        return _wrap_result(result_data, mask_args[0])
    def test_factory_layout_invariants_check(self, device):
        """Non-contiguous component tensors (expanded or strided slices) are rejected."""
        def expanded_values():
            vals = torch.tensor([1.], device=device).expand(4,)
            return (torch.tensor([0, 2, 4], device=device),
                    torch.tensor([0, 1, 0, 1], device=device),
                    vals)

        def expanded_col_indices():
            cols = torch.tensor([0], device=device).expand(4,)
            return torch.tensor([0, 2, 4]), cols, torch.tensor([1, 2, 3, 4])

        def strided_crow_indices():
            crows = torch.arange(6, device=device)
            return (crows[::2],
                    torch.tensor([0, 1, 0, 1], device=device),
                    torch.tensor([1, 2, 3, 4]))

        # Components are built inside the assertRaisesRegex context.
        layout_cases = (
            ("expected values to be a strided and contiguous tensor", expanded_values),
            ("expected col_indices to be a strided and contiguous tensor", expanded_col_indices),
            ("expected crow_indices to be a strided and contiguous tensor", strided_crow_indices),
        )
        for regex, make_components in layout_cases:
            with self.assertRaisesRegex(RuntimeError, regex):
                torch.sparse_csr_tensor(*make_components())
# Example #20
# 0
 def test_sparse_csr_print(self, device):
     """Compare printed CSR tensors and their components with the expect file.

     For each (shape, nnz) header and each (index dtype, floating dtype) pair,
     the same small CSR tensor (crow [0, 2, 4], col [0, 1, 0, 1]) is printed
     with its components. The shape/nnz values only appear in header lines.
     """
     orig_maxDiff = self.maxDiff
     # Show the complete diff if the output differs from the expect file.
     self.maxDiff = None
     try:
         shape_nnz = [((10, 10), 10), ((100, 10), 10), ((1000, 10), 10)]
         printed = []
         for shape, nnz in shape_nnz:
             values_shape = torch.Size((nnz, ))
             col_indices_shape = torch.Size((nnz, ))
             crow_indices_shape = torch.Size((shape[0] + 1, ))
             printed.append("# shape: {}".format(torch.Size(shape)))
             printed.append("# nnz: {}".format(nnz))
             printed.append(
                 "# crow_indices shape: {}".format(crow_indices_shape))
             printed.append("# col_indices shape: {}".format(col_indices_shape))
             printed.append("# values_shape: {}".format(values_shape))
             for index_dtype in [torch.int32, torch.int64]:
                 for dtype in torch.testing.floating_types():
                     printed.append("########## {}/{} ##########".format(
                         dtype, index_dtype))
                     x = torch.sparse_csr_tensor(
                         torch.tensor([0, 2, 4], dtype=index_dtype),
                         torch.tensor([0, 1, 0, 1], dtype=index_dtype),
                         torch.tensor([1, 2, 3, 4]),
                         dtype=dtype,
                         device=device)
                     printed.append("# sparse tensor")
                     printed.append(str(x))
                     printed.append("# _crow_indices")
                     printed.append(str(x.crow_indices()))
                     printed.append("# _col_indices")
                     printed.append(str(x.col_indices()))
                     printed.append("# _values")
                     printed.append(str(x.values()))
                     printed.append('')
                 printed.append('')
         self.assertExpected('\n'.join(printed))
     finally:
         # Restore maxDiff even when assertExpected fails.
         self.maxDiff = orig_maxDiff
# Example #21
# 0
def _binary_helper(fn, args, kwargs, inplace):
    """Apply the binary op ``fn`` to the data of two masked operands with equal masks.

    ``args[0]`` and ``args[1]`` are the operands. No later positional
    argument may be a tensor, and ``kwargs`` must be empty. When the lhs data
    is sparse COO/CSR, ``fn`` is applied to the stored values and the result
    is rebuilt with the lhs sparsity pattern and size. If both operands share
    that sparse layout, their indices and sizes must match. Strided data is
    passed to ``fn`` directly. With ``inplace`` the result is written back
    into ``args[0]``; otherwise a new wrapped result is returned.

    Raises:
        ValueError: on non-empty ``kwargs``, mismatched masks, or mismatched
            sparse indices/sizes.
        TypeError: if a positional argument after the second is a tensor.
    """
    if len(kwargs) != 0:
        raise ValueError("len(kwargs) must equal 0")
    for a in args[2:]:
        if torch.is_tensor(a):
            raise TypeError(
                "MaskedTensor binary ops do not support Tensor arguments aside from the lhs and rhs"
            )

    if not _masks_match(*args[:2]):
        raise ValueError(
            "Input masks must match. If you need support for this, please open an issue on Github."
        )

    data_args, data_kwargs = _map_mt_args_kwargs(args, kwargs,
                                                 lambda x: x.get_data())
    mask_args, mask_kwargs = _map_mt_args_kwargs(args, kwargs,
                                                 lambda x: x.get_mask())

    args0_layout = data_args[0].layout
    same_layout = args0_layout == data_args[1].layout

    if args0_layout == torch.sparse_coo:
        if same_layout:
            if not _tensors_match(data_args[0].indices(),
                                  data_args[1].indices()):
                raise ValueError(
                    "sparse_coo indices must match. If you need support for this, please open an issue on Github."
                )
            if data_args[0].size() != data_args[1].size():
                raise ValueError(
                    "input1 and input2 must have the same size for binary functions."
                )

            data_args[1] = data_args[1].values()

        i = data_args[0].indices()
        size = data_args[0].size()
        data_args[0] = data_args[0].values()
        v = fn(*data_args)
        result_data = torch.sparse_coo_tensor(i, v, size)

    elif args0_layout == torch.sparse_csr:
        if same_layout:
            if not (_tensors_match(data_args[0].crow_indices(),
                                   data_args[1].crow_indices())
                    and _tensors_match(data_args[0].col_indices(),
                                       data_args[1].col_indices())):
                raise ValueError(
                    "sparse_csr indices must match. If you need support for this, please open an issue on Github."
                )
            # Mirror the COO branch. Equal indices fix the row count but not
            # the number of columns.
            if data_args[0].size() != data_args[1].size():
                raise ValueError(
                    "input1 and input2 must have the same size for binary functions."
                )

            data_args[1] = data_args[1].values()

        crow = data_args[0].crow_indices()
        col = data_args[0].col_indices()
        size = data_args[0].size()
        data_args[0] = data_args[0].values()
        v = fn(*data_args)
        # Pass the size explicitly. Otherwise it is inferred from the indices
        # as (rows, max column + 1), dropping trailing empty columns.
        result_data = torch.sparse_csr_tensor(crow, col, v, size=size)

    else:
        result_data = fn(*data_args)

    if inplace:
        args[0]._set_data_mask(result_data, mask_args[0])
        return args[0]
    else:
        result_mask = _get_at_least_one_mask(*args[:2])
        # sparse tensors don't have strides so we can only expand if the layout is strided
        if args0_layout == torch.strided:
            result_mask = result_mask.expand_as(result_data)
        return _wrap_result(result_data, result_mask)
# Example #22
# 0
    def test_where(self, sparse_kind, fill_value):
        """Check ``torch._masked._where`` with a sparse mask and a sparse input.

        ``sparse_kind`` selects the layout used for the mask, the input and the
        expected result:
        - 'coo': both dimensions sparse;
        - 'hybrid_coo': one sparse dimension, so stored values are whole rows;
        - 'csr'.
        ``fill_value`` is wrapped with ``torch.tensor`` (input dtype/device)
        and passed to ``_where``. The sparse result is compared with a
        hand-built expectation. Densifying the result is then checked against
        ``torch.where`` on the densified mask and input.
        """

        is_hybrid = False
        # For each layout, ``to_sparse`` converts a dense tensor and
        # ``set_values`` writes into the stored values, so an element stays
        # specified but takes the given value.
        if sparse_kind == 'coo':

            def to_sparse(dense):
                return dense.to_sparse(2)

            def set_values(sparse, index, value):
                sparse._values()[index] = value

        elif sparse_kind == 'hybrid_coo':
            is_hybrid = True

            def to_sparse(dense):
                return dense.to_sparse(1)

            def set_values(sparse, index, value):
                sparse._values()[index] = value

        elif sparse_kind == 'csr':

            def to_sparse(dense):
                return dense.to_sparse_csr()

            def set_values(sparse, index, value):
                sparse.values()[index] = value

        else:
            assert 0, sparse_kind

        mask = torch.tensor([[1, 0, 1, 0, 0], [1, 1, 1, 1, 0], [0, 1, 0, 1, 0],
                             [0, 0, 0, 0, 0], [0, 0, 1, 1, 0],
                             [1, 1, 0, 0, 0]]).to(dtype=bool)
        mask = to_sparse(mask)
        # mark some specified mask elements as explicitly masked out (stored as False).
        # Hybrid values are 2-D (stored row, column), so those indices are pairs.
        if is_hybrid:
            set_values(mask, (1, 1), False)
            set_values(mask, (-2, -2), False)
        else:
            set_values(mask, 3, False)
            set_values(mask, -3, False)

        input = torch.tensor([[1, 0, 0, 0, -1], [2, 3, 0, 0, -2],
                              [0, 4, 5, 0, -3], [0, 0, 6, 7, 0],
                              [0, 8, 9, 0, -3], [10, 11, 0, 0, -5]])
        input = to_sparse(input)
        # make specified input elements have zero values:
        # F is the value expected at the masked-out positions of ``tmp``:
        # fill_value for the hybrid layout, 0 otherwise.
        if is_hybrid:
            set_values(input, (1, 1), 0)
            set_values(input, (-1, 0), 0)
            F = fill_value
        else:
            set_values(input, 3, 0)
            set_values(input, -3, 0)
            F = 0

        # expected where result:
        Z = 99
        # Z value corresponds to masked-in elements that are not
        # specified in the input and it will be replaced with a zero
        tmp = torch.tensor([[1, F, Z, F, F], [2, F, Z, Z, F], [F, 4, F, Z, F],
                            [0, 0, 0, 0, 0], [F, F, 9, F, F], [Z, 11, F, F,
                                                               F]])
        tmp = to_sparse(tmp)

        sparse = torch._masked._where(
            mask, input,
            torch.tensor(fill_value, dtype=input.dtype, device=input.device))

        # Build the expected result from tmp's sparsity pattern (Z -> 0), and
        # a boolean mask that is True at every element stored in ``sparse``.
        if tmp.layout == torch.sparse_coo:
            expected_sparse = torch.sparse_coo_tensor(
                tmp.indices(),
                torch.where(tmp.values() != Z, tmp.values(),
                            tmp.values().new_full([], 0)), input.shape)
            # indices come from ``sparse`` (readable via .indices(), so it is
            # coalesced); flag this mask as coalesced too.
            outmask = torch.sparse_coo_tensor(
                sparse.indices(),
                sparse.values().new_full(sparse.values().shape,
                                         1).to(dtype=bool),
                sparse.shape)._coalesced_(True)
        elif tmp.layout == torch.sparse_csr:
            expected_sparse = torch.sparse_csr_tensor(
                tmp.crow_indices(), tmp.col_indices(),
                torch.where(tmp.values() != Z, tmp.values(),
                            tmp.values().new_full([], 0)), input.shape)
            outmask = torch.sparse_csr_tensor(
                sparse.crow_indices(), sparse.col_indices(),
                sparse.values().new_full(sparse.values().shape,
                                         1).to(dtype=bool), sparse.shape)
        else:
            assert 0

        self.assertEqual(sparse, expected_sparse)

        # check invariance:
        #  torch.where(mask.to_dense(), input.to_dense(), fill_value)
        #    == where(mask, input, fill_value).to_dense(fill_value)
        expected = torch.where(mask.to_dense(), input.to_dense(),
                               torch.full(input.shape, F))
        dense = torch.where(outmask.to_dense(), sparse.to_dense(),
                            torch.full(sparse.shape, F))
        self.assertEqual(dense, expected)
    def test_factory_shape_invariants_check(self, device):
        """A well-formed 2x10 CSR tensor constructs; wrong-rank sizes or
        components and mismatched lengths fail with their specific messages."""
        crow_indices = [0, 2, 4]
        col_indices = [0, 1, 0, 1]
        values = [1, 2, 3, 4]
        size = (2, 10)

        torch.sparse_csr_tensor(torch.tensor(crow_indices), torch.tensor(col_indices),
                                torch.tensor(values), size=size, device=device)

        # (expected error regex, deferred component builder, size). Builders
        # run inside the assertRaisesRegex context, in argument order.
        shape_cases = [
            (r"size of a CSR tensor must be of length 2, but got: 3",
             lambda: (torch.tensor(crow_indices), torch.tensor(col_indices), torch.tensor(values)),
             (2, 10, 2)),
            (r"crow_indices must have dim\=1 but got crow_indices\.dim\(\)\=2",
             lambda: (torch.tensor(crow_indices).repeat(2, 1), torch.tensor(col_indices),
                      torch.tensor(values)),
             size),
            (r"col_indices must have dim\=1 but got col_indices\.dim\(\)\=2",
             lambda: (torch.tensor(crow_indices), torch.tensor(col_indices).repeat(2, 1),
                      torch.tensor(values)),
             size),
            (r"values must have dim\=1 but got values\.dim\(\)\=2",
             lambda: (torch.tensor(crow_indices), torch.tensor(col_indices),
                      torch.tensor(values).repeat(2, 1)),
             size),
            (r"crow_indices\.numel\(\) must be size\(0\) \+ 1, but got: 3",
             lambda: (torch.tensor(crow_indices), torch.tensor(col_indices), torch.tensor(values)),
             (1, 1)),
            (r"col_indices and values must have equal sizes, " +
             r"but got col_indices\.numel\(\): 3, values\.numel\(\): 4",
             lambda: (torch.tensor(crow_indices), torch.tensor([0, 1, 0]), torch.tensor(values)),
             size),
        ]
        for regex, make_components, csr_size in shape_cases:
            with self.assertRaisesRegex(RuntimeError, regex):
                torch.sparse_csr_tensor(*make_components(), size=csr_size, device=device)