Example #1
0
 def test_initialize_q_batch_largeZ(self):
     """Check initialize_q_batch with a huge spread of Y values and a large eta.

     With Y spanning -1e12..1e12 and eta=100, the scaled scores (eta * Z, where
     Z is presumably a standardized version of Y -- confirm in the
     implementation) are extreme; selection must still return the requested
     candidates with the same layout, device and dtype as ``X``.
     """
     for dtype in (torch.float, torch.double):
         # testing large eta*Z
         X = torch.rand(5, 3, 4, device=self.device, dtype=dtype)
         Y = torch.tensor([-1e12, 0, 0, 0, 1e12], device=self.device, dtype=dtype)
         ics = initialize_q_batch(X=X, Y=Y, n=2, eta=100)
         # Pin the full output contract (as the basic test does), not just
         # the number of selected candidates.
         self.assertEqual(ics.shape, torch.Size([2, 3, 4]))
         self.assertEqual(ics.device, X.device)
         self.assertEqual(ics.dtype, X.dtype)

 def test_initialize_q_batch(self, cuda=False):
     """Exercise initialize_q_batch for shape/device/dtype, n == len(X), warnings and errors.

     Args:
         cuda: If True, run the checks on the CUDA device instead of the CPU.
     """
     device = torch.device("cuda") if cuda else torch.device("cpu")
     for dtype in (torch.float, torch.double):
         # basic test
         X = torch.rand(5, 3, 4, device=device, dtype=dtype)
         Y = torch.rand(5, device=device, dtype=dtype)
         ics = initialize_q_batch(X=X, Y=Y, n=2)
         self.assertEqual(ics.shape, torch.Size([2, 3, 4]))
         self.assertEqual(ics.device, X.device)
         self.assertEqual(ics.dtype, X.dtype)
         # ensure nothing happens if we want all samples
         ics = initialize_q_batch(X=X, Y=Y, n=5)
         self.assertTrue(torch.equal(X, ics))
         # an all-zero Y must trigger exactly one BadInitialCandidatesWarning
         Y = torch.zeros(5, device=device, dtype=dtype)
         with warnings.catch_warnings(record=True) as w:
             # Force recording regardless of ambient filters (e.g. `-W ignore`
             # or `-W error`) so the check below is deterministic; the filter
             # change is undone when the context exits.
             warnings.simplefilter("always", category=BadInitialCandidatesWarning)
             ics = initialize_q_batch(X=X, Y=Y, n=2)
         # Count only the warning under test so unrelated warnings emitted in
         # the same call cannot make the assertion flaky.
         bad = [rec for rec in w if issubclass(rec.category, BadInitialCandidatesWarning)]
         self.assertEqual(len(bad), 1)
         self.assertEqual(ics.shape, torch.Size([2, 3, 4]))
         # requesting more candidates (10) than available samples (5) must fail
         with self.assertRaises(RuntimeError):
             initialize_q_batch(X=X, Y=Y, n=10)