예제 #1
0
    def testVmap(self):
        """Checks `extensions.vmap` on unary and binary element-wise functions.

        Covers a tf_np ndarray input, a raw `tf.Tensor` input (whose result is
        asserted to come back as a `tf_np.ndarray`), and two batched arguments
        whose per-example shapes differ and must broadcast.
        """
        square = extensions.vmap(lambda z: z * z)

        # Batched input supplied as a tf_np ndarray.
        values = tf_np.arange(10)
        self.assertAllClose(values * values, square(values))

        # Batched input supplied as a plain tf.Tensor.
        tensor = tf.range(10)
        tensor_as_np = tf_np.asarray(tensor)
        squared = square(tensor)
        self.assertIsInstance(squared, tf_np.ndarray)
        self.assertAllClose(tensor_as_np * tensor_as_np, squared)

        # Mapping over the leading axis gives per-example shapes (3,) and
        # (2, 3); they broadcast inside the mapped function, so the batched
        # result equals lhs with a unit axis inserted at position 1, plus rhs.
        add = extensions.vmap(lambda a, b: a + b)
        lhs = tf_np.random.randn(10, 3)
        rhs = tf_np.random.randn(10, 2, 3)
        expected_sum = tf_np.expand_dims(lhs, 1) + rhs
        self.assertAllClose(expected_sum, add(lhs, rhs))
예제 #2
0
    def test_vmap_out_axes(self):
        """Checks where `out_axes` places (or drops) the batch axis of outputs.

        Every mapped function is the identity, so each case only exercises
        how vmap lays out the batch dimension of the result.
        """
        batch = tf_np.arange(6).reshape([2, 3])

        # out_axes=0: the batch axis stays in front, so the identity
        # round-trips, including through a tuple of outputs.
        leading = extensions.vmap(lambda x: x, out_axes=0)
        self.assertAllClose(batch, leading(batch))
        self.assertAllClose([batch, batch], leading((batch, batch)))

        # out_axes=-1: the batch axis moves last, a transpose for 2-D input.
        trailing = extensions.vmap(lambda x: x, out_axes=-1)
        self.assertAllClose(batch.T, trailing(batch))

        # out_axes=None: the output is not batched; the test expects the
        # first example back.
        unbatched = extensions.vmap(lambda x: x, out_axes=None)
        self.assertAllClose(batch[0], unbatched(batch))

        # out_axes given as a nested structure mirroring the
        # (list, tuple, dict) output, with a different axis per leaf.
        nested_axes = ([0], (-1, None), {'a': 1})
        nested = extensions.vmap(lambda x: x, out_axes=nested_axes)
        as_list, as_tuple, as_dict = nested(
            ([batch], (batch, batch), {'a': batch}))
        self.assertAllClose([batch], as_list)
        self.assertAllClose((batch.T, batch[0]), as_tuple)
        self.assertAllClose(batch.T, as_dict['a'])
예제 #3
0
 def test_vmap_unbatched_object_passthrough_issue_183(self):
     """Regression test for https://github.com/google/jax/issues/183.

     The first argument is a Python callable marked unbatched (in_axes None),
     so vmap must hand it through unchanged while mapping over the second.
     """
     def call_with(func, value):
         return func(value)

     mapped = extensions.vmap(call_with, (None, 0))
     result = mapped(lambda v: v + 1, tf_np.arange(3))
     self.assertAllClose(result, np.arange(1, 4))