예제 #1
0
def cg(A, b, x0, out=None):
    """
    Conjugate gradients method for solving a system of linear equations Ax = b

    Parameters
    ----------
    A : ht.DNDarray
        2D symmetric, positive definite Matrix
    b : ht.DNDarray
        1D vector
    x0 : ht.DNDarray
        Arbitrary 1D starting vector
    out : ht.DNDarray, optional
        Output Vector. Note: the solution is not written into ``out`` in place;
        the solution array is returned in either case.

    Returns
    -------
    ht.DNDarray
        The solution x of the system of linear equations, after at most ``len(b)``
        iterations or as soon as the residual norm drops below 1e-10. If ``x0`` already
        satisfies the tolerance, it is returned unchanged.

    Raises
    ------
    TypeError
        If ``A``, ``b`` or ``x0`` is not a ``ht.DNDarray``.
    RuntimeError
        If ``A`` is not 2-dimensional, or ``b`` or ``x0`` is not 1-dimensional.
    """
    if not all(isinstance(arg, ht.DNDarray) for arg in (A, b, x0)):
        raise TypeError(
            "A, b and x0 need to be of type ht.DNDarray, but were {}, {}, {}".format(
                type(A), type(b), type(x0)))

    if A.numdims != 2:
        raise RuntimeError("A needs to be a 2D matrix")
    if b.numdims != 1:
        raise RuntimeError("b needs to be a 1D vector")
    if x0.numdims != 1:
        raise RuntimeError("x0 needs to be a 1D vector")

    tol = 1e-10
    x = x0
    r = b - ht.matmul(A, x0)
    rsold = ht.matmul(r, r)
    # If x0 already solves the system, r == 0 and the first step would compute
    # alpha = 0 / 0 = nan, turning x into NaNs; return the starting vector instead.
    if ht.sqrt(rsold).item() < tol:
        return x
    p = r

    for i in range(len(b)):
        Ap = ht.matmul(A, p)
        alpha = rsold / ht.matmul(p, Ap)
        x = x + alpha * p
        r = r - alpha * Ap
        rsnew = ht.matmul(r, r)
        if ht.sqrt(rsnew).item() < tol:
            print("Residual reaches tolerance in it = {}".format(i))
            break

        # new search direction, A-conjugate to the previous ones
        p = r + ((rsnew / rsold) * p)
        rsold = rsnew

    return x
예제 #2
0
파일: spectral.py 프로젝트: mtar/heat
    def _spectral_embedding(self, X):
        """
        Helper function to embed the dataset X into the eigenvectors of the graph Laplacian matrix

        Parameters
        ----------
        X : ht.DNDarray
            Sample matrix from which the graph Laplacian is constructed

        Returns
        -------
        ht.DNDarray, shape=(m_lanczos):
            Eigenvalues of the graph's Laplacian matrix, sorted in ascending order.
        ht.DNDarray, shape=(n, m_lanczos):
            Eigenvectors of the graph's Laplacian matrix, ordered like the eigenvalues.
        """
        L = self._laplacian.construct(X)
        # 3. Eigenvalue and -vector calculation via Lanczos Algorithm
        # start from the normalized all-ones vector (n entries of 1/sqrt(n) -> Euclidean norm 1)
        v0 = ht.full(
            (L.shape[0], ),
            fill_value=1.0 / math.sqrt(L.shape[0]),
            dtype=L.dtype,
            split=0,
            device=L.device,
        )
        V, T = ht.lanczos(L, self.n_lanczos, v0)

        # 4. Calculate and Sort Eigenvalues and Eigenvectors of tridiagonal matrix T
        # torch.eig is not available in recent torch versions, so prefer torch.linalg.eig
        # and only fall back to the legacy API when the new one does not exist.
        T_local = T._DNDarray__array
        try:
            evals, evecs = torch.linalg.eig(T_local)
        except AttributeError:  # torch version is less than 1.9.0
            # legacy API: eigenvalues come as (m, 2) rows of (real, imaginary) parts
            evals, evecs = torch.eig(T_local, eigenvectors=True)
            evals = evals[:, 0]
        else:
            # T is real symmetric, so the imaginary parts are (numerically) zero
            evals, evecs = evals.real, evecs.real

        evals, idx = torch.sort(evals, dim=0)
        eigenvalues = ht.array(evals)
        # If x is an Eigenvector of T, then y = V@x is the corresponding Eigenvector of L
        eigenvectors = ht.matmul(V, ht.array(evecs))[:, idx]

        return eigenvalues, eigenvectors
