def cell_fn(carry_0, xs_t):
  # theta is implicitly captured.
  y = theta.delta + xs_t.x + carry_0.y
  z = y + 1
  carry_1 = NestedMap(y=y)
  base_layer.add_summary('test_summary', z)
  return carry_1, NestedMap(z=z)
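# A minimal standalone sketch (not library code) of how a carry/xs step
# function like cell_fn above is typically driven by jax.lax.scan. It uses
# plain dicts in place of NestedMap and omits the add_summary call, which only
# works inside a JaxContext; the names `delta` and `x` mirror the snippet
# above.
import jax
import jax.numpy as jnp


def run_scan_sketch():
  theta = {'delta': jnp.float32(0.5)}

  def step(carry, xs_t):
    # theta is implicitly captured, as in cell_fn above.
    y = theta['delta'] + xs_t['x'] + carry['y']
    z = y + 1
    return {'y': y}, {'z': z}

  init_carry = {'y': jnp.float32(0.0)}
  xs = {'x': jnp.arange(4, dtype=jnp.float32)}  # leading dim = scan length.
  final_carry, ys = jax.lax.scan(step, init_carry, xs)
  return final_carry, ys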
def _forward_summary(self, summaries):
  """Forwards summaries from the inner JaxContext to the outer context."""
  p = self.params
  for summary_key, summary_value in summaries.items():
    logging.info((summary_key, summary_value))
    summary_type = base_layer.get_summary_type_from_key(summary_key)
    assert summary_value.shape[0] == p.num_stages
    if p.unpack_summaries:
      # Unstack summary_value along the stage dimension.
      unstacked_values = jnp.split(summary_value, p.num_stages)
      for i, v in enumerate(unstacked_values):
        base_layer.add_summary(f'{summary_key}/{i}', v, summary_type)
    else:
      base_layer.add_summary(f'{summary_key}', summary_value, summary_type)
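# A small standalone sketch (assumed shapes, not library code) of what the
# unpack_summaries branch above does: a summary stacked over pipeline stages,
# shape [num_stages, ...], is split into num_stages per-stage values that are
# then re-logged under '<key>/<i>'.
import jax.numpy as jnp

num_stages = 4
stacked = jnp.arange(8, dtype=jnp.float32).reshape(num_stages, 2)
unstacked = jnp.split(stacked, num_stages)  # list of [1, 2]-shaped arrays.
assert len(unstacked) == num_stages
assert unstacked[0].shape == (1, 2)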
def apply_gradient(
    self,
    old_vars: NestedJTensor,
    transformed_grads: NestedJTensor,
    var_is_learnable: NestedBool,
) -> NestedJTensor:
  """Applies grads to model variables.

  Note: in a flax model, learnable variables are often referred to as
  'params'. But since 'params' in Lingvo often refers to a hyperparams.Params,
  we refer to the learnable weights of a network as 'variables'.

  Args:
    old_vars: a nested structure of model variables.
    transformed_grads: grads of loss wrt the old_vars. Must be of the same
      structure as old_vars. 'transformed_grads' have already gone through
      various gradient transformations.
    var_is_learnable: a nested structure of boolean values indicating whether
      a var is trainable. Must be of the same structure as old_vars.
      'non-trainable' vars include batch norm stats, various other counts,
      etc. Only learnable variables are updated.

  Returns:
    Updated variables. Only learnable variables are updated.
  """
  p = self.params
  tf.nest.assert_same_structure(old_vars, transformed_grads)
  tf.nest.assert_same_structure(old_vars, var_is_learnable)
  assert p.skip_zero_gradients is None

  # Add a summary of total var norm.
  var_squared = jax.tree_map(lambda x: jnp.sum(x * x), old_vars)
  var_squared, _ = jax.tree_flatten(var_squared)
  var_squared = jnp.concatenate([x[jnp.newaxis] for x in var_squared])
  var_norm = jnp.sqrt(jnp.sum(var_squared))
  base_layer.add_summary('var_norm', var_norm)

  # TODO(yonghui): implement skip_zero_gradients.
  # TODO(yonghui): implement numerical checks.

  def _adjust_var(old_var, transformed_grad, is_learnable):
    if is_learnable:
      return old_var + transformed_grad
    else:
      return old_var

  return tf.nest.map_structure(_adjust_var, old_vars, transformed_grads,
                               var_is_learnable)
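# A minimal, self-contained sketch (not library code) of the update rule in
# _adjust_var above: learnable variables receive old_var + transformed_grad
# (the learning rate and sign are assumed to already be folded into the
# transformed gradients), while non-learnable variables pass through.
import jax
import jax.numpy as jnp

old_vars = {'w': jnp.ones(3), 'bn_count': jnp.zeros(())}
grads = {'w': jnp.full((3,), -0.1), 'bn_count': jnp.zeros(())}
learnable = {'w': True, 'bn_count': False}

new_vars = jax.tree_util.tree_map(
    lambda v, g, l: v + g if l else v, old_vars, grads, learnable)
# new_vars['w'] == [0.9, 0.9, 0.9]; new_vars['bn_count'] is unchanged.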
def compute_loss(self, predictions: NestedMap,
                 input_batch: NestedMap) -> Tuple[Metrics, Dict[str, Any]]:
  """Computes the loss and other metrics for the given predictions.

  Args:
    predictions: The output of `compute_predictions`.
    input_batch: A `.NestedMap` object containing input tensors to this tower.

  Returns:
    - A dict or NestedMap containing str keys and (metric, weight) pairs as
      values, where one of the entries is expected to correspond to the loss.
    - A dict containing arbitrary tensors describing something about each
      training example, where the first dimension of each tensor is the batch
      index. The base class just returns an empty dict.
  """
  avg_xent = predictions.softmax_output.avg_xent
  total_weight = predictions.softmax_output.total_weight
  metrics = NestedMap(
      avg_xent=(avg_xent, total_weight),
      num_predictions=(total_weight, jnp.array(1.0, total_weight.dtype)))
  # Compute top-1 and top-5 accuracy and add summaries.
  acc1 = metric_utils.top_k_accuracy(
      1,
      predictions.softmax_output.logits,
      label_probs=input_batch.label_probs,
      weights=predictions.example_weights)
  acc5 = metric_utils.top_k_accuracy(
      5,
      predictions.softmax_output.logits,
      label_probs=input_batch.label_probs,
      weights=predictions.example_weights)
  metrics.update(
      accuracy=(acc1, predictions.softmax_output.total_weight),
      acc5=(acc5, predictions.softmax_output.total_weight),
      error=(1.0 - acc1, predictions.softmax_output.total_weight),
      error5=(1.0 - acc5, predictions.softmax_output.total_weight))
  # Add top-1 and top-5 accuracies to summaries.
  base_layer.add_summary('acc1', acc1)
  base_layer.add_summary('acc5', acc5)
  return metrics, {}
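# metric_utils.top_k_accuracy is not shown here; below is a hedged sketch of
# one common way to compute a weighted top-k accuracy from logits and one-hot
# label probabilities. The helper name and exact semantics are assumptions and
# only loosely mirror the call signature used above.
import jax.numpy as jnp


def top_k_accuracy_sketch(k, logits, label_probs, weights):
  # logits, label_probs: [batch, num_classes]; weights: [batch].
  top_k_idx = jnp.argsort(-logits, axis=-1)[:, :k]           # [batch, k]
  labels = jnp.argmax(label_probs, axis=-1)                   # [batch]
  correct = jnp.any(top_k_idx == labels[:, None], axis=-1)    # [batch], bool
  correct = correct.astype(jnp.float32)
  return jnp.sum(correct * weights) / jnp.maximum(jnp.sum(weights), 1e-8)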
def fprop(self,
          inputs: JTensor,
          class_weights: JTensor,
          class_ids: Optional[JTensor] = None,
          class_probabilities: Optional[JTensor] = None) -> NestedMap:
  """Computes logits, cross entropy etc.

  Args:
    inputs: a single JTensor with shape [..., input_dim].
    class_weights: a JTensor with shape [..., 1] containing the weights for
      each target word.
    class_ids: a JTensor with shape [..., 1] of int32 dtype containing the
      target class labels.
    class_probabilities: a JTensor with shape [..., num_classes] of float
      values indicating class-membership probabilities.

  Returns:
    A `.NestedMap` containing the following fields

    - logits: with shape [..., num_classes]. Unnormalized softmax's logits.
    - per_example_argmax: with shape [...]. argmax of i-th example.
    - per_example_xent: with shape [...]. Cross entropy between i-th example's
      prediction and its label.
    - per_example_weight: with shape [...]. class_weights casted to this
      layer's dtype.
    - total_xent: A scalar. The sum of per_example_weight * per_example_xent.
    - total_weight: A scalar. The sum of per_example_weight.
    - avg_xent: A scalar. total_loss / total_weight.
  """
  p = self.params
  # Assert that one of class_ids or class_probabilities is not None.
  if class_ids is None and class_probabilities is None:
    raise ValueError('One of class_ids or class_probabilities must be given.')

  # Compute logits.
  inputs_dtype = inputs.dtype
  logits = self.get_logits(inputs)
  # We perform softmax in float32 to improve stability.
  logits = logits.astype(jnp.float32)
  log_probs = jax.nn.log_softmax(logits)

  if class_probabilities is None:
    class_probabilities = jax.nn.one_hot(
        jnp.squeeze(class_ids, axis=-1), p.num_classes)
    class_probabilities = jax.lax.stop_gradient(class_probabilities)

  per_example_xent = -jnp.sum(log_probs * class_probabilities, axis=-1)
  per_example_argmax = jax.lax.stop_gradient(jnp.argmax(logits, axis=-1))

  # Compute the total softmax cross-entropy for the entire sequence.
  total_xent = jnp.sum(
      jnp.expand_dims(per_example_xent, axis=-1) * class_weights)
  total_weight = jnp.sum(class_weights)

  if p.use_tgt_labels_size_as_loss_denominator:
    loss_denominator = jnp.sum(jnp.ones_like(class_weights))
  else:
    loss_denominator = total_weight
  avg_xent = (total_xent / loss_denominator).astype(inputs_dtype)
  z_loss = (
      jnp.sum(self.compute_z_loss(logits) * class_weights) / loss_denominator)
  z_loss *= p.z_loss_weight
  base_layer.add_summary('aux_z_loss', z_loss)
  aux_loss_ctx = py_utils.AuxLossContext.Current()
  if aux_loss_ctx is not None:
    aux_loss_ctx.AddLoss(z_loss)

  output_nmap = NestedMap(
      logits=logits.astype(inputs_dtype),
      log_probs=log_probs.astype(inputs_dtype),
      per_example_argmax=per_example_argmax.astype(inputs_dtype),
      per_example_xent=per_example_xent.astype(inputs_dtype),
      total_xent=total_xent.astype(inputs_dtype),
      # base_model.py _compute_xent_loss_helper uses avg_xent_weight if set,
      # this helper is currently used by LanguageModel only, if we have
      # EncoderDecoder model we will have to adjust weighting as well.
      avg_xent_weight=loss_denominator,
      avg_xent=avg_xent,
      total_weight=total_weight)
  return output_nmap
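# compute_z_loss is not shown above. As an assumption, a common definition of
# the z-loss (used to keep the softmax normalizer close to 1) is the squared
# log-partition function of the logits; the sketch below illustrates that
# definition, not necessarily the exact one used by this layer.
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def z_loss_sketch(logits):
  # logits: [..., num_classes]; returns a per-example penalty of shape [...].
  log_z = logsumexp(logits, axis=-1)
  return jnp.square(log_z)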
def train_step_single_learner(
    jax_task: base_task.SingleTask,
    states: TrainState,
    prng_key: JTensor,
    inputs: Union[JTensor, NestedMap],
    data_parallel_axis_name: Optional[str] = 'batch',
    fprop_dtype: jnp.dtype = jnp.float32
) -> Tuple[TrainState, Any, Any, Any, SummaryDict]:
  """Trains a model for a single step.

  This function works for both pmap-ed and pjit-ed models. When this function
  is called from a pmap-ed trainer, data_parallel_axis_name has to be set.
  Otherwise, data_parallel_axis_name should be either an empty string or None.

  TODO(yonghui): Maybe refactor pmap and pjit into two functions.

  This utility is specialized for the single learner case.

  Args:
    jax_task: An instance of base_task.SingleTask.
    states: An instance of model.TrainState.
    prng_key: A PRNGKey, of shape [2], of type np.uint32.
    inputs: Inputs to the mdl.fprop() function.
    data_parallel_axis_name: a string, the device axis to aggregate gradients
      over.
    fprop_dtype: fprop datatype, can be either jnp.float32 or jnp.bfloat16.

  Returns:
    A tuple of the following elements.

    updated_states - updated states.
    loss - loss as computed by mdl.fprop.
    mean_metrics - a dict of metrics. Each element of the dict is a pair
      (metric, weight).
    per_example_out - auxiliary per-example output as computed in mdl.fprop.
    summary_tensors - A dict or nested map of summary tensors computed in
      forward as well as backward.
  """
  assert len(jax_task.learners) == 1
  learner = jax_task.learners[0]
  model = jax_task.model

  context_p = base_layer.JaxContext.Params().Set(do_eval=False)
  # Fold in global_step as part of the random seed key, so that random
  # numbers depend on the global step.
  prng_key = jax.random.fold_in(prng_key, states.step)

  if data_parallel_axis_name:
    in_pmap = True
  else:
    in_pmap = False

  prng_key, subkey = jax.random.split(prng_key)

  def _loss_fn(
      mdl_vars: NestedJTensor, inputs: NestedMap
  ) -> Tuple[JTensor, Tuple[Any, NestedMap, SummaryDict, SummaryDict]]:
    """Computes loss as well as other auxiliary outputs."""
    if fprop_dtype == jnp.float32:
      pass
    elif fprop_dtype == jnp.bfloat16:
      mdl_vars = jax.tree_map(_maybe_to_bfloat16, mdl_vars)
      inputs = jax.tree_map(_maybe_to_bfloat16, inputs)
    else:
      raise NotImplementedError(f'fprop_dtype {fprop_dtype} not supported.')

    with base_layer.JaxContext.new_context(
        params=context_p, prng_key=subkey,
        global_step=states.step) as jax_context:
      jax_context.bind(model, model.vars_to_flax_vars(mdl_vars),
                       [base_layer.SCOPE_VARS, base_layer.SCOPE_AUX_LOSS])

      metrics, per_example_output = model.fprop(inputs)
      loss_name = learner.loss_name
      assert loss_name in metrics
      loss, loss_weight = metrics[loss_name]
      assert loss.ndim == 0, 'loss has to be a scalar.'
      assert loss_weight.ndim == 0, 'loss_weight has to be a scalar.'
      loss_weight = jax.lax.stop_gradient(loss_weight)
      if in_pmap:
        # Renormalize loss weight by the total weight across all replicas.
        # This also takes care of dividing loss by num of data parallel
        # replicas.
        loss_weight /= jax.lax.psum(
            loss_weight, axis_name=data_parallel_axis_name)
      else:
        # loss_weight == 1 in spmd.
        loss_weight /= jnp.sum(loss_weight)
      weighted_loss = loss * loss_weight
      # Fetch forward-updated vars, which often include batch norm vars, other
      # misc stats, etc.
      forward_updated_vars = model.updated_vars
      # Finally, fetch all the summary tensors.
      summary_tensors = base_layer.all_summaries()
      if in_pmap:
        summary_tensors = summary_utils.aggregate_per_replica_summaries(
            summary_tensors, data_parallel_axis_name)
    if fprop_dtype == jnp.bfloat16 and weighted_loss.dtype == fprop_dtype:
      weighted_loss = weighted_loss.astype(jnp.float32)
    return weighted_loss, (metrics, forward_updated_vars, summary_tensors,
                           per_example_output)

  grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)

  (weighted_loss, (metrics, fwd_updated_vars, fwd_summary_tensors,
                   per_example_out)), grads = grad_fn(states.mdl_vars, inputs)

  if in_pmap:
    # Scale weighted_loss back after gradient computation.
    # This is the average loss across all replicas.
    weighted_loss = jax.lax.psum(
        weighted_loss, axis_name=data_parallel_axis_name)
  else:
    # No sum of loss over all the replicas.
    pass

  if in_pmap:
    mean_metrics = type(metrics)()
    for key in metrics:
      value, weight = metrics[key]
      sum_value = jax.lax.psum(
          value * weight, axis_name=data_parallel_axis_name)
      sum_weight = jax.lax.psum(weight, axis_name=data_parallel_axis_name)
      mean_metrics[key] = (sum_value / (sum_weight + 1e-8), sum_weight)
  else:
    # No aggregation is needed.
    mean_metrics = metrics

  var_weight_params = model.vars
  tf.nest.assert_same_structure(states.mdl_vars, var_weight_params)
  tf.nest.assert_same_structure(states.mdl_vars, grads)
  tf.nest.assert_same_structure(states.mdl_vars, fwd_updated_vars)
  var_is_learnable = tf.nest.map_structure(
      lambda x: not base_layer.var_not_trainable(x), var_weight_params)
  tf.nest.assert_same_structure(states.mdl_vars, var_is_learnable)

  def _maybe_zero_out_grad_fn(var_grad, var, var_learnable):
    if var_learnable:
      return var_grad
    else:
      return jnp.zeros_like(var)

  # Zero out gradients for non-learnable vars.
  grads = tf.nest.map_structure(_maybe_zero_out_grad_fn, grads,
                                states.mdl_vars, var_is_learnable)

  if in_pmap:
    # Aggregate grads across different model replicas.
    grads = jax.lax.psum(grads, axis_name=data_parallel_axis_name)
  else:
    # No gradient aggregation is needed.
    pass

  # Carry out backward computation under a JaxContext.
  prng_key, subkey = jax.random.split(prng_key)
  with base_layer.JaxContext.new_context(
      params=context_p, prng_key=subkey,
      global_step=states.step) as jax_context:
    # Nothing is allowed to change, except for summaries.
    jax_context.bind(model, model.vars_to_flax_vars(states.mdl_vars),
                     [base_layer.SCOPE_AUX_LOSS])

    # Add a summary for the learning rate.
    learning_rate = learner.optimizer.get_learning_rate(states.step)
    base_layer.add_summary('learning_rate', learning_rate)

    # Apply gradient transformations.
    transformed_grads, new_opt_states = learner.update_states(
        grads, states.opt_states[0], states.mdl_vars, var_weight_params)

    # Gradient descent on learnable vars.
    mdl_vars = learner.apply_gradient(states.mdl_vars, transformed_grads,
                                      var_is_learnable)

    def _synchronize_vars_using_mean(new_var: NestedMap,
                                     old_var: NestedMap) -> NestedMap:
      """Synchronizes variables across replicas by averaging."""
      delta = new_var - old_var
      delta_mean = jax.lax.pmean(delta, axis_name=data_parallel_axis_name)
      updated_var = old_var + delta_mean
      return updated_var

    def _update_non_learnable_var(old_var: NestedMap, new_var: NestedMap,
                                  var_params: ParamsT) -> NestedMap:
      """Updates non-trainable variables, using cross-replica synchronization.

      Args:
        old_var: Nested map of old variables.
        new_var: Nested map of new variables.
        var_params: Nested map of var param attributes such as whether a
          variable is trainable or requires synchronization across replicas.

      Returns:
        Updated variables.

      Raises:
        ValueError: if no synchronization method is provided for non-trainable
          variables.
      """
      if not base_layer.var_not_trainable(var_params):
        assert new_var is None
        return old_var
      elif not in_pmap:
        # No aggregation is needed.
        assert new_var is not None
        return new_var
      elif base_layer.var_requires_mean_sync(var_params):
        assert new_var is not None
        return _synchronize_vars_using_mean(new_var, old_var)
      else:
        raise ValueError('Non-trainable variables must have a cross-replica '
                         'synchronization method specified.')

    var_weight_params = model.vars
    tf.nest.assert_same_structure(mdl_vars, var_weight_params)
    mdl_vars = tf.nest.map_structure(_update_non_learnable_var, mdl_vars,
                                     fwd_updated_vars, var_weight_params)

    new_states = states.new_state(
        mdl_vars=mdl_vars, opt_states=[new_opt_states])
    # Finally, fetch all backward summary tensors. We do not aggregate the
    # scalar summaries with pmean because the grads are already psum-ed.
    bwd_summary_tensors = base_layer.all_summaries()

  summary_tensors = NestedMap(
      fwd_summary_tensors=fwd_summary_tensors,
      bwd_summary_tensors=bwd_summary_tensors)

  return (new_states, weighted_loss, mean_metrics, per_example_out,
          summary_tensors)
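# A hedged sketch (not library code) of how a pmap-ed trainer might wrap
# train_step_single_learner. `jax_task`, the replicated train states, the
# per-device PRNG keys, and the input batch are assumed to be prepared
# elsewhere; the axis name must match data_parallel_axis_name so the
# psum/pmean calls above resolve.
import functools
import jax


def build_pmap_train_step(jax_task):
  """Returns a pmap-ed train step bound to `jax_task` (sketch only)."""
  return jax.pmap(
      functools.partial(train_step_single_learner, jax_task),
      axis_name='batch',
      donate_argnums=(0,))

# Usage (all arguments carry a leading device axis):
#   p_train_step = build_pmap_train_step(jax_task)
#   (states, loss, mean_metrics, per_example_out, summaries) = p_train_step(
#       train_states, per_device_prng_keys, batch)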
def fprop(self,
          inputs: JTensor,
          paddings: Optional[JTensor] = None) -> Tuple[JTensor, JTensor]:
  """Computes distances of the given input 'x' to all centroids.

  Args:
    inputs: Input tensor of shape [B, L, N, H] or [B, L, D].
    paddings: If not None, a tensor of shape [B, L]. The padding tensor is
      supplied when we want certain tokens to not affect the centroids.

  Returns:
    dists: "distances" of the given input 'x' to all centroids. Shape
      [B, L, N, K].
    nearest_centroid: The inputs with the input embeddings replaced by the
      centroid embeddings, with the same shape as the inputs, i.e.,
      [B, L, N, H].
  """
  p = self.params
  theta = self.local_theta()
  inputs = self._cast_to_fprop_dtype(inputs)
  inputs_shape = inputs.shape
  if len(inputs_shape) == 3:
    inputs = jnp.reshape(
        inputs,
        [inputs_shape[0], inputs_shape[1], p.num_heads, p.dim_per_head])

  if paddings is not None:
    # Shape [B, L, 1, 1].
    paddings_4d = paddings[:, :, jnp.newaxis, jnp.newaxis]

  dists = -2 * jnp.einsum('BLNH, NKH -> BLNK', inputs, theta.means)
  # [B, L, N, 1]
  inputs_norm_sq = jnp.sum(jnp.square(inputs), axis=-1, keepdims=True)
  # [N, K]
  means_norm_sq = jnp.sum(jnp.square(theta.means), axis=-1, keepdims=False)
  # [1, 1, N, K]
  means_norm_sq = means_norm_sq[jnp.newaxis, jnp.newaxis, :, :]
  dists += inputs_norm_sq + means_norm_sq

  # Shape [B, L, N, K], the same as 'dists' above.
  nearest_one_hot = jax.nn.one_hot(
      jnp.argmin(dists, axis=-1), p.num_clusters, dtype=theta.means.dtype)
  # Apply paddings.
  if paddings is not None:
    nearest_one_hot *= (1 - paddings_4d)

  # Same shape as the input [B, L, N, H].
  nearest_centroid = jnp.einsum('BLNK, NKH -> BLNH', nearest_one_hot,
                                theta.means)

  means_norm = jnp.linalg.norm(theta.means, ord=2, axis=-1)
  base_layer.add_summary('k_means/centroid/l2_norm_avg', jnp.mean(means_norm))
  base_layer.add_summary('k_means/centroid/l2_norm_min', jnp.min(means_norm))
  base_layer.add_summary('k_means/centroid/l2_norm_max', jnp.max(means_norm))

  if not self.do_eval:
    # To update the centroids (self.vars.means), we apply gradient descent on
    # the mini-batch of input, which yields the following:
    #   new_centroid = centroid + (1 - decay) * (x_mean - centroid)
    # where x_mean is the average over all the input vectors closest to this
    # centroid.

    # Sum away batch and sequence length dimensions to get per cluster count.
    # Shape: [N, K]
    per_cluster_count = jnp.sum(nearest_one_hot, axis=[0, 1])
    base_layer.add_summary('k_means/centroid/avg_cluster_count',
                           jnp.mean(per_cluster_count))

    # Sum of the input per each closest centroid.
    sum_x = jnp.einsum('BLNK, BLNH -> NKH', nearest_one_hot, inputs)

    # Sum up cluster counts across replicas.
    # If per_cluster_count for a cluster is 0, then 'nearest_one_hot' in that
    # cluster's position will always be 0, hence 'sum_x' in that dimension
    # will be 0.
    new_means = sum_x / (
        p.epsilon + jnp.expand_dims(per_cluster_count, axis=-1))
    updated_means = (1.0 - p.decay) * new_means + p.decay * theta.means
    updated_means = jnp.array(updated_means, self.vars.means.dtype)
    self.update_var('means', updated_means)
  return dists, nearest_centroid
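# A tiny standalone check (not library code) of the identity the distance
# computation above relies on:
#   ||x - c||^2 = ||x||^2 - 2 * <x, c> + ||c||^2,
# shown here for a single head with K centroids.
import jax.numpy as jnp

x = jnp.array([[1.0, 2.0], [0.5, -1.0]])                   # [B, H]
means = jnp.array([[0.0, 0.0], [1.0, 1.0], [2.0, -1.0]])   # [K, H]

expanded = -2 * x @ means.T                                 # [B, K]
expanded += jnp.sum(x * x, axis=-1, keepdims=True)
expanded += jnp.sum(means * means, axis=-1)[jnp.newaxis, :]

direct = jnp.sum((x[:, None, :] - means[None, :, :]) ** 2, axis=-1)
assert jnp.allclose(expanded, direct)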
def scale_gradients(self,
                    raw_grads: NestedMap) -> Tuple[NestedMap, JTensor]:
  """Scales the gradients.

  Args:
    raw_grads: A nested structure of gradient values.

  Returns:
    A nested structure with the rescaled gradient values, and a predicate
    tensor indicating whether the step is valid, i.e., no anomaly (e.g. NaN or
    Inf, or an excessively large gradient norm) was detected and the step
    should not be skipped.
  """
  p = self.params
  learner_name = self.params.name
  # Compute the gradient norm.
  grad_squared = jax.tree_map(lambda x: jnp.sum(x * x), raw_grads)
  if p.grad_norm_individual_vars:
    grad_norms = jax.tree_map(jnp.sqrt, grad_squared)
    var_keys = py_utils.extract_prefixed_keys_from_nested_map(grad_norms)

    def add_grad_norm_summary(key, value):
      base_layer.add_summary(f'{learner_name}/var_grad_norm/{key}', value)

    jax.tree_multimap(add_grad_norm_summary, var_keys, grad_norms)

  grad_squared, _ = jax.tree_flatten(grad_squared)
  grad_squared = jnp.concatenate([x[jnp.newaxis] for x in grad_squared])
  raw_grad_norm = jnp.sqrt(jnp.sum(grad_squared))
  base_layer.add_summary(f'{learner_name}/grad_norm', raw_grad_norm)

  def keep_step(grad_norm):
    keep_threshold = p.skip_step_gradient_norm_value
    if keep_threshold:
      return jnp.logical_and(
          jnp.all(jnp.isfinite(grad_norm)),
          jnp.all(jnp.less(grad_norm, keep_threshold)))
    else:
      return jnp.all(jnp.isfinite(grad_norm))

  def clip_grads(grads, grad_norm):
    if p.optimizer.clip_gradient_norm_to_value:
      assert p.optimizer.clip_gradient_single_norm_to_value == 0.
      grad_scale = jnp.minimum(
          jnp.array(1, grad_norm.dtype),
          jnp.array(p.optimizer.clip_gradient_norm_to_value, grad_norm.dtype)
          / grad_norm)
      grads = jax.tree_map(lambda g: g * grad_scale, grads)
    elif p.optimizer.clip_gradient_single_norm_to_value:
      assert p.optimizer.clip_gradient_norm_to_value == 0.
      grad_single_norm = jax.tree_map(lambda x: jnp.sqrt(jnp.sum(x * x)),
                                      grads)

      def scale_gradient(grad, norm):
        return grad * jnp.minimum(
            jnp.array(1, norm.dtype),
            jnp.array(p.optimizer.clip_gradient_single_norm_to_value,
                      norm.dtype) / norm)

      grads = jax.tree_map(scale_gradient, grads, grad_single_norm)
      grad_scale = jnp.array(1.0)
    else:
      # No clipping is needed.
      grad_scale = jnp.array(1.0)
    return grads, grad_scale

  # Mark the step as invalid if any gradient anomaly is detected (e.g. NaN or
  # Inf, or an excessively large gradient norm).
  valid_step = keep_step(raw_grad_norm)
  grads, grad_scale = clip_grads(raw_grads, raw_grad_norm)
  base_layer.add_summary('grad_scale', grad_scale)
  return grads, valid_step
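# A compact standalone sketch (not library code) of the global-norm clipping
# rule implemented by clip_grads above when clip_gradient_norm_to_value is
# set: every leaf is scaled by min(1, max_norm / ||g||), where ||g|| is the
# norm over the whole gradient tree.
import jax
import jax.numpy as jnp


def clip_by_global_norm_sketch(grads, max_norm):
  leaves = jax.tree_util.tree_leaves(grads)
  global_norm = jnp.sqrt(sum(jnp.sum(g * g) for g in leaves))
  scale = jnp.minimum(1.0, max_norm / global_norm)
  clipped = jax.tree_util.tree_map(lambda g: g * scale, grads)
  return clipped, global_norm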
def compute_and_update_moments(
    self, inputs: JTensor,
    paddings: JTensor) -> Tuple[JTensor, JTensor, JTensor, JTensor]:
  """Computes moments and updates state.

  Args:
    inputs: The inputs JTensor. Shaped [..., dim].
    paddings: The paddings JTensor. Shaped [..., 1], with the same rank as the
      inputs JTensor.

  Returns:
    Tuple of (mean, variance, beta, gamma).
  """
  p = self.params
  theta = self.local_theta()
  if self.do_eval:
    # The mean and variance used for normalization.
    norm_mean, norm_variance = theta.moving_mean, theta.moving_variance
    base_layer.add_summary('moving_mean', theta.moving_mean)
    base_layer.add_summary('moving_variance', theta.moving_variance)
  else:
    rank = inputs.ndim
    reduce_over_dims = list(range(0, rank - 1))
    mean, variance = compute_moments(
        inputs,
        paddings,
        reduce_over_dims,
        enable_cross_replica_sum_on_tpu=p.enable_cross_replica_sum_on_tpu,
        keepdims=True)

    new_moving_mean = theta.moving_mean * p.decay + mean * (1.0 - p.decay)
    self.update_var('moving_mean', new_moving_mean)
    new_moving_variance = (
        theta.moving_variance * p.decay + variance * (1.0 - p.decay))
    self.update_var('moving_variance', new_moving_variance)

    # Add some summaries for visualization.
    base_layer.add_summary('mean', mean)
    base_layer.add_summary('variance', variance)
    base_layer.add_summary('moving_mean', theta.moving_mean)
    base_layer.add_summary('moving_variance', theta.moving_variance)
    if p.use_moving_avg_in_training:
      # Use the global statistics for normalization.
      norm_mean = theta.moving_mean
      norm_variance = theta.moving_variance
    else:
      # Use the batch statistics for normalization.
      norm_mean = mean
      norm_variance = variance

  beta, gamma = self._get_beta_gamma()
  return norm_mean, norm_variance, beta, gamma
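# A standalone numeric illustration (not library code) of the exponential
# moving average update used above: new = old * decay + batch_stat * (1 - decay).
import jax.numpy as jnp

decay = 0.99
moving_mean = jnp.array(0.0)
batch_mean = jnp.array(1.0)
new_moving_mean = moving_mean * decay + batch_mean * (1.0 - decay)
# new_moving_mean == 0.01: the statistic drifts slowly toward the batch value.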