Example #1
0
  def testGather(self):
    """all_gather over a pmap axis hands every device the full input."""
    num_devices = xla_bridge.device_count()
    shape = (num_devices, 4)
    # pmap splits the leading axis, so each device receives one row of x.
    x = onp.arange(prod(shape), dtype=onp.float32).reshape(shape)

    gather = pmap(lambda row: lax.all_gather(row, 'i'), axis_name='i')
    # Every device ends up holding all of x, so the stacked output is x
    # repeated once per device.
    expected = onp.array([x for _ in range(num_devices)])
    self.assertAllClose(gather(x), expected, check_dtypes=False)
Example #2
0
 def inner_pmap(params, samples):
   """Run the MLP on positionally-encoded samples and gather the outputs.

   An inner function is needed because only JAX types can be passed to a
   pmap; the encoding settings are read from ``scene_params``, which is
   captured from the enclosing scope rather than passed as an argument.

   Returns:
     A ``(rgb_features, sigma)`` tuple, each leaf gathered across the
     ``'batch'`` named axis.
   """
   min_deg = scene_params['_min_deg_point']
   max_deg = scene_params['_max_deg_point']
   legacy_order = scene_params['_legacy_posenc_order']
   encoded = model_utils.posenc(samples, min_deg, max_deg, legacy_order)

   rgb_raw, sigma_raw = mlp_model.apply(params, encoded)
   outputs = (rgb_activation(rgb_raw), sigma_activation(sigma_raw))
   return lax.all_gather(outputs, axis_name='batch')
Example #3
0
# Let's test out the jax pmap functionality.

from jax import lax, local_device_count, pmap, random
import jax.numpy as jnp


def run_pmap_demo(n=None, rows=5000, cols=6000):
    """Run a pmap / all_gather demo and return its intermediate results.

    Args:
      n: number of devices to map over. Defaults to ``local_device_count()``;
        ``pmap`` fails if ``n`` exceeds the number of local devices.
      rows: number of rows of the random matrix drawn on each device.
      cols: number of columns of the random matrix drawn on each device.

    Returns:
      ``(keys, means, gathered)`` where ``keys`` holds one PRNG key per
      device, ``means`` has shape ``(n,)`` and ``gathered`` has shape
      ``(n, n, rows, rows)``.
    """
    if n is None:
        n = local_device_count()

    # One PRNG key per device so every device draws a different matrix.
    keys = random.split(random.PRNGKey(0), n)

    # Create n random rows x cols matrices, one per device.
    mats = pmap(lambda key: random.normal(key, (rows, cols)))(keys)

    # Local matmul on each device in parallel (no data transfer);
    # result.shape is (n, rows, rows).
    result = pmap(lambda x: jnp.dot(x, x.T))(mats)

    # Mean of each device's matrix, computed in parallel; shape (n,).
    means = pmap(jnp.mean)(result)

    # all_gather gives every device the stacked results of all devices, so
    # each device holds (n, rows, rows) and the output is (n, n, rows, rows).
    gathered = pmap(lambda x: lax.all_gather(x, 'i'), axis_name='i')(result)
    return keys, means, gathered


if __name__ == "__main__":
    keys, means, y = run_pmap_demo()
    print("keys", keys)
    print(means)
    print("all_gather:")
    print(y.shape)
Example #4
0
 def testCollectiveAllGather(self):
   """Inside xmap, all_gather over 'i' returns the full array at every index."""
   x = jnp.arange(4)

   def gather_plus_index(v):
     return lax.all_gather(v, 'i') + lax.axis_index('i')

   mapped = xmap(gather_plus_index, in_axes=['i', ...], out_axes=['i', ...])
   # Row k is all of x shifted by k; since x == arange(4), that is x + x[k].
   expected = x + x[:, jnp.newaxis]
   self.assertAllClose(mapped(x), expected)
Example #5
0
 def f(x):
   """Gather a ones array shaped like ``x`` and the constant 1 over axis 'i'.

   Returns:
     A ``(gathered_ones, gathered_constant)`` tuple.
   """
   ones = jnp.ones_like(x)
   return (lax.all_gather(ones, axis_name='i'),
           lax.all_gather(1, axis_name='i'))
Example #6
0
 def f(x):
   """Return ``x`` gathered from every member of the ``'i'`` named axis."""
   gathered = lax.all_gather(x, axis_name='i')
   return gathered