def train_step(optimizer,
               inputs,
               learning_rate_fn,
               dropout_rng,
               preprocess_fn,
               example_weights=None,
               grad_clip=None,
               epsilon=1e-9):
  """Performs a single training step.

  Loss positions are weighted/masked according to the weights returned by
  preprocess_fn (e.g. BOS/PAD positions receive zero weight).

  Args:
    inputs: Inputs passed to preprocess_fn, which returns
      (inputs, targets, weights).
    optimizer: Flax optimizer.
    learning_rate_fn: function from step idx --> learning rate.
    dropout_rng: RNG for dropout.
    preprocess_fn: function mapping (inputs, rng, mode) ->
      (inputs, targets, weights).
    example_weights: Optional [batch] weights for the loss on each example.
      See utils.compute_weighted_cross_entropy for details.
    grad_clip: If not None, clip gradients so their global L2 norm is at most
      grad_clip.
    epsilon: Epsilon for denominator of loss averaging.

  Returns:
    new_optimizer, metrics, new_dropout_rng
  """
  # We handle PRNG splitting inside the top pmap, rather than handling it
  # outside in the training loop - doing the latter can add some stalls to
  # the devices.
  dropout_rng, new_dropout_rng = jrandom.split(dropout_rng)
  dropout_rng, preprocess_rng = jrandom.split(dropout_rng)

  inputs, targets, weights = preprocess_fn(
      inputs, rng=preprocess_rng, mode=Mode.train)

  if isinstance(targets, dict):
    classification_targets = targets['classification']
    classification_weights = weights['classification']
    regression_targets = targets['regression']
    regression_weights = weights['regression']
  else:
    # Default to classification loss.
    classification_targets = targets
    classification_weights = weights
    regression_targets = None
    regression_weights = None

  if classification_targets is None and regression_targets is None:
    raise ValueError('No targets specified for train step.')
  if classification_weights is None and regression_weights is None:
    raise ValueError('No weights specified for train step.')

  def loss_fn(model):
    """Loss function used for training."""
    # Stateful collection for tracking internal state like activations.
    with nn.stateful() as batch_stats:
      with nn.stochastic(dropout_rng):
        outputs = model(inputs, train=True, cache=None)

    if isinstance(outputs, dict):
      logits = outputs.get('logits', None)
      regression_predictions = outputs.get('regression', None)
    else:
      logits = outputs
      regression_predictions = None

    mean_loss = 0.0

    # Classification loss.
    if classification_targets is not None:
      (classification_loss,
       classification_weight_sum) = utils.compute_weighted_cross_entropy(
           logits,
           classification_targets,
           token_weights=classification_weights,
           example_weights=example_weights)
      # Handle the case where nothing is masked out in BERT
      # (only occurs with very short sequences).
      classification_weight_sum = jnp.maximum(classification_weight_sum,
                                              epsilon)
      mean_classification_loss = (
          classification_loss / classification_weight_sum)
      mean_loss += mean_classification_loss

    # Regression loss.
    if regression_targets is not None:
      regression_loss, regression_weight_sum = utils.compute_weighted_mse(
          regression_predictions,
          regression_targets,
          weights=regression_weights)
      regression_weight_sum = jnp.maximum(regression_weight_sum, epsilon)
      mean_regression_loss = regression_loss / regression_weight_sum
      outputs['regression_loss'] = mean_regression_loss
      # TODO(ddohan): Allow weighting each loss separately.
      mean_loss += mean_regression_loss

    return mean_loss, (outputs, batch_stats)

  step = optimizer.state.step
  lr = learning_rate_fn(step)
  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (_, (outputs, batch_stats)), grad = grad_fn(optimizer.target)

  try:
    grad = jax.lax.pmean(grad, 'batch')
  except NameError:
    # pmean requires a bound 'batch' axis; skip averaging when running
    # outside of pmap (e.g. in single-device tests).
    pass

  if grad_clip is not None:
    # Clip gradients after pmean aggregation.
    unclipped_grad = grad
    grad = jax.example_libraries.optimizers.clip_grads(grad, grad_clip)

  new_optimizer = optimizer.apply_gradient(grad, learning_rate=lr)

  # TODO(ddohan): Avoid computing metrics except when needed.
  if isinstance(outputs, dict):
    logits = outputs.get('logits', None)
  else:
    logits = outputs

  metrics = dict()
  if logits is not None:
    classification_metrics = utils.compute_metrics(
        logits, classification_targets, classification_weights)
    metrics.update(classification_metrics)
  if regression_targets is not None:
    # TODO(ddohan): Implement regression metrics.
    logging.info('Regression metrics not implemented yet.')
    # regression = outputs.get('regression', None)
    # regression_metrics = utils.compute_metrics(logits, regression_targets,
    #                                            classification_weights)
  metrics['learning_rate'] = lr

  # Training metrics.
  metrics['l2_param_sum'] = utils.l2_regularization(optimizer.target.params)

  # Gradient norms.
  grad_l2_tree = utils.l2_norm(grad)
  grad_l2_sum = jax.tree_util.tree_reduce(op.add, grad_l2_tree)
  grad_l2_max = jax.tree_util.tree_reduce(jnp.maximum, grad_l2_tree)
  metrics['l2_grad_sum'] = grad_l2_sum
  metrics['l2_grad_max'] = grad_l2_max

  # Store any tagged metrics.
  batch_stats = batch_stats.as_dict()
  if batch_stats:

    def clean_name(k):
      return 'nn/' + k.replace('MultiHeadDotProductAttention_', '').replace(
          '/Transformer1DBlock_', '')

    stats = {clean_name(k): v['tag'] for k, v in batch_stats.items()}
    metrics.update(stats)

  if grad_clip is not None:
    # Unclipped gradient norms (if applicable).
    grad_l2_tree = utils.l2_norm(unclipped_grad)
    grad_l2_sum = jax.tree_util.tree_reduce(op.add, grad_l2_tree)
    grad_l2_max = jax.tree_util.tree_reduce(jnp.maximum, grad_l2_tree)
    metrics['l2_noclip_grad_sum'] = grad_l2_sum
    metrics['l2_noclip_grad_max'] = grad_l2_max

  return new_optimizer, metrics, new_dropout_rng
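

# NOTE: Hypothetical usage sketch, not part of the original module. It shows
# one way the train_step above might be driven under jax.pmap; the 'batch'
# axis name must match the axis used by jax.lax.pmean inside train_step.
# `train_step_fn` is expected to have the signature of the train_step defined
# above; the optimizer is assumed to already be replicated across local
# devices (e.g. with flax.jax_utils.replicate), and `train_batches` to yield
# arrays with a leading device dimension.
def _example_pmap_train_loop(train_step_fn, replicated_optimizer,
                             train_batches, learning_rate_fn, preprocess_fn,
                             grad_clip=None):
  """Hypothetical sketch of a pmap-based training loop (illustration only)."""
  import functools  # Local import keeps this sketch self-contained.

  p_train_step = jax.pmap(
      functools.partial(
          train_step_fn,
          learning_rate_fn=learning_rate_fn,
          preprocess_fn=preprocess_fn,
          grad_clip=grad_clip),
      axis_name='batch')

  # One dropout RNG per local device; train_step splits and returns fresh ones.
  dropout_rngs = jrandom.split(jrandom.PRNGKey(0), jax.local_device_count())

  metrics = None
  for batch in train_batches:
    replicated_optimizer, metrics, dropout_rngs = p_train_step(
        replicated_optimizer, batch, dropout_rng=dropout_rngs)
  return replicated_optimizer, metrics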
def train_step(optimizer,
               inputs,
               learning_rate_fn,
               example_weights=None,
               dropout_rng=None,
               grad_clip=None,
               bos_token=0):
  """Performs a single training step. Masks out BOS/PAD positions.

  Args:
    optimizer: Flax optimizer.
    inputs: [batch x length] inputs.
    learning_rate_fn: function from step idx --> learning rate.
    example_weights: Optional [batch] weights for the loss on each example.
      See utils.compute_weighted_cross_entropy for details.
    dropout_rng: RNG for dropout.
    grad_clip: If not None, clip gradients to [-x, +x].
    bos_token: Beginning of sentence token used to generate weight mask.

  Returns:
    new_optimizer, metrics, new_dropout_rng
  """
  # BOS token is equal to PAD when seen on the output (loss) side, so this
  # masks out both BOS and PAD positions.
  token_weights = jnp.where(inputs != bos_token, 1, 0)

  # We handle PRNG splitting inside the top pmap, rather than handling it
  # outside in the training loop - doing the latter can add some stalls to
  # the devices.
  dropout_rng, new_dropout_rng = random.split(dropout_rng)

  def loss_fn(model):
    """Loss function used for training."""
    with nn.stochastic(dropout_rng):
      logits = model(inputs, train=True, cache=None)
    loss, weight_sum = utils.compute_weighted_cross_entropy(
        logits,
        inputs,
        token_weights=token_weights,
        example_weights=example_weights)
    mean_loss = loss / weight_sum
    return mean_loss, logits

  step = optimizer.state.step
  lr = learning_rate_fn(step)

  # Get gradient.
  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (_, logits), grad = grad_fn(optimizer.target)
  grad = jax.lax.pmean(grad, 'batch')

  # Compute metrics from forward pass.
  metrics = utils.compute_metrics(logits, inputs, token_weights)
  metrics['learning_rate'] = lr
  metrics['l2_param_sum'] = utils.l2_regularization(optimizer.target.params)

  # Gradient norms.
  grad_l2_tree = utils.l2_norm(grad)
  grad_l2_sum = jax.tree_util.tree_reduce(op.add, grad_l2_tree)
  grad_l2_max = jax.tree_util.tree_reduce(jnp.maximum, grad_l2_tree)
  metrics['grad_l2_sum'] = grad_l2_sum
  metrics['grad_l2_max'] = grad_l2_max

  # TODO(ddohan): Clip by global grad norm.
  if grad_clip is not None:
    # Clip gradients after pmean aggregation.
    clip = lambda g: jnp.clip(g, -grad_clip, grad_clip)  # pylint: disable=invalid-unary-operand-type
    grad = jax.tree_util.tree_map(clip, grad)

    # Metrics for clipped grads.
    clipped_grad_l2_tree = utils.l2_norm(grad)
    clipped_grad_l2_sum = jax.tree_util.tree_reduce(op.add,
                                                    clipped_grad_l2_tree)
    clipped_grad_l2_max = jax.tree_util.tree_reduce(jnp.maximum,
                                                    clipped_grad_l2_tree)
    metrics['gradclip_l2_sum'] = clipped_grad_l2_sum
    metrics['gradclip_l2_max'] = clipped_grad_l2_max

  # Apply gradients and return new optimizer.
  new_optimizer = optimizer.apply_gradient(grad, learning_rate=lr)
  return new_optimizer, metrics, new_dropout_rng
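

# Illustration (not from the original code) of the BOS/PAD weight mask used by
# the train_step above: when the BOS id doubles as the PAD id (0 by default),
# both positions receive zero loss weight.
def _example_token_weights():
  """Hypothetical example of the loss mask for [BOS, tok, tok, PAD, PAD]."""
  inputs = jnp.array([[0, 5, 7, 0, 0]])
  token_weights = jnp.where(inputs != 0, 1, 0)  # -> [[0, 1, 1, 0, 0]]
  return token_weights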