예제 #3
0
    def _spectral_embedding(self, x: DNDarray) -> Tuple[DNDarray, DNDarray]:
        """
        Helper function for dataset x embedding.
        Returns Tuple(Eigenvalues, Eigenvectors) of the graph's Laplacian matrix.
        The eigenvalues are sorted in ascending order and the eigenvectors (columns) follow that order.

        Parameters
        ----------
        x : DNDarray
            Sample Matrix for which the embedding should be calculated

        """
        L = self._laplacian.construct(x)
        # 3. Eigenvalue and -vector calculation via Lanczos Algorithm
        # start from the normalized all-ones vector (n entries of 1/sqrt(n) -> Euclidean norm 1)
        v0 = ht.full(
            (L.shape[0],),
            fill_value=1.0 / math.sqrt(L.shape[0]),
            dtype=L.dtype,
            split=0,
            device=L.device,
        )
        V, T = ht.lanczos(L, self.n_lanczos, v0)

        # 4. Calculate and Sort Eigenvalues and Eigenvectors of tridiagonal matrix T
        # torch.eig is not available in recent torch versions, so prefer torch.linalg.eig
        # and only fall back to the legacy API when the new one does not exist.
        try:
            evals, evecs = torch.linalg.eig(T.larray)
        except AttributeError:  # torch version is less than 1.9.0
            # legacy API: eigenvalues come as (m, 2) rows of (real, imaginary) parts
            evals, evecs = torch.eig(T.larray, eigenvectors=True)
            evals = evals[:, 0]
        else:
            # T is real symmetric, so the imaginary parts are (numerically) zero
            evals, evecs = evals.real, evecs.real

        evals, idx = torch.sort(evals, dim=0)
        eigenvalues = ht.array(evals)
        # If x is an Eigenvector of T, then y = V@x is the corresponding Eigenvector of L
        eigenvectors = ht.matmul(V, ht.array(evecs))[:, idx]

        return eigenvalues, eigenvectors
예제 #4
0
    def _spectral_embedding(self, x: DNDarray) -> Tuple[DNDarray, DNDarray]:
        """
        Helper function for dataset x embedding.
        Returns Tuple(Eigenvalues, Eigenvectors) of the graph's Laplacian matrix.
        The eigenvalues are sorted in ascending order and the eigenvectors (columns) follow that order.

        Parameters
        ----------
        x : DNDarray
            Sample Matrix for which the embedding should be calculated

        Notes
        -----
        Only the real parts of the eigenvalues and eigenvectors of the tridiagonal Lanczos matrix are kept.
        That matrix is real symmetric, so imaginary parts returned by the general eigensolver are
        (numerically) zero.

        """
        L = self._laplacian.construct(x)
        # 3. Eigenvalue and -vector calculation via Lanczos Algorithm
        # start from the normalized all-ones vector (n entries of 1/sqrt(n) -> Euclidean norm 1)
        v0 = ht.full(
            (L.shape[0], ),
            fill_value=1.0 / math.sqrt(L.shape[0]),
            dtype=L.dtype,
            split=0,
            device=L.device,
        )
        V, T = ht.lanczos(L, self.n_lanczos, v0)

        # 4. Calculate and Sort Eigenvalues and Eigenvectors of tridiagonal matrix T
        # Only the solver call itself is guarded: an AttributeError raised by any other
        # operation must not be mistaken for an old torch version.
        try:
            evals, evecs = torch.linalg.eig(T.larray)
        except AttributeError:  # torch version is less than 1.9.0
            # legacy API: eigenvalues come as (m, 2) rows of (real, imaginary) parts
            evals, evecs = torch.eig(T.larray, eigenvectors=True)
            evals = evals[:, 0]
        else:
            # drop the (numerically zero) imaginary parts before leaving torch,
            # so no complex arithmetic is needed on the heat level
            evals, evecs = evals.real, evecs.real

        evals, idx = torch.sort(evals, dim=0)
        eigenvalues = ht.array(evals)
        # If x is an Eigenvector of T, then y = V@x is the corresponding Eigenvector of L
        eigenvectors = ht.matmul(V, ht.array(evecs))[:, idx]

        return eigenvalues, eigenvectors
