def test_matmul(self):
        """MatmulLazyVariable(lhs, rhs).matmul(mat) must agree with the dense
        product (lhs @ rhs) @ mat, both in value and in the gradients that flow
        back to lhs and rhs."""
        left = torch.randn(5, 3, requires_grad=True)
        right = torch.randn(3, 4, requires_grad=True)
        lazy = MatmulLazyVariable(left, right)
        dense_rhs = torch.randn(4, 10)
        lazy_result = lazy.matmul(dense_rhs)

        # Detached leaf copies so the dense reference accumulates its own .grad.
        left_ref = left.clone().detach().requires_grad_(True)
        right_ref = right.clone().detach().requires_grad_(True)
        rhs_ref = dense_rhs.clone().detach().requires_grad_(True)
        expected = left_ref.matmul(right_ref).matmul(rhs_ref)

        self.assertTrue(approx_equal(lazy_result, expected))

        expected.sum().backward()

        lazy_result.sum().backward()
        self.assertTrue(approx_equal(left.grad, left_ref.grad))
        self.assertTrue(approx_equal(right.grad, right_ref.grad))
    def test_batch_diag(self):
        """diag() of a batched MatmulLazyVariable equals the diagonal of each
        batch entry of the dense product lhs @ rhs."""
        left = torch.randn(4, 5, 3)
        right = torch.randn(4, 3, 5)
        dense = left.matmul(right)
        # One row per batch entry: the diagonal of dense[b] for b = 0..3.
        expected = torch.cat([dense[b].diag().unsqueeze(0) for b in range(4)])

        lazy = MatmulLazyVariable(left, right)
        self.assertTrue(approx_equal(expected, lazy.diag()))
Example #3
0
def test_batch_diag():
    """Batched MatmulLazyVariable.diag() matches the per-batch diagonals of
    the dense product lhs @ rhs."""
    left = Variable(torch.randn(4, 5, 3))
    right = Variable(torch.randn(4, 3, 5))
    dense = left.matmul(right)
    # Stack the diagonal of each of the 4 batch entries into one tensor.
    expected = torch.cat(
        [dense[batch].diag().unsqueeze(0) for batch in range(4)]
    )

    lazy = MatmulLazyVariable(left, right)
    assert approx_equal(expected.data, lazy.diag().data)
    def test_solve(self):
        """woodbury_solve with a pivoted-Cholesky factor approximates
        (K + I)^{-1} b, where K is an RBF covariance on ``size`` points.

        The diagonal term handed to woodbury_factor / woodbury_solve is
        torch.ones(size), and the dense reference adds torch.eye(size) to K to
        match it. Every dimension is derived from ``size`` so the test stays
        self-consistent if ``size`` changes.
        """
        size = 100
        num_rhs = 50
        # Second argument to pivoted_cholesky (presumably the maximum rank).
        max_rank = 10
        train_x = torch.linspace(0, 1, size)
        covar_matrix = RBFKernel()(train_x, train_x).evaluate()
        piv_chol = pivoted_cholesky.pivoted_cholesky(covar_matrix, max_rank)
        # Diagonal shift; must agree with the torch.eye(size) added below.
        woodbury_factor = pivoted_cholesky.woodbury_factor(piv_chol, torch.ones(size))

        rhs_vector = torch.randn(size, num_rhs)
        shifted_covar_matrix = covar_matrix + torch.eye(size)
        real_solve = shifted_covar_matrix.inverse().matmul(rhs_vector)
        approx_solve = pivoted_cholesky.woodbury_solve(
            rhs_vector, piv_chol, woodbury_factor, torch.ones(size)
        )

        self.assertTrue(approx_equal(approx_solve, real_solve, 2e-4))
Example #5
0
    def test_batch_left_t_interp_on_a_vector(self):
        """left_t_interp applied to a length-9 vector matches the dense batched
        product batch_interp_matrix^T @ vector."""
        vec = torch.randn(9)

        # Dense reference: treat the vector as a (1, 9, 1) batch of columns.
        interp_t = self.batch_interp_matrix.transpose(-1, -2)
        column = vec.unsqueeze(-1).unsqueeze(0)
        expected = torch.matmul(interp_t, column).squeeze(-1)

        computed = left_t_interp(
            self.batch_interp_indices,
            self.batch_interp_values,
            Variable(vec),
            6,
        ).data
        self.assertTrue(approx_equal(computed, expected))
