def testBasic(self):
  """pmap of `x - psum(x)` subtracts the cross-device sum from every shard."""
  n_devices = xla_bridge.device_count()
  shape = (n_devices, 4)
  data = onp.arange(prod(shape), dtype=onp.float32).reshape(shape)
  # Host-side reference: every row minus the sum over the mapped (leading) axis.
  want = data - onp.sum(data, 0)
  centered = pmap(lambda v: v - lax.psum(v, 'i'), axis_name='i')
  self.assertAllClose(centered(data), want, check_dtypes=False)
def mapped_update(i, opt_state, batch, rng):
  """Multi-device optimizer step whose gradients are all-reduced over "batch".

  Intended to run under a mapped transformation that binds the "batch" axis
  name. All tensors are assumed to have a leading dimension equal to
  num_devices.

  Args:
    i: step index forwarded to `opt_update`.
    opt_state: optimizer state; parameters are read with
      `trax_opt.get_params`.
    batch: this device's input batch.
    rng: PRNG key forwarded to the loss function.

  Returns:
    Whatever `opt_update(i, grads, opt_state)` returns.
  """
  _, opt_update = optimizer(lr_fun)
  params = trax_opt.get_params(opt_state)
  local_grads = backend.grad(loss_fun)(params, batch, predict_fun, rng)

  def _all_reduce(g):
    # Sum this leaf over every participant of the "batch" axis.
    return lax.psum(g, "batch")

  summed_grads = jax.tree_util.tree_map(_all_reduce, local_grads)
  return opt_update(i, summed_grads, opt_state)
def mapped_update(i, opt_state, batch, rng):
  """Multi-device optimizer step whose gradients are all-reduced over "batch".

  All tensors are assumed to have a leading dimension equal to n_devices.

  Args:
    i: step index forwarded to `optimizer.tree_update`.
    opt_state: a (params, slots, opt_params) triple.
    batch: this device's input batch.
    rng: PRNG key; split into one key for this step and one to return.

  Returns:
    A pair (result of `optimizer.tree_update`, PRNG key for the next step).
  """
  params, slots, opt_params = opt_state
  step_rng, next_rng = jax_random.split(rng)
  local_grads = backend.grad(loss_fn)(params, batch, predict_fn, step_rng)
  summed_grads = jax.tree_util.tree_map(
      lambda g: lax.psum(g, "batch"), local_grads)
  new_state = optimizer.tree_update(i, summed_grads, params, slots, opt_params)
  return new_state, next_rng
def update(params, opt_state, x, y_true):
  """Runs one optax step using gradients summed over the 'device' axis.

  Returns:
    (new_params, new_opt_state, loss.mean()).

    NOTE(review): the loss is not reduced across 'device'. `.mean()` only
    averages this device's loss value; confirm callers expect that.
  """
  loss, local_grads = value_and_grad(mean_cross_entropy)(params, x, y_true)
  # Gradients are *summed* (not averaged) across devices.
  grads = tree_map(lambda v: psum(v, 'device'), local_grads)
  updates, new_opt_state = opt.update(grads, opt_state, params)
  new_params = optax.apply_updates(params, updates)
  return new_params, new_opt_state, loss.mean()
def _axis_size(a, axis):
  """Returns the product of the sizes of `a` along one or more axes.

  `axis` may be a single axis or a tuple/list of axes. Each entry is resolved
  through `maybe_named_axis` with two callbacks: one that indexes
  `np.shape(a)` and one that evaluates `lax.psum(1, name)`. Presumably the
  first is used for positional axes and the second for axis names; confirm
  against `maybe_named_axis`.
  """
  axes = axis if isinstance(axis, (tuple, list)) else (axis,)
  a_shape = np.shape(a)
  total = 1
  # Renamed loop variable: the original reused `a`, shadowing the argument.
  for ax in axes:
    total *= maybe_named_axis(
        ax, lambda i: a_shape[i], lambda name: lax.psum(1, name))
  return total
