def test_initialize_q_batch_largeZ(self):
    """Verify initialization copes with extreme objective values (huge eta * Z)."""
    for dtype in (torch.float, torch.double):
        # Candidates and objectives whose spread makes eta * Z very large.
        X_cand = torch.rand(5, 3, 4, device=self.device, dtype=dtype)
        obj_vals = torch.tensor(
            [-1e12, 0, 0, 0, 1e12], device=self.device, dtype=dtype
        )
        selected = initialize_q_batch(X=X_cand, Y=obj_vals, n=2, eta=100)
        # Still returns exactly n candidates despite the extreme weights.
        self.assertEqual(selected.shape[0], 2)
def test_initialize_q_batch(self, cuda=False):
    """Exercise initialize_q_batch: output shape/device/dtype, the n == num
    candidates passthrough, the bad-candidates warning, and the n-too-large error."""
    device = torch.device("cuda" if cuda else "cpu")
    for dtype in (torch.float, torch.double):
        X_cand = torch.rand(5, 3, 4, device=device, dtype=dtype)
        obj_vals = torch.rand(5, device=device, dtype=dtype)
        # Basic case: select n=2 of the 5 candidates.
        selected = initialize_q_batch(X=X_cand, Y=obj_vals, n=2)
        self.assertEqual(selected.shape, torch.Size([2, 3, 4]))
        self.assertEqual(selected.device, X_cand.device)
        self.assertEqual(selected.dtype, X_cand.dtype)
        # Requesting all samples should return X unchanged.
        selected = initialize_q_batch(X=X_cand, Y=obj_vals, n=5)
        self.assertTrue(torch.equal(X_cand, selected))
        # All-zero objective values should emit BadInitialCandidatesWarning.
        obj_vals = torch.zeros(5, device=device, dtype=dtype)
        with warnings.catch_warnings(record=True) as recorded:
            selected = initialize_q_batch(X=X_cand, Y=obj_vals, n=2)
            self.assertEqual(len(recorded), 1)
            self.assertTrue(
                issubclass(recorded[-1].category, BadInitialCandidatesWarning)
            )
            self.assertEqual(selected.shape, torch.Size([2, 3, 4]))
        # Asking for more draws than there are candidates must raise.
        with self.assertRaises(RuntimeError):
            initialize_q_batch(X=X_cand, Y=obj_vals, n=10)