def testInfeedThenOutfeedInALoop(self):
    """Round-trip several arrays through infeed/outfeed inside a ``fori_loop``.

    A jitted function runs a loop of ``n`` iterations on a background thread;
    each iteration infeeds one (3, 4) float32 array and outfeeds that array
    multiplied by 2. The test thread feeds ``n`` random arrays one at a time,
    reads each result back and checks it against ``x * 2``.
    """
    # NOTE(review): presumably stops the host-callback receiver so it does not
    # consume the outfeed this test reads by hand — confirm against `hcb`.
    hcb.stop_outfeed_receiver()

    def double_step(_step, token):
        # One loop iteration: pull an array from the infeed, push back 2 * array.
        fed, token = lax.infeed(
            token, shape=jax.ShapedArray((3, 4), jnp.float32))
        doubled = fed * np.float32(2)
        return lax.outfeed(token, doubled)

    @jax.jit
    def run_loop(n):
        # The token is the loop carry, so it is threaded through every
        # infeed/outfeed pair; the resulting token itself is not returned.
        start_token = lax.create_token(n)
        lax.fori_loop(0, n, double_step, start_token)
        return n

    device = jax.local_devices()[0]
    n = 10

    # The computation runs on its own thread; this thread plays the host side.
    worker = threading.Thread(target=run_loop, args=(n,))
    worker.start()

    for _ in range(n):
        x = np.random.randn(3, 4).astype(np.float32)
        device.transfer_to_infeed((x,))
        out_shape = xla_client.shape_from_pyval((x,))
        out_shape = out_shape.with_major_to_minor_layout_if_absent()
        (y,) = device.transfer_from_outfeed(out_shape)
        self.assertAllClose(y, x * np.float32(2))

    worker.join()
def testInfeedThenOutfeed(self):
    """Feed one array in, read ``array + 1`` back out, via infeed and outfeed.

    The jitted function runs on a background thread: it infeeds a (3, 4)
    float32 array and outfeeds that array plus 1. The test thread supplies a
    random array, reads the outfed result, joins the worker, and then checks
    the result against ``y + 1``.
    """
    # NOTE(review): presumably stops the host-callback receiver so it does not
    # consume the outfeed this test reads by hand — confirm against `hcb`.
    hcb.stop_outfeed_receiver()

    # NOTE(review): this test uses both the `np` and `onp` aliases; confirm
    # that both are imported at the top of the file.
    @jax.jit
    def f(x):
        token = lax.create_token(x)
        fed, token = lax.infeed(
            token, shape=jax.ShapedArray((3, 4), np.float32))
        incremented = fed + onp.float32(1)
        token = lax.outfeed(token, incremented)
        # The returned value is tied to the outfeed token.
        return lax.tie_in(token, x - 1)

    x = onp.float32(7.5)
    y = onp.random.randn(3, 4).astype(onp.float32)

    # The computation runs on its own thread; this thread plays the host side.
    worker = threading.Thread(target=f, args=(x,))
    worker.start()

    device = jax.local_devices()[0]
    device.transfer_to_infeed((y,))
    out_shape = xla_client.shape_from_pyval((y,))
    out_shape = out_shape.with_major_to_minor_layout_if_absent()
    (out,) = device.transfer_from_outfeed(out_shape)

    worker.join()
    self.assertAllClose(out, y + onp.float32(1))
@classmethod
def tearDownClass(cls):
    """Stop the host-callback outfeed receiver after the class's tests.

    unittest invokes this hook on the class itself with no arguments, so it
    must be a classmethod; as a plain function the call fails with a
    ``TypeError`` for the missing ``cls`` argument.
    """
    hcb.stop_outfeed_receiver()