def test_interpolated_toeplitz_gp_marginal_log_likelihood_forward():
    """The lazy Toeplitz MLL matches a densely-computed Gaussian log likelihood."""
    x = Variable(torch.linspace(0, 1, 5))
    y = torch.randn(5)

    base_kernel = RBFKernel()
    base_kernel.initialize(log_lengthscale=-4)
    grid_kernel = GridInterpolationKernel(base_kernel)
    grid_kernel.eval()
    grid_kernel.initialize_interpolation_grid(10, [(0, 1)])
    covar_x = grid_kernel.forward(x.unsqueeze(1), x.unsqueeze(1))

    # Densify the interpolated covariance: W_left @ T @ W_right^T + 1e-4 * I
    column = covar_x.c.data
    dense_toeplitz = utils.toeplitz.sym_toeplitz(column)
    left_interp = index_coef_to_sparse(covar_x.J_left, covar_x.C_left, len(column))
    right_interp = index_coef_to_sparse(covar_x.J_right, covar_x.C_right, len(column))
    left_dense = left_interp.to_dense()
    right_dense = right_interp.to_dense()
    dense_covar = left_dense.matmul(dense_toeplitz.matmul(right_dense.t())) + torch.eye(len(x)) * 1e-4

    # Dense reference: -0.5 * (log|K| + y' K^-1 y + n log(2 pi))
    quad_form = y.dot(dense_covar.inverse().matmul(y))
    chol_covar = torch.potrf(dense_covar)
    log_det = chol_covar.diag().log().sum() * 2
    expected = -0.5 * (log_det + quad_form + math.log(2 * math.pi) * len(y))

    res = covar_x.exact_gp_marginal_log_likelihood(Variable(y))
    assert all(torch.abs((res.data - expected) / expected) < 0.05)
def test_interpolated_toeplitz_gp_marginal_log_likelihood_backward():
    """Gradients of the lazy Toeplitz MLL agree with a dense autograd reference."""
    x = Variable(torch.linspace(0, 1, 5))
    y = Variable(torch.randn(5), requires_grad=True)
    noise = Variable(torch.Tensor([1e-4]), requires_grad=True)

    base_kernel = RBFKernel()
    base_kernel.initialize(log_lengthscale=-4)
    grid_kernel = GridInterpolationKernel(base_kernel)
    grid_kernel.eval()
    grid_kernel.initialize_interpolation_grid(10, [(0, 1)])
    covar_x = grid_kernel.forward(x.unsqueeze(1), x.unsqueeze(1))

    c = Variable(covar_x.c.data, requires_grad=True)
    left_sparse = index_coef_to_sparse(covar_x.J_left, covar_x.C_left, len(c))
    right_sparse = index_coef_to_sparse(covar_x.J_right, covar_x.C_right, len(c))
    left_dense = Variable(left_sparse.to_dense())
    right_dense = Variable(right_sparse.to_dense())

    # Rebuild the symmetric Toeplitz matrix entry-by-entry so autograd tracks c.
    T = Variable(torch.zeros(len(c), len(c)))
    for row in range(len(c)):
        for col in range(len(c)):
            T[row, col] = utils.toeplitz.sym_toeplitz_getitem(c, row, col)

    # Dense autograd reference for the negative log likelihood.
    WTW = left_dense.matmul(T.matmul(right_dense.t())) + Variable(torch.eye(len(x))) * noise
    quad_form = y.dot(WTW.inverse().matmul(y))
    log_det = _det(WTW).log()
    reference_nll = -0.5 * (log_det + quad_form + math.log(2 * math.pi) * len(y))
    reference_nll.backward()

    actual_c_grad = c.grad.data.clone()
    actual_y_grad = y.grad.data.clone()
    actual_noise_grad = noise.grad.data.clone()
    for leaf in (c, y, noise):
        leaf.grad.data.fill_(0)

    # Gradients through the lazy variable's own MLL implementation.
    covar_x = gpytorch.lazy.ToeplitzLazyVariable(c, covar_x.J_left, covar_x.C_left,
                                                 covar_x.J_right, covar_x.C_right, noise)
    res = covar_x.exact_gp_marginal_log_likelihood(y)
    res.backward()

    res_c_grad = covar_x.c.grad.data
    res_y_grad = y.grad.data
    res_noise_grad = noise.grad.data
    assert (actual_c_grad - res_c_grad).norm() / res_c_grad.norm() < 0.05
    assert (actual_y_grad - res_y_grad).norm() / res_y_grad.norm() < 1e-3
    assert (actual_noise_grad - res_noise_grad).norm() / res_noise_grad.norm() < 1e-3
