コード例 #1
0
 def test_batch_eval_neg_holder_table(self, cuda=False):
     """Evaluate ``neg_holder_table`` on a batch of two all-zero points.

     For each of ``torch.float`` and ``torch.double``, checks that the output
     keeps the input's dtype and device type, has shape ``[2]`` (one value
     per input row), and is numerically zero.

     Args:
         cuda: If True, run on the CUDA device; otherwise run on CPU.
     """
     device = torch.device("cuda") if cuda else torch.device("cpu")
     for dtype in (torch.float, torch.double):
         X = torch.zeros(2, 2, device=device, dtype=dtype)
         res = neg_holder_table(X)
         self.assertEqual(res.dtype, dtype)
         self.assertEqual(res.device.type, device.type)
         self.assertEqual(res.shape, torch.Size([2]))
         # assertLess on a Python float reports the actual value on failure,
         # whereas assertTrue(a < b) only reports "False is not true".
         self.assertLess(res.abs().sum().item(), 1e-6)
コード例 #2
0
 def test_neg_holder_table_global_maxima(self, cuda=False):
     """Check ``neg_holder_table`` at the known global maximizers.

     For each of ``torch.float`` and ``torch.double``, evaluates the function
     at ``GLOBAL_MAXIMIZERS`` and checks that:
       * every output is within 1e-5 of ``GLOBAL_MAXIMUM``, and
       * every entry of the gradient w.r.t. the inputs is below 1e-3 in
         absolute value (the points are near-stationary).

     Args:
         cuda: If True, run on the CUDA device; otherwise run on CPU.
     """
     device = torch.device("cuda") if cuda else torch.device("cpu")
     for dtype in (torch.float, torch.double):
         X = torch.tensor(
             GLOBAL_MAXIMIZERS, device=device, dtype=dtype, requires_grad=True
         )
         res = neg_holder_table(X)
         # Backpropagating the sum accumulates exactly the same X.grad as
         # backpropagating each element of res separately, and unlike
         # unpacking ``[*res]`` it also works when res is 0-dimensional.
         res.sum().backward()
         # Compare Python floats so a failure reports the actual error.
         max_err = (res - GLOBAL_MAXIMUM).abs().max().item()
         self.assertLess(max_err, 1e-5)
         self.assertLess(X.grad.abs().max().item(), 1e-3)