def test_kronecker_product_mul():
    """Check ``kronecker_product_mul`` against a dense Kronecker product.

    Three random square factors (3x3, 2x2, 3x3) are multiplied into an
    18x9 random matrix both via the fast routine and via the explicitly
    materialized Kronecker product; the results must agree in Frobenius norm.
    """
    factors = [torch.randn(3, 3), torch.randn(2, 2), torch.randn(3, 3)]
    rhs = torch.randn(3 * 2 * 3, 9)

    fast_result = kronecker_product_mul(factors, rhs)
    dense_result = kronecker_product(factors).mm(rhs)

    assert torch.norm(fast_result - dense_result) < 1e-4
# Example #2
def test_kronecker_product_toeplitz_matmul():
    """Check ``kronecker_product_toeplitz_matmul`` against a dense reference.

    Each row of ``toeplitz_columns`` is passed as both arguments of
    ``toeplitz`` to build one dense 3x3 factor; the Kronecker product of the
    three factors (27x27) times ``matrix`` must match the fast routine.
    """
    toeplitz_columns = torch.randn(3, 3)
    matrix = torch.randn(27, 10)
    res = kronecker_product_toeplitz_matmul(
        toeplitz_columns, toeplitz_columns, matrix)

    dense_factors = torch.zeros(3, 3, 3)
    for idx, column in enumerate(toeplitz_columns):
        dense_factors[idx] = toeplitz(column, column)

    expected = kronecker_product(dense_factors).mm(matrix)
    assert torch.norm(res - expected) < 1e-4
# Example #3
def test_kronecker_product():
    """``kronecker_product`` of a 2x3 and a 2x2 matrix equals the 4x6 result.

    The expected matrix is written out by hand: block (r, c) is
    ``left[r, c] * right``.
    """
    left = torch.Tensor([[1, 2, 3], [4, 5, 6]])
    right = torch.Tensor([[1, 2], [4, 3]])
    expected = torch.Tensor([
        [1, 2, 2, 4, 3, 6],
        [4, 3, 8, 6, 12, 9],
        [4, 8, 5, 10, 6, 12],
        [16, 12, 20, 15, 24, 18],
    ])

    assert torch.equal(kronecker_product([left, right]), expected)
        mean_x = self.mean_module(x)
        covar_x = self.grid_covar_module(x)
        return GaussianRandomVariable(mean_x, covar_x)


# Module-level fixture: build a dense reference covariance (``WKW``) from the
# pieces of the lazy Kronecker-product covariance so tests can compare the
# lazy variable's ``evaluate()`` against it.
# NOTE(review): ``Model``, ``x``, ``utils``, ``kronecker_product`` and
# ``list_of_indices_and_values_to_sparse`` are defined/imported outside this
# view -- presumably in the truncated class above and the file header.
prior_observation_model = Model()
prior_observation_model.eval()
pred = prior_observation_model(x)
lazy_kronecker_product_var = pred.covar()
# One dense Toeplitz matrix per row of ``columns``:
# shape (columns.size(0), columns.size(1), columns.size(1)).
Ts = torch.zeros(lazy_kronecker_product_var.columns.size()[0],
                 lazy_kronecker_product_var.columns.size()[1],
                 lazy_kronecker_product_var.columns.size()[1])
for i in range(lazy_kronecker_product_var.columns.size()[0]):
    # Expand row i of ``columns`` into a dense (presumably symmetric, per the
    # helper's name) Toeplitz matrix.
    Ts[i] = utils.toeplitz.sym_toeplitz(
        lazy_kronecker_product_var.columns[i].data)
# Dense Kronecker product of all the Toeplitz factors.
K = kronecker_product(Ts)
# Sparse interpolation matrices built from the lazy variable's J_* lists
# (presumably indices) and C_* lists (presumably values).
W_left = list_of_indices_and_values_to_sparse(
    lazy_kronecker_product_var.J_lefts, lazy_kronecker_product_var.C_lefts,
    lazy_kronecker_product_var.columns)
W_right = list_of_indices_and_values_to_sparse(
    lazy_kronecker_product_var.J_rights, lazy_kronecker_product_var.C_rights,
    lazy_kronecker_product_var.columns)
# Reference covariance: W_right @ (W_left @ K)^T + diag(added_diag).
# NOTE(review): for symmetric K this is the transpose of
# W_left @ K @ W_right^T; the two coincide when the left and right
# interpolation are identical (both derived from the same ``x`` here,
# presumably) -- confirm if this fixture is reused with x1 != x2.
WKW = torch.dsmm(W_right,
                 torch.dsmm(W_left, K).t()) + torch.diag(
                     lazy_kronecker_product_var.added_diag.data)


def test_evaluate():
    """Dense evaluation of the lazy covariance must match the reference ``WKW``."""
    assert utils.approx_equal(lazy_kronecker_product_var.evaluate(), WKW)