Example #6
0
    def test_inv_quad_log_det_many_vectors(self):
        """Check NonLazyVariable.inv_quad_log_det on a batch of two matrices.

        Forward: the inverse quadratic term must match a dense reference
        (A_k^{-1} V_k, multiplied elementwise by V_k and summed over both
        trailing dims, giving one value per matrix). The log-determinant must
        match the precomputed ``self.log_dets``. Both are checked to 1 decimal
        place.

        Backward: gradients w.r.t. the matrices and the vectors must match the
        dense reference. The log-det part of the reference matrix gradient is
        added by hand as ``grad_output[k] * A_k^{-1}``.

        NOTE(review): the fixtures (``mats_var``, ``vecs_var``, their
        ``_clone`` copies and ``log_dets``) are created outside this view,
        presumably in ``setUp``.
        """
        # Forward pass
        # Dense reference: stack A_0^{-1}, A_1^{-1}, apply to the vectors, then
        # reduce sum(V * A^{-1} V) over dims 2 and 1 -> one scalar per matrix.
        actual_inv_quad = torch.cat([
            self.mats_var_clone[0].inverse().unsqueeze(0),
            self.mats_var_clone[1].inverse().unsqueeze(0),
        ]).matmul(self.vecs_var_clone).mul(self.vecs_var_clone).sum(2).sum(1)
        # The lazy result is computed under num_trace_samples(1000); the loose
        # places=1 / epsilon=1e-1 tolerances suggest it is an approximation.
        with gpytorch.settings.num_trace_samples(1000):
            nlv = NonLazyVariable(self.mats_var)
            res_inv_quad, res_log_det = nlv.inv_quad_log_det(
                inv_quad_rhs=self.vecs_var, log_det=True)
        for i in range(self.mats_var.size(0)):
            self.assertAlmostEqual(res_inv_quad.data[i],
                                   actual_inv_quad.data[i],
                                   places=1)
            self.assertAlmostEqual(res_log_det.data[i],
                                   self.log_dets[i],
                                   places=1)

        # Backward
        # Distinct upstream gradients per batch entry for each output.
        inv_quad_grad_output = torch.Tensor([3, 4])
        log_det_grad_output = torch.Tensor([4, 2])
        # Reference, part 1: autograd through the dense inverse-quadratic form.
        actual_inv_quad.backward(gradient=inv_quad_grad_output)
        # Reference, part 2: the log-det gradient, built by hand.
        # d log|A| / dA = A^{-T}; using A^{-1} here presumably assumes the
        # matrices are symmetric -- confirm against setUp.
        mat_log_det_grad = torch.cat([
            self.mats_var_clone[0].data.inverse().mul(
                log_det_grad_output[0]).unsqueeze(0),
            self.mats_var_clone[1].data.inverse().mul(
                log_det_grad_output[1]).unsqueeze(0),
        ])
        # Must come after actual_inv_quad.backward(), which creates .grad.
        self.mats_var_clone.grad.data.add_(mat_log_det_grad)
        # retain_graph=True because res_inv_quad and res_log_det come from the
        # same inv_quad_log_det call and are backpropagated one after another.
        res_inv_quad.backward(gradient=inv_quad_grad_output, retain_graph=True)
        res_log_det.backward(gradient=log_det_grad_output)

        self.assertTrue(
            approx_equal(self.mats_var_clone.grad.data,
                         self.mats_var.grad.data,
                         epsilon=1e-1))
        self.assertTrue(
            approx_equal(self.vecs_var_clone.grad.data,
                         self.vecs_var.grad.data))