def threefry_random_bits(key: jnp.ndarray, bit_width, shape):
  """Sample uniform random bits of given width and shape using PRNG key.

  Args:
    key: a threefry PRNG key (checked with `_is_threefry_prng_key`).
    bit_width: one of 8, 16, 32 or 64; selects the output dtype through
      `UINT_DTYPES[bit_width]`.
    shape: anything accepted by `core.as_named_shape`; it may contain named
      (mapped) axes as well as positional dimensions.

  Returns:
    An array of dtype `UINT_DTYPES[bit_width]` reshaped to `shape`.

  Raises:
    TypeError: if `key` is not a threefry key or `bit_width` is unsupported.
    ValueError: if a named axis' declared size differs from the size
      reported by `lax.psum(1, name)`.
  """
  if not _is_threefry_prng_key(key):
    raise TypeError("_random_bits got invalid prng key.")
  if bit_width not in (8, 16, 32, 64):
    raise TypeError("requires 8-, 16-, 32- or 64-bit field width.")
  shape = core.as_named_shape(shape)
  # For every named axis, check the declared size against the real mapped
  # size, then fold this position's axis index into the key. Presumably this
  # gives each position along the axis its own random stream.
  for name, size in shape.named_items:
    real_size = lax.psum(1, name)
    if real_size != size:
      raise ValueError(
          f"The shape of axis {name} was specified as {size}, "
          f"but it really is {real_size}")
    axis_index = lax.axis_index(name)
    key = threefry_fold_in(key, axis_index)
  # Number of output elements: only the positional dimensions count here.
  size = prod(shape.positional)
  # Compute ceil(bit_width * size / 32) in a way that is friendly to shape
  # polymorphism
  max_count, r = divmod(bit_width * size, 32)
  if r > 0:
    max_count += 1

  # A single `lax.iota` is capped at uint32-max elements, so large requests
  # are split into `nblocks` full-size blocks plus a remainder of `rem` words.
  # If `max_count` is not a constant (shape polymorphism), a single block of
  # `max_count` words is assumed.
  if core.is_constant_dim(max_count):
    nblocks, rem = divmod(max_count, jnp.iinfo(np.uint32).max)
  else:
    nblocks, rem = 0, max_count

  if not nblocks:
    bits = threefry_2x32(key, lax.iota(np.uint32, rem))
  else:
    # One subkey per full block, plus one extra key for the remainder.
    keys = threefry_split(key, nblocks + 1)
    subkeys, last_key = keys[:-1], keys[-1]
    blocks = vmap(threefry_2x32, in_axes=(0, None))(subkeys, lax.iota(np.uint32, jnp.iinfo(np.uint32).max))
    last = threefry_2x32(last_key, lax.iota(np.uint32, rem))
    bits = lax.concatenate([blocks.ravel(), last], 0)

  dtype = UINT_DTYPES[bit_width]
  if bit_width == 64:
    # `bits` holds 2 * size uint32 words. The first half supplies the high
    # 32 bits and the second half the low 32 bits of each uint64.
    bits = [lax.convert_element_type(x, dtype) for x in jnp.split(bits, 2)]
    bits = lax.shift_left(bits[0], dtype(32)) | bits[1]
  elif bit_width in [8, 16]:
    # this is essentially bits.view(dtype)[:size]
    # Each uint32 word is split into 32 // bit_width lanes: the word is
    # shifted right by bit_width * k and masked to the dtype's max value.
    # The reshape with dimensions (1, 0) flattens word-major (all lanes of
    # word 0, then word 1, ...), and the result is truncated to `size`.
    bits = lax.bitwise_and(
      np.uint32(np.iinfo(dtype).max),
      lax.shift_right_logical(
        lax.broadcast(bits, (1, )),
        lax.mul(
          np.uint32(bit_width),
          lax.broadcasted_iota(np.uint32, (32 // bit_width, 1), 0)
        )
      )
    )
    bits = lax.reshape(bits, (np.uint32(max_count * 32 // bit_width), ), (1, 0))
    bits = lax.convert_element_type(bits, dtype)[:size]
  # For bit_width == 32, `bits` already holds exactly `size` uint32 words.
  return lax.reshape(bits, shape)
def testResourceConflictArgs(self):
  """Two named axes mapped to the same resource may not share an input."""
  arg = np.arange(16).reshape(4, 4)
  # Both 'a' and 'b' are assigned to mesh resource 'x'.
  conflicting = xmap(lambda v: lax.psum(v, ('a', 'b')),
                     in_axes=['a', 'b'], out_axes=[],
                     axis_resources={'a': 'x', 'b': 'x'})
  expected_msg = (r"Axes `a` and `b` are both mapped to the resource `x`, but they "
                  r"coincide in the named_shape of an input to an xmapped function "
                  r"<lambda>")
  with self.assertRaisesRegex(JAXTypeError, expected_msg):
    conflicting(arg)
def testLoopCollectives(self):
  """Collectives over an axis assigned to a loop resource are rejected."""
  arg = np.arange(16)
  looped = xmap(lambda v: lax.psum(v, 'i'), in_axes=['i'], out_axes=[],
                axis_resources={'i': 'l'})
  expected_msg = (r"Named axes with loop resources assigned to them cannot be "
                  r"referenced inside the xmapped computation \(e.g. in "
                  r"collectives\), but `i` violates that rule")
  with self.assertRaisesRegex(RuntimeError, expected_msg):
    looped(arg)
def train_step(optimizer, inputs, labels, learning_rate_fn, dropout_rng=None):
  """Runs a single data-parallel pretraining step.

  Gradients of the pretraining loss are computed locally, clipped by global
  norm (clip value 1.0), summed over the 'batch' axis with `lax.psum`
  (optionally in bfloat16), and applied with the step's learning rate.

  Args:
    optimizer: optimizer used for training.
    inputs: inputs passed to the model (presumably [word_ids, mask, type_ids]).
    labels: targets passed to `get_pretrain_loss`.
    learning_rate_fn: maps the optimizer step to a learning rate.
    dropout_rng: PRNG key used for dropout. It is passed directly to
      `random.split`, so leaving it as None is presumably an error; confirm.

  Returns:
    (new_optimizer, total_loss, lm_loss, sentence_loss, new_dropout_rng),
    where new_dropout_rng is the key to use for the next step.
  """
  dropout_rng, new_dropout_rng = random.split(dropout_rng)

  def loss_fn(model):
    with nn.stochastic(dropout_rng):
      act_dtype = (jnp.bfloat16 if FLAGS.use_bfloat16_activation
                   else jnp.float32)
      lm_outputs, sentence_outputs = model(inputs, train=True, dtype=act_dtype)
      # Model outputs must be float32 even when activations are bfloat16.
      assert lm_outputs.dtype == jnp.float32
      assert sentence_outputs.dtype == jnp.float32
      total_loss, lm_loss, sentence_loss = get_pretrain_loss(
          labels, lm_outputs, sentence_outputs)
    return total_loss, (lm_loss, sentence_loss)

  def clip_by_global_norm(grads, clip_norm=1.0):
    leaves, treedef = jax.tree_flatten(grads)
    global_norm = jnp.sqrt(jnp.sum([jnp.linalg.norm(g)**2 for g in leaves]))
    clipped = [g * clip_norm / jnp.maximum(global_norm, clip_norm)
               for g in leaves]
    return jax.tree_unflatten(treedef, clipped)

  def cast_tree(tree, dtype):
    return jax.tree_map(lambda leaf: leaf.astype(dtype), tree)

  lr = learning_rate_fn(optimizer.state[0].step)
  total_loss, (lm_loss, sentence_loss), grads = optimizer.compute_gradient(
      loss_fn)
  # NOTE(review): clipping uses the local (per-replica) gradient norm, before
  # the cross-replica psum. Confirm that this ordering is intended.
  grads = clip_by_global_norm(grads)
  reduce_in_bf16 = FLAGS.reduce_gradients_in_bf16
  if reduce_in_bf16:
    # Presumably done to shrink the all-reduce payload.
    grads = cast_tree(grads, jnp.bfloat16)
  grads = lax.psum(grads, 'batch')
  if reduce_in_bf16:
    grads = cast_tree(grads, jnp.float32)
  new_optimizer = optimizer.apply_gradient(grads, learning_rate=lr)
  return new_optimizer, total_loss, lm_loss, sentence_loss, new_dropout_rng
def DISABLED_testSum(self):
  """papply(np.sum) should lower to a single psum over the papply axis."""
  pfun, axis_name = papply(np.sum, 5)
  zeros = onp.zeros(5)
  got_jaxpr = make_jaxpr(pfun)(zeros)
  want_jaxpr = make_jaxpr(lambda v: lax.psum(v, axis_name))(zeros)
  assert repr(got_jaxpr) == repr(want_jaxpr)

  inputs = onp.arange(3.)
  result = serial_pmap(pfun, axis_name)(inputs)
  self.assertAllClose(result, onp.sum(inputs), check_dtypes=False)
def testCollectiveReduce(self):
  """psum over named axis 'a' inside xmap, next to an independently mapped arg."""
  a_shape = (16, 8, 5)
  b_shape = (2, 7)
  a = jnp.arange(np.prod(a_shape)).reshape(a_shape)
  b = jnp.arange(np.prod(b_shape)).reshape(b_shape)
  fm = xmap(lambda u, v: (lax.psum(u * 2, 'a'), v * 4),
            in_axes=[['a', 'b', ...], {0: 'c'}],
            out_axes=[['b', ...], {0: 'c'}],
            axis_resources={'a': 'x', 'b': 'y', 'c': 'x'})
  reduced, scaled = fm(a, b)
  # 'a' is dimension 0 of `a`, so reducing over it is a sum over axis 0.
  self.assertAllClose(reduced, (a * 2).sum(0))
  self.assertAllClose(scaled, b * 4)
def testJitPmapComposition(self):
  """jit and pmap compose in either order around a psum-based function."""
  centered = lambda v: v - lax.psum(v, 'i')
  shape = (xla_bridge.device_count(), 4)
  data = onp.arange(prod(shape), dtype=onp.float32).reshape(shape)
  want = data - onp.sum(data, 0)
  for composed in (jit(pmap(centered, 'i')), pmap(jit(centered), 'i')):
    self.assertAllClose(composed(data), want, check_dtypes=False)
def testSum(self):
  """papply of a column sum should lower to a psum over the papply axis."""
  pfun, axis_name = _papply(lambda v: np.sum(v, axis=0))
  got_jaxpr = make_jaxpr(pfun)(onp.ones(3))
  # NOTE(review): the reference is traced on a (5, 3) input, while `pfun` is
  # traced on a (3,) input. The repr comparison only holds if jaxpr reprs omit
  # shapes; confirm against the JAX version in use.
  want_jaxpr = make_jaxpr(lambda v: lax.psum(v, axis_name))(
      onp.zeros((5, 3)))
  assert repr(got_jaxpr) == repr(want_jaxpr)

  arg = onp.arange(15.).reshape((5, 3))
  mapped_out = soft_pmap(pfun, axis_name)(arg)
  self.assertAllClose(mapped_out[0], onp.sum(arg, axis=0), check_dtypes=False)
def testPartiallyMappedNested(self, device_mesh_shape):
  """Nested pmaps where scalar `x` is unmapped (in_axes=None) at both levels."""
  mesh_shape = self._getMeshShape(device_mesh_shape)
  inner = pmap(lambda u, v: u - lax.psum(v, 'i'), axis_name='i',
               in_axes=(None, 0))
  nested = pmap(inner, axis_name='j', in_axes=(None, 0))
  x = 3.
  y = onp.arange(prod(mesh_shape), dtype=onp.float32).reshape(mesh_shape)
  # 'i' is the inner pmap axis, i.e. dimension 1 of `y`.
  want = onp.broadcast_to(x - onp.sum(y, 1, keepdims=True), mesh_shape)
  self.assertAllClose(nested(x, y), want, check_dtypes=False)
def test_pmap(self):
  """Each new input dtype adds exactly one entry to the compilation cache."""
  with tempfile.TemporaryDirectory() as tmpdir:
    cc.initialize_cache(tmpdir)
    centered = pmap(lambda v: v - lax.psum(v, 'i'), axis_name='i')
    for want_files, dtype in enumerate((np.int64, np.float32), start=1):
      centered(np.arange(jax.device_count(), dtype=dtype))
      self.assertEqual(len(os.listdir(tmpdir)), want_files)
def _evaluate_batch(flax_module, batch_stats, batch, metrics_bundle,
                    apply_one_hot_in_loss):
  """Evaluates metrics on the given batch.

  Currently we assume each metric_fn in metrics_bundle has the API:
    metric_fn(logits, targets, weights)
  and returns an array of shape [batch_size]. We also assume that to compute
  the aggregate metric, one should sum across all batches, then divide by the
  total samples seen (calculated by the 'denominator' metric). In this way we
  currently only support metrics of the form 1/N sum f(inputs, targets). Note,
  the caller is responsible for dividing by metrics['denominator'] when
  computing the mean of each metric.

  Args:
    flax_module: A flax.nn.Module.
    batch_stats: A flax.nn.Collection object tracking batch_stats.
    batch: A dictionary with keys 'inputs', 'targets' and optionally
      'weights' (all-ones weights are used when it is missing).
    metrics_bundle: A dict mapping metric names to metric functions.
    apply_one_hot_in_loss: Whether the integer targets should be one-hot
      encoded (to the logits' last dimension) before computing metrics.

  Returns:
    A dictionary with the same keys as metrics_bundle, mapping each to the
    metric summed over every example of the sharded batch (all devices).
  """
  with nn.stateful(batch_stats, mutable=False):
    logits = flax_module(batch['inputs'], train=False)
  targets = batch['targets']
  if apply_one_hot_in_loss:
    targets = one_hot(batch['targets'], logits.shape[-1])

  weights = batch.get('weights')  # Weights might not be defined.
  if weights is None:
    weights = jnp.ones(targets.shape[0])

  # This psum is required to correctly evaluate with multihost. Only host 0
  # will report the metrics, so we must aggregate across all hosts. Each
  # device's per-example metrics have shape [batch_size]; the psum sums them
  # elementwise across every device on the 'batch' axis (still
  # [batch_size]), and the outer jnp.sum then reduces the batch dimension.
  # The result is the metric summed over all samples in the sharded batch.
  return {
      key: jnp.sum(lax.psum(metric_fn(logits, targets, weights),
                            axis_name='batch'))
      for key, metric_fn in metrics_bundle.items()
  }
def allreduce_spmd_update(i, op_state, batch):
  """One SPMD optimizer step with gradients all-reduced over 'batch'.

  Args:
    i: step index forwarded to `opt_update`.
    op_state: optimizer state; parameters are read with `get_params`.
    batch: this device's slice of the data.

  Returns:
    The new optimizer state from `opt_update`.
  """
  params = get_params(op_state)
  grad_leaves, grad_treedef = tree_flatten(grad(loss)(params, batch))
  # Sum every gradient leaf across the device-mapped 'batch' axis with the
  # `lax.psum` SPMD primitive (an all-reduce-sum).
  reduced = [lax.psum(leaf, 'batch') for leaf in grad_leaves]
  return opt_update(i, tree_unflatten(grad_treedef, reduced), op_state)
def one_hot(x: Array, num_classes: int, *,
            dtype: Any = jnp.float64, axis: Union[int, AxisName] = -1) -> Array:
  """One-hot encodes the given indices.

  Each index in the input ``x`` is encoded as a vector of zeros of length
  ``num_classes`` with the element at ``index`` set to one::

    >>> jax.nn.one_hot(jnp.array([0, 1, 2]), 3)
    DeviceArray([[1., 0., 0.],
                 [0., 1., 0.],
                 [0., 0., 1.]], dtype=float32)

  Indices outside the range [0, num_classes) will be encoded as zeros::

    >>> jax.nn.one_hot(jnp.array([-1, 3]), 3)
    DeviceArray([[0., 0., 0.],
                 [0., 0., 0.]], dtype=float32)

  Args:
    x: A tensor of indices.
    num_classes: Number of classes in the one-hot dimension. Must be a
      concrete (non-traced) integer.
    dtype: optional, a float dtype for the returned values (default float64 if
      jax_enable_x64 is true, otherwise float32).
    axis: the position at which the one-hot dimension is inserted into the
      output, or the name of a mapped axis. For a named axis, the one-hot
      dimension *is* that axis: each position along it compares ``x`` to its
      own axis index.

  Returns:
    An array of dtype ``dtype``. For an integer ``axis`` it has one more
    dimension than ``x``; for a named axis it has the shape of ``x``.

  Raises:
    ValueError: if ``axis`` names a mapped axis whose size differs from
      ``num_classes``.
  """
  num_classes = core.concrete_or_error(
      int, num_classes,
      "The error arose in jax.nn.one_hot argument `num_classes`.")
  dtype = dtypes.canonicalize_dtype(dtype)
  x = jnp.asarray(x)
  try:
    output_pos_axis = util.canonicalize_axis(axis, x.ndim + 1)
  except TypeError:
    # `axis` is not an integer: treat it as the name of a mapped axis.
    axis_size = lax.psum(1, axis)
    if num_classes != axis_size:
      raise ValueError(
          f"Expected num_classes to match the size of axis {axis}, "
          f"but {num_classes} != {axis_size}") from None
    axis_idx = lax.axis_index(axis)
    return jnp.asarray(x == axis_idx, dtype=dtype)
  axis = operator.index(axis)
  lhs = lax.expand_dims(x, (axis, ))
  # rhs holds 0..num_classes-1 along the inserted axis and size 1 elsewhere,
  # so `lhs == rhs` broadcasts to the one-hot encoding.
  rhs_shape = [1] * x.ndim
  rhs_shape.insert(output_pos_axis, num_classes)
  rhs = lax.broadcast_in_dim(jnp.arange(num_classes, dtype=x.dtype),
                             rhs_shape, (output_pos_axis, ))
  return jnp.asarray(lhs == rhs, dtype=dtype)
def testLogSoftmax(self):
  """papply of log-softmax should replace the inner sum with a psum."""
  def fun(x):
    return x - np.log(np.sum(np.exp(x)))

  pfun, axis_name = papply(fun)
  zeros = onp.zeros(5)
  got_jaxpr = make_jaxpr(pfun)(zeros)
  want_jaxpr = make_jaxpr(
      lambda v: v - np.log(lax.psum(np.exp(v), axis_name)))(zeros)
  assert repr(got_jaxpr) == repr(want_jaxpr)

  inputs = onp.arange(1., 5.)
  result = pmap(pfun, axis_name)(inputs)
  self.assertAllClose(result, fun(inputs), check_dtypes=False)
def testBadAxisSizeError(self):
  """A pmap pinned to all devices rejects inputs with a mismatched leading axis."""
  if xla_bridge.device_count() == 1:
    raise SkipTest("this test requires multiple devices")
  summed = pmap(lambda v: lax.psum(v, 'i'), axis_name='i',
                devices=xla_bridge.devices())
  cases = (
      (np.ones(1),
       r"compiling computation that requires 1 replicas, "
       r"but \d+ devices were specified"),
      (np.ones(xla_bridge.device_count() + 1),
       r"compiling computation that requires \d+ replicas, "
       r"but \d+ devices were specified"),
  )
  for bad_input, pattern in cases:
    with self.assertRaisesRegex(ValueError, pattern):
      summed(bad_input)
def testNestedMeshSPMD(self):
  """An xmap nested inside another xmap should yield an SPMD tiled sharding."""
  inner = xmap(lambda y: (jnp.sin(y) * np.arange(y.size),
                          lax.psum(y, ('a', 'b', 'c'))),
               in_axes={0: 'c'}, out_axes=({1: 'c'}, {}),
               axis_resources={'c': 'z'})
  outer = xmap(lambda x: inner(x * 2), in_axes=[None, 'a', 'b', ...],
               out_axes=(['a', 'b', ...], {}),
               axis_resources={'a': 'x', 'b': 'y'})
  xshape = (8, 2, 4, 5)
  x = jnp.arange(np.prod(xshape)).reshape(xshape)
  outer(x)  # Executes once; the result itself is not inspected.
  hlo = jax.xla_computation(outer)(x).as_hlo_text()
  sharding = re.search(r"sharding={devices=\[([0-9,]+)\][0-9,]+}", hlo)
  self.assertIsNotNone(sharding)
  tile_factors = {int(s) for s in sharding.group(1).split(',')}
  self.assertEqual(tile_factors, {1, 2})
def testLogSoftmax(self):
  """papply of log-softmax should replace the inner sum with a psum (disabled)."""
  raise SkipTest("test doesn't pass yet")  # TODO(frostig)

  # The remainder is kept for when the test is re-enabled.
  def fun(x):
    return x - np.log(np.sum(np.exp(x)))

  pfun, axis_name = _papply(fun)
  zeros = onp.zeros(5)
  got_jaxpr = make_jaxpr(pfun)(zeros)
  want_jaxpr = make_jaxpr(
      lambda v: v - np.log(lax.psum(np.exp(v), axis_name)))(zeros)
  assert repr(got_jaxpr) == repr(want_jaxpr)

  inputs = onp.arange(1., 5.)
  result = soft_pmap(pfun, axis_name)(inputs)
  self.assertAllClose(result, fun(inputs), check_dtypes=False)
def testBadAxisSizeError(self):
  """A pmap pinned to all devices rejects inputs with a mismatched leading axis."""
  if xla_bridge.device_count() == 1:
    raise SkipTest("this test requires multiple devices")
  f = pmap(lambda x: lax.psum(x, 'i'), axis_name='i',
           devices=xla_bridge.devices())
  # Counts are matched with `\d+` rather than `\d`, so the patterns also hold
  # on hosts with 10 or more devices. Literal periods are escaped.
  with self.assertRaisesRegex(
      ValueError,
      r"Leading axis size of input to pmapped function must "
      r"equal the number of local devices passed to pmap\. Got axis_size=1, "
      r"num_local_devices=\d+"):
    f(np.ones(1))
  with self.assertRaisesRegex(
      ValueError,
      r"Leading axis size of input to pmapped function must "
      r"equal the number of local devices passed to pmap\. Got axis_size=\d+, "
      r"num_local_devices=\d+"):
    f(np.ones(xla_bridge.device_count() + 1))
def _eval_model(
    self, params: spec.ParameterContainer, batch: Dict[str, spec.Tensor],
    model_state: spec.ModelAuxiliaryState,
    rng: spec.RandomState) -> Dict[str, spec.Tensor]:
  """Computes summed evaluation metrics for a batch, reduced over 'batch'.

  The return annotation previously claimed a (Tensor, ModelAuxiliaryState)
  tuple. The function actually returns a metrics dict, and the annotation now
  says so.

  Returns:
    A dict with keys 'accuracy' (count of argmax hits), 'loss' (summed
    per-example loss) and 'num_data' (number of logits rows), each summed
    across the 'batch' axis with `lax.psum`.
  """
  logits, _ = self.model_fn(
      params, batch, model_state, spec.ForwardPassMode.EVAL, rng,
      update_batch_norm=False)
  accuracy = jnp.sum(jnp.argmax(logits, axis=-1) == batch['targets'])
  loss = jnp.sum(self.loss_fn(batch['targets'], logits))
  num_data = len(logits)
  metrics = {'accuracy': accuracy, 'loss': loss, 'num_data': num_data}
  metrics = lax.psum(metrics, axis_name='batch')
  return metrics
def testPsumMultiple(self):
  """psum over two nested pmap axes sums over both mapped dimensions."""
  summed = pmap(pmap(lambda v: lax.psum(v, ('i', 'j')), 'i'), 'j')

  def sum_and_broadcast(arr, axis):
    return onp.repeat(onp.sum(arr, axis, keepdims=True), arr.shape[axis], axis)

  n_devices = xla_bridge.device_count()
  n_pairs, leftover = divmod(n_devices, 2)
  # Prefer an (n_pairs, 2) mesh; otherwise fall back to (n_devices, 1).
  if n_pairs > 1 and not leftover:
    shape = (n_pairs, 2, 4)
  else:
    shape = (n_devices, 1, 4)
  data = onp.arange(prod(shape), dtype=onp.float32).reshape(shape)
  want = sum_and_broadcast(sum_and_broadcast(data, 0), 1)
  self.assertAllClose(summed(data), want, check_dtypes=False)
def testPartiallyMapped(self):
  """pmap with in_axes=None broadcasts the unmapped argument to every device."""
  passthrough = pmap(lambda x, y: x, in_axes=(None, 0))
  centered = pmap(lambda x, y: x - lax.psum(y, 'i'), axis_name='i',
                  in_axes=(None, 0))
  mesh_shape = (xla_bridge.device_count(),)
  shape = mesh_shape + (4,)
  x = onp.array(3., dtype=onp.float32)
  y = onp.arange(prod(shape), dtype=onp.float32).reshape(shape)

  def check(actual, expected):
    self.assertAllClose(actual, expected, check_dtypes=True)
    self.assertIsInstance(actual, pxla.ShardedDeviceArray)
    # The values may be replicated, but out_axes is implicitly 0, so the
    # sharding spec must not report replication.
    self.assertEqual(actual.sharding_spec.replication_factor, 1)

  check(passthrough(x, y), onp.broadcast_to(x, mesh_shape))
  check(centered(x, y),
        onp.broadcast_to(x - onp.sum(y, 0, keepdims=True), shape))
def _one_hot(x: Array, num_classes: int, *,
             dtype: Any, axis: Union[int, AxisName]) -> Array:
  """One-hot encodes `x` along a positional axis or a named (mapped) axis.

  Raises:
    ValueError: if `axis` names a mapped axis whose size is not `num_classes`.
  """
  num_classes = core.concrete_or_error(
      int, num_classes,
      "The error arose in jax.nn.one_hot argument `num_classes`.")
  dtype = dtypes.canonicalize_dtype(dtype)
  x = jnp.asarray(x)
  try:
    insert_at = util.canonicalize_axis(axis, x.ndim + 1)
  except TypeError:
    # Not an integer: `axis` names a mapped axis, which serves as the
    # one-hot dimension.
    named_size = lax.psum(1, axis)
    if num_classes != named_size:
      raise ValueError(
          f"Expected num_classes to match the size of axis {axis}, "
          f"but {num_classes} != {named_size}") from None
    return jnp.asarray(x == lax.axis_index(axis), dtype=dtype)

  axis = operator.index(axis)  # type: ignore[arg-type]
  expanded = lax.expand_dims(x, (axis, ))
  iota_shape = [1] * insert_at + [num_classes] + [1] * (x.ndim - insert_at)
  class_ids = lax.broadcasted_iota(x.dtype, iota_shape, insert_at)
  return jnp.asarray(expanded == class_ids, dtype=dtype)
def testIssue804(self):
  """Regression test: a pmapped lax.scan whose body uses psum must not crash."""
  n_devices = xla_bridge.device_count()

  def body(carry, x):
    return carry + lax.psum(x, "i"), carry

  scan_from_zero = partial(lax.scan, body, 0.)
  api.pmap(scan_from_zero, axis_name="i")(np.ones((n_devices, 4)))
def get_axis_size(axis_name=None):
  """Returns the number of participants along `axis_name` (JAX) or in the
  current TF replica context (`num_replicas_in_sync`)."""
  if not JAX_MODE:
    return tf.distribute.get_replica_context().num_replicas_in_sync
  return lax.psum(1, axis_name)
def _sum_seeds_pmapped(seed):
  """Sums `seed` over every member of the mapped 'hosts' axis."""
  total = lax.psum(seed, 'hosts')
  return total