예제 #5
0
def lanczos(
    A: DNDarray,
    m: int,
    v0: Optional[DNDarray] = None,
    V_out: Optional[DNDarray] = None,
    T_out: Optional[DNDarray] = None,
) -> Tuple[DNDarray, DNDarray]:
    r"""
    The Lanczos algorithm is an iterative approximation of the solution to the eigenvalue problem, as an adaptation of
    power methods to find the m "most useful" (tending towards extreme highest/lowest) eigenvalues and eigenvectors of
    an :math:`n \times n` Hermitian matrix, where often :math:`m<<n`.
    It returns two matrices :math:`V` and :math:`T`, where:

        - :math:`V` is a Matrix of size :math:`n\times m`, with orthonormal columns, that span the Krylov subspace \n
        - :math:`T` is a Tridiagonal matrix of size :math:`m\times m`, with coefficients :math:`\alpha_1,..., \alpha_m`
          on the diagonal and coefficients :math:`\beta_1,...,\beta_{m-1}` on the side-diagonals\n

    Parameters
    ----------
    A : DNDarray
        2D symmetric Matrix. Only squareness is validated; symmetry is not checked.
    m : int
        Number of Lanczos iterations. A float is accepted and truncated to ``int``.
    v0 : DNDarray, optional
        1D starting vector of Euclidean norm 1. If not provided, a random vector will be used to start the algorithm.
        The passed vector is never modified; if its split differs from the internal one, a resplit copy is used.
    V_out : DNDarray, optional
        Output Matrix for the Krylov vectors, Shape = (n, m). It is not filled in place; if given, a copy of
        ``V`` is returned instead.
    T_out : DNDarray, optional
        Output Matrix for the Tridiagonal matrix, Shape = (m, m). It is not filled in place; if given, a copy of
        ``T`` is returned instead.

    Returns
    -------
    Tuple[DNDarray, DNDarray]
        ``(V, T)``. ``V`` is returned with ``split=None``; ``T`` has the dtype and device of ``A``.

    Raises
    ------
    TypeError
        If ``A`` is not a DNDarray, ``m`` is neither int nor float, or ``A`` is not square.
    RuntimeError
        If ``A`` is not 2-dimensional.
    ValueError
        If ``m`` is smaller than 1.
    """
    if not isinstance(A, DNDarray):
        raise TypeError("A needs to be of type ht.DNDarray, but was {}".format(type(A)))

    if A.ndim != 2:
        raise RuntimeError("A needs to be a 2D matrix")
    if not isinstance(m, (int, float)):
        raise TypeError("m must be either int or float, but was {}".format(type(m)))
    # array shapes and loop bounds below need an integral iteration count
    m = int(m)
    if m < 1:
        raise ValueError("m must be at least 1, but was {}".format(m))

    n, column = A.shape
    if n != column:
        raise TypeError("Input Matrix A needs to be symmetric.")

    # T matches A in dtype and device, so alpha/beta are stored without precision loss
    # and without cross-device copies
    T = ht.zeros((m, m), dtype=A.dtype, device=A.device)
    # V is row-split only if A is: this gives better memory access in the reorthogonalization
    # Gram-Schmidt algorithm; otherwise V is not split
    V = ht.ones((n, m), split=0 if A.split == 0 else None, dtype=A.dtype, device=A.device)

    def _reorthogonalize(vr, i):
        # Gram-Schmidt of vr against the columns V[:, :i]; modifies vr in place.
        # ToDo: Rethink this; mask torch calls, See issue #494
        # Working on the local torch chunks and summing the partial dot products over all
        # processes is way faster than item access on the ht.DNDarray level.
        for j in range(i):
            vj_loc = V.larray[:, j]
            a = torch.dot(vr.larray, vj_loc)
            b = torch.dot(vj_loc, vj_loc)
            A.comm.Allreduce(ht.communication.MPI.IN_PLACE, a, ht.communication.MPI.SUM)
            A.comm.Allreduce(ht.communication.MPI.IN_PLACE, b, ht.communication.MPI.SUM)
            vr.larray = vr.larray - a / b * vj_loc

    if v0 is None:
        vr = ht.random.rand(n, dtype=A.dtype, split=V.split)
        v0 = vr / ht.norm(vr)
    elif v0.split != V.split:
        # resplit a copy so that the caller's starting vector is left untouched
        v0 = v0.copy()
        v0.resplit_(axis=V.split)

    # 0th iteration; v0 is expected to have Euclidean norm 1
    w = ht.matmul(A, v0)
    alpha = ht.dot(w, v0)
    w = w - alpha * v0
    T[0, 0] = alpha
    V[:, 0] = v0
    for i in range(1, m):
        beta = ht.norm(w)
        if ht.abs(beta) < 1e-10:
            # Lanczos breakdown: continue with a random vector, which is
            # orthogonalized against all previous vectors below
            vr = ht.random.rand(n, dtype=A.dtype, split=V.split)
        else:
            vr = w
            if vr.split != V.split:
                # ht.matmul may return w with a different split than V (e.g. for a
                # column-split A); the chunk-wise Gram-Schmidt needs vr distributed like V.
                # w is recomputed below, so resplitting it in place is safe.
                vr.resplit_(axis=V.split)
        _reorthogonalize(vr, i)
        # normalize to Euclidean norm 1 and use as the i-th Krylov vector
        vi = vr / ht.norm(vr)

        w = ht.matmul(A, vi)
        alpha = ht.dot(w, vi)

        w = w - alpha * vi - beta * V[:, i - 1]

        T[i - 1, i] = beta
        T[i, i - 1] = beta
        T[i, i] = alpha
        V[:, i] = vi

    if V.split is not None:
        V.resplit_(axis=None)

    # V_out / T_out are not written to in place; copies of the results are returned in their stead
    if V_out is not None:
        V = V.copy()
    if T_out is not None:
        T = T.copy()
    return V, T
