예제 #1
0
 def f(x):
     """Round-trip one infeed value through outfeed and return ``x - 1``.

     A token is created from ``x`` and threaded through an infeed of a
     (3, 4) float32 array followed by an outfeed of that array plus one.
     """
     infeed_aval = jax.ShapedArray((3, 4), jnp.float32)
     tok = lax.create_token(x)
     received, tok = lax.infeed(tok, shape=infeed_aval)
     tok = lax.outfeed(tok, received + np.float32(1))
     if config.omnistaging_enabled:
         return x - 1
     # Without omnistaging the result is passed through tie_in together
     # with the final token (presumably so the outfeed stays attached to
     # the returned value -- confirm against the JAX version in use).
     return lax.tie_in(tok, x - 1)
예제 #2
0
파일: pjit_test.py 프로젝트: rsepassi/jax
 def f(x):
   """Outfeed ``x`` three times, each with a different partitioning.

   The partitionings are, in order: ``None``, ``P(nr_devices, 1)`` and
   ``P(1, nr_devices)``. ``x`` is returned unchanged.
   """
   specs = (None, P(nr_devices, 1), P(1, nr_devices))
   tok = lax.create_token(x)
   for spec in specs:
     # Each outfeed consumes the previous token so the sends stay ordered.
     tok = lax.outfeed(tok, x, partitions=(spec,))
   return x
예제 #3
0
 def doubler(_, token):
   """Infeed a (3, 4) float32 array and outfeed it multiplied by two.

   The first argument is ignored. Returns the token produced by the
   outfeed.
   """
   infeed_aval = jax.ShapedArray((3, 4), jnp.float32)
   received, after_infeed = lax.infeed(token, shape=infeed_aval)
   doubled = received * np.float32(2)
   return lax.outfeed(after_infeed, doubled)
예제 #4
0
 def f(x):
   """Infeed a (3, 4) float32 array, outfeed it plus one, return ``x - 1``.

   The token is created from ``x`` and threaded through both feed
   operations; the final token is not used by the return value.
   """
   infeed_aval = jax.ShapedArray((3, 4), jnp.float32)
   tok = lax.create_token(x)
   received, tok = lax.infeed(tok, shape=infeed_aval)
   incremented = received + np.float32(1)
   tok = lax.outfeed(tok, incremented)
   return x - 1
예제 #5
0
 def f(x):
     """Infeed a (3, 4) float32 array, outfeed it plus one, return ``x - 1``.

     The returned value is passed through ``lax.tie_in`` together with the
     token produced by the outfeed.
     """
     infeed_aval = jax.ShapedArray((3, 4), np.float32)
     tok = lax.create_token(x)
     received, tok = lax.infeed(tok, shape=infeed_aval)
     tok = lax.outfeed(tok, received + onp.float32(1))
     result = x - 1
     return lax.tie_in(tok, result)