def test_computes_diag_eval(self):
    """In eval mode, the kernel of the inputs with themselves equals diag(variances)."""
    inputs = torch.tensor([[4], [2], [8]], dtype=torch.float)
    noise = torch.randn(3)
    kernel = WhiteNoiseKernel(variances=noise)
    kernel.eval()
    expected = torch.diag(noise)
    covar = kernel(inputs).evaluate()
    self.assertLess(torch.norm(covar - expected), 1e-5)
def test_computes_diag_eval_batch(self):
    """In eval mode, a batched kernel of the inputs with themselves is one diag(variances) per batch.

    ``variances`` has shape (2, 3, 1); the expected result is a (2, 3, 3) stack
    of diagonal matrices built from each batch entry's variances.
    """
    # Use the same tensor constructor/dtype as the sibling tests instead of the
    # legacy ``torch.Tensor`` constructor.
    a = torch.tensor([[4, 2, 8], [4, 2, 8]], dtype=torch.float).view(2, 3, 1)
    variances = torch.randn(2, 3, 1)
    kernel = WhiteNoiseKernel(variances=variances)
    kernel.eval()
    # Iterating the (2, 3) squeezed tensor yields one (3,) row per batch entry;
    # stacking their diagonals avoids hard-coding the batch indices.
    actual = torch.stack([torch.diag(batch_variances) for batch_variances in variances.squeeze(-1)])
    res = kernel(a).evaluate()
    self.assertLess(torch.norm(res - actual), 1e-5)
def test_computes_zero_eval(self):
    """In eval mode, the kernel between two different input sets is all zeros, in either order."""
    x1 = torch.tensor([[4], [2], [8]], dtype=torch.float)
    x2 = torch.tensor([[3], [7]], dtype=torch.float)
    noise = torch.randn(3)
    kernel = WhiteNoiseKernel(variances=noise)
    kernel.eval()
    # Expected cross-covariance is a zero matrix of shape (len(left), len(right)).
    pairs = ((x1, x2), (x2, x1))
    expectations = [torch.zeros(left.size(0), right.size(0)) for left, right in pairs]
    # Evaluate both orderings before asserting anything.
    results = [kernel(left, right).evaluate() for left, right in pairs]
    for covar, expected in zip(results, expectations):
        self.assertLess(torch.norm(covar - expected), 1e-5)