Example #7
0
    def test_log_det_only(self):
        """NonLazyVariable.log_det() matches the precomputed self.log_det, and
        its gradient matches grad_output * mat^{-1}."""
        # Forward pass, evaluated with 1000 trace samples; checked to 1 decimal place.
        with gpytorch.settings.num_trace_samples(1000):
            estimate = NonLazyVariable(self.mat_var).log_det()
        self.assertAlmostEqual(estimate.data[0], self.log_det, places=1)

        # Backward pass.
        upstream = torch.Tensor([3])
        # d log|A| / dA = A^{-T}; plain A^{-1} presumably relies on mat_var
        # being symmetric (fixture built outside this view) -- confirm.
        expected_grad = self.mat_var_clone.data.inverse().mul(upstream)
        estimate.backward(gradient=upstream)
        self.assertTrue(
            approx_equal(expected_grad, self.mat_var.grad.data, epsilon=1e-1)
        )
Example #8
0
    def test_toeplitz_matmul_batchmat(self):
        """toeplitz_matmul against a batch of 3 right-hand sides matches
        multiplying by the dense Toeplitz matrix built from (col, row)."""
        first_col = torch.tensor([1, 6, 4, 5], dtype=torch.float)
        first_row = torch.tensor([1, 2, 1, 1], dtype=torch.float)
        batch_rhs = torch.randn(3, 4, 2)

        # Reference: materialize the Toeplitz matrix and broadcast over the batch.
        dense_toeplitz = utils.toeplitz.toeplitz(first_col, first_row)
        expected = torch.matmul(dense_toeplitz.unsqueeze(0), batch_rhs)

        # Fast path: col/row get a leading batch dimension of size 1.
        fast = utils.toeplitz.toeplitz_matmul(
            first_col.unsqueeze(0), first_row.unsqueeze(0), batch_rhs
        )
        self.assertTrue(utils.approx_equal(fast, expected))
Example #9
0
    def test_getitem_batch(self):
        """Indexing a batched BlockDiagonalLazyVariable agrees with indexing a
        dense (2, 16, 16) block-diagonal tensor assembled by hand."""
        block_var = Variable(blocks, requires_grad=True)
        dense = Variable(torch.zeros(2, 16, 16))
        # Batch entry `batch` holds blocks batch*4 .. batch*4+3, each on a 4x4 diagonal slot.
        for batch in range(2):
            for blk in range(4):
                lo, hi = blk * 4, (blk + 1) * 4
                dense[batch, lo:hi, lo:hi] = block_var[batch * 4 + blk]

        def lazy():
            # A fresh lazy wrapper per indexing case.
            return BlockDiagonalLazyVariable(NonLazyVariable(block_var), n_blocks=4)

        # Whole first batch entry.
        computed = lazy()[0].evaluate()
        self.assertTrue(approx_equal(dense[0].data, computed.data))

        # Leading rows of the first batch entry.
        computed = lazy()[0, :5].evaluate()
        self.assertTrue(approx_equal(dense[0, :5].data, computed.data))

        # Batch slice, row slice and a single column; no evaluate() call here.
        computed = lazy()[1:, :5, 2]
        self.assertTrue(approx_equal(dense[1:, :5, 2].data, computed.data))
Example #10
0
    def test_inv_quad_only_many_vectors(self):
        """Batched NonLazyVariable.inv_quad over many vectors matches the dense
        reduction sum(V * A^{-1} V) for each of the two matrices, in value and
        in gradients."""
        # Forward pass.
        computed = NonLazyVariable(self.mats_var).inv_quad(self.vecs_var)
        inverses = torch.cat(
            [self.mats_var_clone[k].inverse().unsqueeze(0) for k in range(2)]
        )
        expected = (
            inverses.matmul(self.vecs_var_clone)
            .mul(self.vecs_var_clone)
            .sum(2)
            .sum(1)
        )
        for k in range(self.mats_var.size(0)):
            self.assertAlmostEqual(computed.data[k], expected.data[k], places=1)

        # Backward pass, with the same random upstream gradient for both paths.
        upstream = torch.randn(2)
        expected.backward(gradient=upstream)
        computed.backward(gradient=upstream)

        self.assertTrue(
            approx_equal(
                self.mats_var_clone.grad.data,
                self.mats_var.grad.data,
                epsilon=1e-1,
            )
        )
        self.assertTrue(
            approx_equal(self.vecs_var_clone.grad.data, self.vecs_var.grad.data)
        )
