def testVmap(self):
  """Checks vmap on unary and binary functions over tf_np and tf inputs."""
  # Unary case: element-wise square over a tf_np array.
  square = extensions.vmap(lambda v: v * v)
  values = tf_np.arange(10)
  self.assertAllClose(values * values, square(values))

  # A raw tf.Tensor input is expected to come back as a tf_np.ndarray.
  tensor = tf.range(10)
  tensor_as_np = tf_np.asarray(tensor)
  squared = square(tensor)
  self.assertIsInstance(squared, tf_np.ndarray)
  self.assertAllClose(tensor_as_np * tensor_as_np, squared)

  # Binary case: each per-example lhs row of shape (3,) broadcasts against
  # the matching rhs slice of shape (2, 3); equivalent to inserting axis 1.
  add = extensions.vmap(lambda a, b: a + b)
  lhs = tf_np.random.randn(10, 3)
  rhs = tf_np.random.randn(10, 2, 3)
  self.assertAllClose(tf_np.expand_dims(lhs, 1) + rhs, add(lhs, rhs))
def test_vmap_out_axes(self):
  """Exercises vmap with integer, None and nested (pytree) out_axes specs."""

  def identity(x):
    return x

  # out_axes=0: the mapped dimension stays in front, also inside tuples.
  along_first = extensions.vmap(identity, out_axes=0)
  inp = tf_np.arange(6).reshape([2, 3])
  self.assertAllClose(inp, along_first(inp))
  self.assertAllClose([inp, inp], along_first((inp, inp)))

  # out_axes=-1: the mapped dimension moves last, i.e. a transpose for 2-D.
  along_last = extensions.vmap(identity, out_axes=-1)
  self.assertAllClose(inp.T, along_last(inp))

  # out_axes=None: the result is expected to equal the first example.
  unbatched = extensions.vmap(identity, out_axes=None)
  self.assertAllClose(inp[0], unbatched(inp))

  # A nested out_axes spec is matched structurally against the output.
  nested = extensions.vmap(identity, out_axes=([0], (-1, None), {'a': 1}))
  as_list, as_tuple, as_dict = nested(([inp], (inp, inp), {'a': inp}))
  self.assertAllClose([inp], as_list)
  self.assertAllClose((inp.T, inp[0]), as_tuple)
  self.assertAllClose(inp.T, as_dict['a'])
def test_vmap_unbatched_object_passthrough_issue_183(self):
  """Checks that an argument with in_axes=None is passed through unmapped.

  Regression test for https://github.com/google/jax/issues/183.
  """

  def apply(fn, arg):
    return fn(arg)

  def increment(v):
    return v + 1

  # The callable is not mapped (None); only the array is mapped along axis 0.
  batched_apply = extensions.vmap(apply, (None, 0))
  result = batched_apply(increment, tf_np.arange(3))
  self.assertAllClose(result, np.arange(1, 4))