Example #1
0
 def single_update(i, opt_state, batch, state, rng):
     """Runs one single-device optimizer step.

     Args:
       i: step index forwarded to `optimizer.tree_update`.
       opt_state: (weights, slots, opt_params) triple.
       batch: input batch passed to `model_and_loss_call`.
       state: model state passed to `model_and_loss_call`.
       rng: indexable container whose first element is the rng key to split.

     Returns:
       A triple: the result of `optimizer.tree_update`, the new model state
       (the aux output of the gradient call), and a one-element list holding
       the rng key reserved for the next step.
     """
     weights, slots, opt_params = opt_state
     # One half of the split drives this step; the other is handed back.
     step_key, next_key = jax_random.split(rng[0])
     loss_grad = math.grad(model_and_loss_call, has_aux=True)
     grads, new_state = loss_grad(weights, batch, state, step_key)
     updated = optimizer.tree_update(i, grads, weights, slots, opt_params)
     return updated, new_state, [next_key]
Example #2
0
    def test_custom_zero_grad(self):
        """A layer whose custom backward returns zeros yields an all-zero grad."""

        class IdWithZeroGrad(base.Layer):
            # Identity forward pass whose custom backward ignores the incoming
            # gradient and returns zeros of the same shape instead.

            def forward(self, x, weights):
                return x

            @property
            def has_backward(self):
                return True

            def backward(self, inputs, output, grad, weights, state, new_state,
                         rng):
                zero_grad = jnp.zeros_like(grad)
                return (zero_grad, ())

        layer = IdWithZeroGrad()
        key = math.random.get_prng(0)
        signature = shapes.ShapeDtype((9, 17))
        x = math.random.uniform(key, signature.shape, minval=-1.0, maxval=1.0)
        layer.init(signature)

        def mean_output(inputs):
            return jnp.mean(layer(inputs))

        grad = math.grad(mean_output)(x)
        # One gradient entry per input element.
        self.assertEqual(grad.shape, (9, 17))
        # A sum of squares is 0 only when every entry is 0.
        self.assertEqual(sum(sum(grad * grad)), 0.0)
Example #3
0
    def test_custom_id_grad(self):
        """A custom backward returning the inputs makes the grad equal them."""

        class IdWithIdGrad(base.Layer):
            # Identity forward pass whose custom backward ignores the incoming
            # cotangent and returns the layer inputs as the input gradient.

            def forward(self, x, weights):
                return x

            @property
            def has_backward(self):
                return True

            def backward(self, inputs, output, ct, weights, state, new_state,
                         **kwargs):
                return (inputs, ())

        layer = IdWithIdGrad()
        key = math.random.get_prng(0)
        signature = ShapeDtype((9, 17))
        x = math.random.uniform(key, signature.shape, minval=-1.0, maxval=1.0)
        layer.init(signature)

        def mean_output(inputs):
            return np.mean(layer(inputs))

        grad = math.grad(mean_output)(x)
        # One gradient entry per input element.
        self.assertEqual(grad.shape, (9, 17))
        # The backward hands the inputs straight through, so totals match.
        self.assertEqual(sum(sum(grad)), sum(sum(x)))
Example #4
0
    def test_reformer_rng_consistency(self):
        """Checks that every gradient leaf of a tiny ReformerLM is finite.

        The model uses PoisonOnRNGMismatchAttention, which signals an rng
        mismatch by producing NaNs, so any non-finite gradient means the rngs
        did not line up.
        """
        with math.use_backend('jax'):
            vocab_size = 16
            batch_size = 1
            input_sd = ShapeDtype((batch_size, 8), np.int32)
            input_signature = (input_sd, input_sd)
            model = reformer.ReformerLM(
                vocab_size,
                d_model=32,
                d_ff=64,
                d_attention_key=16,
                d_attention_value=16,
                n_layers=1,
                n_heads=2,
                max_len=16,
                n_chunks=2,
                n_attention_chunks=1,
                mode='train',
                attention_type=PoisonOnRNGMismatchAttention)

            rng = math.random.get_prng(0)
            weights, state = model.init(input_signature)

            def dummy_loss_fn(weights):
                # The same all-zeros int32 tensor is fed as both model inputs.
                inputs = (np.zeros(input_sd.shape, dtype=np.int32), ) * 2
                output = model(inputs, weights=weights, state=state, rng=rng)
                dummy_loss = math.numpy.sum(output[0])
                return dummy_loss

            grad_fn = math.grad(dummy_loss_fn)
            grads = grad_fn(weights)
            # PoisonOnRNGMismatchAttention uses NaNs to signal an rng mismatch.
            # Use assertTrue rather than a bare `assert`: bare asserts are
            # stripped under `python -O`, which would make this test pass
            # vacuously.
            for grad in jax.tree_util.tree_leaves(grads):
                self.assertTrue(
                    onp.all(onp.isfinite(grad)),
                    msg='non-finite gradient leaf (PoisonOnRNGMismatchAttention '
                        'signals an rng mismatch with NaNs)')
