def test_kronecker_product_mul():
    """Check kronecker_product_mul against an explicit Kronecker product.

    Three random square factors (3x3, 2x2, 3x3) are multiplied with a random
    18x9 matrix through ``kronecker_product_mul``; the result must agree with
    materialising the full product via ``kronecker_product`` and calling
    ``mm``, up to a norm difference of 1e-4.
    """
    # Factors are drawn in the order 3x3, 2x2, 3x3.
    factors = [torch.randn(size, size) for size in (3, 2, 3)]
    rhs = torch.randn(18, 9)  # 18 == 3 * 2 * 3 rows, one per row of the product

    fast_result = kronecker_product_mul(factors, rhs)
    dense_result = kronecker_product(factors).mm(rhs)

    difference = torch.norm(fast_result - dense_result)
    assert difference < 1e-4
def test_kronecker_product_toeplitz_matmul():
    """Check kronecker_product_toeplitz_matmul against a dense reference.

    Each row of a random 3x3 tensor is used as both arguments of
    ``toeplitz`` to build one dense factor; the Kronecker product of the
    three factors times a random 27x10 matrix must match the structured
    routine up to a norm difference of 1e-4.
    """
    columns = torch.randn(3, 3)
    rhs = torch.randn(27, 10)

    structured = kronecker_product_toeplitz_matmul(columns, columns, rhs)

    # Dense reference: one 3x3 Toeplitz factor per row of `columns`.
    dense_factors = torch.zeros(3, 3, 3)
    for idx, column in enumerate(columns):
        dense_factors[idx] = toeplitz(column, column)
    reference = kronecker_product(dense_factors).mm(rhs)

    gap = torch.norm(structured - reference)
    assert gap < 1e-4
def test_kronecker_product():
    """Check kronecker_product on a hand-computed 2x3 (x) 2x2 example."""
    left = torch.Tensor([
        [1, 2, 3],
        [4, 5, 6],
    ])
    right = torch.Tensor([
        [1, 2],
        [4, 3],
    ])

    product = kronecker_product([left, right])

    # Block (i, j) of the expected 4x6 result is left[i, j] * right.
    expected = torch.Tensor([
        [1, 2, 2, 4, 3, 6],
        [4, 3, 8, 6, 12, 9],
        [4, 8, 5, 10, 6, 12],
        [16, 12, 20, 15, 24, 18],
    ])
    assert torch.equal(product, expected)
# NOTE(review): the next three statements are the tail of a method whose
# `def` line and enclosing class are outside this chunk. It applies the
# mean and grid-covariance modules to `x` and wraps the two results in a
# GaussianRandomVariable.
mean_x = self.mean_module(x)
covar_x = self.grid_covar_module(x)
return GaussianRandomVariable(mean_x, covar_x)

# Module-level reference data shared by the test below: put a Model in eval
# mode, run it on `x` (defined outside this chunk) and take the covariance
# of the prediction.
prior_observation_model = Model()
prior_observation_model.eval()
pred = prior_observation_model(x)
lazy_kronecker_product_var = pred.covar()

# One dense square matrix per row of `columns`, produced by sym_toeplitz.
Ts = torch.zeros(lazy_kronecker_product_var.columns.size()[0],
                 lazy_kronecker_product_var.columns.size()[1],
                 lazy_kronecker_product_var.columns.size()[1])
for i in range(lazy_kronecker_product_var.columns.size()[0]):
    Ts[i] = utils.toeplitz.sym_toeplitz(
        lazy_kronecker_product_var.columns[i].data)
K = kronecker_product(Ts)

# Sparse matrices assembled from the lazy variable's J (index) and C (value)
# lists -- presumably interpolation weights; confirm against the helper.
W_left = list_of_indices_and_values_to_sparse(
    lazy_kronecker_product_var.J_lefts,
    lazy_kronecker_product_var.C_lefts,
    lazy_kronecker_product_var.columns)
W_right = list_of_indices_and_values_to_sparse(
    lazy_kronecker_product_var.J_rights,
    lazy_kronecker_product_var.C_rights,
    lazy_kronecker_product_var.columns)

# Explicit reference: W_right * (W_left * K)^T plus the added diagonal.
WKW = torch.dsmm(W_right, torch.dsmm(W_left, K).t()) + torch.diag(
    lazy_kronecker_product_var.added_diag.data)


def test_evaluate():
    """evaluate() on the lazy variable must approximately equal the WKW reference."""
    WKW_res = lazy_kronecker_product_var.evaluate()
    assert utils.approx_equal(WKW_res, WKW)
