Example #1
    def testBasic(self):
        f = pmap(lambda x: x - lax.psum(x, 'i'), axis_name='i')

        shape = (xla_bridge.device_count(), 4)
        x = onp.arange(prod(shape), dtype=onp.float32).reshape(shape)
        expected = x - onp.sum(x, 0)

        ans = f(x)
        self.assertAllClose(ans, expected, check_dtypes=False)
Example #2
 def mapped_update(i, opt_state, batch, rng):
   """This is a multi-device version of the update function above."""
   # We assume all tensors have the first dimension = num_devices.
   _, opt_update = optimizer(lr_fun)
   params = trax_opt.get_params(opt_state)
   grads = backend.grad(loss_fun)(params, batch, predict_fun, rng)
   grads = jax.tree_util.tree_map(
       lambda g: lax.psum(g, "batch"), grads)
   return opt_update(i, grads, opt_state)
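
An update function like this only does something meaningful once it is wrapped in jax.pmap with a matching axis name, so that lax.psum has a "batch" axis to reduce over. Below is a minimal, self-contained sketch of that wrapping pattern; the toy loss, parameter shapes, and data are illustrative and not part of the example above.

import jax
import jax.numpy as jnp
from jax import lax

n_devices = jax.local_device_count()

def loss(params, batch):
    x, y = batch
    return jnp.mean((x @ params - y) ** 2)

def update(params, batch, lr=0.1):
    grads = jax.grad(loss)(params, batch)
    # Average gradients across the device-mapped "batch" axis via an all-reduce.
    grads = jax.tree_util.tree_map(
        lambda g: lax.psum(g, "batch") / n_devices, grads)
    return params - lr * grads

p_update = jax.pmap(update, axis_name="batch")

# Replicate parameters and shard the data so the leading dim equals n_devices.
params = jnp.zeros((n_devices, 3))
x = jnp.ones((n_devices, 8, 3))
y = jnp.ones((n_devices, 8))
params = p_update(params, (x, y))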
Example #3
 def mapped_update(i, opt_state, batch, rng):
     """This is a multi-device version of the update function above."""
     # We assume all tensors have the first dimension = n_devices.
     params, slots, opt_params = opt_state
     rng, subrng = jax_random.split(rng)
     grads = backend.grad(loss_fn)(params, batch, predict_fn, rng)
     grads = jax.tree_util.tree_map(lambda g: lax.psum(g, "batch"), grads)
     return optimizer.tree_update(i, grads, params, slots,
                                  opt_params), subrng
Example #4
def update(params, opt_state, x, y_true):
    # calc grads; summed across devices
    loss, grads = value_and_grad(mean_cross_entropy)(params, x, y_true)
    grads = tree_map(lambda v: psum(v, 'device'), grads)
    # apply update
    updates, opt_state = opt.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    # return new states & mean loss
    return params, opt_state, loss.mean()
Example #5
def _axis_size(a, axis):
    if not isinstance(axis, (tuple, list)):
        axis = (axis, )
    size = 1
    a_shape = np.shape(a)
    for a in axis:
        size *= maybe_named_axis(a, lambda i: a_shape[i],
                                 lambda name: lax.psum(1, name))
    return size
Example #6
def threefry_random_bits(key: jnp.ndarray, bit_width, shape):
    """Sample uniform random bits of given width and shape using PRNG key."""
    if not _is_threefry_prng_key(key):
        raise TypeError("_random_bits got invalid prng key.")
    if bit_width not in (8, 16, 32, 64):
        raise TypeError("requires 8-, 16-, 32- or 64-bit field width.")
    shape = core.as_named_shape(shape)
    for name, size in shape.named_items:
        real_size = lax.psum(1, name)
        if real_size != size:
            raise ValueError(
                f"The shape of axis {name} was specified as {size}, "
                f"but it really is {real_size}")
        axis_index = lax.axis_index(name)
        key = threefry_fold_in(key, axis_index)
    size = prod(shape.positional)
    # Compute ceil(bit_width * size / 32) in a way that is friendly to shape
    # polymorphism
    max_count, r = divmod(bit_width * size, 32)
    if r > 0:
        max_count += 1

    if core.is_constant_dim(max_count):
        nblocks, rem = divmod(max_count, jnp.iinfo(np.uint32).max)
    else:
        nblocks, rem = 0, max_count

    if not nblocks:
        bits = threefry_2x32(key, lax.iota(np.uint32, rem))
    else:
        keys = threefry_split(key, nblocks + 1)
        subkeys, last_key = keys[:-1], keys[-1]
        blocks = vmap(threefry_2x32,
                      in_axes=(0, None))(subkeys,
                                         lax.iota(np.uint32,
                                                  jnp.iinfo(np.uint32).max))
        last = threefry_2x32(last_key, lax.iota(np.uint32, rem))
        bits = lax.concatenate([blocks.ravel(), last], 0)

    dtype = UINT_DTYPES[bit_width]
    if bit_width == 64:
        bits = [lax.convert_element_type(x, dtype) for x in jnp.split(bits, 2)]
        bits = lax.shift_left(bits[0], dtype(32)) | bits[1]
    elif bit_width in [8, 16]:
        # this is essentially bits.view(dtype)[:size]
        bits = lax.bitwise_and(
            np.uint32(np.iinfo(dtype).max),
            lax.shift_right_logical(
                lax.broadcast(bits, (1, )),
                lax.mul(
                    np.uint32(bit_width),
                    lax.broadcasted_iota(np.uint32, (32 // bit_width, 1), 0))))
        bits = lax.reshape(bits, (np.uint32(max_count * 32 // bit_width), ),
                           (1, 0))
        bits = lax.convert_element_type(bits, dtype)[:size]
    return lax.reshape(bits, shape)
Example #7
 def testResourceConflictArgs(self):
   fm = xmap(lambda x: lax.psum(x, ('a', 'b')),
             in_axes=['a', 'b'], out_axes=[],
             axis_resources={'a': 'x', 'b': 'x'})
   x = np.arange(16).reshape(4, 4)
   error = (r"Axes `a` and `b` are both mapped to the resource `x`, but they "
            r"coincide in the named_shape of an input to an xmapped function "
            r"<lambda>")
   with self.assertRaisesRegex(JAXTypeError, error):
     fm(x)
Example #8
 def testLoopCollectives(self):
   fm = xmap(lambda x: lax.psum(x, 'i'),
             in_axes=['i'], out_axes=[],
             axis_resources={'i': 'l'})
   x = np.arange(16)
   error = (r"Named axes with loop resources assigned to them cannot be "
            r"referenced inside the xmapped computation \(e.g. in "
            r"collectives\), but `i` violates that rule")
   with self.assertRaisesRegex(RuntimeError, error):
     fm(x)
Example #9
def train_step(optimizer, inputs, labels, learning_rate_fn, dropout_rng=None):
  """A single training step.

  Args:
    optimizer: optimizer used for training
    inputs: inputs to the model [word_ids, mask, type_ids]
    labels: target output [start_positions, end_positions]
    learning_rate_fn: function for tuning learning rate
    dropout_rng: random seed used for dropout

  Returns:
    new_optimizer: updated model optimizer after training step
    loss: sparse categorical crossentropy
    new_dropout_rng: new random seed to be used for next step
  """
  dropout_rng, new_dropout_rng = random.split(dropout_rng)

  def loss_fn(model):
    with nn.stochastic(dropout_rng):
      use_bf16 = FLAGS.use_bfloat16_activation
      dtype = jnp.bfloat16 if use_bf16 else jnp.float32
      lm_outputs, sentence_outputs = model(
          inputs, train=True, dtype=dtype)
      assert lm_outputs.dtype == jnp.float32
      assert sentence_outputs.dtype == jnp.float32
    total_loss, lm_loss, sentence_loss = get_pretrain_loss(
        labels, lm_outputs, sentence_outputs)
    return total_loss, (lm_loss, sentence_loss)

  def clip_by_global_norm(grads):
    _, treedef = jax.tree_flatten(grads)
    grads_flat = treedef.flatten_up_to(grads)
    grad_norms = [jnp.linalg.norm(gd)**2 for gd in grads_flat]
    global_norm = jnp.sqrt(jnp.sum(grad_norms))
    clip_norm = 1.0
    grads_flat = [
        gd * clip_norm / jnp.maximum(global_norm, clip_norm)
        for gd in grads_flat
    ]
    return jax.tree_unflatten(treedef, grads_flat)

  step = optimizer.state[0].step
  lr = learning_rate_fn(step)
  total_loss, (lm_loss,
               sentence_loss), grads = optimizer.compute_gradient(loss_fn)
  clipped_grads = clip_by_global_norm(grads)
  if FLAGS.reduce_gradients_in_bf16:
    clipped_grads = jax.tree_map(lambda x: x.astype(jnp.bfloat16),
                                 clipped_grads)
  clipped_grads = lax.psum(clipped_grads, 'batch')
  if FLAGS.reduce_gradients_in_bf16:
    clipped_grads = jax.tree_map(lambda x: x.astype(jnp.float32), clipped_grads)
  new_optimizer = optimizer.apply_gradient(clipped_grads, learning_rate=lr)

  return new_optimizer, total_loss, lm_loss, sentence_loss, new_dropout_rng
Example #10
  def DISABLED_testSum(self):
    pfun, axis_name = papply(np.sum, 5)

    jaxpr = make_jaxpr(pfun)(onp.zeros(5))
    expected_jaxpr = make_jaxpr(
        lambda x: lax.psum(x, axis_name))(onp.zeros(5))
    assert repr(jaxpr) == repr(expected_jaxpr)

    ans = serial_pmap(pfun, axis_name)(onp.arange(3.))
    expected = onp.sum(onp.arange(3.))
    self.assertAllClose(ans, expected, check_dtypes=False)
Example #11
 def testCollectiveReduce(self):
   fm = xmap(lambda a, b: (lax.psum(a * 2, 'a'), b * 4),
             in_axes=[['a', 'b', ...], {0: 'c'}],
             out_axes=[['b', ...], {0: 'c'}],
             axis_resources={'a': 'x', 'b': 'y', 'c': 'x'})
   ashape = (16, 8, 5)
   a = jnp.arange(np.prod(ashape)).reshape(ashape)
   bshape = (2, 7)
   b = jnp.arange(np.prod(bshape)).reshape(bshape)
   c, d = fm(a, b)
   self.assertAllClose(c, (a * 2).sum(0))
   self.assertAllClose(d, b * 4)
Example #12
  def testJitPmapComposition(self):
    f = lambda x: x - lax.psum(x, 'i')

    shape = (xla_bridge.device_count(), 4)
    x = onp.arange(prod(shape), dtype=onp.float32).reshape(shape)
    expected = x - onp.sum(x, 0)

    ans = jit(pmap(f, 'i'))(x)
    self.assertAllClose(ans, expected, check_dtypes=False)

    ans = pmap(jit(f), 'i')(x)
    self.assertAllClose(ans, expected, check_dtypes=False)
Example #13
    def testSum(self):
        pfun, axis_name = _papply(lambda x: np.sum(x, axis=0))

        jaxpr = make_jaxpr(pfun)(onp.ones(3))
        expected_jaxpr = make_jaxpr(lambda x: lax.psum(x, axis_name))(
            onp.zeros((5, 3)))
        assert repr(jaxpr) == repr(expected_jaxpr)

        arg = onp.arange(15.).reshape((5, 3))
        ans = soft_pmap(pfun, axis_name)(arg)[0]
        expected = onp.sum(arg, axis=0)
        self.assertAllClose(ans, expected, check_dtypes=False)
Example #14
  def testPartiallyMappedNested(self, device_mesh_shape):
    mesh_shape = self._getMeshShape(device_mesh_shape)

    f = pmap(lambda x, y: x - lax.psum(y, 'i'), axis_name='i', in_axes=(None, 0))
    f = pmap(f, axis_name='j', in_axes=(None, 0))

    x = 3.
    y = onp.arange(prod(mesh_shape), dtype=onp.float32).reshape(mesh_shape)
    expected = onp.broadcast_to(x - onp.sum(y, 1, keepdims=True), mesh_shape)

    ans = f(x, y)
    self.assertAllClose(ans, expected, check_dtypes=False)
Example #15
 def test_pmap(self):
     with tempfile.TemporaryDirectory() as tmpdir:
         cc.initialize_cache(tmpdir)
         f = pmap(lambda x: x - lax.psum(x, 'i'), axis_name='i')
         x = np.arange(jax.device_count(), dtype=np.int64)
         f(x)
         files_in_directory = len(os.listdir(tmpdir))
         self.assertEqual(files_in_directory, 1)
         x = np.arange(jax.device_count(), dtype=np.float32)
         f(x)
         files_in_directory = len(os.listdir(tmpdir))
         self.assertEqual(files_in_directory, 2)
Example #16
def _evaluate_batch(flax_module, batch_stats, batch, metrics_bundle,
                    apply_one_hot_in_loss):
    """Evaluates metrics on the given batch.

  Currently we assume each metric_fn in metrics_bundle has the API:
    metric_fn(logits, targets, weights)
  and returns an array of shape [batch_size]. We also assume that to compute
  the aggregate metric, one should sum across all batches, then divide by the
  total samples seen (calculated by the 'denominator' metric). In this way we
  currently only support metrics of the form (1/N) * sum(f(inputs, targets)).
  Note that the caller is responsible for dividing by metrics['denominator']
  when computing the mean of each metric.

  Args:
    flax_module: A flax.nn.Module
    batch_stats: A flax.nn.Collection object tracking batch_stats.
    batch: A dictionary with keys 'inputs', 'targets', 'weights'.
    metrics_bundle: A group of metrics to use for evaluation.
    apply_one_hot_in_loss: Indicates whether or not the targets are one hot
      encoded.

  Returns:
    A dictionary with the same keys as metrics, but mapping to the summed metric
    across the sharded batch_dim.

  """
    with nn.stateful(batch_stats, mutable=False):
        logits = flax_module(batch['inputs'], train=False)
    targets = batch['targets']

    if apply_one_hot_in_loss:
        targets = one_hot(batch['targets'], logits.shape[-1])

    # map the dict values (which are functions) to function(logits, targets, weights)
    weights = batch.get('weights')  # Weights might not be defined.
    eval_batch_size = targets.shape[0]
    if weights is None:
        weights = jnp.ones(eval_batch_size)

    # This psum is required to correctly evaluate with multihost. Only host 0
    # will report the metrics, so we must aggregate across all hosts. The psum
    # will map an array of shape [n_global_devices, batch_size] -> [batch_size]
    # by summing across the devices dimension. The outer sum then sums across the
    # batch dim. The result is that we have summed across all samples in the
    # sharded batch.

    evaluated_metrics = {}
    for key in metrics_bundle:
        per_example_metrics = metrics_bundle[key](logits, targets, weights)
        evaluated_metrics[key] = jnp.sum(
            lax.psum(per_example_metrics, axis_name='batch'))

    return evaluated_metrics
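
The docstring above states that the caller must sum these per-batch results and divide by the 'denominator' metric. A hedged sketch of that caller-side aggregation follows; eval_step and eval_batches are hypothetical stand-ins for a pmapped version of _evaluate_batch (with the module, batch_stats, and metrics bundle already bound) and an iterator of sharded batches.

import collections

import jax
import numpy as np

def aggregate_eval_metrics(eval_step, eval_batches):
    """Sums per-batch metric dicts and normalizes by the 'denominator' metric."""
    totals = collections.defaultdict(float)
    for batch in eval_batches:
        batch_metrics = jax.device_get(eval_step(batch))
        for key, value in batch_metrics.items():
            # The psum already summed across devices, so every device holds the
            # same scalar; take one replica rather than summing replicas again.
            totals[key] += float(np.ravel(value)[0])
    denominator = totals.pop('denominator')
    return {key: value / denominator for key, value in totals.items()}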
Example #17
    def allreduce_spmd_update(i, op_state, batch):

        params = get_params(op_state)
        grads = grad(loss)(params, batch)
        leaves, local_treedef = tree_flatten(grads)
        # We compute the total gradients, summing across the device-mapped axis,
        # using the `lax.psum` SPMD primitive, which does a fast all-reduce-sum.
        grads = [lax.psum(dw, 'batch') for dw in leaves]
        grads = tree_unflatten(local_treedef, grads)
        op_state = opt_update(i, grads, op_state)

        return op_state
Example #18
def one_hot(x: Array,
            num_classes: int,
            *,
            dtype: Any = jnp.float64,
            axis: Union[int, AxisName] = -1) -> Array:
    """One-hot encodes the given indicies.

  Each index in the input ``x`` is encoded as a vector of zeros of length
  ``num_classes`` with the element at ``index`` set to one::

    >>> jax.nn.one_hot(jnp.array([0, 1, 2]), 3)
    DeviceArray([[1., 0., 0.],
                 [0., 1., 0.],
                 [0., 0., 1.]], dtype=float32)

  Indices outside the range [0, num_classes) will be encoded as zeros::

    >>> jax.nn.one_hot(jnp.array([-1, 3]), 3)
    DeviceArray([[0., 0., 0.],
                 [0., 0., 0.]], dtype=float32)

  Args:
    x: A tensor of indices.
    num_classes: Number of classes in the one-hot dimension.
    dtype: optional, a float dtype for the returned values (default float64 if
      jax_enable_x64 is true, otherwise float32).
    axis: the axis or axes along which the function should be
      computed.
  """
    num_classes = core.concrete_or_error(
        int, num_classes,
        "The error arose in jax.nn.one_hot argument `num_classes`.")
    dtype = dtypes.canonicalize_dtype(dtype)
    x = jnp.asarray(x)
    try:
        output_pos_axis = util.canonicalize_axis(axis, x.ndim + 1)
    except TypeError:
        axis_size = lax.psum(1, axis)
        if num_classes != axis_size:
            raise ValueError(
                f"Expected num_classes to match the size of axis {axis}, "
                f"but {num_classes} != {axis_size}") from None
        axis_idx = lax.axis_index(axis)
        return jnp.asarray(x == axis_idx, dtype=dtype)
    axis = operator.index(axis)
    lhs = lax.expand_dims(x, (axis, ))
    rhs_shape = [1] * x.ndim
    rhs_shape.insert(output_pos_axis, num_classes)
    rhs = lax.broadcast_in_dim(jnp.arange(num_classes, dtype=x.dtype),
                               rhs_shape, (output_pos_axis, ))
    return jnp.asarray(lhs == rhs, dtype=dtype)
Example #19
    def testLogSoftmax(self):
        def fun(x):
            return x - np.log(np.sum(np.exp(x)))

        pfun, axis_name = papply(fun)

        jaxpr = make_jaxpr(pfun)(onp.zeros(5))
        expected_jaxpr = make_jaxpr(
            lambda x: x - np.log(lax.psum(np.exp(x), axis_name)))(onp.zeros(5))
        assert repr(jaxpr) == repr(expected_jaxpr)

        ans = pmap(pfun, axis_name)(onp.arange(1., 5.))
        expected = fun(onp.arange(1., 5.))
        self.assertAllClose(ans, expected, check_dtypes=False)
Example #20
  def testBadAxisSizeError(self):
    if xla_bridge.device_count() == 1:
      raise SkipTest("this test requires multiple devices")

    f = pmap(lambda x: lax.psum(x, 'i'), axis_name='i',
             devices=xla_bridge.devices())
    with self.assertRaisesRegex(
        ValueError, r"compiling computation that requires 1 replicas, "
        r"but \d+ devices were specified"):
      f(np.ones(1))

    with self.assertRaisesRegex(
        ValueError, r"compiling computation that requires \d+ replicas, "
        r"but \d+ devices were specified"):
      f(np.ones(xla_bridge.device_count() + 1))
Example #21
 def testNestedMeshSPMD(self):
   h = xmap(lambda y: (jnp.sin(y) * np.arange(y.size), lax.psum(y, ('a', 'b', 'c'))),
            in_axes={0: 'c'}, out_axes=({1: 'c'}, {}),
            axis_resources={'c': 'z'})
   f = xmap(lambda x: h(x * 2),
            in_axes=[None, 'a', 'b', ...], out_axes=(['a', 'b', ...], {}),
            axis_resources={'a': 'x', 'b': 'y'})
   xshape = (8, 2, 4, 5)
   x = jnp.arange(np.prod(xshape)).reshape(xshape)
   y = f(x)
   hlo = jax.xla_computation(f)(x).as_hlo_text()
   match = re.search(r"sharding={devices=\[([0-9,]+)\][0-9,]+}", hlo)
   self.assertIsNot(match, None)
   tile_factors = [int(s) for s in match.group(1).split(',')]
   self.assertEqual(set(tile_factors), {1, 2})
Example #22
    def testLogSoftmax(self):
        raise SkipTest("test doesn't pass yet")  # TODO(frostig)

        def fun(x):
            return x - np.log(np.sum(np.exp(x)))

        pfun, axis_name = _papply(fun)

        jaxpr = make_jaxpr(pfun)(onp.zeros(5))
        expected_jaxpr = make_jaxpr(
            lambda x: x - np.log(lax.psum(np.exp(x), axis_name)))(onp.zeros(5))
        assert repr(jaxpr) == repr(expected_jaxpr)

        ans = soft_pmap(pfun, axis_name)(onp.arange(1., 5.))
        expected = fun(onp.arange(1., 5.))
        self.assertAllClose(ans, expected, check_dtypes=False)
Example #23
  def testBadAxisSizeError(self):
    if xla_bridge.device_count() == 1:
      raise SkipTest("this test requires multiple devices")

    f = pmap(lambda x: lax.psum(x, 'i'), axis_name='i',
             devices=xla_bridge.devices())
    with self.assertRaisesRegex(
        ValueError, r"Leading axis size of input to pmapped function must "
        r"equal the number of local devices passed to pmap. Got axis_size=1, "
        r"num_local_devices=\d."):
      f(np.ones(1))

    with self.assertRaisesRegex(
        ValueError, r"Leading axis size of input to pmapped function must "
        r"equal the number of local devices passed to pmap. Got axis_size=\d, "
        r"num_local_devices=\d."):
      f(np.ones(xla_bridge.device_count() + 1))
Example #24
 def _eval_model(
         self, params: spec.ParameterContainer, batch: Dict[str,
                                                            spec.Tensor],
         model_state: spec.ModelAuxiliaryState, rng: spec.RandomState
 ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
     logits, _ = self.model_fn(params,
                               batch,
                               model_state,
                               spec.ForwardPassMode.EVAL,
                               rng,
                               update_batch_norm=False)
     accuracy = jnp.sum(jnp.argmax(logits, axis=-1) == batch['targets'])
     loss = jnp.sum(self.loss_fn(batch['targets'], logits))
     num_data = len(logits)
     metrics = {'accuracy': accuracy, 'loss': loss, 'num_data': num_data}
     metrics = lax.psum(metrics, axis_name='batch')
     return metrics
Example #25
  def testPsumMultiple(self):
    f = lambda x: lax.psum(x, ('i', 'j'))
    f = pmap(pmap(f, 'i'), 'j')

    def sum_and_broadcast(x, axis):
      return onp.repeat(onp.sum(x, axis, keepdims=True), x.shape[axis], axis)

    device_count = xla_bridge.device_count()
    num_pairs, ragged = divmod(device_count, 2)
    if num_pairs > 1 and not ragged:
      shape = (num_pairs, 2, 4)
    else:
      shape = (device_count, 1, 4)
    x = onp.arange(prod(shape), dtype=onp.float32).reshape(shape)

    ans = f(x)
    expected = sum_and_broadcast(sum_and_broadcast(x, 0), 1)
    self.assertAllClose(ans, expected, check_dtypes=False)
Example #26
  def testPartiallyMapped(self):
    f = pmap(lambda x, y: x, in_axes=(None, 0))
    g = pmap(lambda x, y: x - lax.psum(y, 'i'), axis_name='i', in_axes=(None, 0))

    mesh_shape = (xla_bridge.device_count(),)
    shape = mesh_shape + (4,)
    x = onp.array(3., dtype=onp.float32)
    y = onp.arange(prod(shape), dtype=onp.float32).reshape(shape)

    f_expected = onp.broadcast_to(x, mesh_shape)
    f_ans = f(x, y)
    self.assertAllClose(f_ans, f_expected, check_dtypes=True)
    self.assertIsInstance(f_ans, pxla.ShardedDeviceArray)
    # the output is actually replicated (has the same values in each device buffer)
    # but out_axes is implicitly 0, so we shouldn't have replication in the
    # sharding spec.
    self.assertEqual(f_ans.sharding_spec.replication_factor, 1)

    g_expected = onp.broadcast_to(x - onp.sum(y, 0, keepdims=True), shape)
    g_ans = g(x, y)
    self.assertAllClose(g_ans, g_expected, check_dtypes=True)
    self.assertIsInstance(g_ans, pxla.ShardedDeviceArray)
    self.assertEqual(g_ans.sharding_spec.replication_factor, 1)
Example #27
def _one_hot(x: Array, num_classes: int, *, dtype: Any,
             axis: Union[int, AxisName]) -> Array:
    num_classes = core.concrete_or_error(
        int, num_classes,
        "The error arose in jax.nn.one_hot argument `num_classes`.")
    dtype = dtypes.canonicalize_dtype(dtype)
    x = jnp.asarray(x)
    try:
        output_pos_axis = util.canonicalize_axis(axis, x.ndim + 1)
    except TypeError:
        axis_size = lax.psum(1, axis)
        if num_classes != axis_size:
            raise ValueError(
                f"Expected num_classes to match the size of axis {axis}, "
                f"but {num_classes} != {axis_size}") from None
        axis_idx = lax.axis_index(axis)
        return jnp.asarray(x == axis_idx, dtype=dtype)
    axis = operator.index(axis)  # type: ignore[arg-type]
    lhs = lax.expand_dims(x, (axis, ))
    rhs_shape = [1] * x.ndim
    rhs_shape.insert(output_pos_axis, num_classes)
    rhs = lax.broadcasted_iota(x.dtype, rhs_shape, output_pos_axis)
    return jnp.asarray(lhs == rhs, dtype=dtype)
Example #28
 def testIssue804(self):
     num_devices = xla_bridge.device_count()
     f = partial(lax.scan, lambda c, x: (c + lax.psum(x, "i"), c), 0.)
     api.pmap(f, axis_name="i")(np.ones((num_devices, 4)))  # doesn't crash
Example #29
def get_axis_size(axis_name=None):
  if JAX_MODE:
    return lax.psum(1, axis_name)
  ctx = tf.distribute.get_replica_context()
  return ctx.num_replicas_in_sync
Example #30
def _sum_seeds_pmapped(seed):
    return lax.psum(seed, 'hosts')
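
A hedged usage sketch: a helper like this is only meaningful under a jax.pmap (or equivalent) that defines the 'hosts' axis. The wrapping below is illustrative; after the psum every participating device holds the same total.

import jax
import jax.numpy as jnp
from jax import lax

def _sum_seeds_pmapped(seed):
    return lax.psum(seed, 'hosts')

# One seed per local device; each device contributes its seed and all receive the sum.
seeds = jnp.arange(jax.local_device_count(), dtype=jnp.uint32)
summed = jax.pmap(_sum_seeds_pmapped, axis_name='hosts')(seeds)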