Esempio n. 1
0
    def test_expand_apply(self):
        """expand_apply should add a leading axis around the call and drop it after."""

        def add_expecting_batched(x, y):
            # The callee only accepts inputs that carry one extra leading axis,
            # so a ValueError here means expand_apply did not expand the input.
            # Checks run in order: the first argument, then the second.
            for arr, rank, msg in ((x, 3, "a must be shape 3"),
                                   (y, 2, "b must be shape 2")):
                if len(arr.shape) != rank:
                    raise ValueError(msg)
            return x + y

        wrapped = basic.expand_apply(add_expecting_batched)
        result = wrapped(jnp.ones([3, 4]), jnp.ones([4]))
        # After the inner broadcasted add, the added axis is removed again.
        expected = 2 * jnp.ones([3, 4])
        np.testing.assert_array_equal(result, expected)
Esempio n. 2
0
 def test_expand_apply_raises(self):
     """expand_apply must reject any axis other than 0 or -1."""
     unsupported_axis = 1
     expected_pattern = "only supports axis=0 or axis=-1"
     # The error is expected while building the wrapper; the trailing call is
     # kept so the test still fails if validation were deferred to call time.
     with self.assertRaisesRegex(ValueError, expected_pattern):
         basic.expand_apply(lambda: 1, axis=unsupported_axis)()