Example #11
0
    def test_batch_left_t_interp_on_a_matrix(self):
        """left_t_interp applied to a (9, 3) matrix matches the dense batched
        product batch_interp_matrix^T @ matrix."""
        dense_rhs = torch.randn(9, 3)

        computed = left_t_interp(
            self.batch_interp_indices,
            self.batch_interp_values,
            Variable(dense_rhs),
            6,
        ).data
        # Reference: broadcast the matrix against the batched interpolation matrix.
        interp_t = self.batch_interp_matrix.transpose(-1, -2)
        expected = torch.matmul(interp_t, dense_rhs.unsqueeze(0))
        self.assertTrue(approx_equal(computed, expected))
Example #12
0
def test_toeplitz_matmul_batchmat():
    """toeplitz_matmul over a batch of 3 matrices agrees with multiplying by the
    dense Toeplitz matrix built from (col, row)."""
    first_col = torch.Tensor([1, 6, 4, 5])
    first_row = torch.Tensor([1, 2, 1, 1])
    batch_rhs = torch.randn(3, 4, 2)

    # Reference product through the materialized Toeplitz matrix.
    dense_toeplitz = utils.toeplitz.toeplitz(first_col, first_row)
    expected = torch.matmul(dense_toeplitz.unsqueeze(0), batch_rhs)

    # Fast path with a leading batch dimension of size 1 on col/row.
    fast = utils.toeplitz.toeplitz_matmul(
        first_col.unsqueeze(0), first_row.unsqueeze(0), batch_rhs
    )
    assert utils.approx_equal(fast, expected)
    def test_get_indices(self):
        """RootLazyTensor._get_indices returns the same entries as advanced
        indexing into the dense matrix root @ root^T."""
        root = torch.randn(5, 3)
        dense = root.matmul(root.transpose(-1, -2))
        lazy = RootLazyTensor(root)

        index_cases = [
            # Four distinct (row, col) pairs.
            ([1, 2, 4, 0], [0, 1, 3, 2]),
            # Longer case, including repeated pairs and diagonal entries.
            (
                [1, 2, 4, 0, 1, 2, 3, 1, 2, 2, 1, 1, 0, 0, 4, 4, 4, 4],
                [0, 1, 3, 2, 3, 4, 2, 2, 1, 1, 2, 1, 2, 4, 4, 3, 3, 0],
            ),
        ]
        for rows, cols in index_cases:
            row_idx = torch.tensor(rows, dtype=torch.long)
            col_idx = torch.tensor(cols, dtype=torch.long)
            self.assertTrue(
                approx_equal(
                    dense[row_idx, col_idx],
                    lazy._get_indices(row_idx, col_idx),
                )
            )
    def test_interpolation(self):
        """Interpolating f(t) = t^2 from a 50-point grid on [-0.05, 1.05] onto
        100 points in [0.01, 1] reproduces f at those points."""
        x = torch.linspace(0.01, 1, 100).unsqueeze(1)
        grid = torch.linspace(-0.05, 1.05, 50).unsqueeze(0)
        interp_indices, interp_values = Interpolation().interpolate(grid, x)
        # In-place squeeze of dim 0 (drops it if it has size 1).
        interp_indices = interp_indices.squeeze_(0)
        interp_values = interp_values.squeeze_(0)

        grid_values = grid.squeeze(0).pow(2)
        target = x.pow(2).squeeze(-1)

        interpolated = utils.left_interp(
            interp_indices, interp_values, grid_values.unsqueeze(1)
        ).squeeze()

        self.assertTrue(utils.approx_equal(interpolated, target))