def foo_kp_toeplitz_gp_marginal_log_likelihood_backward():
    """Compare exact-GP marginal log likelihood gradients with a dense reference.

    A 3-D grid-interpolated covariance is built from per-dimension Toeplitz
    columns ``cs``. The marginal log likelihood and its gradients with respect
    to ``cs`` and ``y`` are computed densely, then compared with the gradients
    produced by
    ``kroneckerProductLazyVariable.exact_gp_marginal_log_likelihood``:

    * first with ``gpytorch.functions.num_trace_samples = 100`` (the ``cs``
      gradient is only required to be within 5% relative error), and
    * then with ``gpytorch.functions.fastest = False`` (1e-3 tolerance).

    The ``gpytorch.functions`` settings touched here are restored afterwards.
    Note: the name does not start with ``test``, so pytest's default
    collection does not run this function.
    """
    num_dims = 3
    grid_size = 5

    def rel_err(expected, result):
        # Relative error of ``result`` against ``expected``, normalised by
        # ``result`` (same normalisation as the original assertions).
        return (expected - result).norm() / result.norm()

    x = torch.cat([Variable(torch.linspace(0, 1, 2)).unsqueeze(1)] * num_dims,
                  1)
    y = Variable(torch.randn(2), requires_grad=True)
    rbf_module = RBFKernel()
    rbf_module.initialize(log_lengthscale=-2)
    covar_module = GridInterpolationKernel(rbf_module)
    covar_module.eval()
    covar_module.initialize_interpolation_grid(grid_size,
                                               [(0, 1)] * num_dims)

    kronecker_var = covar_module.forward(x, x)

    # Leaf variable holding one Toeplitz column per input dimension; both the
    # dense reference and the lazy variable are differentiated w.r.t. it.
    cs = Variable(torch.zeros(num_dims, grid_size), requires_grad=True)
    J_lefts, C_lefts, J_rights, C_rights, Ts = [], [], [], [], []
    for i in range(num_dims):
        covar_x = covar_module.forward(x[:, i].unsqueeze(1),
                                       x[:, i].unsqueeze(1))
        cs.data[i] = covar_x.c.data
        J_lefts.append(covar_x.J_left)
        C_lefts.append(covar_x.C_left)
        J_rights.append(covar_x.J_right)
        C_rights.append(covar_x.C_right)
        # Fill the dense Toeplitz factor entry by entry from cs[i] so the
        # autograd graph connects every entry back to cs.
        T = Variable(torch.zeros(grid_size, grid_size))
        for k in range(grid_size):
            for j in range(grid_size):
                T[k, j] = utils.toeplitz.toeplitz_getitem(cs[i], cs[i], k, j)
        Ts.append(T)

    W_left = list_of_indices_and_values_to_sparse(J_lefts, C_lefts, cs)
    W_right = list_of_indices_and_values_to_sparse(J_rights, C_rights, cs)
    W_left_dense = Variable(W_left.to_dense())
    W_right_dense = Variable(W_right.to_dense())
    K = kronecker_product(Ts)
    WKW = W_left_dense.matmul(K.matmul(W_right_dense.t()))
    quad_form_actual = y.dot(WKW.inverse().matmul(y))
    log_det_actual = _det(WKW).log()

    # The leading -0.5 makes this the log marginal likelihood itself (not its
    # negation); its gradients are what the lazy variable must reproduce.
    actual_mll = -0.5 * (log_det_actual + quad_form_actual +
                         math.log(2 * math.pi) * len(y))
    actual_mll.backward()
    actual_cs_grad = cs.grad.data.clone()
    actual_y_grad = y.grad.data.clone()

    # gpytorch.functions settings are module-level globals: remember them so
    # this test does not leak configuration into other tests.
    missing = object()
    saved_settings = {
        name: getattr(gpytorch.functions, name, missing)
        for name in ('num_trace_samples', 'fastest')
    }
    try:
        y.grad.data.fill_(0)
        cs.grad.data.fill_(0)

        lazy_var = gpytorch.lazy.kroneckerProductLazyVariable(
            cs, kronecker_var.J_lefts, kronecker_var.C_lefts,
            kronecker_var.J_rights, kronecker_var.C_rights)
        gpytorch.functions.num_trace_samples = 100
        res = lazy_var.exact_gp_marginal_log_likelihood(y)
        res.backward()

        # Gradients accumulate on the leaf ``cs`` passed to the lazy variable
        # (not on the last per-dimension ``covar_x``).
        res_cs_grad = cs.grad.data.clone()
        res_y_grad = y.grad.data.clone()

        assert rel_err(actual_cs_grad, res_cs_grad) < 0.05
        assert rel_err(actual_y_grad, res_y_grad) < 1e-3

        y.grad.data.fill_(0)
        cs.grad.data.fill_(0)

        gpytorch.functions.fastest = False
        res = lazy_var.exact_gp_marginal_log_likelihood(y)
        res.backward()

        res_cs_grad = cs.grad.data.clone()
        res_y_grad = y.grad.data.clone()

        assert rel_err(actual_cs_grad, res_cs_grad) < 1e-3
        assert rel_err(actual_y_grad, res_y_grad) < 1e-3
    finally:
        for name, value in saved_settings.items():
            if value is not missing:
                setattr(gpytorch.functions, name, value)
            elif hasattr(gpytorch.functions, name):
                delattr(gpytorch.functions, name)