コード例 #1
0
ファイル: infeed_test.py プロジェクト: yuejiesong1900/jax
 def f(x):
     token = lax.create_token(x)
     y, token = lax.infeed(token,
                           shape=jax.ShapedArray((3, 4), jnp.float32))
     token = lax.outfeed(token, y + np.float32(1))
     return x - 1 if config.omnistaging_enabled else lax.tie_in(
         token, x - 1)
コード例 #2
0
ファイル: pjit_test.py プロジェクト: rsepassi/jax
 def f(x):
   token = lax.create_token(x)
   token = lax.outfeed(token, x, partitions=(None,))
   token = lax.outfeed(token, x, partitions=(P(nr_devices, 1),))
   token = lax.outfeed(token, x, partitions=(P(1, nr_devices),))
   return x
コード例 #3
0
ファイル: infeed_test.py プロジェクト: zhaowilliam/jax
 def doubler(_, token):
   y, token = lax.infeed(
       token, shape=jax.ShapedArray((3, 4), jnp.float32))
   return lax.outfeed(token, y * np.float32(2))
コード例 #4
0
ファイル: infeed_test.py プロジェクト: zhaowilliam/jax
 def f(x):
   token = lax.create_token(x)
   y, token = lax.infeed(
       token, shape=jax.ShapedArray((3, 4), jnp.float32))
   token = lax.outfeed(token, y + np.float32(1))
   return x - 1
コード例 #5
0
ファイル: infeed_test.py プロジェクト: yueyedeai/jax
 def f(x):
     token = lax.create_token(x)
     y, token = lax.infeed(token,
                           shape=jax.ShapedArray((3, 4), np.float32))
     token = lax.outfeed(token, y + onp.float32(1))
     return lax.tie_in(token, x - 1)