def _average_multidevice_gradients(gradients, adasum=False):
  """Averages gradients over all the devices across different hosts."""
  n = fastmath.global_device_count() // base.N_WEIGHTS_SHARDS
  if adasum:
    # This implements a version of the Adasum algorithm from the following
    # paper: https://arxiv.org/pdf/2006.02924.pdf
    lg = max([i for i in range(20) if 2**i <= n])
    for lg_i in range(lg):
      shift = 2**lg_i
      perm = []
      for i in range(n):
        block_i = i % (2 * shift)  # we do blocks of 2*shift size
        if block_i < shift:
          perm.append((i, i + shift))
        else:
          perm.append((i, i - shift))
      perm_grad = jax.lax.ppermute(gradients, perm=perm, axis_name='batch')
      gradients = fastmath.nested_map_multiarg(_adasum_merge, gradients,
                                               perm_grad)
  if base.N_WEIGHTS_SHARDS > 1:  # only sum gradients from matching shards
    groups = [[base.N_WEIGHTS_SHARDS * i + d for i in range(int(n))]
              for d in range(base.N_WEIGHTS_SHARDS)]
    gradients_psum = fastmath.psum(gradients, 'batch',
                                   axis_index_groups=groups)
  else:
    gradients_psum = fastmath.psum(gradients, 'batch')  # sum all gradients
  n = jnp.array(n, dtype=jnp.float32)
  return fastmath.nested_map(lambda g: g / n, gradients_psum)
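# The helper `_adasum_merge` used above is not shown in this listing. Below is
# a minimal sketch of what such a pairwise merge could look like, following
# the Adasum paper (https://arxiv.org/pdf/2006.02924.pdf); the epsilon constant
# and exact form are assumptions, not necessarily the library's code.
def _adasum_merge(g1, g2):
  """Pairwise Adasum combination of two gradients (sketch, assumed form)."""
  # Remove from each gradient half of its projection onto the other one:
  # merging identical gradients then yields their average, while orthogonal
  # gradients are simply added.
  frac1 = jnp.vdot(g1, g2) / (2 * jnp.vdot(g1, g1) + 1e-30)
  frac2 = jnp.vdot(g1, g2) / (2 * jnp.vdot(g2, g2) + 1e-30)
  return (1 - frac1) * g1 + (1 - frac2) * g2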
def _n_weights_per_core(weights):  # pylint: disable=invalid-name
  """Calculates the number of weights per core.

  In multi-device settings, gradients and losses are averaged over all devices.
  When loss is weighted and the number of weights can differ by device, e.g.,
  when the weights represent the number of tokens in a batch of sentences
  (which can differ from device to device), we want to make sure each token
  on each device is weighted in the same way. This function ensures that by
  reporting the number of weights per core in multi-core settings (and simply
  np.sum(weights) in a single-core setting).

  Args:
    weights: tensor with arbitrary shape

  Returns:
    a scalar equal to np.sum(weights) in 1-machine settings and to the sum
    of weights over all cores divided by the number of cores otherwise
  """
  weights_sum = jnp.sum(weights)
  if fastmath.device_count() < 2:
    return weights_sum
  else:
    try:
      n_devices_total = fastmath.psum(jnp.array(1.0), 'batch')
      return fastmath.psum(weights_sum, 'batch') / n_devices_total
    except (NameError, ValueError):  # running outside of pmap, e.g., on init
      return weights_sum  # fall back to the sum
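# A minimal usage sketch; the helper name and arguments below are hypothetical,
# not from the library. Dividing a weighted loss sum by the per-core weight
# total keeps every token weighted identically on all devices, even when the
# token counts per device differ.
def _weighted_loss_example(per_example_losses, weights):
  """Hypothetical helper showing how _n_weights_per_core would be used."""
  total = jnp.sum(per_example_losses * weights)
  return total / _n_weights_per_core(weights)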
def _average_multidevice_gradients(gradients, adasum=False):
  """Averages gradients over all the devices across different hosts."""
  gradients_psum = fastmath.psum(gradients, 'batch')  # sum over all devices
  n = fastmath.psum(jnp.array(1.0), 'batch')  # number of devices on all hosts
  if not adasum:
    return fastmath.nested_map(lambda g: g / n, gradients_psum)
  # This implements an approximation of the Adasum algorithm from the following
  # paper: https://arxiv.org/pdf/2006.02924.pdf
  # Since implementing halving and averaging half-by-half is tricky, we first
  # sum over all devices and use that sum as the point of comparison for each
  # gradient. For 2 devices this algorithm is the same as in the paper, but
  # with more devices it does a different kind of averaging. It still has the
  # property that orthogonal gradients will result in a sum while identical
  # ones will be averaged, as postulated in the paper.
  adasum_nominator = fastmath.nested_map_multiarg(
      lambda g, q: jnp.vdot(g, q),  # pylint: disable=unnecessary-lambda
      gradients, gradients_psum)
  grad_norm = fastmath.nested_map(lambda g: jnp.vdot(g, g), gradients)
  # If all devices have identical gradients, then the nominator is equal
  # to n * grad_norm; if they're orthogonal, then nominator = grad_norm.
  scaled_grads = fastmath.nested_map_multiarg(
      lambda g, nominator, g_norm: g * (1 - (nominator - g_norm) / (n * g_norm)),
      gradients, adasum_nominator, grad_norm)
  return fastmath.psum(scaled_grads, 'batch')
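# A small host-side check (illustration only, not library code) of the property
# claimed above: with the scaling g * (1 - (vdot(g, sum) - |g|^2) / (n * |g|^2)),
# identical per-device gradients come out averaged and orthogonal ones summed.
import numpy as np

def _simulate_scaled_psum(per_device_grads):
  """Simulates psum plus the Adasum-style scaling on the host."""
  n = len(per_device_grads)
  g_sum = sum(per_device_grads)  # what psum would return on every device
  scaled = []
  for g in per_device_grads:
    nominator = np.vdot(g, g_sum)
    g_norm = np.vdot(g, g)
    scaled.append(g * (1 - (nominator - g_norm) / (n * g_norm)))
  return sum(scaled)  # the final psum of the scaled gradients

g = np.array([1.0, 2.0])
assert np.allclose(_simulate_scaled_psum([g, g]), g)      # identical -> average
a, b = np.array([1.0, 0.0]), np.array([0.0, 1.0])
assert np.allclose(_simulate_scaled_psum([a, b]), a + b)  # orthogonal -> sum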
def _average_multidevice_gradients(gradients):
  """Averages gradients over all the devices across different hosts."""
  # Sum gradients over all devices across all hosts.
  gradients = fastmath.psum(gradients, 'batch')
  # Calculate the total number of devices.
  # Note: the usual n_devices is only the number of devices at this host,
  # here we are calculating the number of all devices across all hosts.
  n_devices_total = fastmath.psum(jnp.array(1.0), 'batch')
  # Average across hosts.
  return fastmath.nested_map(lambda g: g / n_devices_total, gradients)
def _multi_device_update_fn(
    weights_and_slots, step, opt_params, batch, state, rng):
  # We assume all tensors have the first dimension = n_devices.
  weights, slots = weights_and_slots
  (loss, state), gradients = forward_and_backward_fn(
      batch, weights, state, rng)
  # Gradients now need to be summed over all the devices across different host
  # machines; n_devices is only the number of devices on *this* host machine.
  gradients = fastmath.psum(gradients, 'batch')
  n_devices_total = fastmath.psum(jnp.array(1.0), 'batch')
  # Average across hosts.
  gradients = jax.tree_util.tree_map(lambda g: g / n_devices_total, gradients)
  weights, slots, stats = optimizer.tree_update(
      step, gradients, weights, slots, opt_params)
  stats['loss'] = loss
  return (weights, slots), state, stats
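# A minimal sketch of how such an update function would typically be wrapped;
# the call below is an assumption about the surrounding code, not shown here.
# The axis_name passed to pmap must match the 'batch' axis used by psum above.
_pmapped_update_fn = fastmath.pmap(_multi_device_update_fn, axis_name='batch')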
def mapped_update(weights_and_slots, i, opt_params, batch, state, rng):
  """This is a multi-device version of the update function above."""
  # We assume all tensors have the first dimension = n_devices.
  weights, slots = weights_and_slots
  rng, subrng = jax_random.split(rng)
  grad_fn = fastmath.grad(model_and_loss_call, has_aux=True)
  grads, state = grad_fn(weights, batch, state, rng)
  # We do a psum(1.0) here instead of `n_devices` since `n_devices` is just
  # the number of devices on this host machine, whereas psum goes over all
  # devices of all hosts (e.g., a TPU pod) and we need to average over all
  # of them.
  grads = jax.tree_util.tree_map(
      lambda g: (  # pylint: disable=g-long-lambda
          fastmath.psum(g, 'batch') / fastmath.psum(np.array(1.0), 'batch')),
      grads)
  new_weights, new_slots, stats = optimizer.tree_update(
      i, grads, weights, slots, opt_params)
  return (new_weights, new_slots), stats, state, subrng
def _make_weights_and_state_same_across_hosts(weights_and_state):
  """Makes train and eval model's weights and state the same across hosts."""
  # We assume that they have already been replicated, i.e., the leading axis
  # is self._n_devices.
  # This is the total number of devices across all hosts.
  n_devices_total = fastmath.psum(jnp.array(1.0), 'devices')
  # This sums up the weights and state across all devices.
  # NOTE: There will not be any leading axis remaining because we psum
  # over it.
  weights_and_state = fastmath.psum(weights_and_state, 'devices')
  # We finally take the average over all devices.
  weights_and_state = jax.tree_util.tree_map(
      lambda ws: ws / n_devices_total, weights_and_state)
  return weights_and_state
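# A minimal usage sketch (an assumption about the calling code, not shown
# here): this function only works under pmap with the matching axis name
# 'devices'; otherwise the psum calls have no named axis to reduce over.
_synced_weights_and_state_fn = fastmath.pmap(
    _make_weights_and_state_same_across_hosts, axis_name='devices')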
def _average_multidevice_gradients(gradients, adasum=False):
  """Averages gradients over all the devices across different hosts."""
  n = fastmath.global_device_count()  # keep as a Python int for the loops below
  if adasum:
    # This implements a version of the Adasum algorithm from the following
    # paper: https://arxiv.org/pdf/2006.02924.pdf
    lg = max([i for i in range(20) if 2**i <= n])
    for lg_i in range(lg):
      shift = 2**lg_i
      perm = []
      for i in range(n):
        block_i = i % (2 * shift)  # we do blocks of 2*shift size
        if block_i < shift:
          perm.append((i, i + shift))
        else:
          perm.append((i, i - shift))
      perm_grad = jax.lax.ppermute(gradients, perm=perm, axis_name='batch')
      gradients = fastmath.nested_map_multiarg(_adasum_merge, gradients,
                                               perm_grad)
  gradients_psum = fastmath.psum(gradients, 'batch')  # sum over all devices
  n = jnp.array(n, dtype=jnp.float32)  # convert only now, for the division
  return fastmath.nested_map(lambda g: g / n, gradients_psum)
def _get_conditionally_synced_rng(self):
  if self._sync and fastmath.global_device_count() > 1:
    return fastmath.psum(self.rng, 'batch')
  else:
    return self.rng
def forward(self, xs):
  self._validate_forward_inputs(xs)
  (step, layers_state) = self.state
  # Get N+1 rngs, N for running layers and one extra.
  rngs = _split_rngs(self.rng, self._n_layers + 1)
  rng0, rngs = rngs[0], rngs[1:]
  if not self.sublayers:  # No-op: leave args unchanged.
    self.state = (step + 1, layers_state)
    return xs

  # Prepare the stack and do some safety checks as in the parent class.
  stack = xs
  new_state = []
  n_layers = self._n_layers
  weights = self.weights
  if n_layers != 1 and len(weights) != n_layers:
    raise ValueError('number of weights ({}) not equal to number of layers '
                     '({})'.format(len(weights), n_layers))
  if n_layers != 1 and len(layers_state) != n_layers:
    raise ValueError('length of state ({}) not equal to number of layers '
                     '({})'.format(len(layers_state), n_layers))

  # TODO(chowdhery): try different strategies, also try running not all
  # layers backwards by using fastmath.stop_gradient where needed.

  # Calculate how many layers to run forward.
  if self._mode == 'train':
    # warmup goes from 1.0 at start to 0.0 at skipping_warmup_steps and after
    w_steps = float(self._skipping_warmup_steps)
    f_step = step.astype(jnp.float32)
    warmup = jnp.maximum(0.0, (w_steps - f_step) / w_steps)
    # low is the minimum number of layers to *not* skip, from n_layers to 0
    low = warmup * float(n_layers)
    # high should be so that (high - n_layers) / high = 1.0 - skip_fraction
    # because (high - n_layers) / high is the probability we're not skipping
    # (after warmup); so high - n_layers = high - high * skip_fraction
    high = float(n_layers) / self._skip_fraction
    # We want the same rng0 on all cores.
    if fastmath.device_count() > 1:
      rng0 = fastmath.psum(rng0, 'batch')
    n_forward_layers = random.uniform(rng0, (), jnp.float32, low, high)
  else:
    n_forward_layers = float(n_layers)

  # Run layers skipping after a certain number.
  cur_layer_idx = 0.0
  for layer, p, s, rng in zip(self.sublayers, weights, layers_state, rngs):
    inputs = _inputs_from_stack(layer, stack)
    def CondF(t):
      return layer.pure_fn(t[0], t[1], t[2], t[3])  # pylint: disable=cell-var-from-loop
    def PassF(t):
      return t[0], t[2]
    outputs, s = fastmath.cond(
        fastmath.lt(cur_layer_idx, n_forward_layers),
        CondF, PassF, (inputs, p, s, rng)
    )
    stack = _outputs_onto_stack(layer, outputs, stack)
    new_state.append(s)
    cur_layer_idx += 1.0
  self.state = (step + 1, new_state)
  return stack
def _average_multidevice_gradients(gradients):
  """Averages gradients over all the devices across different hosts."""
  gradients_psum = fastmath.psum(gradients, 'batch')  # sum over all devices
  n_devices_total = fastmath.psum(jnp.array(1.0), 'batch')
  return fastmath.nested_map(lambda g: g / n_devices_total, gradients_psum)