Example #1
 def test_call_and_grad(self):
     layer_partial = tl.Serial(
         tl.Branch(tl.Embedding(3, 4), tl.PaddingMask()),
         sparsity.Favor(d_feature=4, n_heads=2),
         tl.Select([0], n_in=2),
     )
     layer = tl.Serial(
         tl.Branch(tl.Embedding(3, 4), tl.PaddingMask()),
         sparsity.Favor(d_feature=4, n_heads=2),
         tl.Select([0], n_in=2),
         tl.WeightedCategoryCrossEntropy(),
     )
     x = np.ones((1, 2), dtype=np.int32)
     w = np.ones_like(x).astype(np.float32)
     x_sig = shapes.signature(x)
     w_sig = shapes.signature(w)
     layer_partial.init(x_sig)
     y = layer_partial(x)
     self.assertEqual(y.shape, (1, 2, 4))
     layer.init((x_sig, x_sig, w_sig))
     y = layer((x, x, w))
     self.assertEqual(y.shape, ())
     state = layer.state
     rng = fastmath.random.get_prng(0)
     # pure_fn returns (output, new_state); keep only the scalar loss output
     # so fastmath.grad can differentiate it with respect to the weights.
     fwd = lambda weights, inp: layer.pure_fn(
         inp, weights, state, rng=rng)[0]
     g = fastmath.grad(fwd)(layer.weights, (x, x, w))
     self.assertEqual(g[0][1][0].shape, (3, 4))
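
For reference, `fastmath.grad` mirrors the `jax.grad` API used throughout these
examples: given a callable that returns a scalar, it returns a callable that
computes gradients with respect to the argument selected by `argnums`
(argument 0 by default). A minimal, self-contained sketch with toy values (the
function and shapes below are illustrative, not taken from the test above):

import numpy as np

from trax import fastmath
from trax.fastmath import numpy as jnp


def toy_loss(weights, inputs):
  # Scalar-valued loss: mean of a linear map of the inputs.
  return jnp.mean(jnp.dot(inputs, weights))


weights = np.ones((3, 2), dtype=np.float32)
inputs = np.ones((4, 3), dtype=np.float32)

# Differentiate with respect to argument 0 (the weights).
grad_fn = fastmath.grad(toy_loss, argnums=0)
dweights = grad_fn(weights, inputs)
assert dweights.shape == weights.shape  # (3, 2)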
Example #2
    def test_custom_zero_grad(self, backend):
        class IdWithZeroGrad(tl.Layer):
            def forward(self, x):
                return x

            @property
            def has_backward(self):
                return True

            def backward(self, inputs, output, grad, weights, state, new_state,
                         rng):
                return (jnp.zeros_like(grad), ())

        with fastmath.use_backend(backend):
            layer = IdWithZeroGrad()
            rng = fastmath.random.get_prng(0)
            input_signature = shapes.ShapeDtype((9, 17))
            random_input = fastmath.random.uniform(rng,
                                                   input_signature.shape,
                                                   minval=-1.0,
                                                   maxval=1.0)
            layer.init(input_signature)
            f = lambda x: jnp.mean(layer(x))
            grad = fastmath.grad(f)(random_input)
            self.assertEqual(grad.shape, (9, 17))  # Gradient for each input.
            self.assertEqual(sum(sum(grad * grad)), 0.0)  # Each one is 0.
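
A note on the protocol exercised here: when `has_backward` returns True, Trax
uses the layer's `backward` method as a custom vector-Jacobian product instead
of differentiating `forward`. The returned tuple holds the gradient with
respect to the inputs and the gradient with respect to the (here empty)
weights, so returning `jnp.zeros_like(grad)` blocks all gradient flow through
the layer, which is exactly what the two assertions verify.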
Example #3
 def single_update(weights_and_slots, i, opt_params, batch, state, rng):
     weights, slots = weights_and_slots
     rng, subrng = jax_random.split(rng[0])
     grad_fn = fastmath.grad(model_and_loss_call, has_aux=True)
     grads, state = grad_fn(weights, batch, state, rng)
     new_weights, new_slots, stats = optimizer.tree_update(
         i, grads, weights, slots, opt_params)
     return (new_weights, new_slots), stats, state, [subrng]
Example #4
 def mock_training_step(x, weights, state, rng):
   def compute_mock_loss(weights):
     logits, new_state = model.pure_fn(x, weights, state, rng)
     loss = fastmath.numpy.mean(logits[..., 0])
     return loss, (new_state, logits)
   gradients, (new_state, logits) = fastmath.grad(
       compute_mock_loss, has_aux=True)(weights)
   new_weights = fastmath.nested_map_multiarg(
       lambda w, g: w - 1e-4 * g, weights, gradients)
   return new_weights, new_state, logits
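
Because `fastmath.grad` is called with `has_aux=True`, `compute_mock_loss`
returns a `(loss, aux)` pair and the gradient function returns
`(gradients, aux)`; that is how `new_state` and `logits` are threaded out
alongside the gradients. The weight update itself is a plain SGD step applied
leaf by leaf with `nested_map_multiarg`.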
Example #5
  def __init__(self, model, task, eval_model=None, eval_task=None,
               output_dir=None, checkpoint_at=None, eval_at=None):
    """Configures a training `Loop`, including a random initialization.

    Args:
      model: Trax layer, representing the core model to be trained. Loss
          functions and eval functions (a.k.a. metrics) are considered to be
          outside the core model, taking core model output and data labels as
          their two inputs.
      task: TrainTask instance, which defines the training data, loss function,
          and optimizer to be used in this training loop.
      eval_model: Optional Trax layer, representing the model used for
          evaluation, e.g., with dropout turned off. If None, the training
          model (`model`) will be used.
      eval_task: EvalTask instance or None. If None, don't do any evals.
      output_dir: Path telling where to save outputs (evals and checkpoints).
          Can be None if both `eval_task` and `checkpoint_at` are None.
      checkpoint_at: Function (integer --> boolean) telling, for step n, whether
          that step should have its checkpoint saved. If None, the default is
          periodic checkpointing at `task.n_steps_per_checkpoint`.
      eval_at: Function (integer --> boolean) that says, for training step n,
          whether that step should run evals. If None, run when checkpointing.
    """
    self._task = task
    self._model = model
    self._model_in_training = tl.Serial(model, task.loss_layer)
    self._eval_model = model if eval_model is None else eval_model
    self._eval_task = eval_task
    self._output_dir = os.path.expanduser(output_dir) if output_dir else None
    default_fn = _at_step_1_and_periodically_at(task.n_steps_per_checkpoint)
    self._checkpoint_at = checkpoint_at or default_fn
    self._eval_at = eval_at or default_fn
    if eval_task is None:
      self._eval_at = _never
    self._step = 0

    batch_signature = shapes.signature(task.sample_batch)
    self._batch_signature = batch_signature
    # Initialize the model and the optimizer; discard the return values
    # (model weights/state, optimizer slots/params), since they're available
    # from the model and optimizer objects.
    _, _ = self._model_in_training.init(batch_signature)
    _, _ = task.optimizer.tree_init(self._model_in_training.weights)

    self._gradients_and_state_fn = (
        fastmath.jit(fastmath.grad(self._model_in_training.pure_fn,
                                   argnums=1,  # arg1 of pure_fn: weights
                                   has_aux=True)))  # return (gradients, state)

    if eval_task is not None:
      model_with_metrics = _model_with_metrics(self._eval_model, eval_task)
      self._eval_weights = model_with_metrics.weights[1]  # just the eval part
      self._eval_state = model_with_metrics.state[1]  # just the eval part
      self._metrics_fn = fastmath.jit(model_with_metrics.pure_fn)
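
The docstring above describes `checkpoint_at` and `eval_at` as step predicates
(integer --> boolean). As a hedged sketch of wiring up this constructor,
assuming the `trax.supervised.training` API shown here (the toy data stream,
model, and hyperparameters are placeholders, not taken from the source):

import numpy as np

from trax import layers as tl
from trax import optimizers
from trax.supervised import training


def toy_stream():
  # Placeholder batches of (inputs, targets, weights) for a 2-class task.
  while True:
    x = np.random.randint(0, 256, size=(8, 16), dtype=np.int32)
    y = np.random.randint(0, 2, size=(8,), dtype=np.int32)
    w = np.ones_like(y, dtype=np.float32)
    yield (x, y, w)


model = tl.Serial(
    tl.Embedding(vocab_size=256, d_feature=32),
    tl.Mean(axis=1),
    tl.Dense(2),
    tl.LogSoftmax(),
)
task = training.TrainTask(
    labeled_data=toy_stream(),
    loss_layer=tl.WeightedCategoryCrossEntropy(),
    optimizer=optimizers.Adam(0.01),
    n_steps_per_checkpoint=100,
)
eval_task = training.EvalTask(
    labeled_data=toy_stream(),
    metrics=[tl.WeightedCategoryCrossEntropy(), tl.WeightedCategoryAccuracy()],
)
loop = training.Loop(
    model, task,
    eval_task=eval_task,
    output_dir='/tmp/toy_train_dir',
    checkpoint_at=lambda step: step % 100 == 0,  # step predicate: int -> bool
)
loop.run(n_steps=200)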
Example #6
 def test_causal_call_and_grad(self):
     layer = tl.Serial(tl.Dense(4),
                       sparsity.CausalFavor(d_feature=4, n_heads=2),
                       tl.L2Loss())
     x = np.random.uniform(size=(1, 2, 4)).astype(np.float32)
     w = np.ones_like(x)
     x_sig = shapes.signature(x)
     w_sig = shapes.signature(w)
     layer.init((x_sig, x_sig, w_sig))
     y = layer((x, x, w))
     self.assertEqual(y.shape, ())
     state = layer.state
     rng = fastmath.random.get_prng(0)
     fwd = lambda weights, inp: layer.pure_fn(
         inp, weights, state, rng=rng)[0]
     g = fastmath.grad(fwd)(layer.weights, (x, x, w))
     self.assertEqual(g[0][0].shape, (4, 4))
Example #7
 def mapped_update(weights_and_slots, i, opt_params, batch, state, rng):
   """This is a multi-device version of the update function above."""
   # We assume all tensors have the first dimension = n_devices.
   weights, slots = weights_and_slots
   rng, subrng = jax_random.split(rng)
   grad_fn = fastmath.grad(model_and_loss_call, has_aux=True)
   grads, state = grad_fn(weights, batch, state, rng)
   # We do a psum(1.0) here instead of `n_devices` since `n_devices` is just
   # the number of devices on this host machine, however psum goes over all
   # devices of all hosts (ex: a TPU pod) and we need to be averaging over all
   # of them.
   grads = jax.tree_util.tree_map(
       lambda g: (  # pylint: disable=g-long-lambda
           fastmath.psum(g, 'batch') / fastmath.psum(np.array(1.0), 'batch')),
       grads)
   new_weights, new_slots, stats = optimizer.tree_update(
       i, grads, weights, slots, opt_params)
   return (new_weights, new_slots), stats, state, subrng
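
Note that the `fastmath.psum(..., 'batch')` calls only resolve when this
function runs under a parallel map with a matching axis name, e.g. something
along the lines of `fastmath.pmap(mapped_update, axis_name='batch')`; the
exact wrapping is an assumption inferred from the axis name above, not a line
from this source.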
Example #8
    def test_custom_id_grad(self, backend):
        # After changes to some fastmath.custom_vjp functions (made so that we
        # could land JAX PR #4008), this test started failing with a ValueError
        # from TensorFlow:
        # ValueError: ('custom_gradient function expected to return', 2,
        #              'gradients but returned', 3, 'instead.').
        # TODO(mattjj,lukaszkaiser): revive this test after landing #4008.
        if backend == fastmath.Backend.TFNP:
            raise unittest.SkipTest(
                'temporarily skipping test so that we can '
                'land https://github.com/google/jax/pull/4008')

        class IdWithIdGrad(tl.Layer):
            def forward(self, x):
                return x

            @property
            def has_backward(self):
                return True

            def backward(self, inputs, output, grad, weights, state, new_state,
                         rng):
                return (inputs, ())

        with fastmath.use_backend(backend):
            layer = IdWithIdGrad()
            rng = fastmath.random.get_prng(0)
            input_signature = shapes.ShapeDtype((9, 17))
            random_input = fastmath.random.uniform(rng,
                                                   input_signature.shape,
                                                   minval=-1.0,
                                                   maxval=1.0)
            layer.init(input_signature)
            f = lambda x: jnp.mean(layer(x))
            grad = fastmath.grad(f)(random_input)
            self.assertEqual(grad.shape, (9, 17))  # Gradient for each input.
            self.assertEqual(sum(sum(grad)),
                             sum(sum(random_input)))  # Same as input.
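
In this variant `backward` ignores the incoming cotangent and returns the
forward `inputs` as the gradient with respect to the inputs, so the final
assertion checks (via summed values) that the computed "gradient" equals the
random input itself.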