Example #1
0
    def test_batchapply_accepts_float(self):
        """BatchApply accepts a plain Python float next to a batched array.

        The wrapped function rejects any first argument that is not rank 2,
        so this passes only if BatchApply hands it a rank-2 view of the
        rank-3 input and returns a result matching the original [2, 3, 4]
        shape (ones + 2. == threes).
        """
        def add_if_rank_two(x, y):
            # Guard: fail loudly if the batched input was not reduced to rank 2.
            if x.ndim != 2:
                raise ValueError("a must be shape 2")
            return x + y

        batched_fn = basic.BatchApply(add_if_rank_two)
        batched_input = jnp.ones([2, 3, 4])
        result = batched_fn(batched_input, 2.)
        expected = 3 * jnp.ones([2, 3, 4])
        np.testing.assert_array_equal(result, expected)
Example #2
0
    def test_batchapply_accepts_none(self):
        """BatchApply tolerates a None positional argument beside an array.

        The wrapped function insists that its first argument is exactly None
        and that its second is rank 2; the result is expected to come back
        with the original [2, 3, 4] shape, scaled by 3.
        """
        def triple_checked(maybe_none, x):
            # The None argument must reach the wrapped function as None.
            if maybe_none is not None:
                raise ValueError("a must be None.")
            # The array argument must arrive as rank 2, not the rank-3 input.
            if x.ndim != 2:
                raise ValueError("b must be shape 2")
            return 3 * x

        ones = jnp.ones([2, 3, 4])
        result = basic.BatchApply(triple_checked)(None, ones)
        np.testing.assert_array_equal(result, 3 * jnp.ones([2, 3, 4]))
Example #3
0
    def test_batchapply(self):
        """BatchApply reshapes the rank-3 input while leaving a rank-1 arg as-is.

        Inside the wrapped function the first argument must be rank 2 and the
        second rank 1; the summed output must match the original [2, 3, 4]
        shape (ones + ones == twos).
        """
        def add_checked(x, y):
            # Checked in order: first argument, then second, each with its
            # own error message.
            expectations = (
                (x, 2, "a must be shape 2"),
                (y, 1, "b must be shape 1"),
            )
            for arr, wanted_rank, message in expectations:
                if arr.ndim != wanted_rank:
                    raise ValueError(message)
            return x + y

        batched_fn = basic.BatchApply(add_checked)
        result = batched_fn(jnp.ones([2, 3, 4]), jnp.ones([4]))
        expected = 2 * jnp.ones([2, 3, 4])
        np.testing.assert_array_equal(result, expected)
Example #4
0
 def test_batchapply_raises(self):
     """Calling a BatchApply-wrapped function with no inputs raises ValueError."""
     def constant_one():
         return 1

     # Only the start of the error message is matched, as a regex.
     expected_message = "requires at least one input"
     with self.assertRaisesRegex(ValueError, expected_message):
         basic.BatchApply(constant_one)()