Esempio n. 1
0
 def f(x):
     """Infeed a (3, 4) float32 array, outfeed it plus one, and return x - 1.

     When omnistaging is disabled, the result is passed through
     ``lax.tie_in`` with the final outfeed token so it carries a dependence
     on the infeed/outfeed chain; with omnistaging enabled, ``x - 1`` is
     returned directly.
     """
     tok = lax.create_token(x)
     feed_aval = jax.ShapedArray((3, 4), jnp.float32)
     received, tok = lax.infeed(tok, shape=feed_aval)
     tok = lax.outfeed(tok, received + np.float32(1))
     if config.omnistaging_enabled:
         return x - 1
     # Without omnistaging, tie the result to the outfeed token.
     return lax.tie_in(tok, x - 1)
Esempio n. 2
0
 def f(x):
   """Outfeed ``x`` three times with different partition specs; return ``x``.

   The sends are chained through a single token, in this order:
   ``None``, ``P(nr_devices, 1)``, ``P(1, nr_devices)``.
   """
   tok = lax.create_token(x)
   for dims in (None, (nr_devices, 1), (1, nr_devices)):
     # Construct each spec immediately before its outfeed call.
     spec = None if dims is None else P(*dims)
     tok = lax.outfeed(tok, x, partitions=(spec,))
   return x
Esempio n. 3
0
 def doubler(_, token):
   """Read a (3, 4) float32 array from infeed and outfeed it doubled.

   The first argument is ignored. Returns the token produced by the outfeed.
   """
   feed_aval = jax.ShapedArray((3, 4), jnp.float32)
   value, token = lax.infeed(token, shape=feed_aval)
   doubled = value * np.float32(2)
   return lax.outfeed(token, doubled)
Esempio n. 4
0
 def f(x):
   """Infeed a (3, 4) float32 array, outfeed it plus one, and return x - 1.

   The token returned by the final outfeed is not used afterwards.
   """
   tok = lax.create_token(x)
   feed_aval = jax.ShapedArray((3, 4), jnp.float32)
   received, tok = lax.infeed(tok, shape=feed_aval)
   lax.outfeed(tok, received + np.float32(1))
   return x - 1
Esempio n. 5
0
 def f(x):
     """Infeed a (3, 4) float32 array, outfeed it plus one, and return x - 1.

     The result goes through ``lax.tie_in`` with the final outfeed token,
     giving it a data dependence on the infeed/outfeed chain.
     """
     tok = lax.create_token(x)
     feed_aval = jax.ShapedArray((3, 4), np.float32)
     received, tok = lax.infeed(tok, shape=feed_aval)
     tok = lax.outfeed(tok, received + onp.float32(1))
     return lax.tie_in(tok, x - 1)