def _test_equivalence_to_reference_code(
    self, model_cls, inp, input_signature, common_kwargs, *test_kwargs):
  """Checks that non-reference code paths match the reference implementation."""
  ref_model = model_cls(use_reference_code=True, **common_kwargs)
  rng = math.random.get_prng(123)
  weights, state = ref_model.init(input_signature, rng)

  ref_all = self._run_forward_and_backward(ref_model, inp, weights, state)
  ref_out, ref_state, ref_inp_grad, ref_weights_grad = ref_all

  for kwargs in test_kwargs:
    test_model = model_cls(**common_kwargs, **kwargs)
    state = test_model.init(input_signature, rng)[1]
    test_all = self._run_forward_and_backward(test_model, inp, weights, state)
    test_out, test_state, test_inp_grad, test_weights_grad = test_all

    self.assertEqual(jax.tree_structure(ref_out),
                     jax.tree_structure(test_out))
    self.assertEqual(jax.tree_structure(ref_state),
                     jax.tree_structure(test_state))
    self.assertEqual(jax.tree_structure(ref_inp_grad),
                     jax.tree_structure(test_inp_grad))
    self.assertEqual(jax.tree_structure(ref_weights_grad),
                     jax.tree_structure(test_weights_grad))

    check_close = lambda x, y: self.assertAllClose(x, y, rtol=1e-3, atol=1e-3)
    math.nested_map_multiarg(check_close, ref_out, test_out)
    math.nested_map_multiarg(check_close, ref_state, test_state)
    math.nested_map_multiarg(check_close, ref_inp_grad, test_inp_grad)
    math.nested_map_multiarg(check_close, ref_weights_grad, test_weights_grad)
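
# A usage sketch (not part of the original file): one way a test method might
# call the equivalence helper above. The layer class
# `efficient_attention.SelfAttention`, its keyword arguments, the shapes, and
# the module-level imports assumed here (`import numpy as np`,
# `from trax import shapes`, `from trax.layers.research import
# efficient_attention`) are illustrative assumptions; real tests may use
# different layers, kwargs, and inputs.
def test_self_attention_equivalence_sketch(self):
  inp = np.random.uniform(size=(2, 8, 16)).astype(np.float32)
  input_signature = shapes.ShapeDtype((2, 8, 16), dtype=np.float32)
  common_kwargs = dict(n_heads=2, d_qk=4, d_v=4, causal=True)
  # Each trailing dict builds one non-reference variant that is compared
  # against the reference model constructed with use_reference_code=True.
  self._test_equivalence_to_reference_code(
      efficient_attention.SelfAttention, inp, input_signature, common_kwargs,
      dict(use_python_loop=True), dict(use_python_loop=False))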

def mock_training_step(x, weights, state, rng):
  """Runs one SGD-style update using a mock scalar loss on the logits."""
  def compute_mock_loss(weights):
    logits, new_state = model.pure_fn(x, weights, state, rng)
    # Reduce the logits to a scalar so we can differentiate through the model.
    loss = math.numpy.mean(logits[..., 0])
    return loss, (new_state, logits)
  gradients, (new_state, logits) = math.grad(
      compute_mock_loss, has_aux=True)(weights)
  new_weights = math.nested_map_multiarg(
      lambda w, g: w - 1e-4 * g, weights, gradients)
  return new_weights, new_state, logits
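
# A driver sketch (assumed, not from the original source): how the mock step
# above could be iterated. `model`, `input_signature`, and `x` are placeholders
# expected to be defined by the enclosing test, and the step count and reuse of
# a single rng are illustrative choices only.
rng = math.random.get_prng(0)
weights, state = model.init(input_signature, rng)
for _ in range(3):
  weights, state, logits = mock_training_step(x, weights, state, rng)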

def reverse_and_grad(self, output, ct, weights=(), state=(), new_state=(),
                     rng=None):
  rngs = _split_rngs(rng, len(self.sublayers))

  accumulator_output, *context = output
  context = tuple(context)
  accumulator_output_ct, *context_ct = ct
  context_ct = tuple(context_ct)

  # Forward pass through self.compute_residual. Outputs that will not receive
  # a gradient signal from subsequent layers are moved to aux.
  def call_compute_residual(x, weights):
    res, _ = self.compute_residual.pure_fn(
        x, weights=weights, state=state[0], rng=rngs[0])
    if not isinstance(res, (tuple, list)):
      return res, None
    else:
      n_differentiable = 1
      if self.attention_layer is not None:
        n_differentiable = min(len(res), self.attention_layer.n_in)
      return res[:n_differentiable], res[n_differentiable:]

  stack = context
  inputs = _inputs_from_stack(self.compute_residual, stack)
  outputs, compute_residual_vjpfun, outputs_aux = jax.vjp(
      call_compute_residual, inputs, weights[0], has_aux=True)
  if outputs_aux is not None:
    n_differentiable_outputs = len(outputs)
    outputs = outputs + outputs_aux
  stack = _outputs_onto_stack(self.compute_residual, outputs, stack)

  # If there is an attention layer, run its combined forward/backward pass to
  # get both the residual and the cotangents; otherwise the residual is simply
  # the top of the stack.
  stack_ct = accumulator_output_ct
  if self.attention_layer is None:
    residual = stack[0] if isinstance(stack, (tuple, list)) else stack
  else:
    inputs = _inputs_from_stack(self.attention_layer, stack)
    (residual, _, attn_inputs_ct, attn_weights_ct
    ) = self.attention_layer.forward_and_or_backward(
        inputs, weights[1], new_state[1], rngs[1],
        output_grad=accumulator_output_ct,
        compute_output=True, update_state=False)
    stack_ct = _outputs_onto_stack(
        self.attention_layer, attn_inputs_ct, stack_ct,
        self.attention_layer.n_out, self.attention_layer.n_in)

  # Backward pass through self.compute_residual.
  compute_residual_ct = _inputs_from_stack(
      self.compute_residual, stack_ct, self.compute_residual.n_out)
  if outputs_aux is not None:
    if not isinstance(compute_residual_ct, (tuple, list)):
      compute_residual_ct = (compute_residual_ct,)
    compute_residual_ct = compute_residual_ct[:n_differentiable_outputs]
    assert len(compute_residual_ct) == n_differentiable_outputs
  (compute_residual_inputs_ct, compute_residual_weights_ct
  ) = compute_residual_vjpfun(compute_residual_ct)
  stack_ct = _outputs_onto_stack(
      self.compute_residual, compute_residual_inputs_ct, stack_ct,
      self.compute_residual.n_out, self.compute_residual.n_in)
  if not isinstance(stack_ct, (tuple, list)):
    stack_ct = (stack_ct,)
  stack_ct = (accumulator_output_ct,) + math.nested_map_multiarg(
      lambda x, y: x + y, context_ct[:len(stack_ct)], stack_ct
      ) + context_ct[len(stack_ct):]

  # Reconstruct the block's input from its output and the recomputed residual.
  reconstructed_x = accumulator_output - residual
  stack = (reconstructed_x,) + context
  if self.attention_layer is None:
    weights_ct = (compute_residual_weights_ct,)
  else:
    weights_ct = (compute_residual_weights_ct, attn_weights_ct)
  return stack, (stack_ct, weights_ct)
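
# Toy illustration (not from the original source) of the reconstruction
# identity used above, with made-up shapes and a stand-in residual function:
# the forward pass computes accumulator_output = x + residual(context), so
# recomputing residual(context) on the reverse pass recovers
# x = accumulator_output - residual(context) without x ever being stored.
import jax.numpy as jnp

x = jnp.arange(4.0)
context = jnp.ones(4)
residual = 2.0 * context                         # stands in for compute_residual(context)
accumulator_output = x + residual                # forward direction
reconstructed_x = accumulator_output - residual  # reverse direction
assert jnp.allclose(reconstructed_x, x)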