Example #15
0
    def test_diag(self):
        """KroneckerProductLazyVariable.diag() matches diag(kron(kron(a, b), c)),
        both for the plain factors and for factors tiled with repeat(3, 1, 1)."""
        def build(*mats):
            # Wrap each dense factor and form their Kronecker product lazily.
            return KroneckerProductLazyVariable(*(NonLazyVariable(m) for m in mats))

        # Unbatched factors.
        avar, bvar, cvar = Variable(a), Variable(b), Variable(c)
        computed = build(avar, bvar, cvar).diag()
        expected = kron(kron(avar, bvar), cvar).diag()
        self.assertTrue(approx_equal(computed.data, expected.data))

        # Batch of 3: compare against the diagonal of each dense batch entry.
        avar, bvar, cvar = (Variable(m.repeat(3, 1, 1)) for m in (a, b, c))
        computed = build(avar, bvar, cvar).diag()
        dense = kron(kron(avar, bvar), cvar)
        expected = torch.stack([dense[k].diag() for k in range(3)])
        self.assertTrue(approx_equal(computed.data, expected.data))
    def test_matmul_batch(self):
        """Batched InterpolatedLazyVariable.matmul matches the dense product
        W_left @ K @ W_right^T @ M, where W_left / W_right are the dense
        interpolation matrices encoded by the index/value pairs below."""
        # Three distinct interpolation rows, tiled to 9 rows and a batch of 5.
        left_interp_indices = Variable(
            torch.LongTensor([[2, 3], [3, 4], [4, 5]])
        ).repeat(5, 3, 1)
        left_interp_values = Variable(
            torch.Tensor([[1, 2], [0.5, 1], [1, 3]])
        ).repeat(5, 3, 1)
        right_interp_indices = Variable(
            torch.LongTensor([[0, 1], [1, 2], [2, 3]])
        ).repeat(5, 3, 1)
        right_interp_values = Variable(
            torch.Tensor([[1, 2], [2, 0.5], [1, 3]])
        ).repeat(5, 3, 1)

        # Base matrix K = X^T X for each of the 5 batch entries.
        base_mat = torch.randn(5, 6, 6)
        base_mat = base_mat.transpose(1, 2).matmul(base_mat)
        test_matrix = Variable(torch.randn(1, 9, 4))

        base_lazy = NonLazyVariable(Variable(base_mat, requires_grad=True))
        interp_lazy = InterpolatedLazyVariable(
            base_lazy,
            left_interp_indices,
            left_interp_values,
            right_interp_indices,
            right_interp_values,
        )
        computed = interp_lazy.matmul(test_matrix)

        # Dense interpolation matrices: the 3 distinct rows (value v at column
        # index i) tiled to (5, 9, 6), mirroring repeat(5, 3, 1) above.
        dense_left = torch.Tensor([
            [0, 0, 1, 2, 0, 0],
            [0, 0, 0, 0.5, 1, 0],
            [0, 0, 0, 0, 1, 3],
        ]).repeat(5, 3, 1)
        dense_right = torch.Tensor([
            [1, 2, 0, 0, 0, 0],
            [0, 2, 0.5, 0, 0, 0],
            [0, 0, 1, 3, 0, 0],
        ]).repeat(5, 3, 1)
        expected = (
            dense_left.matmul(base_mat)
            .matmul(dense_right.transpose(-1, -2))
            .matmul(test_matrix.data)
        )

        self.assertTrue(approx_equal(computed.data, expected))
Example #17
0
    def test_pivoted_cholesky(self):
        """The pivoted Cholesky factor L of a batched RBF covariance (second
        argument 10) reconstructs the covariance as L^T L to within 2e-4."""
        size = 100
        # Two 1-D input sets stacked into shape (2, size, 1).
        grids = [torch.linspace(0, 1, size), torch.linspace(0, 0.5, size)]
        train_x = torch.stack(grids, 0).unsqueeze(-1)
        covar_matrix = RBFKernel()(train_x, train_x)
        piv_chol = pivoted_cholesky.pivoted_cholesky(covar_matrix, 10)
        covar_approx = piv_chol.transpose(1, 2).matmul(piv_chol)

        self.assertTrue(approx_equal(covar_approx, covar_matrix, 2e-4))
