def _average_multidevice_gradients(gradients, adasum=False):
  """Averages gradients over all the devices across different hosts."""
  gradients_psum = fastmath.psum(gradients, 'batch')  # sum over all devices
  n = fastmath.psum(jnp.array(1.0), 'batch')  # number of devices on all hosts
  if not adasum:
    return fastmath.nested_map(lambda g: g / n, gradients_psum)
  # This implements an approximation of the Adasum algorithm from the
  # following paper: https://arxiv.org/pdf/2006.02924.pdf
  # Since implementing halving and averaging half-by-half is tricky, we first
  # sum over all devices and use that sum as the point of comparison for each
  # gradient. For 2 devices this algorithm matches the paper, but with more
  # devices it does a different kind of averaging. It still has the property
  # that orthogonal gradients are summed while identical ones are averaged,
  # as postulated in the paper.
  adasum_nominator = fastmath.nested_map_multiarg(
      lambda g, q: jnp.vdot(g, q),  # pylint: disable=unnecessary-lambda
      gradients, gradients_psum)
  grad_norm = fastmath.nested_map(lambda g: jnp.vdot(g, g), gradients)
  # If all devices have identical gradients, the numerator equals
  # n * grad_norm; if they are all orthogonal, it equals grad_norm.
  scaled_grads = fastmath.nested_map_multiarg(
      lambda g, nominator, g_norm: g * (1 - (nominator - g_norm)
                                        / (n * g_norm)),
      gradients, adasum_nominator, grad_norm)
  return fastmath.psum(scaled_grads, 'batch')
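# The comments above claim that identical gradients get averaged while
# orthogonal ones get summed. A toy, single-process re-derivation of that
# scaling in plain NumPy (not part of the original code; `sum` over a list
# plays the role of the two psums):
import numpy as np

def toy_adasum(per_device_grads):
  n = len(per_device_grads)
  g_sum = sum(per_device_grads)  # stands in for the first psum
  scaled = []
  for g in per_device_grads:
    nominator = np.vdot(g, g_sum)
    g_norm = np.vdot(g, g)
    scaled.append(g * (1 - (nominator - g_norm) / (n * g_norm)))
  return sum(scaled)  # stands in for the second psum

g = np.array([1.0, 0.0])
h = np.array([0.0, 1.0])
print(toy_adasum([g, g]))  # identical -> averaged: [1. 0.]
print(toy_adasum([g, h]))  # orthogonal -> summed: [1. 1.]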
def _assert_all_equal(self, t1, t2, tol=1e-5):
  def eq(x1, x2):
    diff = np.maximum(np.abs(x1 - x2) - tol, 0.0)
    self.assertLessEqual(
        np.sum(diff), 0.0,
        msg=f'\n{x1}\n !=\n{x2}\n diff:\n{x1 - x2}')
  fastmath.nested_map_multiarg(eq, t1, t2)
def _test_equivalence_to_reference_code(
    self, model_cls, inp, input_signature, common_kwargs, *test_kwargs):
  ref_model = model_cls(use_reference_code=True, **common_kwargs)
  rng = fastmath.random.get_prng(123)
  weights, state = ref_model.init(input_signature, rng)

  ref_all = self._run_forward_and_backward(ref_model, inp, weights, state)
  ref_out, ref_state, ref_inp_grad, ref_weights_grad = ref_all

  for kwargs in test_kwargs:
    test_model = model_cls(**common_kwargs, **kwargs)
    state = test_model.init(input_signature, rng)[1]
    test_all = self._run_forward_and_backward(test_model, inp, weights, state)
    test_out, test_state, test_inp_grad, test_weights_grad = test_all

    self.assertEqual(jax.tree_structure(ref_out),
                     jax.tree_structure(test_out))
    self.assertEqual(jax.tree_structure(ref_state),
                     jax.tree_structure(test_state))
    self.assertEqual(jax.tree_structure(ref_inp_grad),
                     jax.tree_structure(test_inp_grad))
    self.assertEqual(jax.tree_structure(ref_weights_grad),
                     jax.tree_structure(test_weights_grad))

    check_close = lambda x, y: self.assertAllClose(x, y, rtol=1e-3, atol=1e-3)
    fastmath.nested_map_multiarg(check_close, ref_out, test_out)
    fastmath.nested_map_multiarg(check_close, ref_state, test_state)
    fastmath.nested_map_multiarg(check_close, ref_inp_grad, test_inp_grad)
    fastmath.nested_map_multiarg(check_close, ref_weights_grad,
                                 test_weights_grad)
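# The test above relies on `_run_forward_and_backward`, which is not shown in
# this excerpt. A plausible sketch, assuming it returns the four values
# unpacked above (output, new state, input gradient, weights gradient) by
# differentiating `model.pure_fn` with an all-ones output cotangent:
def _run_forward_and_backward(self, model, inp, weights, state):
  def forward(inp, weights):
    return model.pure_fn(inp, weights, state,
                         rng=fastmath.random.get_prng(0))
  out, vjpfun, new_state = fastmath.vjp(forward, inp, weights, has_aux=True)
  inp_grad, weights_grad = vjpfun(jnp.ones_like(out))
  return out, new_state, inp_grad, weights_grad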
def _average_multidevice_gradients(gradients, adasum=False):
  """Averages gradients over all the devices across different hosts."""
  n = fastmath.global_device_count() // base.N_WEIGHTS_SHARDS
  if adasum:
    # This implements a version of the Adasum algorithm from the following
    # paper: https://arxiv.org/pdf/2006.02924.pdf
    lg = max([i for i in range(20) if 2**i <= n])
    for lg_i in range(lg):
      shift = 2**lg_i
      perm = []
      for i in range(n):
        block_i = i % (2 * shift)  # we do blocks of 2*shift size
        if block_i < shift:
          perm.append((i, i + shift))
        else:
          perm.append((i, i - shift))
      perm_grad = jax.lax.ppermute(gradients, perm=perm, axis_name='batch')
      gradients = fastmath.nested_map_multiarg(
          _adasum_merge, gradients, perm_grad)
  if base.N_WEIGHTS_SHARDS > 1:  # only sum gradients from matching shards
    groups = [[base.N_WEIGHTS_SHARDS * i + d for i in range(int(n))]
              for d in range(base.N_WEIGHTS_SHARDS)]
    gradients_psum = fastmath.psum(gradients, 'batch',
                                   axis_index_groups=groups)
  else:
    gradients_psum = fastmath.psum(gradients, 'batch')  # sum all gradients
  n = jnp.array(n, dtype=jnp.float32)
  return fastmath.nested_map(lambda g: g / n, gradients_psum)
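# `_adasum_merge` is referenced above but not defined in this excerpt. A
# sketch of the pairwise Adasum composition from the paper
# (https://arxiv.org/pdf/2006.02924.pdf): orthogonal gradients add, parallel
# ones average. The 1e-30 epsilon for numerical safety is an assumption.
def _adasum_merge(g1, g2):
  frac1 = jnp.vdot(g1, g2) / (2 * jnp.vdot(g1, g1) + 1e-30)
  frac2 = jnp.vdot(g1, g2) / (2 * jnp.vdot(g2, g2) + 1e-30)
  return (1 - frac1) * g1 + (1 - frac2) * g2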
def mock_training_step(x, weights, state, rng):
  # `model` is captured from the enclosing scope; this is a single plain
  # SGD step on a mock scalar loss.
  def compute_mock_loss(weights):
    logits, new_state = model.pure_fn(x, weights, state, rng)
    loss = fastmath.numpy.mean(logits[..., 0])
    return loss, (new_state, logits)
  gradients, (new_state, logits) = fastmath.grad(
      compute_mock_loss, has_aux=True)(weights)
  new_weights = fastmath.nested_map_multiarg(
      lambda w, g: w - 1e-4 * g, weights, gradients)
  return new_weights, new_state, logits
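# Hypothetical usage sketch: jit the step once and iterate. `model`, `x`,
# `weights`, `state` and `rng` are assumed to exist in the surrounding scope.
training_step = fastmath.jit(mock_training_step)
for _ in range(10):
  weights, state, logits = training_step(x, weights, state, rng)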
def _average_multidevice_gradients(gradients, adasum=False):
  """Averages gradients over all the devices across different hosts."""
  # Keep `n` as a Python int here: it drives `range` and the permutation
  # construction below. It is converted to a float array only for the final
  # division.
  n = fastmath.global_device_count()
  if adasum:
    # This implements a version of the Adasum algorithm from the following
    # paper: https://arxiv.org/pdf/2006.02924.pdf
    lg = max([i for i in range(20) if 2**i <= n])
    for lg_i in range(lg):
      shift = 2**lg_i
      perm = []
      for i in range(n):
        block_i = i % (2 * shift)  # we do blocks of 2*shift size
        if block_i < shift:
          perm.append((i, i + shift))
        else:
          perm.append((i, i - shift))
      perm_grad = jax.lax.ppermute(gradients, perm=perm, axis_name='batch')
      gradients = fastmath.nested_map_multiarg(
          _adasum_merge, gradients, perm_grad)
  gradients_psum = fastmath.psum(gradients, 'batch')  # sum over all devices
  n = jnp.array(n, dtype=jnp.float32)
  return fastmath.nested_map(lambda g: g / n, gradients_psum)
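# Hypothetical usage sketch: the psum/ppermute collectives above are only
# defined inside a pmapped computation whose mapped axis is named 'batch'.
# `per_device_grads` (a gradient pytree with a leading device axis) is an
# assumed name.
grads_avg = fastmath.pmap(
    _average_multidevice_gradients, axis_name='batch')(per_device_grads)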
def reverse_and_grad(self, output, ct, weights=(), state=(), new_state=(),
                     rng=None):
  rngs = _split_rngs(rng, len(self.sublayers))

  accumulator_output, *context = output
  context = tuple(context)
  accumulator_output_ct, *context_ct = ct
  context_ct = tuple(context_ct)

  # Forward pass through self.compute_residual. Outputs that will not receive
  # a gradient signal from subsequent layers are moved to aux.
  def call_compute_residual(x, weights):
    res, _ = self.compute_residual.pure_fn(
        x, weights=weights, state=state[0], rng=rngs[0])
    if not isinstance(res, (tuple, list)):
      return res, None
    else:
      n_differentiable = 1
      if self.attention_layer is not None:
        n_differentiable = min(len(res), self.attention_layer.n_in)
      return res[:n_differentiable], res[n_differentiable:]

  stack = context
  inputs = cb.inputs_from_stack(stack, self.compute_residual.n_in)
  outputs, compute_residual_vjpfun, outputs_aux = fastmath.vjp(
      call_compute_residual, inputs, weights[0], has_aux=True)
  if outputs_aux is not None:
    n_differentiable_outputs = len(outputs)
    outputs = outputs + outputs_aux
  stack = cb.outputs_onto_stack(outputs, stack, self.compute_residual.n_in)

  stack_ct = accumulator_output_ct
  if self.attention_layer is None:
    residual = stack[0] if isinstance(stack, (tuple, list)) else stack
  else:
    inputs = cb.inputs_from_stack(stack, self.attention_layer.n_in)
    (residual, _, attn_inputs_ct, attn_weights_ct
    ) = self._forward_and_or_backward(
        inputs, weights[1], new_state[1], rngs[1],
        output_grad=accumulator_output_ct,
        compute_output=True, update_state=False)
    stack_ct = cb.outputs_onto_stack(
        attn_inputs_ct, stack_ct, self.attention_layer.n_out)

  compute_residual_ct = cb.inputs_from_stack(
      stack_ct, self.compute_residual.n_out)
  if outputs_aux is not None:
    if not isinstance(compute_residual_ct, (tuple, list)):
      compute_residual_ct = (compute_residual_ct,)
    compute_residual_ct = compute_residual_ct[:n_differentiable_outputs]
    assert len(compute_residual_ct) == n_differentiable_outputs
  (compute_residual_inputs_ct, compute_residual_weights_ct
  ) = compute_residual_vjpfun(compute_residual_ct)
  stack_ct = cb.outputs_onto_stack(
      compute_residual_inputs_ct, stack_ct, self.compute_residual.n_out)
  if not isinstance(stack_ct, (tuple, list)):
    stack_ct = (stack_ct,)
  stack_ct = (accumulator_output_ct,) + fastmath.nested_map_multiarg(
      lambda x, y: x + y, context_ct[:len(stack_ct)], stack_ct
  ) + context_ct[len(stack_ct):]

  reconstructed_x = accumulator_output - residual
  stack = (reconstructed_x,) + context
  if self.attention_layer is None:
    weights_ct = (compute_residual_weights_ct,)
  else:
    weights_ct = (compute_residual_weights_ct, attn_weights_ct)
  return stack, (stack_ct, weights_ct)
def reverse_and_grad(self, output, ct, weights=(), state=(), new_state=(),
                     rng=None):
  rngs = _split_rngs(rng, len(self.sublayers))

  accumulator_output, *context = output
  context = tuple(context)
  accumulator_output_ct, *context_ct = ct
  context_ct = tuple(context_ct)

  # Forward pass through self._compute_residual. Outputs that will not
  # receive a gradient signal from subsequent layers are moved to aux.
  def call_compute_residual(x, weights):
    state_to_pass = state[0]  # old_state

    # _replace_second_time is currently used exclusively in the
    # _RememberInReverse layer to combat numerical instability in Reformer2
    # when quantizing the mask in SparseFF.
    def _replace_second_time(stt, nstt):
      if (isinstance(stt, tuple) and len(stt) == 2 and
          isinstance(stt[1], dict) and 'running_second_time' in stt[1]):
        return (nstt[0], {'running_second_time_yes': ()})
      elif isinstance(stt, (tuple, list)):
        assert isinstance(nstt, (tuple, list)) and len(nstt) == len(stt)
        return type(stt)([
            _replace_second_time(s, ns) for s, ns in zip(stt, nstt)
        ])
      else:
        return stt

    state_to_pass = _replace_second_time(state_to_pass, new_state[0])
    res, _ = self._compute_residual.pure_fn(
        x, weights=weights, state=state_to_pass, rng=rngs[0])
    if not isinstance(res, (tuple, list)):
      return res, None
    else:
      n_differentiable = 1
      if self._attention_layer is not None:
        n_differentiable = min(len(res), self._attention_layer.n_in)
      return res[:n_differentiable], res[n_differentiable:]

  stack = context
  inputs = cb.inputs_from_stack(stack, self._compute_residual.n_in)
  outputs, compute_residual_vjpfun, outputs_aux = fastmath.vjp(
      call_compute_residual, inputs, weights[0], has_aux=True)
  if outputs_aux is not None:
    n_differentiable_outputs = len(outputs)
    outputs = outputs + outputs_aux
  stack = cb.outputs_onto_stack(outputs, stack, self._compute_residual.n_in)

  stack_ct = accumulator_output_ct
  if self._attention_layer is None:
    residual = stack[0] if isinstance(stack, (tuple, list)) else stack
  else:
    inputs = cb.inputs_from_stack(stack, self._attention_layer.n_in)
    (residual, _, attn_inputs_ct, attn_weights_ct
    ) = self._forward_and_or_backward(
        inputs, weights[1], new_state[1], rngs[1],
        output_grad=accumulator_output_ct,
        compute_output=True, update_state=False)
    stack_ct = cb.outputs_onto_stack(
        attn_inputs_ct, stack_ct, self._attention_layer.n_out)

  compute_residual_ct = cb.inputs_from_stack(
      stack_ct, self._compute_residual.n_out)
  if outputs_aux is not None:
    if not isinstance(compute_residual_ct, (tuple, list)):
      compute_residual_ct = (compute_residual_ct,)
    compute_residual_ct = compute_residual_ct[:n_differentiable_outputs]
    assert len(compute_residual_ct) == n_differentiable_outputs
  (compute_residual_inputs_ct, compute_residual_weights_ct
  ) = compute_residual_vjpfun(compute_residual_ct)
  stack_ct = cb.outputs_onto_stack(
      compute_residual_inputs_ct, stack_ct, self._compute_residual.n_out)
  if not isinstance(stack_ct, (tuple, list)):
    stack_ct = (stack_ct,)

  def _add(x, y):
    # `None` is for TFNP backend, which uses `None` as the gradient of
    # int/bool instead of an array of dtype `float0`.
    if x is None or x.dtype == jax.float0:
      return y
    if y is None or y.dtype == jax.float0:
      return x
    return x + y

  stack_ct = (accumulator_output_ct,) + fastmath.nested_map_multiarg(
      _add, context_ct[:len(stack_ct)], stack_ct
  ) + context_ct[len(stack_ct):]

  reconstructed_x = accumulator_output - residual
  stack = (reconstructed_x,) + context
  if self._attention_layer is None:
    weights_ct = (compute_residual_weights_ct,)
  else:
    weights_ct = (compute_residual_weights_ct, attn_weights_ct)
  return stack, (stack_ct, weights_ct)
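# Hypothetical sketch (names assumed, not part of the original code): how a
# caller might use `reverse_and_grad` to walk a stack of reversible layers
# backwards, reconstructing inputs from outputs and accumulating weight
# gradients instead of storing forward activations.
def reversible_backward(layers, output, output_ct, weights, states,
                        new_states, rngs):
  weight_grads = []
  for layer, w, s, ns, rng in reversed(
      list(zip(layers, weights, states, new_states, rngs))):
    output, (output_ct, w_ct) = layer.reverse_and_grad(
        output, output_ct, weights=w, state=s, new_state=ns, rng=rng)
    weight_grads.append(w_ct)
  return output, output_ct, weight_grads[::-1]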