def test_inv_quad_log_det_many_vectors(self):
    """Check ``inv_quad_log_det`` for one matrix against a block of RHS vectors.

    Exact reference values are computed from ``self.mat_var_clone`` with an
    explicit inverse and determinant. The ``NonLazyVariable`` result is then
    compared against them, both in value and in the gradients it pushes back
    into the matrix and the vectors.
    """
    # ---- Forward pass: exact reference quantities ----
    mat_inverse = self.mat_var_clone.inverse()
    # Sum over every entry of (A^{-1} V) * V.
    expected_inv_quad = mat_inverse.matmul(self.vecs_var_clone).mul(self.vecs_var_clone).sum()
    expected_log_det = self.mat_var_clone.det().log()

    # 1000 trace samples; the log-det estimate is presumably stochastic,
    # hence the loose tolerances used below.
    with gpytorch.settings.num_trace_samples(1000):
        lazy_mat = NonLazyVariable(self.mat_var)
        computed_inv_quad, computed_log_det = lazy_mat.inv_quad_log_det(
            inv_quad_rhs=self.vecs_var,
            log_det=True,
        )
    self.assertAlmostEqual(computed_inv_quad.item(), expected_inv_quad.item(), places=1)
    self.assertAlmostEqual(computed_log_det.item(), expected_log_det.item(), places=1)

    # ---- Backward pass: weight each output's gradient differently ----
    inv_quad_weight = torch.Tensor([3])
    log_det_weight = torch.Tensor([4])
    expected_inv_quad.backward(gradient=inv_quad_weight)
    expected_log_det.backward(gradient=log_det_weight)
    # Both outputs come from one inv_quad_log_det call, so keep the graph
    # alive for the second backward.
    computed_inv_quad.backward(gradient=inv_quad_weight, retain_graph=True)
    computed_log_det.backward(gradient=log_det_weight)

    mat_grads_match = approx_equal(self.mat_var_clone.grad.data, self.mat_var.grad.data, epsilon=1e-1)
    self.assertTrue(mat_grads_match)
    vec_grads_match = approx_equal(self.vecs_var_clone.grad.data, self.vecs_var.grad.data)
    self.assertTrue(vec_grads_match)
def test_inv_quad_log_det_many_vectors(self):
    """Check batched ``inv_quad_log_det`` against exact per-matrix references.

    For the first two matrices of ``self.mats_var_clone``, the reference
    inverse-quadratic term and log determinant are computed with explicit
    inverses and determinants. The ``NonLazyVariable`` wrapper around
    ``self.mats_var`` must reproduce both values and the gradients flowing
    into the matrices and the vectors.
    """
    batch_indices = (0, 1)

    # ---- Forward pass: exact per-batch references ----
    # stack(...) is the same as cat([x.unsqueeze(0), ...]).
    batch_inverses = torch.stack([self.mats_var_clone[b].inverse() for b in batch_indices])
    # Reduce dims 2 then 1 to get one inverse-quadratic value per batch entry.
    expected_inv_quad = batch_inverses.matmul(self.vecs_var_clone).mul(self.vecs_var_clone).sum(2).sum(1)
    expected_log_det = torch.stack([self.mats_var_clone[b].det().log() for b in batch_indices])

    with gpytorch.settings.num_trace_samples(1000):
        lazy_mats = NonLazyVariable(self.mats_var)
        computed_inv_quad, computed_log_det = lazy_mats.inv_quad_log_det(
            inv_quad_rhs=self.vecs_var,
            log_det=True,
        )
    self.assertTrue(approx_equal(computed_inv_quad.data, expected_inv_quad.data, epsilon=1e-1))
    self.assertTrue(approx_equal(computed_log_det.data, expected_log_det.data, epsilon=1e-1))

    # ---- Backward pass: one gradient weight per batch entry ----
    inv_quad_weights = torch.Tensor([3, 4])
    log_det_weights = torch.Tensor([4, 2])
    expected_inv_quad.backward(gradient=inv_quad_weights)
    expected_log_det.backward(gradient=log_det_weights)
    # Both outputs share the graph of a single inv_quad_log_det call.
    computed_inv_quad.backward(gradient=inv_quad_weights, retain_graph=True)
    computed_log_det.backward(gradient=log_det_weights)

    mats_grads_match = approx_equal(self.mats_var_clone.grad.data, self.mats_var.grad.data, epsilon=1e-1)
    self.assertTrue(mats_grads_match)
    vecs_grads_match = approx_equal(self.vecs_var_clone.grad.data, self.vecs_var.grad.data)
    self.assertTrue(vecs_grads_match)
