コード例 #1
0
ファイル: pmap_test.py プロジェクト: mattwescott/jax
  def testTrees(self):
    """Compare pmap'd collectives on a dict pytree with NumPy references.

    Each leaf is an (n, n) array, where n is the device count. pmap maps over
    the leading axis of every leaf under the axis name 'i'. Reductions are
    checked on every backend; all_to_all / ppermute only when the device
    under test is neither CPU nor GPU.
    """

    def ptranspose(x, axis_name):
      # all_to_all with split_axis=0 and concat_axis=0.
      return lax.all_to_all(x, axis_name, 0, 0)

    def protate(x, axis_name):
      # Send each shard from position k to position (k + 1) mod size.
      size = lax.psum(1, axis_name)
      pairs = [(src, (src + 1) % size) for src in range(size)]
      return lax.ppermute(x, axis_name, pairs)

    def tree_f(fn):
      # Lift a leaf-level function so it acts on every leaf of a pytree.
      return partial(tree_util.tree_map, fn)

    def jax_f(collective):
      # Run `collective` under pmap with the mapped axis named 'i'.
      return pmap(lambda shard: collective(shard, 'i'), 'i')

    def onp_f(reducer):
      # NumPy reference: reduce each leaf over axis 0, then broadcast the
      # result back to the leaf's full shape.
      return tree_f(
          lambda leaf: onp.broadcast_to(reducer(leaf, 0), leaf.shape))

    onp_transpose = tree_f(onp.transpose)
    onp_rotate = tree_f(lambda leaf: onp.concatenate([leaf[-1:], leaf[:-1]]))

    n = xla_bridge.device_count()
    # Leaves are disjoint integer ranges of n * n values, starting at
    # 1, 2 and 4 times n * n respectively.
    x = {key: onp.arange(start * n * n, (start + 1) * n * n).reshape([n, n])
         for key, start in (('a', 1), ('b', 2), ('c', 4))}

    assert_allclose = partial(tree_util.tree_multimap,
                              partial(self.assertAllClose, check_dtypes=False))
    reductions = [(lax.pmax, onp.max), (lax.pmin, onp.min),
                  (lax.psum, onp.sum), (lax.pmean, onp.mean)]
    for collective, reducer in reductions:
      assert_allclose(jax_f(collective)(x), onp_f(reducer)(x))
    if jtu.device_under_test() not in ("cpu", "gpu"):
      # NOTE: all-to-all and ppermute only supported on TPU.
      assert_allclose(jax_f(ptranspose)(x), onp_transpose(x))
      assert_allclose(jax_f(protate)(x), onp_rotate(x))
コード例 #2
0
ファイル: batching_test.py プロジェクト: yuyuexi/jax
 def testAllToAll(self, vmap_axis, split_axis, concat_axis):
   shape = (4, 4, 4, 4)
   x = np.arange(np.prod(shape)).reshape(shape)
   f = vmap(lambda x: lax.all_to_all(x, 'i', split_axis, concat_axis),
            in_axes=vmap_axis, axis_name='i')
   y = f(x)
   ref = jnp.moveaxis(x, (vmap_axis, split_axis + (vmap_axis <= split_axis)),
                         (concat_axis + 1, 0))
   self.assertAllClose(y, ref)
コード例 #3
0
 def f(x):
   """Apply all_to_all across the tuple of named axes ('i', 'j').

   split_axis and concat_axis are free variables resolved from the enclosing
   scope (not visible in this snippet).
   """
   names = ('i', 'j')
   return lax.all_to_all(x, names, split_axis, concat_axis)
コード例 #4
0
ファイル: pmap_test.py プロジェクト: mattwescott/jax
 def f(x):
   """Apply all_to_all over axis 'i', splitting and concatenating on axis 0."""
   return lax.all_to_all(x, axis_name='i', split_axis=0, concat_axis=0)