Example 1: a test helper that checks optimized model configurations against a reference implementation; outputs, state, and gradients must share the same tree structure and be numerically close.
    def _test_equivalence_to_reference_code(self, model_cls, inp,
                                            input_signature, common_kwargs,
                                            *test_kwargs):
        # Build the reference implementation; its weights are shared with every
        # test configuration so the results stay directly comparable.
        ref_model = model_cls(use_reference_code=True, **common_kwargs)
        rng = math.random.get_prng(123)
        weights, state = ref_model.init(input_signature, rng)

        ref_all = self._run_forward_and_backward(ref_model, inp, weights,
                                                 state)
        ref_out, ref_state, ref_inp_grad, ref_weights_grad = ref_all

        # Every optimized configuration must reproduce the reference results:
        # identical tree structure and numerically close values for outputs,
        # state, input gradients, and weight gradients.
        for kwargs in test_kwargs:
            test_model = model_cls(**common_kwargs, **kwargs)
            # Reuse the reference weights; only the state is re-initialized.
            state = test_model.init(input_signature, rng)[1]
            test_all = self._run_forward_and_backward(test_model, inp, weights,
                                                      state)
            test_out, test_state, test_inp_grad, test_weights_grad = test_all

            self.assertEqual(jax.tree_structure(ref_out),
                             jax.tree_structure(test_out))
            self.assertEqual(jax.tree_structure(ref_state),
                             jax.tree_structure(test_state))
            self.assertEqual(jax.tree_structure(ref_inp_grad),
                             jax.tree_structure(test_inp_grad))
            self.assertEqual(jax.tree_structure(ref_weights_grad),
                             jax.tree_structure(test_weights_grad))

            def check_close(x, y):
                self.assertAllClose(x, y, rtol=1e-3, atol=1e-3)
            math.nested_map_multiarg(check_close, ref_out, test_out)
            math.nested_map_multiarg(check_close, ref_state, test_state)
            math.nested_map_multiarg(check_close, ref_inp_grad, test_inp_grad)
            math.nested_map_multiarg(check_close, ref_weights_grad,
                                     test_weights_grad)
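
The helper above calls _run_forward_and_backward, which is not shown in this snippet. A minimal sketch of what such a helper could look like, assuming it differentiates model.pure_fn through jax.vjp with an all-ones cotangent on a single-array output (the fixed rng and the single-array assumption are illustrative, not taken from the original test file):

    def _run_forward_and_backward(self, model, inp, weights, state):
        # Hypothetical sketch: differentiate model.pure_fn with respect to both
        # the input and the weights, using an all-ones cotangent on the output.
        def forward(inp, weights):
            return model.pure_fn(inp, weights, state, rng=jax.random.PRNGKey(0))

        out, vjpfun, new_state = jax.vjp(forward, inp, weights, has_aux=True)
        inp_grad, weights_grad = vjpfun(math.numpy.ones_like(out))
        return out, new_state, inp_grad, weights_grad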
Example 2: a mock training step that differentiates model.pure_fn and applies a plain SGD update to the weights.
        def mock_training_step(x, weights, state, rng):
            def compute_mock_loss(weights):
                logits, new_state = model.pure_fn(x, weights, state, rng)
                # Arbitrary scalar loss: the mean of the first logit channel.
                loss = math.numpy.mean(logits[..., 0])
                return loss, (new_state, logits)

            # Differentiate the mock loss w.r.t. the weights; has_aux carries
            # the new state and the logits through undifferentiated.
            gradients, (new_state, logits) = math.grad(compute_mock_loss,
                                                       has_aux=True)(weights)
            # Plain SGD update with a fixed 1e-4 step size, applied leaf-wise
            # over the weight tree.
            new_weights = math.nested_map_multiarg(lambda w, g: w - 1e-4 * g,
                                                   weights, gradients)
            return new_weights, new_state, logits
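
A hypothetical driver for the step above (an assumption, not part of the source): initialize the model once, then apply a few mock steps, threading the updated weights and state through. shapes.signature refers to trax.shapes.signature; model, x, and rng are assumed to be defined in the enclosing test.

        # Hypothetical usage; model, x and rng come from the enclosing test.
        weights, state = model.init(shapes.signature(x), rng)
        for _ in range(3):
            weights, state, logits = mock_training_step(x, weights, state, rng)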
Example 3: reverse_and_grad for a reversible half-residual block, which reconstructs the block's input and computes gradients during the backward pass.
    def reverse_and_grad(self,
                         output,
                         ct,
                         weights=(),
                         state=(),
                         new_state=(),
                         rng=None):
        rngs = _split_rngs(rng, len(self.sublayers))

        accumulator_output, *context = output
        context = tuple(context)
        accumulator_output_ct, *context_ct = ct
        context_ct = tuple(context_ct)

        # Forward pass through self.compute_residual. Outputs that will not receive
        # a gradient signal from subsequent layers are moved to aux.
        def call_compute_residual(x, weights):
            res, _ = self.compute_residual.pure_fn(x,
                                                   weights=weights,
                                                   state=state[0],
                                                   rng=rngs[0])
            if not isinstance(res, (tuple, list)):
                return res, None
            else:
                n_differentiable = 1
                if self.attention_layer is not None:
                    n_differentiable = min(len(res), self.attention_layer.n_in)
                return res[:n_differentiable], res[n_differentiable:]

        # Run compute_residual under jax.vjp so that its backward pass can be
        # replayed below; non-differentiable extra outputs come back via aux.
        stack = context
        inputs = _inputs_from_stack(self.compute_residual, stack)
        outputs, compute_residual_vjpfun, outputs_aux = jax.vjp(
            call_compute_residual, inputs, weights[0], has_aux=True)
        if outputs_aux is not None:
            n_differentiable_outputs = len(outputs)
            outputs = outputs + outputs_aux
        stack = _outputs_onto_stack(self.compute_residual, outputs, stack)

        stack_ct = accumulator_output_ct
        # Backward through the attention layer (if present), reusing its fused
        # forward-and-backward path to obtain the residual together with the
        # cotangents for its inputs and weights.
        if self.attention_layer is None:
            residual = stack[0] if isinstance(stack, (tuple, list)) else stack
        else:
            inputs = _inputs_from_stack(self.attention_layer, stack)
            (residual, _, attn_inputs_ct,
             attn_weights_ct) = self.attention_layer.forward_and_or_backward(
                 inputs,
                 weights[1],
                 new_state[1],
                 rngs[1],
                 output_grad=accumulator_output_ct,
                 compute_output=True,
                 update_state=False)
            stack_ct = _outputs_onto_stack(self.attention_layer,
                                           attn_inputs_ct, stack_ct,
                                           self.attention_layer.n_out,
                                           self.attention_layer.n_in)

        # Backward through compute_residual via the saved VJP, keeping only the
        # cotangents that belong to its differentiable outputs.
        compute_residual_ct = _inputs_from_stack(self.compute_residual,
                                                 stack_ct,
                                                 self.compute_residual.n_out)
        if outputs_aux is not None:
            if not isinstance(compute_residual_ct, (tuple, list)):
                compute_residual_ct = (compute_residual_ct, )
            compute_residual_ct = (
                compute_residual_ct[:n_differentiable_outputs])
            assert len(compute_residual_ct) == n_differentiable_outputs
        (compute_residual_inputs_ct, compute_residual_weights_ct
         ) = compute_residual_vjpfun(compute_residual_ct)
        stack_ct = _outputs_onto_stack(self.compute_residual,
                                       compute_residual_inputs_ct, stack_ct,
                                       self.compute_residual.n_out,
                                       self.compute_residual.n_in)
        if not isinstance(stack_ct, (tuple, list)):
            stack_ct = (stack_ct, )
        # The accumulator cotangent passes through unchanged; the context
        # cotangents accumulate the contributions from the residual branch.
        stack_ct = (accumulator_output_ct, ) + math.nested_map_multiarg(
            lambda x, y: x + y, context_ct[:len(stack_ct)],
            stack_ct) + context_ct[len(stack_ct):]

        # Reversibility: the block's input is recovered exactly by subtracting
        # the recomputed residual from the accumulator output.
        reconstructed_x = accumulator_output - residual
        stack = (reconstructed_x, ) + context
        if self.attention_layer is None:
            weights_ct = (compute_residual_weights_ct, )
        else:
            weights_ct = (compute_residual_weights_ct, attn_weights_ct)
        return stack, (stack_ct, weights_ct)
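
A self-contained sketch (not from the source) of the reversible-residual identity that reverse_and_grad relies on: the block's forward pass computes accumulator_output = x + residual(context), so x is reconstructed exactly as accumulator_output - residual(context), while a fresh jax.vjp over the residual computation supplies the cotangents. residual_fn below is only a stand-in for compute_residual plus the optional attention layer.

import jax
import jax.numpy as jnp

def residual_fn(context):
    # Stand-in for self.compute_residual (+ optional attention layer).
    return 0.5 * jnp.tanh(context)

x = jnp.arange(6.0).reshape(2, 3)
context = jnp.ones((2, 3))
accumulator_output = x + residual_fn(context)  # forward pass of the block

# Reverse pass: recompute the residual, reconstruct x, and get its VJP.
residual, vjp_fn = jax.vjp(residual_fn, context)
reconstructed_x = accumulator_output - residual
assert jnp.allclose(reconstructed_x, x)

# The output cotangent flows both directly to x and, through the residual
# branch, to the context, mirroring how stack_ct is assembled above.
output_ct = jnp.ones_like(accumulator_output)
(context_ct,) = vjp_fn(output_ct)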