def test_inv_quad_log_det_many_vectors(self):
    """Check batched ``inv_quad_log_det`` with precomputed log-det references.

    The inverse-quadratic reference comes from explicit inverses of the first
    two matrices in ``self.mats_var_clone``. The log-det values come from
    ``self.log_dets``. Its gradient is added by hand to the reference matrix
    gradient before the gradients are compared.
    """
    batch_indices = (0, 1)

    # ---- Forward pass ----
    batch_inverses = torch.stack([self.mats_var_clone[b].inverse() for b in batch_indices])
    expected_inv_quad = batch_inverses.matmul(self.vecs_var_clone).mul(self.vecs_var_clone).sum(2).sum(1)

    with gpytorch.settings.num_trace_samples(1000):
        lazy_mats = NonLazyVariable(self.mats_var)
        computed_inv_quad, computed_log_det = lazy_mats.inv_quad_log_det(
            inv_quad_rhs=self.vecs_var,
            log_det=True,
        )
    for batch_idx in range(self.mats_var.size(0)):
        self.assert_scalar_almost_equal(
            computed_inv_quad.data[batch_idx], expected_inv_quad.data[batch_idx], places=1
        )
        self.assert_scalar_almost_equal(
            computed_log_det.data[batch_idx], self.log_dets[batch_idx], places=1
        )

    # ---- Backward pass ----
    inv_quad_weights = torch.Tensor([3, 4])
    log_det_weights = torch.Tensor([4, 2])
    expected_inv_quad.backward(gradient=inv_quad_weights)

    # No autograd log-det reference here, so add weight * A^{-1} to the grad
    # by hand. d log det(A) / dA is A^{-T}. This equals A^{-1} only for
    # symmetric A — presumably the fixture matrices are symmetric; confirm.
    manual_log_det_grad = torch.stack([
        self.mats_var_clone[b].data.inverse().mul(log_det_weights[b]) for b in batch_indices
    ])
    self.mats_var_clone.grad.data.add_(manual_log_det_grad)

    computed_inv_quad.backward(gradient=inv_quad_weights, retain_graph=True)
    computed_log_det.backward(gradient=log_det_weights)

    mats_grads_match = approx_equal(self.mats_var_clone.grad.data, self.mats_var.grad.data, epsilon=1e-1)
    self.assertTrue(mats_grads_match)
    vecs_grads_match = approx_equal(self.vecs_var_clone.grad.data, self.vecs_var.grad.data)
    self.assertTrue(vecs_grads_match)
def test_inv_quad_log_det_vector(self):
    """Check ``inv_quad_log_det`` for one matrix and a single RHS vector.

    Reference values come from an explicit inverse and determinant of
    ``self.mat_var_clone``. The ``NonLazyVariable`` result must match them in
    value, and backpropagating both outputs must give matching gradients for
    the matrix and the vector.
    """
    # Forward pass: exact reference, the sum of (A^{-1} v) * v, and log det(A).
    actual_inv_quad = self.mat_var_clone.inverse().matmul(self.vec_var_clone).mul(self.vec_var_clone).sum()
    actual_log_det = self.mat_var_clone.det().log()
    with gpytorch.settings.num_trace_samples(1000):
        nlv = NonLazyVariable(self.mat_var)
        res_inv_quad, res_log_det = nlv.inv_quad_log_det(inv_quad_rhs=self.vec_var, log_det=True)
    # Compare Python floats, as the log-det check does. assertAlmostEqual's
    # round(abs(a - b), places) check is meant for numbers, not tensors.
    self.assertAlmostEqual(res_inv_quad.item(), actual_inv_quad.item(), places=1)
    self.assertAlmostEqual(res_log_det.item(), actual_log_det.item(), places=1)

    # Backward: implicit unit gradient on each output.
    actual_inv_quad.backward()
    actual_log_det.backward()
    # Both results come from one inv_quad_log_det call, so keep the graph for
    # the second backward.
    res_inv_quad.backward(retain_graph=True)
    res_log_det.backward()
    self.assertTrue(approx_equal(self.mat_var_clone.grad.data, self.mat_var.grad.data, epsilon=1e-1))
    self.assertTrue(approx_equal(self.vec_var_clone.grad.data, self.vec_var.grad.data))