def cg(A, b, x0, out=None):
    """
    Conjugate gradients method for solving a system of linear equations Ax = b

    Parameters
    ----------
    A : ht.DNDarray
        2D symmetric, positive definite Matrix
    b : ht.DNDarray
        1D vector
    x0 : ht.DNDarray
        Arbitrary 1D starting vector
    out : ht.DNDarray, optional
        Output Vector

    Returns
    -------
    ht.DNDarray
        Returns the solution x of the system of linear equations. If out is given, it is returned

    Raises
    ------
    TypeError
        If A, b or x0 is not a ht.DNDarray
    RuntimeError
        If A is not 2D, or b / x0 are not 1D
    """
    if (
        not isinstance(A, ht.DNDarray)
        or not isinstance(b, ht.DNDarray)
        or not isinstance(x0, ht.DNDarray)
    ):
        raise TypeError(
            "A, b and x0 need to be of type ht.DNDarray, but were {}, {}, {}".format(
                type(A), type(b), type(x0)
            )
        )
    if not A.numdims == 2:
        raise RuntimeError("A needs to be a 2D matrix")
    if not b.numdims == 1:
        raise RuntimeError("b needs to be a 1D vector")
    if not x0.numdims == 1:
        # fixed: message previously referred to "c" although the check is on x0
        raise RuntimeError("x0 needs to be a 1D vector")

    # initial residual, search direction and squared residual norm
    r = b - ht.matmul(A, x0)
    p = r
    rsold = ht.matmul(r, r)
    x = x0

    # CG converges in at most n iterations in exact arithmetic
    for i in range(len(b)):
        Ap = ht.matmul(A, p)
        alpha = rsold / ht.matmul(p, Ap)
        x = x + alpha * p
        r = r - alpha * Ap
        rsnew = ht.matmul(r, r)
        if ht.sqrt(rsnew).item() < 1e-10:
            print("Residual reaches tolerance in it = {}".format(i))
            # NOTE(review): assigning to `out` rebinds the local name only; the
            # caller's array is not filled in place — confirm intended semantics.
            if out is not None:
                out = x
                return out
            return x
        p = r + ((rsnew / rsold) * p)
        rsold = rsnew

    if out is not None:
        out = x
        return out
    return x
def _spectral_embedding(self, X):
    """
    Helper function to embed the dataset X into the eigenvectors of the graph Laplacian matrix

    Returns
    -------
    ht.DNDarray, shape=(m_lanczos,):
        Eigenvalues of the graph's Laplacian matrix.
    ht.DNDarray, shape=(n, m_lanczos):
        Eigenvectors of the graph's Laplacian matrix.
    """
    L = self._laplacian.construct(X)
    # 3. Eigenvalue and -vector calculation via Lanczos Algorithm
    # starting vector: constant vector of Euclidean norm 1
    v0 = ht.full(
        (L.shape[0],),
        fill_value=1.0 / math.sqrt(L.shape[0]),
        dtype=L.dtype,
        split=0,
        device=L.device,
    )
    V, T = ht.lanczos(L, self.n_lanczos, v0)

    # 4. Calculate and Sort Eigenvalues and Eigenvectors of tridiagonal matrix T
    # renamed from `eval`/`evec` so the builtin `eval` is not shadowed
    evals, evecs = torch.eig(T._DNDarray__array, eigenvectors=True)
    # If x is an Eigenvector of T, then y = V@x is the corresponding Eigenvector of L
    evals, idx = torch.sort(evals[:, 0], dim=0)
    eigenvalues = ht.array(evals)
    eigenvectors = ht.matmul(V, ht.array(evecs))[:, idx]

    return eigenvalues, eigenvectors
