예제 #1
0
    def test_right_orthonormalization(self):
        """Verify that ``ortho_right`` produces right-orthonormal cores.

        Also verify that it leaves the represented tensor unchanged.

        A non-operator tensor train is formed by keeping only index ``0:1`` on
        axis 2 of every fixture core. After right-orthonormalization, each core
        except the first (index 0) is contracted with itself over axes 1 and 3.
        After squeezing singleton axes, the result must match the identity of
        size ``ranks[i]`` to within ``self.tol``. The full tensors before and
        after must agree to within ``self.tol`` in relative Euclidean norm.
        """
        # keep a single slice along axis 2 so the train is not an operator
        vector_cores = [self.cores[k][:, :, 0:1, :] for k in range(self.order)]
        t_vec = TT(vector_cores)

        t_ortho = t_vec.ortho_right()

        # number of cores (skipping the first) whose self-contraction is not
        # close to the identity
        n_bad = sum(
            1
            for k in range(1, self.order)
            if np.linalg.norm(
                np.tensordot(t_ortho.cores[k], t_ortho.cores[k],
                             axes=([1, 3], [1, 3])).squeeze()
                - np.eye(t_ortho.ranks[k])
            ) > self.tol
        )

        # relative deviation between the original and orthonormalized tensors
        reference = t_vec.full().flatten()
        difference = t_ortho.full().flatten() - reference
        rel_err = np.linalg.norm(difference) / np.linalg.norm(reference)

        self.assertEqual(n_bad, 0)
        self.assertLess(rel_err, self.tol)
예제 #2
0
    def test_right_orthonormalization(self):
        """Test right-orthonormalization and its argument validation.

        A non-operator tensor train is built by keeping only index ``0:1`` on
        axis 2 of every fixture core, then right-orthonormalized with
        ``threshold=1e-14``. The test checks the following:

        * every core except the first, contracted with itself over axes 1 and
          3 and squeezed, equals the identity of size ``ranks[i]`` within
          ``self.tol``;
        * the full tensor is unchanged within ``self.tol`` in relative norm;
        * ``ortho_right`` raises ``ValueError`` for ``max_rank=0`` and for
          ``threshold=-1``;
        * ``ortho_right`` raises ``TypeError`` for a non-integer
          ``start_index`` and, separately, for a non-integer ``end_index``.
        """
        # construct non-operator tensor train
        cores = [self.cores[i][:, :, 0:1, :] for i in range(self.order)]
        t_col = TT(cores)

        # right-orthonormalize t
        t_right = t_col.ortho_right(threshold=1e-14)

        # indices of cores (first core excluded) whose self-contraction deviates
        # from the identity; listing them makes a failure self-explanatory
        not_orthonormal = [
            i for i in range(1, self.order)
            if np.linalg.norm(
                np.tensordot(t_right.cores[i], t_right.cores[i],
                             axes=([1, 3], [1, 3])).squeeze()
                - np.eye(t_right.ranks[i])
            ) > self.tol
        ]

        # compute relative error between t_right and the original t_col
        t_full = t_col.full().flatten()
        rel_err = np.linalg.norm(t_right.full().flatten() -
                                 t_full) / np.linalg.norm(t_full)

        # check if t_right is right-orthonormal and equal to t_col
        self.assertEqual(not_orthonormal, [],
                         msg="cores not right-orthonormal: %s" % not_orthonormal)
        self.assertLess(rel_err, self.tol)

        # check if orthonormalization fails if maximum rank is not positive
        with self.assertRaises(ValueError):
            t_col.ortho_right(max_rank=0)

        # check if orthonormalization fails if threshold is negative
        with self.assertRaises(ValueError):
            t_col.ortho_right(threshold=-1)

        # check if orthonormalization fails if start and end indices are not
        # integers; each call needs its own context manager, because the first
        # exception raised inside an assertRaises block exits it and any later
        # statements in that block would never execute
        with self.assertRaises(TypeError):
            t_col.ortho_right(start_index="a")
        with self.assertRaises(TypeError):
            t_col.ortho_right(end_index="b")