def foo_kp_toeplitz_gp_marginal_log_likelihood_backward():
    """Compare marginal-log-likelihood gradients with a dense reference.

    The reference builds the explicit matrix ``W_left K W_right^T`` from
    per-dimension Toeplitz factors and differentiates
    ``-0.5 * (log det + y^T WKW^{-1} y + len(y) * log(2 * pi))`` with respect
    to the Toeplitz columns ``cs`` and the targets ``y``. The same gradients
    are then obtained from ``exact_gp_marginal_log_likelihood`` on the lazy
    variable, first with ``num_trace_samples = 100`` (5% tolerance on the
    ``cs`` gradient) and then with ``gpytorch.functions.fastest = False``
    (1e-3 tolerance).

    The name lacks the ``test_`` prefix, so default pytest discovery does not
    collect this function. It also leaves ``gpytorch.functions.num_trace_samples``
    and ``gpytorch.functions.fastest`` modified after it returns.
    """
    x = torch.cat([Variable(torch.linspace(0, 1, 2)).unsqueeze(1)] * 3, 1)
    y = Variable(torch.randn(2), requires_grad=True)

    rbf_module = RBFKernel()
    rbf_module.initialize(log_lengthscale=-2)
    covar_module = GridInterpolationKernel(rbf_module)
    covar_module.eval()
    covar_module.initialize_interpolation_grid(5, [(0, 1), (0, 1), (0, 1)])
    kronecker_var = covar_module.forward(x, x)

    # `cs` holds one column per input dimension and is the leaf variable
    # whose gradient is compared between the two code paths.
    cs = Variable(torch.zeros(3, 5), requires_grad=True)
    J_lefts = []
    C_lefts = []
    J_rights = []
    C_rights = []
    Ts = []
    for i in range(3):
        x_i = x[:, i].unsqueeze(1)
        covar_x = covar_module.forward(x_i, x_i)
        cs.data[i] = covar_x.c.data
        J_lefts.append(covar_x.J_left)
        C_lefts.append(covar_x.C_left)
        J_rights.append(covar_x.J_right)
        C_rights.append(covar_x.C_right)

        # Fill the dense factor entry by entry from cs[i] -- presumably so
        # that it stays a differentiable function of `cs`.
        size = len(cs[i].data)
        T = Variable(torch.zeros(size, size))
        for k in range(size):
            for j in range(size):
                T[k, j] = utils.toeplitz.toeplitz_getitem(cs[i], cs[i], k, j)
        Ts.append(T)

    W_left = list_of_indices_and_values_to_sparse(J_lefts, C_lefts, cs)
    W_right = list_of_indices_and_values_to_sparse(J_rights, C_rights, cs)
    W_left_dense = Variable(W_left.to_dense())
    W_right_dense = Variable(W_right.to_dense())
    K = kronecker_product(Ts)
    WKW = W_left_dense.matmul(K.matmul(W_right_dense.t()))

    # Dense reference value and its gradients.
    quad_form_actual = y.dot(WKW.inverse().matmul(y))
    log_det_actual = _det(WKW).log()
    actual_nll = -0.5 * (log_det_actual + quad_form_actual + math.log(2 * math.pi) * len(y))
    actual_nll.backward()
    actual_cs_grad = cs.grad.data.clone()
    actual_y_grad = y.grad.data.clone()

    # Gradients accumulate across backward() calls, so clear them first.
    y.grad.data.fill_(0)
    cs.grad.data.fill_(0)

    # NOTE(review): the lowercase `kroneckerProductLazyVariable` is kept from
    # the original; it looks like a typo for a PascalCase class name --
    # confirm against gpytorch.lazy.
    kronecker_var = gpytorch.lazy.kroneckerProductLazyVariable(
        cs,
        kronecker_var.J_lefts,
        kronecker_var.C_lefts,
        kronecker_var.J_rights,
        kronecker_var.C_rights,
    )

    gpytorch.functions.num_trace_samples = 100
    res = kronecker_var.exact_gp_marginal_log_likelihood(y)
    res.backward()
    # Read the gradient from `cs`, the leaf passed to the lazy variable.
    # The original read `covar_x.cs`, but `covar_x` is the last per-dimension
    # loop value and is only ever accessed through `.c`.
    res_cs_grad = cs.grad.data
    res_y_grad = y.grad.data
    assert (actual_cs_grad - res_cs_grad).norm() / res_cs_grad.norm() < 0.05
    assert (actual_y_grad - res_y_grad).norm() / res_y_grad.norm() < 1e-3

    y.grad.data.fill_(0)
    cs.grad.data.fill_(0)

    gpytorch.functions.fastest = False
    res = kronecker_var.exact_gp_marginal_log_likelihood(y)
    res.backward()
    res_cs_grad = cs.grad.data
    res_y_grad = y.grad.data
    assert (actual_cs_grad - res_cs_grad).norm() / res_cs_grad.norm() < 1e-3
    assert (actual_y_grad - res_y_grad).norm() / res_y_grad.norm() < 1e-3