def test_expand_apply(self):
    """expand_apply adds one axis to each input and removes it from the output.

    The wrapped function only accepts a rank-3 first argument and a rank-2
    second argument. Because the test passes rank-2 and rank-1 arrays, the
    call can only succeed if expand_apply expanded both inputs. The output
    must come back with the original, unexpanded shape.
    """
    def requires_expanded(a, b):
        # Fail loudly if expand_apply did not add an axis to each input.
        if a.ndim != 3:
            raise ValueError("a must be rank 3")
        if b.ndim != 2:
            raise ValueError("b must be rank 2")
        return a + b

    out = basic.expand_apply(requires_expanded)(jnp.ones([3, 4]), jnp.ones([4]))
    # The temporarily added axis must be stripped from the result.
    self.assertEqual(out.shape, (3, 4))
    np.testing.assert_array_equal(out, 2 * jnp.ones([3, 4]))
def test_expand_apply_raises(self):
    """expand_apply must reject an axis other than 0 or -1 with a ValueError."""
    def constant():
        return 1

    # Both wrapping and invoking stay inside the context manager, so the error
    # is caught whether it is raised at wrap time or at call time.
    with self.assertRaisesRegex(ValueError, "only supports axis=0 or axis=-1"):
        wrapped = basic.expand_apply(constant, axis=1)
        wrapped()