def test_expand_shape_input_is_scalar_and_two_dims_expected(self):
    """A Python scalar padded to two dims yields the all-ones shape (1, 1)."""
    scalar = 3
    padded_shape = util.expand_shape(scalar, 2)
    self.assertEqual((1, 1), padded_shape)
def test_expand_shape_input_has_two_missing_dims(self):
    """A (3, 2) array padded to four dims gains two trailing singleton axes."""
    matrix = jnp.array([[3, 4], [5, 6], [7, 8]])
    padded_shape = util.expand_shape(matrix, 4)
    self.assertEqual((3, 2, 1, 1), padded_shape)
def test_expand_shape_input_is_scalar_and_one_dim_expected(self):
    """A Python scalar padded to a single dim yields the shape (1,)."""
    scalar = 3
    padded_shape = util.expand_shape(scalar, 1)
    self.assertEqual((1,), padded_shape)
def test_expand_shape_input_has_more_dims(self):
    """Requesting fewer dims than the input has leaves its (3, 2) shape intact."""
    matrix = jnp.array([[3, 4], [5, 6], [7, 8]])
    padded_shape = util.expand_shape(matrix, 1)
    self.assertEqual((3, 2), padded_shape)