Example #18
0
    def test_rotate_matrix_reverse(self):
        """circulant.rotate(a, -i) equals (Q0^i)^{-1} @ a for the 5x5 cyclic
        permutation Q0 and i = 1..4."""
        a = torch.randn(5, 5)
        # Q0: a one at (0, 4) and ones on the sub-diagonal (k, k-1), k = 1..4.
        shift = torch.zeros(5, 5)
        shift[0, 4] = 1
        shift[1:, :-1] = torch.eye(4)

        power = shift.clone()  # holds Q0^i during iteration i
        for i in range(1, 5):
            rotated = circulant.rotate(a, -i)
            expected = power.inverse().matmul(a)
            self.assertTrue(utils.approx_equal(expected, rotated))
            power = power.matmul(shift)
Example #19
0
    def test_matmul(self):
        """BlockDiagLazyTensor(blocks).matmul(rhs) matches multiplying by the
        dense 32x32 block-diagonal matrix, in value and in gradients."""
        rhs = torch.randn(4 * 8, 4, requires_grad=True)
        rhs_ref = rhs.clone().detach().requires_grad_(True)
        blocks_lazy = self.blocks.clone().requires_grad_(True)
        blocks_ref = self.blocks.clone().requires_grad_(True)

        # Place blocks_ref[k] in diagonal slot k (assumes self.blocks holds
        # eight 4x4 blocks -- fixture defined outside this view).
        dense = torch.zeros(32, 32)
        for k in range(8):
            span = slice(k * 4, (k + 1) * 4)
            dense[span, span] = blocks_ref[k]

        computed = BlockDiagLazyTensor(NonLazyTensor(blocks_lazy)).matmul(rhs)
        expected = dense.matmul(rhs_ref)

        self.assertTrue(approx_equal(computed, expected))

        expected.sum().backward()
        computed.sum().backward()

        self.assertTrue(approx_equal(rhs.grad, rhs_ref.grad))
        self.assertTrue(approx_equal(blocks_lazy.grad, blocks_ref.grad))
def test_matmul_batch_mat():
    """Batched KroneckerProductLazyVariable.matmul matches the dense
    kron(kron(a, b), c) product, in value and in gradients."""
    def batched(m):
        # Tile a factor 3 times via repeat(3, 1, 1) as a fresh grad-tracking leaf.
        return Variable(m.repeat(3, 1, 1), requires_grad=True)

    avar, bvar, cvar = batched(a), batched(b), batched(c)
    mat = Variable(torch.randn(3, 24, 5), requires_grad=True)
    lazy = KroneckerProductLazyVariable(
        NonLazyVariable(avar), NonLazyVariable(bvar), NonLazyVariable(cvar)
    )
    computed = lazy.matmul(mat)

    # Separate leaves for the dense reference path.
    avar_copy, bvar_copy, cvar_copy = batched(a), batched(b), batched(c)
    mat_copy = Variable(mat.data.clone(), requires_grad=True)
    expected = kron(kron(avar_copy, bvar_copy), cvar_copy).matmul(mat_copy)
    assert approx_equal(computed.data, expected.data)

    expected.sum().backward()
    computed.sum().backward()
    pairs = (
        (avar_copy, avar),
        (bvar_copy, bvar),
        (cvar_copy, cvar),
        (mat_copy, mat),
    )
    for ref, var in pairs:
        assert approx_equal(ref.grad.data, var.grad.data)
    def test_matmul_batch_mat(self):
        """Batched KroneckerProductLazyTensor.matmul matches the dense
        kron(kron(a, b), c) product, in value and in gradients."""
        factors = [m.repeat(3, 1, 1).requires_grad_(True) for m in (a, b, c)]
        mat = torch.randn(3, 24, 5, requires_grad=True)
        lazy = KroneckerProductLazyTensor(*[NonLazyTensor(f) for f in factors])
        computed = lazy.matmul(mat)

        # Detached leaf copies give the dense reference its own autograd graph.
        factor_copies = [f.clone().detach().requires_grad_(True) for f in factors]
        mat_copy = mat.clone().detach().requires_grad_(True)
        a_copy, b_copy, c_copy = factor_copies
        expected = kron(kron(a_copy, b_copy), c_copy).matmul(mat_copy)
        self.assertTrue(approx_equal(computed, expected))

        expected.sum().backward()
        computed.sum().backward()
        for ref, var in zip(factor_copies + [mat_copy], factors + [mat]):
            self.assertTrue(approx_equal(ref.grad, var.grad))
