Beispiel #1
0
  def testInfeedThenOutfeedInALoop(self):
    """Round-trips `n` arrays through infeed/outfeed inside a jitted loop.

    A background thread runs a jitted `fori_loop` whose body reads a (3, 4)
    float32 array from infeed and writes twice its value to outfeed. The main
    thread feeds `n` random arrays, reads each outfeed result, and checks it.
    Any exception raised by the jitted computation in the worker thread is
    re-raised on the main thread after the worker finishes.
    """
    # NOTE(review): presumably stopped so that host_callback's outfeed
    # receiver does not consume the outfeed values this test reads directly.
    hcb.stop_outfeed_receiver()

    def doubler(_, token):
      # Loop body: read one (3, 4) float32 array, outfeed 2 * that array.
      y, token = lax.infeed(
          token, shape=jax.ShapedArray((3, 4), jnp.float32))
      return lax.outfeed(token, y * np.float32(2))

    @jax.jit
    def f(n):
      token = lax.create_token(n)
      # The final token is not returned; the infeed/outfeed side effects of
      # the loop are what the test observes.
      token = lax.fori_loop(0, n, doubler, token)
      return n

    device = jax.local_devices()[0]
    n = 10
    thread_errors = []

    def run():
      # An exception escaping a Thread target is only reported through
      # threading.excepthook and would not fail the test; keep it so the
      # main thread can re-raise it after join().
      try:
        f(n)
      except Exception as e:
        thread_errors.append(e)

    # daemon=True: if an assertion below fails, the worker is likely left
    # waiting for infeed data that never arrives; a non-daemon thread would
    # then prevent the interpreter from exiting.
    execution = threading.Thread(target=run, daemon=True)
    execution.start()
    for _ in range(n):
      x = np.random.randn(3, 4).astype(np.float32)
      device.transfer_to_infeed((x,))
      y, = device.transfer_from_outfeed(xla_client.shape_from_pyval((x,))
                                        .with_major_to_minor_layout_if_absent())
      self.assertAllClose(y, x * np.float32(2))
    execution.join()
    if thread_errors:
      raise thread_errors[0]
Beispiel #2
0
  def testInfeedThenOutfeed(self):
    """Feeds one array to a jitted computation and reads back `array + 1`.

    A background thread runs a jitted function that reads a (3, 4) float32
    array from infeed and writes that array plus one to outfeed. The main
    thread supplies the input and checks the outfeed value. Any exception
    raised by the jitted computation in the worker thread is re-raised on the
    main thread after the worker finishes.
    """
    # NOTE(review): presumably stopped so that host_callback's outfeed
    # receiver does not consume the outfeed value this test reads directly.
    hcb.stop_outfeed_receiver()
    @jax.jit
    def f(x):
      token = lax.create_token(x)
      y, token = lax.infeed(
          token, shape=jax.ShapedArray((3, 4), np.float32))
      token = lax.outfeed(token, y + onp.float32(1))
      # tie_in makes the returned value depend on the token (a no-op in
      # newer JAX versions).
      return lax.tie_in(token, x - 1)

    x = onp.float32(7.5)
    y = onp.random.randn(3, 4).astype(onp.float32)
    thread_errors = []

    def run():
      # An exception escaping a Thread target is only reported through
      # threading.excepthook and would not fail the test; keep it so the
      # main thread can re-raise it after join().
      try:
        f(x)
      except Exception as e:
        thread_errors.append(e)

    # daemon=True: if the host-side transfer fails, a worker left blocked on
    # infeed/outfeed must not prevent the interpreter from exiting.
    execution = threading.Thread(target=run, daemon=True)
    execution.start()
    device = jax.local_devices()[0]
    device.transfer_to_infeed((y,))
    out, = device.transfer_from_outfeed(
      xla_client.shape_from_pyval((y,)).with_major_to_minor_layout_if_absent())
    execution.join()
    if thread_errors:
      raise thread_errors[0]
    self.assertAllClose(out, y + onp.float32(1))
Beispiel #3
0
 def tearDownClass(cls):
   """Class-level unittest teardown: stop the outfeed receiver.

   Calls `hcb.stop_outfeed_receiver()` (presumably the host_callback module)
   once; nothing else is torn down here.
   """
   stop_receiver = hcb.stop_outfeed_receiver
   stop_receiver()