Exemple #1
0
    def training_cost(self, flax_module, batch_stats, batch, dropout_rng):
        """Compute the training loss for a batch, plus an optional L2 term.

        The model is run in training mode inside `nn.stateful(batch_stats)` and
        `nn.stochastic(dropout_rng)`. Its logits are then scored against the
        (optionally one-hot encoded and label-smoothed) targets with
        `self.loss_fn`.

        Args:
          flax_module: model to evaluate; called with the inputs, the targets,
            four optional position/segmentation arrays and `train=True`. Its
            `.params` are used for the L2 penalty.
          batch_stats: state handed to `nn.stateful`.
          batch: dict holding 'inputs' and 'targets'. The keys
            'inputs_positions', 'targets_positions', 'inputs_segmentation',
            'targets_segmentation' and 'weights' are optional (None if absent).
          dropout_rng: RNG handed to `nn.stochastic`.

        Returns:
          `(total_loss, new_batch_stats)`, where `new_batch_stats` is the state
          collection yielded by the `nn.stateful` context.
        """
        # Positions/segmentations are only needed for packed examples; for
        # unpacked data these lookups just produce None.
        packing_args = [
            batch.get(key)
            for key in ('inputs_positions', 'targets_positions',
                        'inputs_segmentation', 'targets_segmentation')
        ]
        with nn.stateful(batch_stats) as new_batch_stats, \
                nn.stochastic(dropout_rng):
            logits = flax_module(batch['inputs'], batch['targets'],
                                 *packing_args, train=True)

        loss_targets = batch['targets']
        if self.dataset_meta_data['apply_one_hot_in_loss']:
            loss_targets = one_hot(loss_targets, logits.shape[-1])

        smoothing = self.hps.get('label_smoothing')
        if smoothing is not None:
            loss_targets = model_utils.apply_label_smoothing(
                loss_targets, smoothing)

        total_loss = self.loss_fn(logits, loss_targets, batch.get('weights'))

        if self.hps.get('l2_decay_factor'):
            penalty = model_utils.l2_regularization(
                flax_module.params, self.hps.l2_decay_rank_threshold)
            total_loss += 0.5 * self.hps.l2_decay_factor * penalty
        return total_loss, new_batch_stats
Exemple #2
0
    def training_cost(self, flax_module, batch_stats, batch, dropout_rng):
        """Compute `self.loss_fn` on the model output plus an optional L2 term.

        Args:
          flax_module: model called as `flax_module(batch['inputs'],
            train=True)`. Its `.params` feed the L2 penalty.
          batch_stats: state handed to `nn.stateful`.
          batch: dict with 'inputs' and 'targets' and an optional 'weights'.
          dropout_rng: RNG handed to `nn.stochastic`.

        Returns:
          `(total_loss, new_batch_stats)`, where `new_batch_stats` is the state
          collection yielded by the `nn.stateful` context.
        """
        with nn.stateful(batch_stats) as new_batch_stats, \
                nn.stochastic(dropout_rng):
            logits = flax_module(batch['inputs'], train=True)

        labels = batch['targets']
        if self.dataset_meta_data['apply_one_hot_in_loss']:
            labels = one_hot(labels, logits.shape[-1])

        smoothing = self.hps.get('label_smoothing')
        if smoothing is not None:
            labels = model_utils.apply_label_smoothing(labels, smoothing)

        total_loss = self.loss_fn(logits, labels, batch.get('weights'))

        if self.hps.get('l2_decay_factor'):
            penalty = model_utils.l2_regularization(
                flax_module.params, self.hps.l2_decay_rank_threshold)
            total_loss += 0.5 * self.hps.l2_decay_factor * penalty

        return total_loss, new_batch_stats
Exemple #3
0
    def training_cost(self, params, batch, batch_stats=None, dropout_rng=None):
        """Compute `self.loss_fn` on the model logits plus an optional L2 term.

        Args:
          params: trainable parameters, passed to `self.flax_module.apply`
            under the 'params' collection and used for the L2 penalty.
          batch: dict holding 'inputs' and 'targets'. The keys
            'inputs_positions', 'targets_positions', 'inputs_segmentation',
            'targets_segmentation' and 'weights' are optional (None if absent).
          batch_stats: value placed under the 'batch_stats' collection.
          dropout_rng: RNG supplied to the model as the 'dropout' stream.

        Returns:
          `(total_loss, new_batch_stats)`, where `new_batch_stats` is the
          mutated-variables result of `apply(..., mutable=['batch_stats'])`.
        """
        variables = {'params': params, 'batch_stats': batch_stats}
        # Positions and segmentation ids are required only for packed
        # examples; they come through as None when the batch lacks them.
        packing_args = [
            batch.get(key)
            for key in ('inputs_positions', 'targets_positions',
                        'inputs_segmentation', 'targets_segmentation')
        ]
        logits, new_batch_stats = self.flax_module.apply(
            variables,
            batch['inputs'],
            batch['targets'],
            *packing_args,
            mutable=['batch_stats'],
            rngs={'dropout': dropout_rng},
            train=True)

        labels = batch['targets']
        if self.dataset_meta_data['apply_one_hot_in_loss']:
            labels = one_hot(labels, logits.shape[-1])

        smoothing = self.hps.get('label_smoothing')
        if smoothing is not None:
            labels = model_utils.apply_label_smoothing(labels, smoothing)

        total_loss = self.loss_fn(logits, labels, batch.get('weights'))

        if self.hps.get('l2_decay_factor'):
            penalty = model_utils.l2_regularization(
                params, self.hps.l2_decay_rank_threshold)
            total_loss += 0.5 * self.hps.l2_decay_factor * penalty
        return total_loss, new_batch_stats