def test_matmul():
    """BlockDiagonalLazyVariable(blocks).matmul(rhs) matches the dense 32x32
    block-diagonal product, in value and in gradients."""
    rhs = torch.randn(4 * 8, 4)
    # Separate Variables for the lazy path and the dense reference path.
    rhs_var, rhs_var_copy = (Variable(rhs, requires_grad=True) for _ in range(2))
    block_var, block_var_copy = (
        Variable(blocks, requires_grad=True) for _ in range(2)
    )

    # Block k goes into diagonal slot k of a 32x32 dense matrix.
    dense = Variable(torch.zeros(32, 32))
    for k in range(8):
        span = slice(k * 4, (k + 1) * 4)
        dense[span, span] = block_var_copy[k]

    computed = BlockDiagonalLazyVariable(NonLazyVariable(block_var)).matmul(rhs_var)
    expected = dense.matmul(rhs_var_copy)

    assert approx_equal(computed.data, expected.data)

    expected.sum().backward()
    computed.sum().backward()

    assert approx_equal(rhs_var.grad.data, rhs_var_copy.grad.data)
    assert approx_equal(block_var.grad.data, block_var_copy.grad.data)
Example #23
0
    def test_batch_matmul(self):
        """Batched BlockDiagonalLazyVariable(n_blocks=4).matmul matches the
        dense (2, 16, 16) block-diagonal product, in value and in gradients."""
        rhs = torch.randn(2, 4 * 4, 4)
        rhs_var, rhs_var_copy = (Variable(rhs, requires_grad=True) for _ in range(2))
        block_var, block_var_copy = (
            Variable(blocks, requires_grad=True) for _ in range(2)
        )

        # Batch entry `batch` holds blocks batch*4 .. batch*4+3 on its diagonal.
        dense = Variable(torch.zeros(2, 16, 16))
        for batch in range(2):
            for blk in range(4):
                span = slice(blk * 4, (blk + 1) * 4)
                dense[batch, span, span] = block_var_copy[batch * 4 + blk]

        lazy = BlockDiagonalLazyVariable(NonLazyVariable(block_var), n_blocks=4)
        computed = lazy.matmul(rhs_var)
        expected = dense.matmul(rhs_var_copy)

        self.assertTrue(approx_equal(computed.data, expected.data))

        expected.sum().backward()
        computed.sum().backward()

        self.assertTrue(approx_equal(rhs_var.grad.data, rhs_var_copy.grad.data))
        self.assertTrue(approx_equal(block_var.grad.data, block_var_copy.grad.data))
Example #24
0
    def test_batch_diag(self):
        """Batched RootLazyVariable(root).diag() equals the diagonal of each
        batch entry of root @ root^T."""
        root = Variable(torch.randn(4, 5, 3))
        dense = root.matmul(root.transpose(-1, -2))
        expected = torch.cat([dense[k].diag().unsqueeze(0) for k in range(4)])

        lazy = RootLazyVariable(root)
        self.assertTrue(approx_equal(expected.data, lazy.diag().data))
Example #25
0
    def test_potrs(self):
        """tridiag_batch_potrs on a batch holding one lower-triangular factor
        agrees with torch.potrs on the unbatched factor."""
        lower = torch.Tensor(
            [[1, 0, 0, 0], [2, 1, 0, 0], [0, 1, 2, 0], [0, 0, 2, 3]]
        )
        chol = lower.unsqueeze(0)  # batch of size 1
        mat = torch.randn(1, 4, 3)

        reference = torch.potrs(mat[0], chol[0], upper=False)
        batched = tridiag_batch_potrs(mat, chol, upper=False)[0]
        self.assertTrue(approx_equal(reference, batched))