def _spectral_embedding(self, x: DNDarray) -> Tuple[DNDarray, DNDarray]:
    """
    Helper function for dataset x embedding.
    Returns Tuple(Eigenvalues, Eigenvectors) of the graph's Laplacian matrix.

    Parameters
    ----------
    x : DNDarray
        Sample Matrix for which the embedding should be calculated

    Notes
    -----
    ``torch.eig`` was deprecated in torch 1.9 and removed in torch 2.x; the
    primary path therefore uses ``torch.linalg.eig`` and falls back to
    ``torch.eig`` on older torch versions. On the primary path the imaginary
    parts of the eigen-decomposition are discarded.
    """
    L = self._laplacian.construct(x)
    # 3. Eigenvalue and -vector calculation via Lanczos Algorithm
    v0 = ht.full(
        (L.shape[0],),
        fill_value=1.0 / math.sqrt(L.shape[0]),
        dtype=L.dtype,
        split=0,
        device=L.device,
    )
    V, T = ht.lanczos(L, self.n_lanczos, v0)

    try:
        # 4. Calculate and Sort Eigenvalues and Eigenvectors of tridiagonal matrix T
        # torch >= 1.9: torch.linalg.eig returns complex tensors
        evals, evecs = torch.linalg.eig(T.larray)
        # If x is an Eigenvector of T, then y = V@x is the corresponding Eigenvector of L
        evals, idx = torch.sort(evals.real, dim=0)
        eigenvalues = ht.array(evals)
        eigenvectors = ht.matmul(V, ht.array(evecs))[:, idx]
        return eigenvalues.real, eigenvectors.real
    except AttributeError:
        # torch version is less than 1.9.0: torch.eig packs (real, imag) columns
        evals, evecs = torch.eig(T.larray, eigenvectors=True)
        evals, idx = torch.sort(evals[:, 0], dim=0)
        eigenvalues = ht.array(evals)
        eigenvectors = ht.matmul(V, ht.array(evecs))[:, idx]
        return eigenvalues, eigenvectors
def _spectral_embedding(self, x: DNDarray) -> Tuple[DNDarray, DNDarray]:
    """
    Embed the dataset ``x`` via the eigen-decomposition of its graph Laplacian.

    Returns the tuple ``(eigenvalues, eigenvectors)`` of the graph's Laplacian
    matrix; any complex parts produced by the decomposition are dropped.

    Parameters
    ----------
    x : DNDarray
        Sample Matrix for which the embedding should be calculated
    """
    laplacian = self._laplacian.construct(x)
    # Lanczos needs a unit-norm starting vector; use the constant one.
    start = ht.full(
        (laplacian.shape[0],),
        fill_value=1.0 / math.sqrt(laplacian.shape[0]),
        dtype=laplacian.dtype,
        split=0,
        device=laplacian.device,
    )
    V, T = ht.lanczos(laplacian, self.n_lanczos, start)

    # Eigen-decompose the tridiagonal matrix T; if y is an eigenvector of T,
    # then V @ y is the corresponding eigenvector of the Laplacian.
    try:
        # torch >= 1.9 exposes torch.linalg.eig (complex-valued results)
        w, q = torch.linalg.eig(T.larray)
        w, order = torch.sort(w.real, dim=0)
        vals = ht.array(w)
        vecs = ht.matmul(V, ht.array(q))[:, order]
        return vals.real, vecs.real
    except AttributeError:
        # older torch (< 1.9.0): torch.eig packs (real, imag) columns
        w, q = torch.eig(T.larray, eigenvectors=True)
        w, order = torch.sort(w[:, 0], dim=0)
        vals = ht.array(w)
        vecs = ht.matmul(V, ht.array(q))[:, order]
        return vals, vecs