Example #5
0
        def mock_training_step(x, weights, state, rng):
            """Takes one plain gradient-descent step (step size 1e-4).

            The mock loss is the mean of `logits[..., 0]`, where the logits come
            from `model.pure_fn(x, weights, state, rng)`.

            Returns:
              (new_weights, new_state, logits).
            """
            def mock_loss_with_aux(params):
                out, updated_state = model.pure_fn(x, params, state, rng)
                mock_loss = math.numpy.mean(out[..., 0])
                # The aux payload carries the new state and logits out of grad.
                return mock_loss, (updated_state, out)

            loss_grad_fn = math.grad(mock_loss_with_aux, has_aux=True)
            grads, (new_state, logits) = loss_grad_fn(weights)
            new_weights = math.nested_map_multiarg(
                lambda w, g: w - 1e-4 * g, weights, grads)
            return new_weights, new_state, logits
Example #6
0
 def _run_one_step(self):
     """Updates model weights and optimizer slots by running one step/batch.

     Pulls the next batch from the task, differentiates the output of the
     model followed by the task's loss layer with respect to that combined
     layer's weights, and writes the optimizer's updated weights and slots
     back onto `self._model` and the optimizer.
     """
     optimizer = self._task.optimizer
     # TODO(jonni): figure out why JAX tracer needs the following line.
     # NOTE(review): `weights` holds only the core model's weights, while the
     # gradients below are taken w.r.t. `model_with_loss.weights`. This
     # presumably relies on the loss layer contributing no weight leaves and on
     # tree_update matching the two trees leaf-by-leaf — confirm.
     weights = self._model.weights
     opt_params = optimizer._init_opt_params  # pylint: disable=protected-access
     batch = self._task.next_batch()
     # A new Serial wrapper is constructed on every call (it is not cached).
     model_with_loss = tl.Serial(self._model, self._task.loss_layer)
     loss_as_fn_of_weights = lambda w: model_with_loss(batch, weights=w)
     gradients = math.grad(loss_as_fn_of_weights)(model_with_loss.weights)
     self._model.weights, optimizer.slots = optimizer.tree_update(
         self.current_step(), gradients, weights, optimizer.slots,
         opt_params)
Example #7
0
    def __init__(self,
                 model,
                 task,
                 eval_task=None,
                 output_dir=None,
                 checkpoint_at=None):
        """Configures a training `Loop`, including a random initialization.

        Args:
          model: Trax layer, representing the core model to be trained. Loss
              functions and eval functions (a.k.a. metrics) are considered to
              be outside the core model, taking core model output and data
              labels as their two inputs.
          task: TrainTask instance, which defines the training data, loss
              function, and optimizer to be used in this training loop.
          eval_task: EvalTask instance or None. If None, don't do any evals.
          output_dir: Path telling where to save outputs (evals and
              checkpoints). Can be None if both `eval_task` and
              `checkpoint_at` are None.
          checkpoint_at: Function (integer --> boolean) telling, for step n,
              whether that step should have its checkpoint saved. If None,
              don't save any checkpoints.
        """
        self._task = task
        self._model_in_training = tl.Serial(model, task.loss_layer)
        self._eval_task = eval_task
        self._output_dir = output_dir
        self._checkpoint_at = checkpoint_at or _never
        self._step = None

        # Initialize the model and the optimizer. The returned pairs (model
        # weights/state, optimizer slots/params) are discarded because the
        # model and optimizer objects keep them.
        sample_signature = shapes.signature(task.sample_batch)
        _, _ = self._model_in_training.init(sample_signature)
        _, _ = task.optimizer.tree_init(self._model_in_training.weights)

        # Differentiate pure_fn w.r.t. its argument 1 (the weights). has_aux
        # makes the compiled function return (gradients, state).
        grad_of_pure_fn = math.grad(self._model_in_training.pure_fn,
                                    argnums=1,
                                    has_aux=True)
        self._gradients_and_state_fn = math.jit(grad_of_pure_fn)

        if eval_task is None:
            return

        model_with_metrics = _model_with_metrics(model, eval_task)
        # Index 1 selects just the eval part of the combined weights/state.
        self._eval_weights = model_with_metrics.weights[1]
        self._eval_state = model_with_metrics.state[1]
        self._metrics_fn = math.jit(model_with_metrics.pure_fn)
Example #8
0
 def mapped_update(i, opt_state, batch, state, rng):
     """Multi-device counterpart of the single-device update step.

     Every tensor argument is assumed to have leading dimension n_devices.
     Gradients are averaged over the 'batch' collective axis before the
     optimizer update is applied.

     Returns:
       A triple: the result of `optimizer.tree_update`, the new model state
       (the aux output of the gradient call), and the rng key reserved for
       the next step.
     """
     weights, slots, opt_params = opt_state
     step_key, next_key = jax_random.split(rng)
     loss_grad = math.grad(model_and_loss_call, has_aux=True)
     grads, new_state = loss_grad(weights, batch, state, step_key)

     def _mean_over_all_devices(g):
         # psum(1.0) rather than a local `n_devices`: the local count covers
         # only this host, while psum spans every device on every host (e.g. a
         # TPU pod), and the average must be taken over all of them.
         return math.psum(g, 'batch') / math.psum(1.0, 'batch')

     grads = jax.tree_util.tree_map(_mean_over_all_devices, grads)
     updated = optimizer.tree_update(i, grads, weights, slots, opt_params)
     return updated, new_state, next_key