def test_assert_axis_dimension_fail(self):
    """A size one larger than the real extent fails on every axis.

    Axes are addressed both by non-negative and by negative index
    (``range(-ndim, ndim)``), and each failure must mention
    'Expected tensor to have dimension'.

    NOTE(review): another method with this exact name appears later in this
    source. If both live in the same class body, the later definition
    replaces this one and this body never runs; confirm which is intended.
    """
    arr = jnp.ones((3, 2, 7, 2))
    rank = arr.ndim
    for axis in range(-rank, rank):
        wrong_size = arr.shape[axis] + 1
        with self.assertRaisesRegex(
            AssertionError, 'Expected tensor to have dimension'):
            asserts.assert_axis_dimension(arr, axis=axis, expected=wrong_size)
def test_assert_axis_dimension_axis_invalid(self):
    """Axis indices outside a rank-2 tensor are rejected.

    For shape (3, 2), 2 is just past the last non-negative index and -3 is
    just past the last negative index. Both must raise AssertionError whose
    message contains 'not available'.
    """
    matrix = jnp.ones((3, 2))
    out_of_range_axes = (2, -3)
    for bad_axis in out_of_range_axes:
        with self.assertRaisesRegex(AssertionError, 'not available'):
            asserts.assert_axis_dimension(matrix, axis=bad_axis, expected=1)
def test_assert_axis_dimension_fail(self):
    """A size one larger than the real extent fails on every axis.

    Every valid axis index is checked, both non-negative and negative
    (``range(-ndim, ndim)``). Each failure must raise AssertionError whose
    message contains 'Expected tensor to have dimension'.

    NOTE(review): this method name is defined twice in the visible source.
    When both definitions share a class body, only the last one survives.
    This definition therefore iterates negative axes too, so that coverage
    is not silently dropped.
    """
    tensor = jnp.ones((3, 2, 7, 2))
    # range(-ndim, ndim) covers negative axes that enumerate(shape) would skip.
    for i in range(-tensor.ndim, tensor.ndim):
        s = tensor.shape[i]
        with self.assertRaisesRegex(
            AssertionError, 'Expected tensor to have dimension'):
            asserts.assert_axis_dimension(tensor, axis=i, expected=s + 1)
def test_assert_axis_dimension_pass(self):
    """The true extent of every axis is accepted.

    Axes are addressed both by non-negative and by negative index
    (``range(-ndim, ndim)``).

    NOTE(review): another method with this exact name appears later in this
    source. If both live in the same class body, the later definition
    replaces this one and this body never runs; confirm which is intended.
    """
    arr = jnp.ones((3, 2, 7, 2))
    rank = arr.ndim
    axes = range(-rank, rank)
    for axis in axes:
        asserts.assert_axis_dimension(arr, axis=axis, expected=arr.shape[axis])
def test_assert_axis_dimension_pass(self):
    """The true extent of every axis is accepted.

    Every valid axis index is checked, both non-negative and negative
    (``range(-ndim, ndim)``).

    NOTE(review): this method name is defined twice in the visible source.
    When both definitions share a class body, only the last one survives.
    This definition therefore iterates negative axes too, so that coverage
    is not silently dropped.
    """
    tensor = jnp.ones((3, 2, 7, 2))
    # range(-ndim, ndim) covers negative axes that enumerate(shape) would skip.
    for i in range(-tensor.ndim, tensor.ndim):
        s = tensor.shape[i]
        asserts.assert_axis_dimension(tensor, axis=i, expected=s)