def lanczos(
    A: DNDarray,
    m: int,
    v0: Optional[DNDarray] = None,
    V_out: Optional[DNDarray] = None,
    T_out: Optional[DNDarray] = None,
) -> Tuple[DNDarray, DNDarray]:
    r"""
    The Lanczos algorithm is an iterative approximation of the solution to the eigenvalue
    problem, as an adaptation of power methods to find the m "most useful" (tending towards
    extreme highest/lowest) eigenvalues and eigenvectors of an :math:`n \times n` Hermitian
    matrix, where often :math:`m<<n`. It returns two matrices :math:`V` and :math:`T`, where:

        - :math:`V` is a Matrix of size :math:`n\times m`, with orthonormal columns, that
          span the Krylov subspace \n
        - :math:`T` is a Tridiagonal matrix of size :math:`m\times m`, with coefficients
          :math:`\alpha_1,..., \alpha_n` on the diagonal and coefficients
          :math:`\beta_1,...,\beta_{n-1}` on the side-diagonals\n

    Parameters
    ----------
    A : DNDarray
        2D symmetric, positive definite Matrix
    m : int
        Number of Lanczos iterations
    v0 : DNDarray, optional
        1D starting vector of Euclidean norm 1. If not provided, a random vector will be
        used to start the algorithm
    V_out : DNDarray, optional
        Output Matrix for the Krylov vectors, Shape = (n, m)
    T_out : DNDarray, optional
        Output Matrix for the Tridiagonal matrix, Shape = (m, m)
    """
    if not isinstance(A, DNDarray):
        # fixed typo in message: "ht.dndarra" -> "ht.DNDarray"
        raise TypeError("A needs to be of type ht.DNDarray, but was {}".format(type(A)))
    if not (A.ndim == 2):
        raise RuntimeError("A needs to be a 2D matrix")
    if not isinstance(m, (int, float)):
        # fixed typo in message: "eiter" -> "either"
        raise TypeError("m must be either int or float, but was {}".format(type(m)))
    n, column = A.shape
    if n != column:
        raise TypeError("Input Matrix A needs to be symmetric.")

    # NOTE(review): T is created with default dtype/device, unlike V below — confirm
    # whether it should inherit A.dtype / A.device.
    T = ht.zeros((m, m))
    if A.split == 0:
        # This is done for better memory access in the reorthogonalization Gram-Schmidt algorithm
        V = ht.ones((n, m), split=0, dtype=A.dtype, device=A.device)
    else:
        V = ht.ones((n, m), split=None, dtype=A.dtype, device=A.device)

    if v0 is None:
        vr = ht.random.rand(n, split=V.split)
        v0 = vr / ht.norm(vr)
    else:
        if v0.split != V.split:
            # in-place redistribution of the caller's vector
            v0.resplit_(axis=V.split)

    # # 0th iteration #
    # vector v0 has Euclidean norm = 1
    w = ht.matmul(A, v0)
    alpha = ht.dot(w, v0)
    w = w - alpha * v0
    T[0, 0] = alpha
    V[:, 0] = v0

    for i in range(1, int(m)):
        beta = ht.norm(w)
        if ht.abs(beta) < 1e-10:
            # print("Lanczos breakdown in iteration {}".format(i))
            # Lanczos Breakdown, pick a random vector to continue
            vr = ht.random.rand(n, dtype=A.dtype, split=V.split)
            # orthogonalize v_r with respect to all vectors v[i]
            for j in range(i):
                vi_loc = V.larray[:, j]
                a = torch.dot(vr.larray, vi_loc)
                b = torch.dot(vi_loc, vi_loc)
                A.comm.Allreduce(ht.communication.MPI.IN_PLACE, a, ht.communication.MPI.SUM)
                A.comm.Allreduce(ht.communication.MPI.IN_PLACE, b, ht.communication.MPI.SUM)
                vr.larray = vr.larray - a / b * vi_loc
            # normalize v_r to Euclidean norm 1 and set as ith vector v
            vi = vr / ht.norm(vr)
        else:
            vr = w
            # Reorthogonalization
            # ToDo: Rethink this; mask torch calls, See issue #494
            # This is the fast solution, using item access on the ht.dndarray level is way
            # slower
            for j in range(i):
                vi_loc = V.larray[:, j]
                # consistency fix: use the public .larray accessor here as well,
                # instead of the name-mangled _DNDarray__array used previously
                a = torch.dot(vr.larray, vi_loc)
                b = torch.dot(vi_loc, vi_loc)
                A.comm.Allreduce(ht.communication.MPI.IN_PLACE, a, ht.communication.MPI.SUM)
                A.comm.Allreduce(ht.communication.MPI.IN_PLACE, b, ht.communication.MPI.SUM)
                vr.larray = vr.larray - a / b * vi_loc
            vi = vr / ht.norm(vr)

        w = ht.matmul(A, vi)
        alpha = ht.dot(w, vi)
        w = w - alpha * vi - beta * V[:, i - 1]
        T[i - 1, i] = beta
        T[i, i - 1] = beta
        T[i, i] = alpha
        V[:, i] = vi

    if V.split is not None:
        V.resplit_(axis=None)

    # NOTE(review): assigning to T_out/V_out rebinds the local names only — the
    # caller's arrays are not filled in place; fresh copies are returned instead.
    if T_out is not None:
        T_out = T.copy()
        if V_out is not None:
            V_out = V.copy()
            return V_out, T_out
        return V, T_out
    elif V_out is not None:
        V_out = V.copy()
        return V_out, T

    return V, T
