Exemplo n.º 1
0
    def test_left_orthonormalization(self):
        """Check ``ortho_left`` on a tensor train with a singleton axis 2.

        The first ``order - 1`` cores of the left-orthonormalized train must
        satisfy ``C^T C == I`` (contracting over their first two axes) within
        ``self.tol``. The orthonormalized train must also reproduce the
        original full tensor with a relative error below ``self.tol``.
        """

        # keep only index 0 of axis 2 of every core (the axis stays, with
        # size 1) to build a non-operator tensor train
        t_col = TT([self.cores[k][:, :, 0:1, :] for k in range(self.order)])

        # left-orthonormalize with default arguments
        t_left = t_col.ortho_left()

        def deviation_from_identity(k):
            # Gram matrix of core k over its first two axes, compared with
            # the identity of size ranks[k + 1]
            gram = np.tensordot(t_left.cores[k], t_left.cores[k],
                                axes=([0, 1], [0, 1])).squeeze()
            return np.linalg.norm(gram - np.eye(t_left.ranks[k + 1]))

        # number of checked cores (all but the last) that are not
        # left-orthonormal
        n_failures = sum(1 for k in range(self.order - 1)
                         if deviation_from_identity(k) > self.tol)

        # relative distance between the two full representations
        reference = t_col.full().flatten()
        difference = t_left.full().flatten() - reference
        rel_err = np.linalg.norm(difference) / np.linalg.norm(reference)

        # orthonormality first, then equality with the original train
        self.assertEqual(n_failures, 0)
        self.assertLess(rel_err, self.tol)
Exemplo n.º 2
0
    def test_left_orthonormalization(self):
        """Test left-orthonormalization of a non-operator tensor train.

        Checks that ``ortho_left(threshold=1e-14)`` returns a tensor train
        whose first ``order - 1`` cores are left-orthonormal (within
        ``self.tol``). The result must also reproduce the original full
        tensor with a relative error below ``self.tol``. Finally, the test
        checks that invalid arguments are rejected: ``max_rank=0`` and a
        negative ``threshold`` raise ``ValueError``, and non-integer
        ``start_index`` / ``end_index`` each raise ``TypeError``.
        """

        # construct non-operator tensor train: keep only index 0 of axis 2
        # of every core (the axis is kept with size 1)
        cores = [self.cores[i][:, :, 0:1, :] for i in range(self.order)]
        t_col = TT(cores)

        # left-orthonormalize t
        t_left = t_col.ortho_left(threshold=1e-14)

        # test if cores are left-orthonormal: contracting each core with its
        # conjugate over the first two axes must give the identity of size
        # ranks[i + 1]; only the first order - 1 cores are checked.
        # np.conj is a no-op for real cores and yields the correct U^H U
        # Gram matrix if the cores are complex.
        err = 0
        for i in range(self.order - 1):
            c = np.tensordot(np.conj(t_left.cores[i]),
                             t_left.cores[i],
                             axes=([0, 1], [0, 1])).squeeze()
            if np.linalg.norm(c - np.eye(t_left.ranks[i + 1])) > self.tol:
                err += 1

        # convert t_col to full format and flatten
        t_full = t_col.full().flatten()

        # compute relative error
        rel_err = np.linalg.norm(t_left.full().flatten() -
                                 t_full) / np.linalg.norm(t_full)

        # check if t_left is left-orthonormal and equal to t_col
        self.assertEqual(err, 0)
        self.assertLess(rel_err, self.tol)

        # check if orthonormalization fails if maximum rank is not positive
        with self.assertRaises(ValueError):
            t_col.ortho_left(max_rank=0)

        # check if orthonormalization fails if threshold is negative
        with self.assertRaises(ValueError):
            t_col.ortho_left(threshold=-1)

        # check if orthonormalization fails if start and end indices are not
        # integers; each call needs its own context manager, because the
        # first raising call exits a ``with`` block and any later call in
        # the same block would never run
        with self.assertRaises(TypeError):
            t_col.ortho_left(start_index="a")
        with self.assertRaises(TypeError):
            t_col.ortho_left(end_index="b")