Example #1
def train_step(optimizer,
               inputs,
               learning_rate_fn,
               dropout_rng,
               preprocess_fn,
               example_weights=None,
               grad_clip=None,
               epsilon=1e-9):
  """Performs a single training step. Masks out BOS/PAD positions.

  Args:
    optimizer: Flax optimizer.
    inputs: Inputs to preprocess_fn, which returns (inputs, targets, weights).
    learning_rate_fn: function from step idx --> learning rate.
    dropout_rng: RNG for dropout.
    preprocess_fn: function mapping
      (inputs, rng, mode) -> (inputs, targets, weights).
    example_weights: Optional [batch] weights for the loss on each example.
      See utils.compute_weighted_cross_entropy for details.
    grad_clip: If not None, clip gradients to a maximum global norm of grad_clip.
    epsilon: Epsilon for denominator of loss averaging.

  Returns:
    new_optimizer, metrics, new_dropout_rng
  """

  # We handle PRNG splitting inside the top pmap, rather
  # than handling it outside in the training loop - doing the
  # latter can add some stalls to the devices.
  dropout_rng, new_dropout_rng = jrandom.split(dropout_rng)
  dropout_rng, preprocess_rng = jrandom.split(dropout_rng)

  inputs, targets, weights = preprocess_fn(
      inputs, rng=preprocess_rng, mode=Mode.train)

  if isinstance(targets, dict):
    classification_targets = targets['classification']
    classification_weights = weights['classification']

    regression_targets = targets['regression']
    regression_weights = weights['regression']
  else:
    # Default to classification loss.
    classification_targets = targets
    classification_weights = weights
    regression_targets = None
    regression_weights = None

  if classification_targets is None and regression_targets is None:
    raise ValueError('No targets specified for train step.')

  if classification_weights is None and regression_weights is None:
    raise ValueError('No weights specified for train step.')

  def loss_fn(model):
    """Loss function used for training."""
    # Stateful collection for tracking internal state like activations.
    with nn.stateful() as batch_stats:
      with nn.stochastic(dropout_rng):
        outputs = model(inputs, train=True, cache=None)

      if isinstance(outputs, dict):
        logits = outputs.get('logits', None)
        regression_predictions = outputs.get('regression', None)
      else:
        logits = outputs
        regression_predictions = None

    mean_loss = 0.0

    # Classification loss
    if classification_targets is not None:
      classification_loss, classification_weight_sum = utils.compute_weighted_cross_entropy(
          logits,
          classification_targets,
          token_weights=classification_weights,
          example_weights=example_weights)
      # Guard against a zero weight sum, which happens in BERT-style training
      # when nothing is masked out (only occurs with very short sequences).
      classification_weight_sum = jnp.maximum(classification_weight_sum,
                                              epsilon)
      mean_classification_loss = classification_loss / classification_weight_sum
      mean_loss += mean_classification_loss

    if regression_targets is not None:
      regression_loss, regression_weight_sum = utils.compute_weighted_mse(
          regression_predictions,
          regression_targets,
          weights=regression_weights)
      regression_weight_sum = jnp.maximum(regression_weight_sum, epsilon)
      mean_regression_loss = regression_loss / regression_weight_sum
      outputs['regression_loss'] = mean_regression_loss

      # TODO(ddohan): Allow weighting each loss separately.
      mean_loss += mean_regression_loss

    return mean_loss, (outputs, batch_stats)

  step = optimizer.state.step
  lr = learning_rate_fn(step)

  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (_, (outputs, batch_stats)), grad = grad_fn(optimizer.target)

  # Average gradients across devices. lax.pmean needs the 'batch' axis name
  # bound by an enclosing pmap; when the step runs un-pmapped (e.g. in tests),
  # JAX raises a NameError for the unbound axis and the cross-device mean is
  # skipped.
  try:
    grad = jax.lax.pmean(grad, 'batch')
  except NameError:
    pass

  if grad_clip is not None:
    # Clip gradients to a maximum global norm of grad_clip after pmean
    # aggregation; keep the unclipped gradients around for metrics.
    unclipped_grad = grad
    grad = jax.example_libraries.optimizers.clip_grads(grad, grad_clip)

  new_optimizer = optimizer.apply_gradient(grad, learning_rate=lr)

  # TODO(ddohan): Avoid computing metrics except when needed.
  if isinstance(outputs, dict):
    logits = outputs.get('logits', None)
  else:
    logits = outputs

  metrics = dict()
  if logits is not None:
    classification_metrics = utils.compute_metrics(logits,
                                                   classification_targets,
                                                   classification_weights)
    metrics.update(classification_metrics)
  if regression_targets is not None:
    # TODO(ddohan): Implement regression metrics.
    logging.info('Regression metrics are not implemented yet.')
    # regression = outputs.get('regression', None)
    # regression_metrics = utils.compute_metrics(logits, regression_targets,
    #                                                classification_weights)
  metrics['learning_rate'] = lr

  # Training metrics
  metrics['l2_param_sum'] = utils.l2_regularization(optimizer.target.params)

  # Gradient norms
  grad_l2_tree = utils.l2_norm(grad)
  grad_l2_sum = jax.tree_util.tree_reduce(op.add, grad_l2_tree)
  grad_l2_max = jax.tree_util.tree_reduce(jnp.maximum, grad_l2_tree)
  metrics['l2_grad_sum'] = grad_l2_sum
  metrics['l2_grad_max'] = grad_l2_max

  # Store any tagged metrics
  batch_stats = batch_stats.as_dict()
  if batch_stats:

    def clean_name(k):
      return 'nn/' + k.replace('MultiHeadDotProductAttention_', '').replace(
          '/Transformer1DBlock_', '')

    stats = {clean_name(k): v['tag'] for k, v in batch_stats.items()}
    metrics.update(stats)

  if grad_clip is not None:
    # Unclipped gradient norms (if applicable).
    grad_l2_tree = utils.l2_norm(unclipped_grad)
    grad_l2_sum = jax.tree_util.tree_reduce(op.add, grad_l2_tree)
    grad_l2_max = jax.tree_util.tree_reduce(jnp.maximum, grad_l2_tree)
    metrics['l2_noclip_grad_sum'] = grad_l2_sum
    metrics['l2_noclip_grad_max'] = grad_l2_max

  return new_optimizer, metrics, new_dropout_rng
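
This step aggregates gradients with jax.lax.pmean(grad, 'batch') and splits the
dropout PRNG inside the step, which only makes sense when train_step runs under
a jax.pmap whose axis is named 'batch', with one RNG key per device. The driving
loop is not part of the example; the code below is a minimal sketch of such a
wrapper, where learning_rate_fn, preprocess_fn, rng, optimizer and train_iter
are hypothetical stand-ins for objects built elsewhere in the real program.

import jax
from flax import jax_utils
from flax.training import common_utils

# Hypothetical driver; not part of the original example.
p_train_step = jax.pmap(
    lambda opt, batch, rng: train_step(
        opt, batch, learning_rate_fn, rng, preprocess_fn),
    axis_name='batch')  # Must match the axis name passed to lax.pmean.

replicated_optimizer = jax_utils.replicate(optimizer)  # One copy per device.
dropout_rngs = jax.random.split(rng, jax.local_device_count())

for batch in train_iter:
  sharded_batch = common_utils.shard(batch)  # Adds a leading device axis.
  replicated_optimizer, metrics, dropout_rngs = p_train_step(
      replicated_optimizer, sharded_batch, dropout_rngs)
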
Example #2
def train_step(optimizer,
               inputs,
               learning_rate_fn,
               example_weights=None,
               dropout_rng=None,
               grad_clip=None,
               bos_token=0):
    """Performs a single training step. Masks out BOS/PAD positions.

  Args:
    optimizer: Flax optimizer.
    inputs: [batch x length] inputs.
    learning_rate_fn: function from step idx --> learning rate.
    example_weights: Optional [batch] weights for the loss on each example.
      See utils.compute_weighted_cross_entropy for details.
    dropout_rng: RNG for dropout.
    grad_clip: If not None, clip gradients to [-x, +x].
    bos_token: Beginning of sentence token used to generate weight mask.

  Returns:
    new_optimizer, metrics, new_dropout_rng
  """

    # BOS token is equal to PAD when seen on output (loss) side, so this masks
    # out both BOS and PAD positions.
    token_weights = jnp.where(inputs != bos_token, 1, 0)

    # We handle PRNG splitting inside the top pmap, rather
    # than handling it outside in the training loop - doing the
    # latter can add some stalls to the devices.
    dropout_rng, new_dropout_rng = random.split(dropout_rng)

    def loss_fn(model):
        """Loss function used for training."""
        with nn.stochastic(dropout_rng):
            logits = model(inputs, train=True, cache=None)
        loss, weight_sum = utils.compute_weighted_cross_entropy(
            logits,
            inputs,
            token_weights=token_weights,
            example_weights=example_weights)
        mean_loss = loss / weight_sum
        return mean_loss, logits

    step = optimizer.state.step
    lr = learning_rate_fn(step)

    # Get gradient
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (_, logits), grad = grad_fn(optimizer.target)
    grad = jax.lax.pmean(grad, 'batch')

    # Compute metrics from forward pass
    metrics = utils.compute_metrics(logits, inputs, token_weights)
    metrics['learning_rate'] = lr

    metrics['l2_param_sum'] = utils.l2_regularization(optimizer.target.params)

    # Gradient norms
    grad_l2_tree = utils.l2_norm(grad)
    grad_l2_sum = jax.tree_util.tree_reduce(op.add, grad_l2_tree)
    grad_l2_max = jax.tree_util.tree_reduce(jnp.maximum, grad_l2_tree)
    metrics['grad_l2_sum'] = grad_l2_sum
    metrics['grad_l2_max'] = grad_l2_max

    # TODO(ddohan): Clip by global grad norm.
    if grad_clip is not None:
        # Clip gradients after pmean aggregation
        clip = lambda g: jnp.clip(g, -grad_clip, grad_clip)  # pylint: disable=invalid-unary-operand-type
        grad = jax.tree_util.tree_map(clip, grad)

        # Metrics for clipped grads.
        clipped_grad_l2_tree = utils.l2_norm(grad)
        clipped_grad_l2_sum = jax.tree_util.tree_reduce(
            op.add, clipped_grad_l2_tree)
        clipped_grad_l2_max = jax.tree_util.tree_reduce(
            jnp.maximum, clipped_grad_l2_tree)
        metrics['gradclip_l2_sum'] = clipped_grad_l2_sum
        metrics['gradclip_l2_max'] = clipped_grad_l2_max

    # Apply gradients and return new optimizer
    new_optimizer = optimizer.apply_gradient(grad, learning_rate=lr)
    return new_optimizer, metrics, new_dropout_rng
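
utils.compute_weighted_cross_entropy is project code that neither example shows.
From the call sites it is expected to return a (summed loss, weight sum) pair,
so that mean_loss = loss / weight_sum becomes an average over the unmasked
(non-BOS/PAD) positions only. A minimal sketch consistent with that usage,
assuming [batch, length, vocab] logits and integer token-id targets (the real
utility may differ in details):

import jax
import jax.numpy as jnp

def compute_weighted_cross_entropy(logits, targets, token_weights=None,
                                   example_weights=None):
  """Sketch only: summed masked cross entropy plus the normalizer."""
  onehot = jax.nn.one_hot(targets, logits.shape[-1])
  # Per-position negative log likelihood, shape [batch, length].
  nll = -jnp.sum(onehot * jax.nn.log_softmax(logits), axis=-1)
  if example_weights is not None:
    nll = nll * example_weights[:, None]  # Reweight whole examples.
  if token_weights is None:
    token_weights = jnp.ones_like(nll)
  # Summed loss over unmasked positions, plus the normalizer used by callers.
  return jnp.sum(nll * token_weights), jnp.sum(token_weights)

Dividing by the weight sum rather than by batch * length keeps padded positions
from diluting the loss; it is also why Example #1 clamps the weight sum with
epsilon before dividing.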