def test_kp_toeplitz_gp_marginal_log_likelihood_forward():
    """Kronecker-product Toeplitz MLL matches the MLL of the evaluated dense covariance."""
    x = torch.cat([Variable(torch.linspace(0, 1, 2)).unsqueeze(1)] * 3, 1)
    y = torch.randn(2)

    base_kernel = RBFKernel()
    base_kernel.initialize(log_lengthscale=-2)
    grid_kernel = GridInterpolationKernel(base_kernel)
    grid_kernel.eval()
    grid_kernel.initialize_interpolation_grid(5, [(0, 1), (0, 1), (0, 1)])

    kronecker_var = grid_kernel.forward(x, x)
    dense_covar = kronecker_var.evaluate()

    res = kronecker_var.exact_gp_marginal_log_likelihood(Variable(y)).data
    expected = gpytorch.exact_gp_marginal_log_likelihood(dense_covar, Variable(y)).data
    assert all(torch.abs((res - expected) / expected) < 0.05)
def foo_kp_toeplitz_gp_marginal_log_likelihood_backward():
    """Gradients of the KP-Toeplitz MLL agree with a dense autograd reference.

    NOTE(review): the ``foo_`` prefix keeps pytest from collecting this test —
    presumably disabled on purpose; confirm before renaming it to ``test_``.

    Fixed defect: the result gradients were read from ``covar_x.cs.grad`` —
    ``covar_x`` is the last per-dimension Toeplitz lazy variable left over from
    the loop below and has a ``c`` attribute, not ``cs``; that would raise
    AttributeError. The gradients accumulate on the ``cs`` leaf Variable that
    the rebuilt ``kronecker_var`` was constructed from, so read ``cs.grad``.
    """
    x = torch.cat([Variable(torch.linspace(0, 1, 2)).unsqueeze(1)] * 3, 1)
    y = Variable(torch.randn(2), requires_grad=True)

    rbf_module = RBFKernel()
    rbf_module.initialize(log_lengthscale=-2)
    covar_module = GridInterpolationKernel(rbf_module)
    covar_module.eval()
    covar_module.initialize_interpolation_grid(5, [(0, 1), (0, 1), (0, 1)])
    kronecker_var = covar_module.forward(x, x)

    # Rebuild each dimension's Toeplitz factor element-wise so autograd tracks cs.
    cs = Variable(torch.zeros(3, 5), requires_grad=True)
    J_lefts = []
    C_lefts = []
    J_rights = []
    C_rights = []
    Ts = []
    for i in range(3):
        covar_x = covar_module.forward(x[:, i].unsqueeze(1), x[:, i].unsqueeze(1))
        cs.data[i] = covar_x.c.data
        J_lefts.append(covar_x.J_left)
        C_lefts.append(covar_x.C_left)
        J_rights.append(covar_x.J_right)
        C_rights.append(covar_x.C_right)
        T = Variable(torch.zeros(len(cs[i].data), len(cs[i].data)))
        for k in range(len(cs[i].data)):
            for j in range(len(cs[i].data)):
                T[k, j] = utils.toeplitz.toeplitz_getitem(cs[i], cs[i], k, j)
        Ts.append(T)

    W_left = list_of_indices_and_values_to_sparse(J_lefts, C_lefts, cs)
    W_right = list_of_indices_and_values_to_sparse(J_rights, C_rights, cs)
    W_left_dense = Variable(W_left.to_dense())
    W_right_dense = Variable(W_right.to_dense())
    K = kronecker_product(Ts)
    WKW = W_left_dense.matmul(K.matmul(W_right_dense.t()))

    # Dense autograd reference for the negative log likelihood.
    quad_form_actual = y.dot(WKW.inverse().matmul(y))
    log_det_actual = _det(WKW).log()
    actual_nll = -0.5 * (log_det_actual + quad_form_actual + math.log(2 * math.pi) * len(y))
    actual_nll.backward()

    actual_cs_grad = cs.grad.data.clone()
    actual_y_grad = y.grad.data.clone()
    y.grad.data.fill_(0)
    cs.grad.data.fill_(0)

    kronecker_var = gpytorch.lazy.kroneckerProductLazyVariable(
        cs, kronecker_var.J_lefts, kronecker_var.C_lefts,
        kronecker_var.J_rights, kronecker_var.C_rights)

    # Stochastic (fast) trace-estimation path: looser tolerance.
    gpytorch.functions.num_trace_samples = 100
    res = kronecker_var.exact_gp_marginal_log_likelihood(y)
    res.backward()
    res_cs_grad = cs.grad.data  # was covar_x.cs.grad.data (AttributeError)
    res_y_grad = y.grad.data
    assert (actual_cs_grad - res_cs_grad).norm() / res_cs_grad.norm() < 0.05
    assert (actual_y_grad - res_y_grad).norm() / res_y_grad.norm() < 1e-3

    # Exact (slow) path: tight tolerance.
    y.grad.data.fill_(0)
    cs.grad.data.fill_(0)
    gpytorch.functions.fastest = False
    res = kronecker_var.exact_gp_marginal_log_likelihood(y)
    res.backward()
    res_cs_grad = cs.grad.data  # was covar_x.cs.grad.data (AttributeError)
    res_y_grad = y.grad.data
    assert (actual_cs_grad - res_cs_grad).norm() / res_cs_grad.norm() < 1e-3
    assert (actual_y_grad - res_y_grad).norm() / res_y_grad.norm() < 1e-3