Example #26
0
    def test_cg(self):
        """linear_cg on a symmetric positive-definite float64 system agrees with
        a Cholesky-based solve."""
        size = 100
        base = torch.randn(size, size, dtype=torch.float64)
        # A A^T, scaled by its norm, plus 0.1 * I.
        matrix = base.matmul(base.transpose(-1, -2))
        matrix.div_(matrix.norm())
        matrix.add_(torch.eye(matrix.size(-1), dtype=torch.float64).mul_(1e-1))

        rhs = torch.randn(size, 50, dtype=torch.float64)
        solves = linear_cg(matrix.matmul, rhs=rhs, max_iter=size)

        # Reference solution from potrf / potrs.
        chol = matrix.potrf()
        expected = torch.potrs(rhs, chol)
        self.assertTrue(approx_equal(solves, expected))
Example #27
0
    def test_get_indices(self):
        """MatmulLazyTensor._get_indices returns the same entries as advanced
        indexing into the dense rank-1 product lhs @ rhs."""
        lhs = torch.randn(5, 1)
        rhs = torch.randn(1, 5)
        dense = lhs.matmul(rhs)
        lazy = MatmulLazyTensor(lhs, rhs)

        index_cases = [
            # Four distinct (row, col) pairs.
            ([1, 2, 4, 0], [0, 1, 3, 2]),
            # Longer case, including repeated pairs and diagonal entries.
            (
                [1, 2, 4, 0, 1, 2, 3, 1, 2, 2, 1, 1, 0, 0, 4, 4, 4, 4],
                [0, 1, 3, 2, 3, 4, 2, 2, 1, 1, 2, 1, 2, 4, 4, 3, 3, 0],
            ),
        ]
        for rows, cols in index_cases:
            row_idx = torch.tensor(rows, dtype=torch.long)
            col_idx = torch.tensor(cols, dtype=torch.long)
            self.assertTrue(
                approx_equal(
                    dense[row_idx, col_idx],
                    lazy._get_indices(row_idx, col_idx),
                )
            )
    def test_log_det_only(self):
        """NonLazyTensor.log_det() matches log(det(A)) in value (1 decimal
        place) and in gradient (epsilon=1e-1)."""
        # Forward pass, evaluated with 1000 trace samples.
        with gpytorch.settings.num_trace_samples(1000):
            estimate = NonLazyTensor(self.mat_var).log_det()
        reference = self.mat_var_clone.det().log()
        self.assertAlmostEqual(estimate.item(), reference.item(), places=1)

        # Backward pass through both paths.
        reference.backward()
        estimate.backward()
        self.assertTrue(
            approx_equal(
                self.mat_var_clone.grad,
                self.mat_var.grad,
                epsilon=1e-1,
            )
        )
Example #29
0
    def test_cg(self):
        """linear_cg on a shifted symmetric positive-definite float64 system
        agrees with a Cholesky-based solve."""
        size = 100
        base = torch.DoubleTensor(size, size).normal_()
        matrix = base.matmul(base.transpose(-1, -2))
        matrix.div_(matrix.norm())
        # Add 0.1 to every diagonal entry.
        jitter = torch.DoubleTensor(matrix.size(-1)).fill_(1e-1).diag()
        matrix.add_(jitter)

        rhs = torch.DoubleTensor(size, 50).normal_()
        solves = linear_cg(matrix.matmul, rhs=rhs, max_iter=size)

        # Reference solution from potrf / potrs.
        chol = matrix.potrf()
        expected = torch.potrs(rhs, chol)
        self.assertTrue(approx_equal(solves, expected))
Example #30
0
    def test_lanczos(self):
        """Lanczos tridiagonalization run for ``size`` iterations reconstructs
        the matrix as Q T Q^T."""
        size = 100
        base = torch.randn(size, size)
        matrix = base.matmul(base.transpose(-1, -2))
        matrix.div_(matrix.norm())
        # Small diagonal jitter of 1e-6.
        matrix.add_(torch.ones(matrix.size(-1)).mul(1e-6).diag())

        q_mat, t_mat = lanczos_tridiag(
            matrix.matmul,
            max_iter=size,
            dtype=matrix.dtype,
            device=matrix.device,
            n_dims=matrix.size(-1),
        )

        reconstruction = q_mat.matmul(t_mat).matmul(q_mat.transpose(-1, -2))
        self.assertTrue(approx_equal(reconstruction, matrix))