예제 #6
0
파일: test_basics.py 프로젝트: mtar/heat
    def test_matmul(self):
        """
        Test ``ht.matmul`` (and the ``@`` operator) against a torch reference product.

        Covers matrix @ matrix for combinations of ``split`` in {None, 0, 1}, the
        ``allow_resplit`` flag, vector @ matrix, matrix @ vector, and the error cases for
        mismatched inner dimensions and a 3D input split along axis 2. Everything after the
        two split=None matrix cases only runs when more than one process is used.
        """
        # mismatched inner dimensions (25 vs 42) must be rejected
        with self.assertRaises(ValueError):
            ht.matmul(ht.ones((25, 25)), ht.ones((42, 42)))

        # cases to test:
        # torch reference operands: a is (n, m), b is (j, k) with j == m; the first row and the
        # last (a) / first (b) column hold aranges so the product is not a constant matrix
        n, m = 21, 31
        j, k = m, 45
        a_torch = torch.ones((n, m), device=self.device.torch_device)
        a_torch[0] = torch.arange(1, m + 1, device=self.device.torch_device)
        a_torch[:, -1] = torch.arange(1,
                                      n + 1,
                                      device=self.device.torch_device)
        b_torch = torch.ones((j, k), device=self.device.torch_device)
        b_torch[0] = torch.arange(1, k + 1, device=self.device.torch_device)
        b_torch[:, 0] = torch.arange(1, j + 1, device=self.device.torch_device)

        # splits None None
        a = ht.ones((n, m), split=None)
        b = ht.ones((j, k), split=None)
        a[0] = ht.arange(1, m + 1)
        a[:, -1] = ht.arange(1, n + 1)
        b[0] = ht.arange(1, k + 1)
        b[:, 0] = ht.arange(1, j + 1)
        ret00 = ht.matmul(a, b)

        self.assertEqual(ht.all(ret00 == ht.array(a_torch @ b_torch)), 1)
        self.assertIsInstance(ret00, ht.DNDarray)
        self.assertEqual(ret00.shape, (n, k))
        self.assertEqual(ret00.dtype, ht.float)
        self.assertEqual(ret00.split, None)
        self.assertEqual(a.split, None)
        self.assertEqual(b.split, None)

        # splits None None
        # with allow_resplit=True: the assertions below expect the operand a to have been
        # resplit in place to split=0, while the result stays split=None
        a = ht.ones((n, m), split=None)
        b = ht.ones((j, k), split=None)
        a[0] = ht.arange(1, m + 1)
        a[:, -1] = ht.arange(1, n + 1)
        b[0] = ht.arange(1, k + 1)
        b[:, 0] = ht.arange(1, j + 1)
        ret00 = ht.matmul(a, b, allow_resplit=True)

        self.assertEqual(ht.all(ret00 == ht.array(a_torch @ b_torch)), 1)
        self.assertIsInstance(ret00, ht.DNDarray)
        self.assertEqual(ret00.shape, (n, k))
        self.assertEqual(ret00.dtype, ht.float)
        self.assertEqual(ret00.split, None)
        self.assertEqual(a.split, 0)
        self.assertEqual(b.split, None)

        # all remaining cases involve distributed operands and only run on multiple processes
        if a.comm.size > 1:
            # splits 00
            # mixed float64 / float32 operands: the result is expected to be float64
            a = ht.ones((n, m), split=0, dtype=ht.float64)
            b = ht.ones((j, k), split=0)
            a[0] = ht.arange(1, m + 1)
            a[:, -1] = ht.arange(1, n + 1)
            b[0] = ht.arange(1, k + 1)
            b[:, 0] = ht.arange(1, j + 1)
            ret00 = a @ b

            ret_comp00 = ht.array(a_torch @ b_torch, split=0)
            self.assertTrue(ht.equal(ret00, ret_comp00))
            self.assertIsInstance(ret00, ht.DNDarray)
            self.assertEqual(ret00.shape, (n, k))
            self.assertEqual(ret00.dtype, ht.float64)
            self.assertEqual(ret00.split, 0)

            # splits 00 (numpy)
            # np.ones produces float64, so the result is expected to be float64
            a = ht.array(np.ones((n, m)), split=0)
            b = ht.array(np.ones((j, k)), split=0)
            a[0] = ht.arange(1, m + 1)
            a[:, -1] = ht.arange(1, n + 1)
            b[0] = ht.arange(1, k + 1)
            b[:, 0] = ht.arange(1, j + 1)
            ret00 = a @ b

            ret_comp00 = ht.array(a_torch @ b_torch, split=0)
            self.assertTrue(ht.equal(ret00, ret_comp00))
            self.assertIsInstance(ret00, ht.DNDarray)
            self.assertEqual(ret00.shape, (n, k))
            self.assertEqual(ret00.dtype, ht.float64)
            self.assertEqual(ret00.split, 0)

            # splits 01
            a = ht.ones((n, m), split=0)
            b = ht.ones((j, k), split=1, dtype=ht.float64)
            a[0] = ht.arange(1, m + 1)
            a[:, -1] = ht.arange(1, n + 1)
            b[0] = ht.arange(1, k + 1)
            b[:, 0] = ht.arange(1, j + 1)
            ret00 = ht.matmul(a, b)

            ret_comp01 = ht.array(a_torch @ b_torch, split=0)
            self.assertTrue(ht.equal(ret00, ret_comp01))
            self.assertIsInstance(ret00, ht.DNDarray)
            self.assertEqual(ret00.shape, (n, k))
            self.assertEqual(ret00.dtype, ht.float64)
            self.assertEqual(ret00.split, 0)

            # splits 10
            a = ht.ones((n, m), split=1)
            b = ht.ones((j, k), split=0)
            a[0] = ht.arange(1, m + 1)
            a[:, -1] = ht.arange(1, n + 1)
            b[0] = ht.arange(1, k + 1)
            b[:, 0] = ht.arange(1, j + 1)
            ret00 = ht.matmul(a, b)

            ret_comp10 = ht.array(a_torch @ b_torch, split=1)
            self.assertTrue(ht.equal(ret00, ret_comp10))
            self.assertIsInstance(ret00, ht.DNDarray)
            self.assertEqual(ret00.shape, (n, k))
            self.assertEqual(ret00.dtype, ht.float)
            self.assertEqual(ret00.split, 1)

            # splits 11
            a = ht.ones((n, m), split=1)
            b = ht.ones((j, k), split=1)
            a[0] = ht.arange(1, m + 1)
            a[:, -1] = ht.arange(1, n + 1)
            b[0] = ht.arange(1, k + 1)
            b[:, 0] = ht.arange(1, j + 1)
            ret00 = ht.matmul(a, b)

            ret_comp11 = ht.array(a_torch @ b_torch, split=1)
            self.assertTrue(ht.equal(ret00, ret_comp11))
            self.assertIsInstance(ret00, ht.DNDarray)
            self.assertEqual(ret00.shape, (n, k))
            self.assertEqual(ret00.dtype, ht.float)
            self.assertEqual(ret00.split, 1)

            # splits 11 (torch)
            a = ht.array(torch.ones((n, m), device=self.device.torch_device),
                         split=1)
            b = ht.array(torch.ones((j, k), device=self.device.torch_device),
                         split=1)
            a[0] = ht.arange(1, m + 1)
            a[:, -1] = ht.arange(1, n + 1)
            b[0] = ht.arange(1, k + 1)
            b[:, 0] = ht.arange(1, j + 1)
            ret00 = ht.matmul(a, b)

            ret_comp11 = ht.array(a_torch @ b_torch, split=1)
            self.assertTrue(ht.equal(ret00, ret_comp11))
            self.assertIsInstance(ret00, ht.DNDarray)
            self.assertEqual(ret00.shape, (n, k))
            self.assertEqual(ret00.dtype, ht.float)
            self.assertEqual(ret00.split, 1)

            # splits 0 None
            a = ht.ones((n, m), split=0)
            b = ht.ones((j, k), split=None)
            a[0] = ht.arange(1, m + 1)
            a[:, -1] = ht.arange(1, n + 1)
            b[0] = ht.arange(1, k + 1)
            b[:, 0] = ht.arange(1, j + 1)
            ret00 = ht.matmul(a, b)

            ret_comp0 = ht.array(a_torch @ b_torch, split=0)
            self.assertTrue(ht.equal(ret00, ret_comp0))
            self.assertIsInstance(ret00, ht.DNDarray)
            self.assertEqual(ret00.shape, (n, k))
            self.assertEqual(ret00.dtype, ht.float)
            self.assertEqual(ret00.split, 0)

            # splits 1 None
            a = ht.ones((n, m), split=1)
            b = ht.ones((j, k), split=None)
            a[0] = ht.arange(1, m + 1)
            a[:, -1] = ht.arange(1, n + 1)
            b[0] = ht.arange(1, k + 1)
            b[:, 0] = ht.arange(1, j + 1)
            ret00 = ht.matmul(a, b)

            ret_comp1 = ht.array(a_torch @ b_torch, split=1)
            self.assertTrue(ht.equal(ret00, ret_comp1))
            self.assertIsInstance(ret00, ht.DNDarray)
            self.assertEqual(ret00.shape, (n, k))
            self.assertEqual(ret00.dtype, ht.float)
            self.assertEqual(ret00.split, 1)

            # splits None 0
            a = ht.ones((n, m), split=None)
            b = ht.ones((j, k), split=0)
            a[0] = ht.arange(1, m + 1)
            a[:, -1] = ht.arange(1, n + 1)
            b[0] = ht.arange(1, k + 1)
            b[:, 0] = ht.arange(1, j + 1)
            ret00 = ht.matmul(a, b)

            ret_comp = ht.array(a_torch @ b_torch, split=0)
            self.assertTrue(ht.equal(ret00, ret_comp))
            self.assertIsInstance(ret00, ht.DNDarray)
            self.assertEqual(ret00.shape, (n, k))
            self.assertEqual(ret00.dtype, ht.float)
            self.assertEqual(ret00.split, 0)

            # splits None 1
            a = ht.ones((n, m), split=None)
            b = ht.ones((j, k), split=1)
            a[0] = ht.arange(1, m + 1)
            a[:, -1] = ht.arange(1, n + 1)
            b[0] = ht.arange(1, k + 1)
            b[:, 0] = ht.arange(1, j + 1)
            ret00 = ht.matmul(a, b)

            ret_comp = ht.array(a_torch @ b_torch, split=1)
            self.assertTrue(ht.equal(ret00, ret_comp))
            self.assertIsInstance(ret00, ht.DNDarray)
            self.assertEqual(ret00.shape, (n, k))
            self.assertEqual(ret00.dtype, ht.float)
            self.assertEqual(ret00.split, 1)

            # vector matrix mult:
            # a -> vector
            # a is a 1D vector of length m, b the (j, k) matrix; the result has shape (k,)
            a_torch = torch.ones((m), device=self.device.torch_device)
            b_torch = torch.ones((j, k), device=self.device.torch_device)
            b_torch[0] = torch.arange(1,
                                      k + 1,
                                      device=self.device.torch_device)
            b_torch[:, 0] = torch.arange(1,
                                         j + 1,
                                         device=self.device.torch_device)
            # splits None None
            a = ht.ones((m), split=None)
            b = ht.ones((j, k), split=None)
            b[0] = ht.arange(1, k + 1)
            b[:, 0] = ht.arange(1, j + 1)
            ret00 = ht.matmul(a, b)

            ret_comp = ht.array(a_torch @ b_torch, split=None)
            self.assertTrue(ht.equal(ret00, ret_comp))
            self.assertIsInstance(ret00, ht.DNDarray)
            self.assertEqual(ret00.shape, (k, ))
            self.assertEqual(ret00.dtype, ht.float)
            self.assertEqual(ret00.split, None)

            # splits None 0
            a = ht.ones((m), split=None)
            b = ht.ones((j, k), split=0)
            b[0] = ht.arange(1, k + 1)
            b[:, 0] = ht.arange(1, j + 1)
            ret00 = ht.matmul(a, b)

            ret_comp = ht.array(a_torch @ b_torch, split=None)
            self.assertTrue(ht.equal(ret00, ret_comp))
            self.assertIsInstance(ret00, ht.DNDarray)
            self.assertEqual(ret00.shape, (k, ))
            self.assertEqual(ret00.dtype, ht.float)
            self.assertEqual(ret00.split, 0)

            # splits None 1
            a = ht.ones((m), split=None)
            b = ht.ones((j, k), split=1)
            b[0] = ht.arange(1, k + 1)
            b[:, 0] = ht.arange(1, j + 1)
            ret00 = ht.matmul(a, b)
            ret_comp = ht.array(a_torch @ b_torch, split=0)
            self.assertTrue(ht.equal(ret00, ret_comp))
            self.assertIsInstance(ret00, ht.DNDarray)
            self.assertEqual(ret00.shape, (k, ))
            self.assertEqual(ret00.dtype, ht.float)
            self.assertEqual(ret00.split, 0)

            # splits 0 None
            # NOTE(review): despite the label, this case builds a with split=None and b with
            # split=0, i.e. it repeats the "splits None 0" case above; a vector with split=0
            # times a matrix with split=None is not exercised. TODO confirm intended inputs
            # and the expected result split.
            a = ht.ones((m), split=None)
            b = ht.ones((j, k), split=0)
            b[0] = ht.arange(1, k + 1)
            b[:, 0] = ht.arange(1, j + 1)
            ret00 = ht.matmul(a, b)

            ret_comp = ht.array(a_torch @ b_torch, split=None)
            self.assertTrue(ht.equal(ret00, ret_comp))
            self.assertIsInstance(ret00, ht.DNDarray)
            self.assertEqual(ret00.shape, (k, ))
            self.assertEqual(ret00.dtype, ht.float)
            self.assertEqual(ret00.split, 0)

            # splits 0 0
            a = ht.ones((m), split=0)
            b = ht.ones((j, k), split=0)
            b[0] = ht.arange(1, k + 1)
            b[:, 0] = ht.arange(1, j + 1)
            ret00 = ht.matmul(a, b)

            ret_comp = ht.array(a_torch @ b_torch, split=None)
            self.assertTrue(ht.equal(ret00, ret_comp))
            self.assertIsInstance(ret00, ht.DNDarray)
            self.assertEqual(ret00.shape, (k, ))
            self.assertEqual(ret00.dtype, ht.float)
            self.assertEqual(ret00.split, 0)

            # splits 0 1
            a = ht.ones((m), split=0)
            b = ht.ones((j, k), split=1)
            b[0] = ht.arange(1, k + 1)
            b[:, 0] = ht.arange(1, j + 1)
            ret00 = ht.matmul(a, b)

            ret_comp = ht.array(a_torch @ b_torch, split=None)
            self.assertTrue(ht.equal(ret00, ret_comp))
            self.assertIsInstance(ret00, ht.DNDarray)
            self.assertEqual(ret00.shape, (k, ))
            self.assertEqual(ret00.dtype, ht.float)
            self.assertEqual(ret00.split, 0)

            # b -> vector
            # a is the (n, m) matrix, b a 1D vector of length j == m; the result has shape (n,)
            a_torch = torch.ones((n, m), device=self.device.torch_device)
            a_torch[0] = torch.arange(1,
                                      m + 1,
                                      device=self.device.torch_device)
            a_torch[:, -1] = torch.arange(1,
                                          n + 1,
                                          device=self.device.torch_device)
            b_torch = torch.ones((j), device=self.device.torch_device)
            # splits None None
            a = ht.ones((n, m), split=None)
            b = ht.ones((j), split=None)
            a[0] = ht.arange(1, m + 1)
            a[:, -1] = ht.arange(1, n + 1)
            ret00 = ht.matmul(a, b)

            ret_comp = ht.array(a_torch @ b_torch, split=None)
            self.assertTrue(ht.equal(ret00, ret_comp))
            self.assertIsInstance(ret00, ht.DNDarray)
            self.assertEqual(ret00.shape, (n, ))
            self.assertEqual(ret00.dtype, ht.float)
            self.assertEqual(ret00.split, None)

            # splits 0 None
            a = ht.ones((n, m), split=0)
            b = ht.ones((j), split=None)
            a[0] = ht.arange(1, m + 1)
            a[:, -1] = ht.arange(1, n + 1)
            ret00 = ht.matmul(a, b)

            ret_comp = ht.array((a_torch @ b_torch), split=None)
            self.assertTrue(ht.equal(ret00, ret_comp))
            self.assertIsInstance(ret00, ht.DNDarray)
            self.assertEqual(ret00.shape, (n, ))
            self.assertEqual(ret00.dtype, ht.float)
            self.assertEqual(ret00.split, 0)

            # splits 1 None
            a = ht.ones((n, m), split=1)
            b = ht.ones((j), split=None)
            a[0] = ht.arange(1, m + 1)
            a[:, -1] = ht.arange(1, n + 1)
            ret00 = ht.matmul(a, b)

            ret_comp = ht.array((a_torch @ b_torch), split=None)
            self.assertTrue(ht.equal(ret00, ret_comp))
            self.assertIsInstance(ret00, ht.DNDarray)
            self.assertEqual(ret00.shape, (n, ))
            self.assertEqual(ret00.dtype, ht.float)
            self.assertEqual(ret00.split, 0)

            # splits None 0
            a = ht.ones((n, m), split=None)
            b = ht.ones((j), split=0)
            a[0] = ht.arange(1, m + 1)
            a[:, -1] = ht.arange(1, n + 1)
            ret00 = ht.matmul(a, b)

            ret_comp = ht.array((a_torch @ b_torch), split=None)
            self.assertTrue(ht.equal(ret00, ret_comp))
            self.assertIsInstance(ret00, ht.DNDarray)
            self.assertEqual(ret00.shape, (n, ))
            self.assertEqual(ret00.dtype, ht.float)
            self.assertEqual(ret00.split, 0)

            # splits 0 0
            a = ht.ones((n, m), split=0)
            b = ht.ones((j), split=0)
            a[0] = ht.arange(1, m + 1)
            a[:, -1] = ht.arange(1, n + 1)
            ret00 = ht.matmul(a, b)

            ret_comp = ht.array((a_torch @ b_torch), split=None)
            self.assertTrue(ht.equal(ret00, ret_comp))
            self.assertIsInstance(ret00, ht.DNDarray)
            self.assertEqual(ret00.shape, (n, ))
            self.assertEqual(ret00.dtype, ht.float)
            self.assertEqual(ret00.split, 0)

            # splits 1 0
            a = ht.ones((n, m), split=1)
            b = ht.ones((j), split=0)
            a[0] = ht.arange(1, m + 1)
            a[:, -1] = ht.arange(1, n + 1)
            ret00 = ht.matmul(a, b)

            ret_comp = ht.array((a_torch @ b_torch), split=None)
            self.assertTrue(ht.equal(ret00, ret_comp))
            self.assertIsInstance(ret00, ht.DNDarray)
            self.assertEqual(ret00.shape, (n, ))
            self.assertEqual(ret00.dtype, ht.float)
            self.assertEqual(ret00.split, 0)

            # a 3D operand split along axis 2 is expected to raise NotImplementedError
            with self.assertRaises(NotImplementedError):
                a = ht.zeros((3, 3, 3), split=2)
                b = a.copy()
                a @ b