Exemple #4
0
def update(optimizer_state, params, batch_stats, metrics_state, batch, step,
           lr, rng, local_device_index, running_train_cost, training_cost,
           grad_clip, optimizer_update_fn, metrics_update_fn):
    """Single step of the training loop.

    Args:
      optimizer_state: the optax optimizer state.
      params: a dict of trainable model parameters. Passed into
        training_cost(...) which then passes into flax_module.apply() as
        {'params': params} as part of the variables dict.
      batch_stats: a dict of non-trainable model state. Passed into
        training_cost(...) which then passes into flax_module.apply() as
        {'batch_stats': batch_stats} as part of the variables dict.
      metrics_state: a pytree of training metrics state, or None to skip the
        metrics update.
      batch: the per-device batch of data to process.
      step: the current global step of this update. Used to fold in to `rng`
        to produce a unique per-device, per-step RNG.
      lr: the floating point learning rate for this step.
      rng: the RNG used for calling the model. `step` and `local_device_index`
        will be folded into this to produce a unique per-device, per-step RNG.
      local_device_index: an integer that is unique to this device amongst all
        devices on this host, usually in the range
        [0, jax.local_device_count()). It is folded in to `rng` to produce a
        unique per-device, per-step RNG.
      running_train_cost: the cumulative train cost over some past number of
        train steps. Reset at evaluation time.
      training_cost: a function used to calculate the training objective that
        will be differentiated to generate updates. Takes (`params`, `batch`,
        `batch_stats`, `dropout_rng`) as inputs and returns `(cost, aux)`,
        where `aux` is a mapping whose 'batch_stats' entry (if any) becomes
        the new batch stats.
      grad_clip: Clip the l2 norm of the gradient at the specified value. For
        minibatches with gradient norm ||g||_2 > grad_clip, we rescale g to
        the value g / ||g||_2 * grad_clip. If None (or any falsy value such as
        0), then no clipping will be applied.
      optimizer_update_fn: the optimizer update function.
      metrics_update_fn: the training metrics update function.

    Returns:
      A 6-tuple of:
        - the new optimizer state,
        - the new params,
        - the new batch stats (None if `training_cost` produced none),
        - the updated running train cost (`running_train_cost` plus this
          step's device-averaged cost),
        - the new training metrics state (None if `metrics_state` is None),
        - the gradient norm, measured on the device-averaged gradient before
          any clipping.
    """
    # `jax.random.split` is very slow outside the train step, so instead we do
    # a `jax.random.fold_in` here.
    rng = jax.random.fold_in(rng, step)
    rng = jax.random.fold_in(rng, local_device_index)

    _inject_learning_rate(optimizer_state, lr)

    def opt_cost(params):
        return training_cost(params,
                             batch=batch,
                             batch_stats=batch_stats,
                             dropout_rng=rng)

    grad_fn = jax.value_and_grad(opt_cost, has_aux=True)
    (cost_value, new_batch_stats), grad = grad_fn(params)
    # The aux output holds the mutated variable collections; keep only the
    # 'batch_stats' entry.
    new_batch_stats = new_batch_stats.get('batch_stats', None)

    # Average the cost and gradient over the pmapped 'batch' axis.
    cost_value, grad = jax.lax.pmean((cost_value, grad), axis_name='batch')

    grad_norm = jnp.sqrt(model_utils.l2_regularization(grad, 0))
    # TODO(znado): move to inside optax gradient clipping.
    if grad_clip:
        # `jax.tree_map` is deprecated/removed in recent JAX releases;
        # `jax.tree_util.tree_map` is the stable equivalent.
        scaled_grad = jax.tree_util.tree_map(
            lambda x: x / (grad_norm + _GRAD_CLIP_EPS) * grad_clip, grad)
        # grad_norm is a traced value, so the branch must be selected with
        # lax.cond rather than a Python `if`.
        grad = jax.lax.cond(grad_norm > grad_clip, lambda _: scaled_grad,
                            lambda _: grad, None)
    model_updates, new_optimizer_state = optimizer_update_fn(
        grad,
        optimizer_state,
        params=params,
        batch=batch,
        batch_stats=new_batch_stats)
    new_params = optax.apply_updates(params, model_updates)

    new_metrics_state = None
    if metrics_state is not None:
        # Note: the pre-update `optimizer_state` is what is passed here.
        new_metrics_state = metrics_update_fn(metrics_state, step, cost_value,
                                              grad, params, new_params,
                                              optimizer_state)

    return (new_optimizer_state, new_params, new_batch_stats,
            running_train_cost + cost_value, new_metrics_state, grad_norm)