def test_trace_logdet_quad_form_factory():
    """trace_logdet_quad_form_factory's value and gradients match a dense reference."""
    x = Variable(torch.linspace(0, 1, 10))
    base_kernel = RBFKernel()
    base_kernel.initialize(log_lengthscale=-4)
    grid_kernel = GridInterpolationKernel(base_kernel)
    grid_kernel.eval()
    grid_kernel.initialize_interpolation_grid(4, [(0, 1)])
    c = Variable(grid_kernel.forward(x.unsqueeze(1), x.unsqueeze(1)).c.data, requires_grad=True)

    # Dense Toeplitz matrix built entry-by-entry so autograd tracks c.
    T = Variable(torch.zeros(4, 4))
    for row in range(4):
        for col in range(4):
            T[row, col] = utils.toeplitz.toeplitz_getitem(c, c, row, col)

    U = torch.randn(4, 4).triu()
    U = Variable(U.mul(U.diag().sign().unsqueeze(1).expand_as(U).triu()), requires_grad=True)
    mu_diff = Variable(torch.randn(4), requires_grad=True)

    # Dense reference: log|T| + mu' T^-1 mu + tr(T^-1 U'U)
    actual = _det(T).log() + mu_diff.dot(T.inverse().mv(mu_diff)) + T.inverse().mm(U.t().mm(U)).trace()
    actual.backward()

    actual_c_grad = c.grad.data.clone()
    actual_mu_diff_grad = mu_diff.grad.data.clone()
    actual_U_grad = U.grad.data.clone()
    c.grad.data.fill_(0)
    mu_diff.grad.data.fill_(0)
    U.grad.data.fill_(0)

    def _matmul_closure_factory(*args):
        # Single covar arg: the Toeplitz column.
        column, = args
        return lambda mat2: sym_toeplitz_matmul(column, mat2)

    def _derivative_quadratic_form_factory(*args):
        return lambda left_vector, right_vector: (
            sym_toeplitz_derivative_quadratic_form(left_vector, right_vector),
        )

    covar_args = (c,)

    # Stochastic trace-estimation path: looser tolerances.
    gpytorch.functions.num_trace_samples = 1000
    res = trace_logdet_quad_form_factory(
        _matmul_closure_factory, _derivative_quadratic_form_factory)()(mu_diff, U, *covar_args)
    res.backward()
    res_c_grad = c.grad.data
    res_mu_diff_grad = mu_diff.grad.data
    res_U_grad = U.grad.data
    assert (res.data - actual.data).norm() / actual.data.norm() < 0.15
    assert (res_c_grad - actual_c_grad).norm() / actual_c_grad.norm() < 0.15
    assert (res_mu_diff_grad - actual_mu_diff_grad).norm() / actual_mu_diff_grad.norm() < 1e-3
    assert (res_U_grad - actual_U_grad).norm() / actual_U_grad.norm() < 1e-3

    # Exact path: tight tolerances.
    c.grad.data.fill_(0)
    mu_diff.grad.data.fill_(0)
    U.grad.data.fill_(0)
    covar_args = (c,)
    gpytorch.functions.fastest = False
    res = trace_logdet_quad_form_factory(
        _matmul_closure_factory, _derivative_quadratic_form_factory)()(mu_diff, U, *covar_args)
    res.backward()
    res_c_grad = c.grad.data
    res_mu_diff_grad = mu_diff.grad.data
    res_U_grad = U.grad.data
    assert (res.data - actual.data).norm() / actual.data.norm() < 1e-3
    assert (res_c_grad - actual_c_grad).norm() / actual_c_grad.norm() < 1e-3
    assert (res_mu_diff_grad - actual_mu_diff_grad).norm() / actual_mu_diff_grad.norm() < 1e-3
    assert (res_U_grad - actual_U_grad).norm() / actual_U_grad.norm() < 1e-3