def _Pack(self, batch):
  """Packs a given batch.

  Note that this may change the batch size.

  This function packs the input batch and adds .segment_ids and .segment_pos
  fields to its `src` and `tgt` fields.

  Args:
    batch: a `.NestedMap` of input tensors to be packed. It is modified in
      place.
  """
  src_actual_seq_len = tf.math.reduce_sum(
      tf.cast(batch.src.ids_indicator, tf.int32), axis=1)
  tgt_actual_seq_len = tf.math.reduce_sum(
      tf.cast(batch.tgt.ids_indicator, tf.int32), axis=1)
  summary_utils.histogram('source_seq_lengths', src_actual_seq_len)
  summary_utils.histogram('target_seq_lengths', tgt_actual_seq_len)

  if not self.params.packing_factor:
    # Supply segment_ids and segment_pos with no packing.
    batch.src.segment_ids = batch.src.ids_indicator
    batch.src.segment_pos = _GetSegmentPos(batch.src.ids_indicator)
    batch.tgt.segment_ids = batch.tgt.ids_indicator
    batch.tgt.segment_pos = _GetSegmentPos(batch.tgt.ids_indicator)
    return

  (src_segment_ids, src_segment_pos, src_indices_in_input, tgt_segment_ids,
   tgt_segment_pos, tgt_indices_in_input) = ops.pack_sequences(
       src_actual_seq_len, tgt_actual_seq_len, self._ScaledBatchSize(),
       self.params.source_max_length, self.params.target_max_length)

  uniq_src_indices_in_input = tf.unique(
      tf.reshape(src_indices_in_input, [-1])).y
  uniq_tgt_indices_in_input = tf.unique(
      tf.reshape(tgt_indices_in_input, [-1])).y
  summary_utils.histogram(
      'packed_source_seq_lengths',
      tf.gather(src_actual_seq_len, uniq_src_indices_in_input, axis=0))
  summary_utils.histogram(
      'packed_target_seq_lengths',
      tf.gather(tgt_actual_seq_len, uniq_tgt_indices_in_input, axis=0))

  # We deferred adding .paddings and use its complement .ids_indicator
  # exclusively so that we can apply the packing with padding set to 0 for
  # all fields.
  def ApplyPackingToSource(x):
    if x.dtype == tf.string:
      return ops.apply_packing(x, '\t', src_segment_ids, src_indices_in_input)
    return ops.apply_packing(x, 0, src_segment_ids, src_indices_in_input)

  batch.src = batch.src.Transform(ApplyPackingToSource)
  batch.src.segment_ids = tf.cast(src_segment_ids, tf.float32)
  batch.src.segment_pos = src_segment_pos

  def ApplyPackingToTarget(x):
    if x.dtype == tf.string:
      return ops.apply_packing(x, '\t', tgt_segment_ids, tgt_indices_in_input)
    return ops.apply_packing(x, 0, tgt_segment_ids, tgt_indices_in_input)

  batch.tgt = batch.tgt.Transform(ApplyPackingToTarget)
  batch.tgt.segment_ids = tf.cast(tgt_segment_ids, tf.float32)
  batch.tgt.segment_pos = tgt_segment_pos
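# The helper _GetSegmentPos used above is not defined in this snippet. A
# minimal sketch of what it could look like, assuming segment_pos is simply
# the 0-based position of each token within its (single, unpacked) segment,
# zeroed out at padded positions. The name matches the call sites above; the
# implementation below is illustrative, not the library's code.
def _GetSegmentPos(ids_indicator):
  """Returns per-token positions given a [batch, time] 0/1 indicator."""
  ids_indicator = tf.cast(ids_indicator, tf.int32)
  # Exclusive cumsum yields 0, 1, 2, ... along the time axis; multiplying by
  # the indicator keeps padded positions at 0.
  # e.g. [1, 1, 1, 0, 0] -> [0, 1, 2, 0, 0].
  return tf.math.cumsum(ids_indicator, axis=1, exclusive=True) * ids_indicator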
def ComputeAndUpdateMoments(self, theta, inputs, paddings=None, **kwargs):
  """Computes moments and updates state.

  Args:
    theta: A `.NestedMap` object containing weights' values of this layer and
      its children layers.
    inputs: The inputs tensor. Shaped [..., dim].
    paddings: The paddings tensor. Shaped [..., 1], with the same rank as the
      input tensor.
    **kwargs: Additional inputs.

  Returns:
    Tuple of (mean, variance, beta, gamma).
  """
  p = self.params
  if paddings is None:
    paddings = self._GetDefaultPaddings(inputs)
  inputs = py_utils.with_dependencies([
      py_utils.assert_shape_match([tf.shape(paddings)[-1]], [1]),
  ], inputs)
  with tf.name_scope(p.name):
    if self.do_eval or p.freeze_bn_stats:
      # The mean and variance used for normalization.
      norm_mean, norm_variance = (self.vars.moving_mean,
                                  self.vars.moving_variance)
    else:
      rank = tf.rank(paddings)
      reduce_over_dims = tf.range(0, rank - 1)
      mean, variance = ComputeMomentsWithPadding(
          inputs, paddings, reduce_over_dims, None,
          p.enable_cross_replica_sum_on_tpu)

      py_utils.UpdateBatchNormVars(self.vars.moving_mean, mean, self._decay)
      py_utils.UpdateBatchNormVars(self.vars.moving_variance, variance,
                                   self._decay)
      # Add some summaries for visualization.
      summary_utils.histogram('%s_mean' % p.name, tf.cast(mean, tf.float32))
      summary_utils.histogram('%s_variance' % p.name,
                              tf.cast(variance, tf.float32))
      summary_utils.histogram('%s_moving_mean' % p.name,
                              tf.cast(self.vars.moving_mean, tf.float32))
      summary_utils.histogram('%s_moving_variance' % p.name,
                              tf.cast(self.vars.moving_variance, tf.float32))
      summary_utils.histogram(
          '%s_mean_diff' % p.name,
          tf.cast(
              tf.cast(mean, self.vars.moving_mean.dtype.base_dtype) -
              self.vars.moving_mean, tf.float32))
      summary_utils.histogram(
          '%s_variance_diff' % p.name,
          tf.cast(
              tf.cast(variance, self.vars.moving_variance.dtype.base_dtype) -
              self.vars.moving_variance, tf.float32))
      if p.use_moving_avg_in_training:
        # Use the global statistics for normalization.
        # Control dependencies on mean and variance make sure
        # moving_mean and variance will be updated for every training step.
        norm_mean = py_utils.with_dependencies([mean], self.vars.moving_mean)
        norm_variance = py_utils.with_dependencies([variance],
                                                   self.vars.moving_variance)
      else:
        # Use the batch statistics for normalization.
        norm_mean = mean
        norm_variance = variance

    norm_mean = py_utils.CheckNumerics(
        norm_mean, 'mean of %s failed numeric check' % p.name)
    norm_variance = py_utils.CheckNumerics(
        norm_variance, 'variance of %s failed numeric check' % p.name)

    beta, gamma = self._GetBetaGamma(theta, inputs, **kwargs)
    return norm_mean, norm_variance, beta, gamma
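# py_utils.UpdateBatchNormVars maintains the moving statistics consumed at
# eval time. A rough sketch of the exponential-moving-average rule it is
# assumed to apply; the function name and implementation below are
# illustrative, not the library's code.
def _update_batch_norm_var_sketch(moving_var, batch_value, decay):
  """Applies moving <- decay * moving + (1 - decay) * batch_value."""
  # Written as an in-place subtraction on the tf.Variable, which is the same
  # update expressed as moving <- moving - (1 - decay) * (moving - value).
  return moving_var.assign_sub((moving_var - batch_value) * (1.0 - decay))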
def FProp(self, theta, x, paddings=None, update=False):
  """Computes distances of the given input 'x' to all centroids.

  This implementation applies layer normalization on 'x' internally first,
  and the returned 'dists' is computed using the normalized 'x'.

  Args:
    theta: A `.NestedMap` of weights' values of this layer.
    x: A tensor of shape [B, L, N, H].
    paddings: If not None, a tensor of shape [B, L].
    update: bool, whether to update centroids using x.

  Returns:
    dists: "distances" of the given input 'x' to all centroids.
      Shape [B, L, N, K].
    k_means_loss: the average squared Euclidean distances to the closest
      centroid, a scalar.
  """
  p = self.params
  x = tf.cast(x, theta.means.dtype)
  if paddings is None:
    paddings = tf.zeros_like(x[:, :, 0, 0])
  # Shape [B, L, 1, 1]
  paddings_4d = paddings[:, :, None, None]

  if p.apply_layer_norm:
    x = KMeansClusteringForAtten.LayerNorm(x, p.epsilon)

  # 'x' is normalized (but theta.means is not), we use negative dot product
  # to approximate the Euclidean distance here.
  dists = -2 * tf.einsum('BLNH, NKH -> BLNK', x, theta.means)

  if not p.apply_layer_norm:
    # If entries are not normalized, compute norms here.
    x_norm_sq = tf.reduce_sum(tf.square(x), axis=-1, keepdims=True)
    means_norm_sq = tf.reduce_sum(
        tf.square(theta.means), axis=-1, keepdims=False)
    means_norm_sq = tf.expand_dims(means_norm_sq, axis=0)
    means_norm_sq = tf.expand_dims(means_norm_sq, axis=0)
    dists += x_norm_sq + means_norm_sq

  # For padded positions we update the distances to very large numbers.
  very_large_dists = tf.ones_like(dists) * tf.constant(
      0.1, dtype=dists.dtype) * dists.dtype.max
  paddings_tiled = tf.tile(paddings_4d, [1, 1, p.num_heads, p.num_clusters])
  dists = tf.where(paddings_tiled > 0.0, very_large_dists, dists)

  # Shape [B, L, N, K], the same as 'dists' above.
  nearest_one_hot = tf.one_hot(
      tf.math.argmin(dists, axis=-1),
      p.num_clusters,
      dtype=theta.means.dtype)
  # Same shape as the input 'x'.
  nearest_centroid = tf.einsum('BLNK, NKH -> BLNH', nearest_one_hot,
                               theta.means)

  diff = tf.math.squared_difference(x, tf.stop_gradient(nearest_centroid))
  diff = py_utils.ApplyPadding(paddings_4d, diff)
  diff = tf.math.reduce_mean(diff, axis=2)

  # The commitment loss which, when backpropagated through, encourages the
  # 'x' values to commit to their chosen centroids.
  diff = tf.cast(diff, tf.float32)
  paddings = tf.cast(paddings, tf.float32)
  k_means_loss = tf.math.reduce_sum(diff) / tf.math.reduce_sum(1.0 - paddings)
  summary_utils.scalar('k_means/squared_distance_loss', k_means_loss)

  # TODO(zhouwk): investigate normalizing theta.means after each update.
  means_norm = tf.norm(theta.means, axis=-1)
  summary_utils.scalar('k_means/centroid_l2_norm/min',
                       tf.math.reduce_min(means_norm))
  summary_utils.scalar('k_means/centroid_l2_norm/mean',
                       tf.math.reduce_mean(means_norm))

  if not update:
    return dists, k_means_loss

  # To update the centroids (self.vars.means), we apply gradient descent on
  # the mini-batch of input 'x', which yields the following:
  #   new_centroid = centroid + (1 - decay) * (x_mean - centroid)
  # where x_mean is the average over all the input vectors closest to this
  # centroid.
  #
  # Note that this approach is equivalent to backprop via
  #   loss = tf.math.reduce_mean(
  #       tf.math.squared_difference(tf.stop_gradient(x), nearest_centroid))
  # except that here the learning rate is independently set via 'decay'.

  # Ensure that the padded positions are not used to update the centroids.
  nearest_one_hot = py_utils.ApplyPadding(paddings_4d, nearest_one_hot)

  # Sum away batch and sequence length dimensions to get per cluster count.
  # Shape: [N, K]
  per_cluster_count = tf.reduce_sum(nearest_one_hot, axis=[0, 1])
  summary_utils.histogram('k_means/per_cluster_vec_count', per_cluster_count)

  # Sum of the input 'x' per each closest centroid.
  sum_x = tf.einsum('BLNK, BLNH -> NKH', nearest_one_hot, x)

  if py_utils.use_tpu():
    per_cluster_count = tf.tpu.cross_replica_sum(per_cluster_count)
    sum_x = tf.tpu.cross_replica_sum(sum_x)

  if p.use_ema:
    updated_ema_count = moving_averages.assign_moving_average(
        self.vars.ema_count,
        tf.cast(per_cluster_count, self.vars.ema_count.dtype),
        p.decay,
        zero_debias=False)
    updated_ema_means = moving_averages.assign_moving_average(
        self.vars.ema_means,
        tf.cast(sum_x, self.vars.ema_means.dtype),
        p.decay,
        zero_debias=False)
    n = tf.reduce_sum(updated_ema_count, axis=-1, keepdims=True)
    updated_ema_count = ((updated_ema_count + p.epsilon) /
                         (n + p.num_clusters * p.epsilon) * n)
    updated_ema_means = updated_ema_means / tf.expand_dims(
        updated_ema_count, axis=-1)
    updated_ema_means = tf.cast(updated_ema_means, self.vars.means.dtype)
    means = tf.cast(theta.means, updated_ema_means.dtype)
    update_means_diff = updated_ema_means - means
  else:
    # If per_cluster_count for a cluster is 0, then 'nearest_one_hot' in that
    # cluster's position will always be 0, hence 'sum_x' in that dimension
    # will be 0.
    new_means = sum_x / tf.maximum(
        tf.constant(1.0, dtype=per_cluster_count.dtype),
        tf.expand_dims(per_cluster_count, axis=-1))
    # Note that we intentionally do not normalize the means after this update
    # as empirically this works better.
    update_means_diff = tf.cast((1.0 - p.decay) * (new_means - theta.means),
                                self.vars.means.dtype)
  return py_utils.with_dependencies(
      [tf.assign_add(self.vars.means, update_means_diff)],
      dists), k_means_loss
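# A small, self-contained illustration (not part of the layer) of the non-EMA
# centroid update above: new_centroid = centroid + (1 - decay) * (x_mean -
# centroid), where x_mean averages the inputs assigned to that centroid. The
# names below are illustrative only.
import numpy as np

def _sketch_centroid_update(centroid, assigned_x, decay=0.999):
  """Returns the updated centroid given the vectors assigned to it."""
  x_mean = assigned_x.mean(axis=0) if len(assigned_x) else centroid
  # Matches update_means_diff = (1 - decay) * (new_means - means) in FProp.
  return centroid + (1.0 - decay) * (x_mean - centroid)

# Example: a centroid at the origin drifts slowly toward its assigned points.
c = np.zeros(4)
pts = np.ones((8, 4))
c = _sketch_centroid_update(c, pts)  # -> 0.001 in every coordinate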
def ComputePredictions(self, theta, source_encs, source_paddings, targets,
                       src_segment_id):
  """Decodes `targets` given encoded source.

  Args:
    theta: A `.NestedMap` object containing weights' values of this layer and
      its children layers.
    source_encs: source encoding, of shape [time, batch, depth].
    source_paddings: source encoding's padding, of shape [time, batch].
    targets: A dict of string to tensors representing the targets one tries
      to predict. Each tensor in targets is of shape [batch, time].
    src_segment_id: source segment id, of shape [time, batch].

  Returns:
    A Tensor with shape [time, batch, params.softmax.input_dim].
  """
  p = self.params
  time, batch = py_utils.GetShape(source_paddings, 2)
  source_encs = py_utils.HasShape(source_encs, [time, batch, p.source_dim])
  with tf.name_scope(p.name):
    target_ids = tf.transpose(targets.ids)
    target_paddings = py_utils.HasRank(targets.paddings, 2)
    target_paddings = tf.expand_dims(tf.transpose(target_paddings), 2)
    if p.packed_input:
      target_segment_id = tf.expand_dims(tf.transpose(targets.segment_ids), 2)
    else:
      target_segment_id = tf.zeros_like(target_paddings)

    if py_utils.use_tpu():
      emb_device = self.cluster.WorkerDeviceInModelSplit(0)
    else:
      emb_device = ''
    with tf.device(emb_device):
      inputs = self.emb.EmbLookup(theta.emb, target_ids)
      inputs = self.ApplyClipping(theta, inputs)
      summary_utils.histogram('input_emb', inputs)
      inputs = self.ApplyDropout(inputs)
      self._emb_out = inputs

      # Layer 0 intertwines with attention.
      (atten_ctxs, xs, atten_probs, _) = self.frnn_with_atten.FProp(
          theta.frnn_with_atten,
          source_encs,
          source_paddings,
          inputs,
          target_paddings,
          src_segment_id=src_segment_id,
          segment_id=target_segment_id)
      self._AddAttenProbsSummary(source_paddings, targets, [atten_probs])

      atten_ctxs = self.ApplyClipping(theta, atten_ctxs)
      summary_utils.histogram('atten_ctxs', atten_ctxs)

      for i, (layer, layer_theta) in enumerate(zip(self.frnn, theta.frnn)):
        # Forward through Layer-(i + 1), since Layer-0 was handled above.
        ys, _ = layer.FProp(
            layer_theta,
            tf.concat([xs, atten_ctxs], 2),
            target_paddings,
            segment_id=target_segment_id)
        ys = self.ApplyDropout(ys)
        if 1 + i >= p.residual_start:
          xs += ys  # Residual skip
          xs = self.ApplyClipping(theta, xs)
        else:
          xs = ys
        summary_utils.histogram('layer_out_%s' % i, xs)

      if p.feed_attention_context_vec_to_softmax:
        xs = tf.concat([xs, atten_ctxs], 2)

      return xs
def _AddAttenProbsHistogramSummary(self, name, atten_probs):
  """Add histogram summary of attention probs."""
  summary_utils.histogram(name + '/atten_probs', atten_probs)
def ComputeAndUpdateMoments(self, theta, inputs, paddings=None):
  """Computes moments and updates state.

  Args:
    theta: A `.NestedMap` object containing weights' values of this layer and
      its children layers.
    inputs: The inputs tensor. Shaped [..., dim].
    paddings: The paddings tensor. Shaped [..., 1], with the same rank as the
      input tensor.

  Returns:
    Tuple of (mean, variance, beta, gamma).
  """
  p = self.params
  if paddings is None:
    paddings = self._GetDefaultPaddings(inputs)
  inputs = py_utils.with_dependencies([
      py_utils.assert_shape_match([tf.shape(inputs)[-1]], [p.dim]),
      py_utils.assert_shape_match([tf.shape(paddings)[-1]], [1]),
  ], inputs)
  with tf.name_scope(p.name):
    if p.is_eval:
      # The mean and variance used for normalization.
      norm_mean, norm_variance = self._moving_mean, self._moving_variance
    else:
      mean, variance = self._Moments(inputs, 1.0 - paddings,
                                     p.enable_cross_replica_sum_on_tpu)

      py_utils.UpdateBatchNormVars(self._moving_mean, mean, self._decay)
      py_utils.UpdateBatchNormVars(self._moving_variance, variance,
                                   self._decay)
      # Add some summaries for visualization.
      summary_utils.histogram('%s_mean' % p.name, tf.cast(mean, tf.float32))
      summary_utils.histogram('%s_variance' % p.name,
                              tf.cast(variance, tf.float32))
      summary_utils.histogram('%s_moving_mean' % p.name,
                              tf.cast(self._moving_mean, tf.float32))
      summary_utils.histogram('%s_moving_variance' % p.name,
                              tf.cast(self._moving_variance, tf.float32))
      summary_utils.histogram('%s_mean_diff' % p.name,
                              tf.cast(mean - self._moving_mean, tf.float32))
      summary_utils.histogram(
          '%s_variance_diff' % p.name,
          tf.cast(variance - self._moving_variance, tf.float32))
      if p.use_moving_avg_in_training:
        # Use the global statistics for normalization.
        # Control dependencies on mean and variance make sure
        # moving_mean and variance will be updated for every training step.
        norm_mean = py_utils.with_dependencies([mean], self._moving_mean)
        norm_variance = py_utils.with_dependencies([variance],
                                                   self._moving_variance)
      else:
        # Use the batch statistics for normalization.
        norm_mean = mean
        norm_variance = variance

    norm_mean = py_utils.CheckNumerics(
        norm_mean, 'mean of %s failed numeric check' % p.name)
    norm_variance = py_utils.CheckNumerics(
        norm_variance, 'variance of %s failed numeric check' % p.name)

    if p.use_moving_avg_in_training:
      beta = 0.0
      gamma = 1.0
    else:
      beta = theta.beta
      gamma = theta.gamma
    return norm_mean, norm_variance, beta, gamma
def FProp(self, theta, input_batch):
  """Embeds source ids and transforms with TransformerStack.

  Args:
    theta: A `.NestedMap` object containing weights' values of this layer and
      its children layers.
    input_batch: A `.NestedMap` with fields:

      - ids: The inputs tensor. It is expected to be of shape [batch, time].
      - paddings: The paddings tensor. Expected shape [batch, time].
      - task_ids: If p.task_emb is provided, must contain per-token task ids
        of shape [batch, time].

  Returns:
    A NestedMap containing

    - encoded: The encoded features, either a tensor of shape
      [time, batch, depth], or a list of tensors if is_transparent is set in
      transformer_stack.
    - padding: of shape [time, batch]
    - segment_id: [time, batch] if packed inputs are supported by the model
      (and all layers), or None otherwise.
    - embedded_inputs: [time, batch, depth] embedded inputs tokens without
      positional encodings.
  """
  p = self.params
  with tf.name_scope(p.name):
    src_segment_id = None
    src_segment_pos = None
    input_ids = py_utils.with_dependencies([
        py_utils.assert_shape_match(
            tf.shape(input_batch.ids), tf.shape(input_batch.paddings)),
        py_utils.assert_equal(tf.rank(input_batch.ids), 2)
    ], input_batch.ids)

    if (not py_utils.use_tpu() and
        tf.flags.FLAGS.transformer_encoder_truncates_inputs):
      max_seq_length = tf.cast(
          tf.reduce_max(tf.reduce_sum(1.0 - input_batch.paddings, 1)),
          tf.int32)
      paddings = py_utils.with_dependencies([
          py_utils.assert_equal(
              tf.constant(True, tf.bool),
              tf.reduce_all(input_batch.paddings[:, max_seq_length:] > 0.5))
      ], input_batch.paddings)
      input_ids = input_ids[:, :max_seq_length]
      paddings = paddings[:, :max_seq_length]
      if p.packed_input:
        src_segment_id = input_batch.segment_ids[:, :max_seq_length]
        src_segment_pos = input_batch.segment_pos[:, :max_seq_length]
    else:
      paddings = input_batch.paddings
      if p.packed_input:
        src_segment_id = input_batch.segment_ids
        src_segment_pos = input_batch.segment_pos

    max_time = tf.shape(input_ids)[1]

    # Input token embeddings + positional embeddings.
    if not p.shared_emb:
      input_embs = self.token_emb.EmbLookup(theta.token_emb,
                                            tf.reshape(input_ids, [-1]))
    else:
      input_embs = self.softmax.EmbLookup(theta.softmax,
                                          tf.reshape(input_ids, [-1]))

    input_embs = tf.reshape(input_embs,
                            [-1, max_time, p.token_emb.embedding_dim])
    # [time, batch, dim]
    orig_input_embs = tf.transpose(input_embs, [1, 0, 2])

    if p.packed_input:
      position_embs = self.position_emb.FPropWithPosition(
          theta.position_emb, src_segment_pos)
    else:
      position_embs = self.position_emb.FProp(theta.position_emb, max_time)
      position_embs = tf.reshape(position_embs,
                                 [1, max_time, p.token_emb.embedding_dim])
    input_embs += position_embs
    if p.task_emb:
      input_embs += self.task_emb.EmbLookup(theta.task_emb,
                                            input_batch.task_ids)

    summary_utils.histogram('input_embs', input_embs)
    if p.model_dim != p.token_emb.embedding_dim:
      input_embs = self.emb_proj.FProp(theta.emb_proj, input_embs)
      summary_utils.histogram('emb_proj', input_embs)

    paddings = tf.cast(tf.transpose(paddings), py_utils.FPropDtype(p))
    if p.packed_input:
      src_segment_id = tf.transpose(src_segment_id)
    input_embs = self.input_dropout.FProp(theta.input_dropout, input_embs)

    # [time, batch, dim]
    transformer_input = tf.transpose(input_embs, [1, 0, 2])

    if not self.do_eval and p.apply_source_mask:
      # Augment padding for masked source word positions.
      dtype = paddings.dtype
      source_mask = tf.where(
          tf.equal(input_ids, p.source_mask_id),
          tf.ones_like(input_ids, dtype=dtype),
          tf.zeros_like(input_ids, dtype=dtype))
      # Make sure padding is between 0 and 1.
      paddings = tf.clip_by_value(paddings + tf.transpose(source_mask), 0.0,
                                  1.0)

    encoded, padding, segment_id = self.transformer_stack.FProp(
        theta.transformer_stack, transformer_input, paddings, src_segment_id)
    return py_utils.NestedMap(
        encoded=encoded,
        padding=padding,
        segment_id=segment_id,
        embedded_inputs=orig_input_embs)
def _ApplyPacking(self, batch):
  """Packs a given batch.

  Note that this may change the batch size.

  This function packs the input batch and adds .segment_ids and .segment_pos
  fields to its `src` and `tgt` fields.

  Args:
    batch: a `.NestedMap` of input tensors to be packed. It is modified in
      place.
  """
  src_actual_seq_len = tf.math.reduce_sum(
      tf.cast(batch.src.ids_indicator, tf.int32), axis=1)
  tgt_actual_seq_len = tf.math.reduce_sum(
      tf.cast(batch.tgt.ids_indicator, tf.int32), axis=1)
  summary_utils.histogram('source_seq_lengths', src_actual_seq_len)
  summary_utils.histogram('target_seq_lengths', tgt_actual_seq_len)

  if not self.params.packing_factor:
    # Supply segment_ids and segment_pos with no packing.
    batch.src.segment_ids = batch.src.ids_indicator
    batch.src.segment_pos = _GetSegmentPos(batch.src.ids_indicator)
    batch.tgt.segment_ids = batch.tgt.ids_indicator
    batch.tgt.segment_pos = _GetSegmentPos(batch.tgt.ids_indicator)
    return

  (src_segment_ids, src_segment_pos, src_indices_in_input, tgt_segment_ids,
   tgt_segment_pos, tgt_indices_in_input) = ops.pack_sequences(
       src_actual_seq_len, tgt_actual_seq_len, self._ScaledBatchSize(),
       self.params.source_max_length, self.params.target_max_length)

  uniq_src_indices_in_input = tf.unique(
      tf.reshape(src_indices_in_input, [-1])).y
  uniq_tgt_indices_in_input = tf.unique(
      tf.reshape(tgt_indices_in_input, [-1])).y
  summary_utils.histogram(
      'packed_source_seq_lengths',
      tf.gather(src_actual_seq_len, uniq_src_indices_in_input, axis=0))
  summary_utils.histogram(
      'packed_target_seq_lengths',
      tf.gather(tgt_actual_seq_len, uniq_tgt_indices_in_input, axis=0))

  # Ratio of number of non-padded tokens. If < 1.0, we are dropping
  # input data because p.packing_factor is too high.
  src_orig_tokens_count = tf.cast(
      tf.reduce_sum(src_actual_seq_len), tf.float32)
  src_packed_tokens_count = tf.reduce_sum(
      tf.cast(src_segment_ids > 0, tf.float32))
  summary_utils.scalar('examples/src_packed_token_ratio',
                       src_packed_tokens_count / src_orig_tokens_count)
  tgt_orig_tokens_count = tf.cast(
      tf.reduce_sum(tgt_actual_seq_len), tf.float32)
  tgt_packed_tokens_count = tf.reduce_sum(
      tf.cast(tgt_segment_ids > 0, tf.float32))
  summary_utils.scalar('examples/tgt_packed_token_ratio',
                       tgt_packed_tokens_count / tgt_orig_tokens_count)

  # We deferred adding .paddings and use its complement .ids_indicator
  # exclusively so that we can apply the packing with padding set to 0 for
  # all fields.
  def ApplyPackingToSource(x):
    if x.dtype == tf.string:
      return ops.apply_packing(x, '\t', src_segment_ids, src_indices_in_input)
    return ops.apply_packing(x, 0, src_segment_ids, src_indices_in_input)

  src_paddings = ops.apply_packing(batch.src.paddings, 1, src_segment_ids,
                                   src_indices_in_input)
  batch.src = batch.src.Transform(ApplyPackingToSource)
  batch.src.paddings = src_paddings
  batch.src.segment_ids = tf.cast(src_segment_ids, tf.float32)
  batch.src.segment_pos = src_segment_pos

  def ApplyPackingToTarget(x):
    if x.dtype == tf.string:
      return ops.apply_packing(x, '\t', tgt_segment_ids, tgt_indices_in_input)
    return ops.apply_packing(x, 0, tgt_segment_ids, tgt_indices_in_input)

  tgt_paddings = ops.apply_packing(batch.tgt.paddings, 1, tgt_segment_ids,
                                   tgt_indices_in_input)
  batch.tgt = batch.tgt.Transform(ApplyPackingToTarget)
  batch.tgt.paddings = tgt_paddings
  batch.tgt.segment_ids = tf.cast(tgt_segment_ids, tf.float32)
  batch.tgt.segment_pos = tgt_segment_pos

  # The number of examples is indicated by the segment_ids of the target.
  num_segments = tf.math.reduce_max(batch.tgt.segment_ids, axis=1)
  num_examples = tf.reduce_sum(num_segments)
  # Note that this is per infeed value when p.use_per_host_infeed = True.
  metric_name = 'examples/num_packed_examples'
  summary_utils.scalar(metric_name, num_examples)
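# Small standalone illustration (not part of the class) of the example-count
# logic above: within one packed row, segment_ids run 1, 1, ..., 2, 2, ...,
# with 0 at padded positions, so the per-row maximum equals the number of
# original examples packed into that row, and summing over rows gives the
# total. The constants below are made-up example values.
import tensorflow as tf

segment_ids = tf.constant([[1., 1., 2., 2., 2., 0.],
                           [1., 1., 1., 1., 0., 0.]])
num_segments = tf.math.reduce_max(segment_ids, axis=1)  # -> [2., 1.]
num_examples = tf.reduce_sum(num_segments)              # -> 3.0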