def test_matmul(self):
    """Exercise ht.matmul / the @ operator across split combinations.

    Covers matrix @ matrix, vector @ matrix and matrix @ vector for the
    split=None/0/1 pairings below, plus the shape-mismatch (ValueError) and
    ndim > 2 (NotImplementedError) error paths.

    NOTE(review): this source arrived with flattened whitespace; the extent of
    the ``if a.comm.size > 1`` guard below was reconstructed to cover the
    both-operands-split cases — confirm against version control.
    """
    # incompatible inner dimensions must raise
    with self.assertRaises(ValueError):
        ht.matmul(ht.ones((25, 25)), ht.ones((42, 42)))

    # cases to test:
    n, m = 21, 31
    j, k = m, 45
    # torch reference operands: ones with a distinctive first row / last (first) column
    a_torch = torch.ones((n, m), device=self.device.torch_device)
    a_torch[0] = torch.arange(1, m + 1, device=self.device.torch_device)
    a_torch[:, -1] = torch.arange(1, n + 1, device=self.device.torch_device)
    b_torch = torch.ones((j, k), device=self.device.torch_device)
    b_torch[0] = torch.arange(1, k + 1, device=self.device.torch_device)
    b_torch[:, 0] = torch.arange(1, j + 1, device=self.device.torch_device)

    # splits None None
    a = ht.ones((n, m), split=None)
    b = ht.ones((j, k), split=None)
    a[0] = ht.arange(1, m + 1)
    a[:, -1] = ht.arange(1, n + 1)
    b[0] = ht.arange(1, k + 1)
    b[:, 0] = ht.arange(1, j + 1)
    ret00 = ht.matmul(a, b)
    self.assertEqual(ht.all(ret00 == ht.array(a_torch @ b_torch)), 1)
    self.assertIsInstance(ret00, ht.DNDarray)
    self.assertEqual(ret00.shape, (n, k))
    self.assertEqual(ret00.dtype, ht.float)
    self.assertEqual(ret00.split, None)
    self.assertEqual(a.split, None)
    self.assertEqual(b.split, None)

    # splits None None
    a = ht.ones((n, m), split=None)
    b = ht.ones((j, k), split=None)
    a[0] = ht.arange(1, m + 1)
    a[:, -1] = ht.arange(1, n + 1)
    b[0] = ht.arange(1, k + 1)
    b[:, 0] = ht.arange(1, j + 1)
    ret00 = ht.matmul(a, b, allow_resplit=True)
    self.assertEqual(ht.all(ret00 == ht.array(a_torch @ b_torch)), 1)
    self.assertIsInstance(ret00, ht.DNDarray)
    self.assertEqual(ret00.shape, (n, k))
    self.assertEqual(ret00.dtype, ht.float)
    self.assertEqual(ret00.split, None)
    # allow_resplit redistributed the first operand in place
    self.assertEqual(a.split, 0)
    self.assertEqual(b.split, None)

    if a.comm.size > 1:
        # splits 00
        a = ht.ones((n, m), split=0, dtype=ht.float64)
        b = ht.ones((j, k), split=0)
        a[0] = ht.arange(1, m + 1)
        a[:, -1] = ht.arange(1, n + 1)
        b[0] = ht.arange(1, k + 1)
        b[:, 0] = ht.arange(1, j + 1)
        ret00 = a @ b
        ret_comp00 = ht.array(a_torch @ b_torch, split=0)
        self.assertTrue(ht.equal(ret00, ret_comp00))
        self.assertIsInstance(ret00, ht.DNDarray)
        self.assertEqual(ret00.shape, (n, k))
        self.assertEqual(ret00.dtype, ht.float64)
        self.assertEqual(ret00.split, 0)

        # splits 00 (numpy)
        a = ht.array(np.ones((n, m)), split=0)
        b = ht.array(np.ones((j, k)), split=0)
        a[0] = ht.arange(1, m + 1)
        a[:, -1] = ht.arange(1, n + 1)
        b[0] = ht.arange(1, k + 1)
        b[:, 0] = ht.arange(1, j + 1)
        ret00 = a @ b
        ret_comp00 = ht.array(a_torch @ b_torch, split=0)
        self.assertTrue(ht.equal(ret00, ret_comp00))
        self.assertIsInstance(ret00, ht.DNDarray)
        self.assertEqual(ret00.shape, (n, k))
        self.assertEqual(ret00.dtype, ht.float64)
        self.assertEqual(ret00.split, 0)

        # splits 01
        a = ht.ones((n, m), split=0)
        b = ht.ones((j, k), split=1, dtype=ht.float64)
        a[0] = ht.arange(1, m + 1)
        a[:, -1] = ht.arange(1, n + 1)
        b[0] = ht.arange(1, k + 1)
        b[:, 0] = ht.arange(1, j + 1)
        ret00 = ht.matmul(a, b)
        ret_comp01 = ht.array(a_torch @ b_torch, split=0)
        self.assertTrue(ht.equal(ret00, ret_comp01))
        self.assertIsInstance(ret00, ht.DNDarray)
        self.assertEqual(ret00.shape, (n, k))
        self.assertEqual(ret00.dtype, ht.float64)
        self.assertEqual(ret00.split, 0)

        # splits 10
        a = ht.ones((n, m), split=1)
        b = ht.ones((j, k), split=0)
        a[0] = ht.arange(1, m + 1)
        a[:, -1] = ht.arange(1, n + 1)
        b[0] = ht.arange(1, k + 1)
        b[:, 0] = ht.arange(1, j + 1)
        ret00 = ht.matmul(a, b)
        ret_comp10 = ht.array(a_torch @ b_torch, split=1)
        self.assertTrue(ht.equal(ret00, ret_comp10))
        self.assertIsInstance(ret00, ht.DNDarray)
        self.assertEqual(ret00.shape, (n, k))
        self.assertEqual(ret00.dtype, ht.float)
        self.assertEqual(ret00.split, 1)

        # splits 11
        a = ht.ones((n, m), split=1)
        b = ht.ones((j, k), split=1)
        a[0] = ht.arange(1, m + 1)
        a[:, -1] = ht.arange(1, n + 1)
        b[0] = ht.arange(1, k + 1)
        b[:, 0] = ht.arange(1, j + 1)
        ret00 = ht.matmul(a, b)
        ret_comp11 = ht.array(a_torch @ b_torch, split=1)
        self.assertTrue(ht.equal(ret00, ret_comp11))
        self.assertIsInstance(ret00, ht.DNDarray)
        self.assertEqual(ret00.shape, (n, k))
        self.assertEqual(ret00.dtype, ht.float)
        self.assertEqual(ret00.split, 1)

        # splits 11 (torch)
        a = ht.array(torch.ones((n, m), device=self.device.torch_device), split=1)
        b = ht.array(torch.ones((j, k), device=self.device.torch_device), split=1)
        a[0] = ht.arange(1, m + 1)
        a[:, -1] = ht.arange(1, n + 1)
        b[0] = ht.arange(1, k + 1)
        b[:, 0] = ht.arange(1, j + 1)
        ret00 = ht.matmul(a, b)
        ret_comp11 = ht.array(a_torch @ b_torch, split=1)
        self.assertTrue(ht.equal(ret00, ret_comp11))
        self.assertIsInstance(ret00, ht.DNDarray)
        self.assertEqual(ret00.shape, (n, k))
        self.assertEqual(ret00.dtype, ht.float)
        self.assertEqual(ret00.split, 1)

    # splits 0 None
    a = ht.ones((n, m), split=0)
    b = ht.ones((j, k), split=None)
    a[0] = ht.arange(1, m + 1)
    a[:, -1] = ht.arange(1, n + 1)
    b[0] = ht.arange(1, k + 1)
    b[:, 0] = ht.arange(1, j + 1)
    ret00 = ht.matmul(a, b)
    ret_comp0 = ht.array(a_torch @ b_torch, split=0)
    self.assertTrue(ht.equal(ret00, ret_comp0))
    self.assertIsInstance(ret00, ht.DNDarray)
    self.assertEqual(ret00.shape, (n, k))
    self.assertEqual(ret00.dtype, ht.float)
    self.assertEqual(ret00.split, 0)

    # splits 1 None
    a = ht.ones((n, m), split=1)
    b = ht.ones((j, k), split=None)
    a[0] = ht.arange(1, m + 1)
    a[:, -1] = ht.arange(1, n + 1)
    b[0] = ht.arange(1, k + 1)
    b[:, 0] = ht.arange(1, j + 1)
    ret00 = ht.matmul(a, b)
    ret_comp1 = ht.array(a_torch @ b_torch, split=1)
    self.assertTrue(ht.equal(ret00, ret_comp1))
    self.assertIsInstance(ret00, ht.DNDarray)
    self.assertEqual(ret00.shape, (n, k))
    self.assertEqual(ret00.dtype, ht.float)
    self.assertEqual(ret00.split, 1)

    # splits None 0
    a = ht.ones((n, m), split=None)
    b = ht.ones((j, k), split=0)
    a[0] = ht.arange(1, m + 1)
    a[:, -1] = ht.arange(1, n + 1)
    b[0] = ht.arange(1, k + 1)
    b[:, 0] = ht.arange(1, j + 1)
    ret00 = ht.matmul(a, b)
    ret_comp = ht.array(a_torch @ b_torch, split=0)
    self.assertTrue(ht.equal(ret00, ret_comp))
    self.assertIsInstance(ret00, ht.DNDarray)
    self.assertEqual(ret00.shape, (n, k))
    self.assertEqual(ret00.dtype, ht.float)
    self.assertEqual(ret00.split, 0)

    # splits None 1
    a = ht.ones((n, m), split=None)
    b = ht.ones((j, k), split=1)
    a[0] = ht.arange(1, m + 1)
    a[:, -1] = ht.arange(1, n + 1)
    b[0] = ht.arange(1, k + 1)
    b[:, 0] = ht.arange(1, j + 1)
    ret00 = ht.matmul(a, b)
    ret_comp = ht.array(a_torch @ b_torch, split=1)
    self.assertTrue(ht.equal(ret00, ret_comp))
    self.assertIsInstance(ret00, ht.DNDarray)
    self.assertEqual(ret00.shape, (n, k))
    self.assertEqual(ret00.dtype, ht.float)
    self.assertEqual(ret00.split, 1)

    # vector matrix mult:
    # a -> vector
    a_torch = torch.ones((m), device=self.device.torch_device)
    b_torch = torch.ones((j, k), device=self.device.torch_device)
    b_torch[0] = torch.arange(1, k + 1, device=self.device.torch_device)
    b_torch[:, 0] = torch.arange(1, j + 1, device=self.device.torch_device)

    # splits None None
    a = ht.ones((m), split=None)
    b = ht.ones((j, k), split=None)
    b[0] = ht.arange(1, k + 1)
    b[:, 0] = ht.arange(1, j + 1)
    ret00 = ht.matmul(a, b)
    ret_comp = ht.array(a_torch @ b_torch, split=None)
    self.assertTrue(ht.equal(ret00, ret_comp))
    self.assertIsInstance(ret00, ht.DNDarray)
    self.assertEqual(ret00.shape, (k,))
    self.assertEqual(ret00.dtype, ht.float)
    self.assertEqual(ret00.split, None)

    # splits None 0
    a = ht.ones((m), split=None)
    b = ht.ones((j, k), split=0)
    b[0] = ht.arange(1, k + 1)
    b[:, 0] = ht.arange(1, j + 1)
    ret00 = ht.matmul(a, b)
    ret_comp = ht.array(a_torch @ b_torch, split=None)
    self.assertTrue(ht.equal(ret00, ret_comp))
    self.assertIsInstance(ret00, ht.DNDarray)
    self.assertEqual(ret00.shape, (k,))
    self.assertEqual(ret00.dtype, ht.float)
    self.assertEqual(ret00.split, 0)

    # splits None 1
    a = ht.ones((m), split=None)
    b = ht.ones((j, k), split=1)
    b[0] = ht.arange(1, k + 1)
    b[:, 0] = ht.arange(1, j + 1)
    ret00 = ht.matmul(a, b)
    ret_comp = ht.array(a_torch @ b_torch, split=0)
    self.assertTrue(ht.equal(ret00, ret_comp))
    self.assertIsInstance(ret00, ht.DNDarray)
    self.assertEqual(ret00.shape, (k,))
    self.assertEqual(ret00.dtype, ht.float)
    self.assertEqual(ret00.split, 0)

    # splits 0 None
    # NOTE(review): despite the label, this sets a split=None / b split=0,
    # duplicating the "splits None 0" case above — confirm intended setup.
    a = ht.ones((m), split=None)
    b = ht.ones((j, k), split=0)
    b[0] = ht.arange(1, k + 1)
    b[:, 0] = ht.arange(1, j + 1)
    ret00 = ht.matmul(a, b)
    ret_comp = ht.array(a_torch @ b_torch, split=None)
    self.assertTrue(ht.equal(ret00, ret_comp))
    self.assertIsInstance(ret00, ht.DNDarray)
    self.assertEqual(ret00.shape, (k,))
    self.assertEqual(ret00.dtype, ht.float)
    self.assertEqual(ret00.split, 0)

    # splits 0 0
    a = ht.ones((m), split=0)
    b = ht.ones((j, k), split=0)
    b[0] = ht.arange(1, k + 1)
    b[:, 0] = ht.arange(1, j + 1)
    ret00 = ht.matmul(a, b)
    ret_comp = ht.array(a_torch @ b_torch, split=None)
    self.assertTrue(ht.equal(ret00, ret_comp))
    self.assertIsInstance(ret00, ht.DNDarray)
    self.assertEqual(ret00.shape, (k,))
    self.assertEqual(ret00.dtype, ht.float)
    self.assertEqual(ret00.split, 0)

    # splits 0 1
    a = ht.ones((m), split=0)
    b = ht.ones((j, k), split=1)
    b[0] = ht.arange(1, k + 1)
    b[:, 0] = ht.arange(1, j + 1)
    ret00 = ht.matmul(a, b)
    ret_comp = ht.array(a_torch @ b_torch, split=None)
    self.assertTrue(ht.equal(ret00, ret_comp))
    self.assertIsInstance(ret00, ht.DNDarray)
    self.assertEqual(ret00.shape, (k,))
    self.assertEqual(ret00.dtype, ht.float)
    self.assertEqual(ret00.split, 0)

    # b -> vector
    a_torch = torch.ones((n, m), device=self.device.torch_device)
    a_torch[0] = torch.arange(1, m + 1, device=self.device.torch_device)
    a_torch[:, -1] = torch.arange(1, n + 1, device=self.device.torch_device)
    b_torch = torch.ones((j), device=self.device.torch_device)

    # splits None None
    a = ht.ones((n, m), split=None)
    b = ht.ones((j), split=None)
    a[0] = ht.arange(1, m + 1)
    a[:, -1] = ht.arange(1, n + 1)
    ret00 = ht.matmul(a, b)
    ret_comp = ht.array(a_torch @ b_torch, split=None)
    self.assertTrue(ht.equal(ret00, ret_comp))
    self.assertIsInstance(ret00, ht.DNDarray)
    self.assertEqual(ret00.shape, (n,))
    self.assertEqual(ret00.dtype, ht.float)
    self.assertEqual(ret00.split, None)

    # splits 0 None
    a = ht.ones((n, m), split=0)
    b = ht.ones((j), split=None)
    a[0] = ht.arange(1, m + 1)
    a[:, -1] = ht.arange(1, n + 1)
    ret00 = ht.matmul(a, b)
    ret_comp = ht.array((a_torch @ b_torch), split=None)
    self.assertTrue(ht.equal(ret00, ret_comp))
    self.assertIsInstance(ret00, ht.DNDarray)
    self.assertEqual(ret00.shape, (n,))
    self.assertEqual(ret00.dtype, ht.float)
    self.assertEqual(ret00.split, 0)

    # splits 1 None
    a = ht.ones((n, m), split=1)
    b = ht.ones((j), split=None)
    a[0] = ht.arange(1, m + 1)
    a[:, -1] = ht.arange(1, n + 1)
    ret00 = ht.matmul(a, b)
    ret_comp = ht.array((a_torch @ b_torch), split=None)
    self.assertTrue(ht.equal(ret00, ret_comp))
    self.assertIsInstance(ret00, ht.DNDarray)
    self.assertEqual(ret00.shape, (n,))
    self.assertEqual(ret00.dtype, ht.float)
    self.assertEqual(ret00.split, 0)

    # splits None 0
    a = ht.ones((n, m), split=None)
    b = ht.ones((j), split=0)
    a[0] = ht.arange(1, m + 1)
    a[:, -1] = ht.arange(1, n + 1)
    ret00 = ht.matmul(a, b)
    ret_comp = ht.array((a_torch @ b_torch), split=None)
    self.assertTrue(ht.equal(ret00, ret_comp))
    self.assertIsInstance(ret00, ht.DNDarray)
    self.assertEqual(ret00.shape, (n,))
    self.assertEqual(ret00.dtype, ht.float)
    self.assertEqual(ret00.split, 0)

    # splits 0 0
    a = ht.ones((n, m), split=0)
    b = ht.ones((j), split=0)
    a[0] = ht.arange(1, m + 1)
    a[:, -1] = ht.arange(1, n + 1)
    ret00 = ht.matmul(a, b)
    ret_comp = ht.array((a_torch @ b_torch), split=None)
    self.assertTrue(ht.equal(ret00, ret_comp))
    self.assertIsInstance(ret00, ht.DNDarray)
    self.assertEqual(ret00.shape, (n,))
    self.assertEqual(ret00.dtype, ht.float)
    self.assertEqual(ret00.split, 0)

    # splits 1 0
    a = ht.ones((n, m), split=1)
    b = ht.ones((j), split=0)
    a[0] = ht.arange(1, m + 1)
    a[:, -1] = ht.arange(1, n + 1)
    ret00 = ht.matmul(a, b)
    ret_comp = ht.array((a_torch @ b_torch), split=None)
    self.assertTrue(ht.equal(ret00, ret_comp))
    self.assertIsInstance(ret00, ht.DNDarray)
    self.assertEqual(ret00.shape, (n,))
    self.assertEqual(ret00.dtype, ht.float)
    self.assertEqual(ret00.split, 0)

    # batched (ndim > 2) matmul is not implemented
    with self.assertRaises(NotImplementedError):
        a = ht.zeros((3, 3, 3), split=2)
        b = a.copy()
        a @ b