Example #1
0
 def test_assert_axis_dimension_fail(self):
   """Every axis, by negative or non-negative index, rejects a wrong size."""
   tensor = jnp.ones((3, 2, 7, 2))
   rank = tensor.ndim
   for axis in range(-rank, rank):
     # One more than the real size can never match.
     wrong_size = tensor.shape[axis] + 1
     with self.assertRaisesRegex(
         AssertionError, 'Expected tensor to have dimension'):
       asserts.assert_axis_dimension(tensor, axis=axis, expected=wrong_size)
Example #2
0
 def test_assert_axis_dimension_axis_invalid(self):
   """Axes just past either end of a 2-D tensor are reported as unavailable."""
   tensor = jnp.ones((3, 2))
   # One past the last non-negative axis, and one before the first negative one.
   out_of_range_axes = (tensor.ndim, -tensor.ndim - 1)
   for axis in out_of_range_axes:
     with self.assertRaisesRegex(AssertionError, 'not available'):
       asserts.assert_axis_dimension(tensor, axis=axis, expected=1)
Example #3
0
 def test_assert_axis_dimension_fail(self):
     """Checks that a wrong expected size is rejected on every axis.

     Negative axis indices are exercised alongside non-negative ones so
     both indexing forms are covered.
     """
     tensor = jnp.ones((3, 2, 7, 2))
     for axis in range(-tensor.ndim, tensor.ndim):
         # subTest names the failing axis and keeps checking the others.
         with self.subTest(axis=axis):
             with self.assertRaisesRegex(AssertionError,
                                         'Expected tensor to have dimension'):
                 # One more than the real size can never match.
                 asserts.assert_axis_dimension(
                     tensor, axis=axis, expected=tensor.shape[axis] + 1)
Example #4
0
 def test_assert_axis_dimension_pass(self):
   """Every axis, by negative or non-negative index, accepts its real size."""
   tensor = jnp.ones((3, 2, 7, 2))
   rank = tensor.ndim
   for axis in range(-rank, rank):
     asserts.assert_axis_dimension(
         tensor, axis=axis, expected=tensor.shape[axis])
Example #5
0
 def test_assert_axis_dimension_pass(self):
     """Checks that each axis accepts its actual size.

     Negative axis indices are exercised alongside non-negative ones so
     both indexing forms are covered.
     """
     tensor = jnp.ones((3, 2, 7, 2))
     for axis in range(-tensor.ndim, tensor.ndim):
         # subTest names the failing axis and keeps checking the others.
         with self.subTest(axis=axis):
             asserts.assert_axis_dimension(
                 tensor, axis=axis, expected=tensor.shape[axis])