def test_batchapply_accepts_float(self):
    """A bare Python float may be passed next to a batched array.

    The wrapped function raises unless its first argument is rank 2. The
    input is rank 3, so the call only succeeds if BatchApply hands it a
    rank-2 array. The float is simply added to every element.
    """
    def add_checked(x, y):
        # Guard: the array argument must arrive with rank 2.
        if len(x.shape) != 2:
            raise ValueError("a must be shape 2")
        return x + y

    batched = basic.BatchApply(add_checked)
    inputs = jnp.ones([2, 3, 4])
    result = batched(inputs, 2.)
    expected = 3 * jnp.ones([2, 3, 4])
    np.testing.assert_array_equal(result, expected)
def test_batchapply_accepts_none(self):
    """A ``None`` positional argument is forwarded untouched.

    The wrapped function raises if its first argument is anything other
    than ``None``, or if its second argument is not rank 2. The rank-3
    input is expected back scaled by 3 with its original shape.
    """
    def scale_checked(maybe_none, arr):
        if maybe_none is not None:
            raise ValueError("a must be None.")
        # Guard: the array argument must arrive with rank 2.
        if len(arr.shape) != 2:
            raise ValueError("b must be shape 2")
        return 3 * arr

    batched = basic.BatchApply(scale_checked)
    inputs = jnp.ones([2, 3, 4])
    result = batched(None, inputs)
    expected = 3 * jnp.ones([2, 3, 4])
    np.testing.assert_array_equal(result, expected)
def test_batchapply(self):
    """Rank-3 and rank-1 inputs are handled together.

    The wrapped function requires its first argument to be rank 2 and its
    second to be rank 1. Since the rank-1 input is passed as-is, it must
    reach the function with its rank unchanged. The sum is expected back
    with the original rank-3 shape.
    """
    def add_checked(lhs, rhs):
        # Guards on the ranks the wrapped function is expected to observe.
        if len(lhs.shape) != 2:
            raise ValueError("a must be shape 2")
        if len(rhs.shape) != 1:
            raise ValueError("b must be shape 1")
        return lhs + rhs

    batched = basic.BatchApply(add_checked)
    big = jnp.ones([2, 3, 4])
    small = jnp.ones([4])
    result = batched(big, small)
    expected = 2 * jnp.ones([2, 3, 4])
    np.testing.assert_array_equal(result, expected)
def test_batchapply_raises(self):
    """Calling BatchApply with no inputs at all is a ValueError."""
    def constant_fn():
        return 1

    with self.assertRaisesRegex(ValueError, "requires at least one input"):
        wrapped = basic.BatchApply(constant_fn)
        wrapped()