def test_computes_diag_eval(self):
     """In eval mode, calling the kernel on one input set should give diag(variances).

     The kernel is built from 3 random variances and evaluated on a (3, 1) input.
     The dense result must match ``torch.diag(variances)`` up to a 1e-5 norm.
     """
     inputs = torch.tensor([4, 2, 8], dtype=torch.float).unsqueeze(-1)
     noise = torch.randn(3)
     white_noise = WhiteNoiseKernel(variances=noise)
     white_noise.eval()
     expected = torch.diag(noise)
     covar = white_noise(inputs).evaluate()
     error = torch.norm(covar - expected)
     self.assertLess(error, 1e-5)
Пример #2
0
 def test_computes_diag_eval_batch(self):
     """Batched eval: each batch entry should evaluate to diag of its own variances.

     Inputs are viewed as (2, 3, 1). Variances are drawn with shape (2, 3, 1), so
     squeezing the trailing dim gives one length-3 variance vector per batch entry.
     The expected result stacks one diagonal matrix per batch entry.
     """
     # torch.tensor with an explicit dtype, consistent with the sibling tests,
     # instead of the legacy torch.Tensor constructor.
     a = torch.tensor([[4, 2, 8], [4, 2, 8]], dtype=torch.float).view(2, 3, 1)
     variances = torch.randn(2, 3, 1)
     kernel = WhiteNoiseKernel(variances=variances)
     kernel.eval()
     # One diag matrix per batch entry, without hard-coding the batch indices.
     actual = torch.stack([torch.diag(v) for v in variances.squeeze(-1)])
     res = kernel(a).evaluate()
     self.assertLess(torch.norm(res - actual), 1e-5)
 def test_computes_zero_eval(self):
     """In eval mode, the kernel between inputs a and b should be all zeros.

     Both argument orders are checked:
     - kernel(a, b) must be a 3x2 zero matrix;
     - kernel(b, a) must be a 2x3 zero matrix.
     """
     x_three = torch.tensor([4, 2, 8], dtype=torch.float).unsqueeze(-1)
     x_two = torch.tensor([3, 7], dtype=torch.float).unsqueeze(-1)
     noise = torch.randn(3)
     white_noise = WhiteNoiseKernel(variances=noise)
     white_noise.eval()
     # Evaluate both orderings before asserting, as the original test does.
     cross_fwd = white_noise(x_three, x_two).evaluate()
     cross_rev = white_noise(x_two, x_three).evaluate()
     self.assertLess(torch.norm(cross_fwd - torch.zeros(3, 2)), 1e-5)
     self.assertLess(torch.norm(cross_rev - torch.zeros(2, 3)), 1e-5)