Ejemplo n.º 1
0
 def test_expand_bounds(self):
     """Check that ``_expand_bounds`` expands bounds to the shape of ``X``.

     Covers scalar bounds (Python float and 0-d tensor), 1-d, 2-d and >2-d
     tensor bounds, the ``RuntimeError`` raised when bounds cannot be
     expanded to ``X``, and the ``None`` passthrough.
     """
     X = torch.zeros(2, 3)
     expected_bounds = torch.zeros(2, 3)
     # Bounds that must expand to the full `2 x 3` shape of X. Each case is
     # labelled so a failure report names the offending input.
     broadcastable_cases = (
         ("float", 0.0),
         ("0-d", torch.tensor(0.0)),
         ("1-d", torch.zeros(3)),
         ("2-d", torch.zeros(1, 3)),
     )
     for label, bounds in broadcastable_cases:
         with self.subTest(bounds=label):
             expanded_bounds = _expand_bounds(bounds=bounds, X=X)
             self.assertTrue(torch.equal(expected_bounds, expanded_bounds))
     # bounds is > 2-d
     bounds = torch.zeros(1, 1, 3)
     with self.assertRaises(RuntimeError):
         # X does not have a t-batch, so 3-d bounds cannot be expanded to it
         _expand_bounds(bounds=bounds, X=X)
     # Once X gains a leading batch dimension, the same 3-d bounds expand.
     X = torch.zeros(4, 2, 3)
     expanded_bounds = _expand_bounds(bounds=bounds, X=X)
     self.assertTrue(torch.equal(expanded_bounds, torch.zeros_like(X)))
     with self.assertRaises(RuntimeError):
         # bounds is not broadcastable to X (leading dims 2 vs 4)
         _expand_bounds(bounds=torch.zeros(2, 1, 3), X=X)
     # bounds is None
     self.assertIsNone(_expand_bounds(bounds=None, X=X))
Ejemplo n.º 2
0
 def test_expand_bounds(self):
     """Scalar, 1-d and 2-d bounds against a `2 x 3` X all give a `1 x 3` row; None gives None."""
     x = torch.zeros(2, 3)
     target = torch.zeros(1, 3)
     # Python float, 0-d tensor, 1-d tensor and an already-2-d tensor are
     # checked in turn; each must come back equal to the `1 x 3` target.
     for raw_bounds in (
         0.0,
         torch.tensor(0.0),
         torch.zeros(3),
         torch.zeros(1, 3),
     ):
         result = _expand_bounds(bounds=raw_bounds, X=x)
         self.assertTrue(torch.equal(target, result))
     # Passing no bounds at all yields None.
     self.assertIsNone(_expand_bounds(bounds=None, X=x))