def _TruncateTargetSequence(self, targets):
  """Returns a copy of `targets` with trailing all-padding time steps removed.

  The ids, labels, weights and paddings fields are [batch, time] tensors. The
  time dimension is cut down to the longest non-padded length in the batch.

  Args:
    targets: a `.NestedMap` with ids, labels, weights and paddings fields.

  Returns:
    A new `.NestedMap` (the input is not modified) whose four fields are
    sliced to [:, :max_seq_length].
  """
  # Shallow copy so the caller's NestedMap keeps its original fields.
  truncated = targets.Pack(targets.Flatten())
  ids = truncated.ids
  labels = truncated.labels
  weights = truncated.weights
  paddings = truncated.paddings

  # Longest real (non-padded) length over the whole batch.
  real_lengths = tf.reduce_sum(1.0 - paddings, 1)
  max_len = tf.to_int32(tf.reduce_max(real_lengths))
  summary_utils.scalar('max_seq_length', max_len)

  # Every step at or past max_len must be padding for every sequence,
  # otherwise the slices below would drop real tokens.
  expected = tf.constant(True, tf.bool)
  tail_is_padding = tf.reduce_all(paddings[:, max_len:] > 0.5)
  paddings = py_utils.with_dependencies(
      [py_utils.assert_equal(expected, tail_is_padding)], paddings)

  shape_check = AssertIdShape(
      py_utils.GetShape(ids),
      py_utils.GetShape(labels),
      py_utils.GetShape(paddings),
      py_utils.GetShape(weights))
  ids = py_utils.with_dependencies([shape_check], ids)

  truncated.ids = ids[:, :max_len]
  truncated.labels = labels[:, :max_len]
  truncated.weights = weights[:, :max_len]
  truncated.paddings = paddings[:, :max_len]
  return truncated
def ScaleGradients(self, var_grads):
  """Scales gradients according to training params.

  Args:
    var_grads: a `.NestedMap` whose values are (var, grad) pairs.

  Returns:
    (has_nan_or_inf, grad_scale, final_var_grads).

    - has_nan_or_inf: a boolean scalar tensor, True if there is any NaN or
      Inf in the (adjusted) gradients or in their global norm.
    - grad_scale: the gradient scale. 0 if gradient updates should be skipped
      for the step.
    - final_var_grads: a `.NestedMap` whose values are (var, grad) pairs, where
      gradients have already been scaled.
  """
  p = self.params
  tp = p.train

  # Computes gradients' norm and adds their summaries. Note that all_grad_norm
  # may be nan, which may cause grad_scale to be nan.
  for name, vg in var_grads.FlattenItems():
    summary_utils.AddNormSummary(p, name, py_utils.NestedMap(s=vg))
  _, all_grad_norm = summary_utils.AddNormSummary(p, 'all', var_grads)
  grad_norm_is_nan_or_inf = tf.logical_or(
      tf.is_nan(all_grad_norm), tf.is_inf(all_grad_norm))

  # Optional gradient adjustment. Note that this happens after computing
  # all_grad_norm, so the norm (and hence the scale below) is based on the
  # unadjusted gradients.
  var_grads = self.AdjustGradients(var_grads)

  # Handles NaN/Inf gradients.
  has_nan_or_inf = self._HasNanOrInf(var_grads)
  # Grad norm can still be inf even if none of the individual grad is inf.
  has_nan_or_inf = tf.logical_or(has_nan_or_inf, grad_norm_is_nan_or_inf)

  # Computes gradient's scale. The factors below are multiplied together.
  grad_scale = tf.constant(1.0)
  if tp.clip_gradient_norm_to_value:
    # If all_grad_norm > tp.clip_gradient_norm_to_value, scales all_grads so
    # that their global norm becomes tp.clip_gradient_norm_to_value;
    # otherwise the scale stays 1.0.
    grad_scale = tf.minimum(1.0,
                            tp.clip_gradient_norm_to_value / all_grad_norm)

  if tp.grad_norm_to_clip_to_zero:
    # If all_grad_norm >= tp.grad_norm_to_clip_to_zero, treats
    # grad_scale as 0. This way, we ignore this step.
    grad_scale *= tf.cast(all_grad_norm < tp.grad_norm_to_clip_to_zero,
                          p.dtype)

  if tp.grad_norm_tracker:
    # The tracker contributes an extra multiplicative factor.
    grad_scale *= self.grad_norm_tracker.FPropDefaultTheta(
        all_grad_norm, has_nan_or_inf)

  # Force grad_scale to be 0 if there is any NaN or Inf in gradients. This
  # also overrides a NaN grad_scale produced by a NaN all_grad_norm above.
  grad_scale = tf.where(has_nan_or_inf, 0.0, grad_scale)

  summary_utils.scalar(p, 'grad_scale_all', grad_scale)
  final_var_grads = py_utils.ApplyGradMultiplier(var_grads, grad_scale)
  return has_nan_or_inf, grad_scale, final_var_grads
def AddSummary(self, lr, optimizer, var_grad):
  """Adds the Adagrad learning-rate and per-variable accumulator summaries.

  Args:
    lr: the learning rate to summarize.
    optimizer: the optimizer whose 'accumulator' slots are summarized.
    var_grad: a `.NestedMap` of (var, grad) pairs; only the vars are used.
  """
  summary_utils.scalar('adagrad_lr', lr)
  for var, _ in var_grad.Flatten():
    accum = optimizer.get_slot(var, 'accumulator')
    # Every summarized variable is expected to own an accumulator slot.
    assert accum is not None
    tag = 'optimizer/adagrad_accum_%s' % var.name
    summary_utils.scalar(tag, tf.reduce_mean(accum))
def IncBy(self, params, delta):
  """Adds `delta` to the counter and returns the updated value.

  Args:
    params: forwarded as the first argument of `summary_utils.scalar`.
    delta: amount to add; converted to int64.

  Returns:
    `tf.identity` of the `tf.assign_add` result on the counter variable.
  """
  increment = tf.to_int64(delta)
  # NOTE: self._value (_var + 0) must be computed before _var is updated, so
  # the summary and the increment are both gated on it.
  with tf.control_dependencies([self._value]):
    summary_utils.scalar(params, self._name, self._value)
    updated = tf.assign_add(self._var, increment)
    return tf.identity(updated)
def _SummarizeTensor(self, t_name):
  """Adds scalar summaries for the 'min' and 'max' state vars of `t_name`."""
  q_vars = [self._GetQStateVar(t_name, suffix) for suffix in ('min', 'max')]
  # Drop the output index from the var name:
  # foo/q/somet_min:0 -> foo/q/somet_min
  tags = [q_var.name.split(':')[0] for q_var in q_vars]
  for tag, q_var in zip(tags, q_vars):
    summary_utils.scalar(tag, q_var)
def _FPropMetrics(self, metrics):
  """Registers eval metrics and records the loss found in `metrics`.

  Args:
    metrics: dict of name -> (value, weight). Must contain 'loss', whose pair
      is (loss, num_predictions). A 'num_samples_in_batch' entry is added to
      this dict in place.
  """
  # Adds stats about the input batch.
  batch_size = tf.convert_to_tensor(self.input_generator.InputBatchSize())
  metrics['num_samples_in_batch'] = (batch_size, tf.constant(1.0))

  # One eval metric (and its summary) per entry.
  for metric_name, (metric_value, metric_weight) in six.iteritems(metrics):
    self.AddEvalMetric(metric_name, metric_value, metric_weight)

  loss, num_predictions = metrics['loss']
  self._loss = py_utils.CheckNumerics(loss)
  self._num_predictions = num_predictions
  summary_utils.scalar('num_predictions', self._num_predictions)
def CreateTaskGlobalStep(params, task_name):
  """Create if needed and return the per-task global step variable.

  Args:
    params: forwarded as the first argument of `summary_utils.scalar`.
    task_name: prefix of the variable name (`<task_name>_global_step`).

  Returns:
    The non-trainable int64 scalar variable (initialized to 0), placed in the
    GLOBAL_VARIABLES and 'TASK_GLOBAL_STEP' collections.
  """
  with tf.name_scope(None), tf.variable_scope(py_utils.global_variable_scope):
    collections = [tf.GraphKeys.GLOBAL_VARIABLES, 'TASK_GLOBAL_STEP']
    step_name = task_name + '_global_step'
    step_params = py_utils.WeightParams(
        [], py_utils.WeightInit.Constant(0), tf.int64)
    _, step_var = py_utils.CreateVariable(
        name=step_name,
        params=step_params,
        trainable=False,
        collections=collections)
    summary_utils.scalar(params, step_var.name, step_var)
  return step_var
def _FPropResult(self, metrics, per_example):
  """Registers eval metrics / per-example tensors and records the loss.

  Args:
    metrics: dict of name -> (value, weight). Must contain 'loss', whose pair
      is (loss, num_predictions). A 'num_samples_in_batch' entry is added to
      this dict in place, and the dict is stored as `self._metrics`.
    per_example: dict of name -> tensor, filtered through
      `FilterPerExampleTensors` before being registered.
  """
  # Adds stats about the input batch.
  global_batch = tf.convert_to_tensor(self.input_generator.GlobalBatchSize())
  metrics['num_samples_in_batch'] = (global_batch, tf.constant(1.0))

  for metric_name, (metric_value, metric_weight) in six.iteritems(metrics):
    self.AddEvalMetric(metric_name, metric_value, metric_weight)

  kept_per_example = self.FilterPerExampleTensors(per_example)
  for tensor_name, tensor in six.iteritems(kept_per_example):
    self.AddPerExampleTensor(tensor_name, tensor)

  loss, num_predictions = metrics['loss']
  self._loss = py_utils.CheckNumerics(loss)
  self._num_predictions = num_predictions
  self._metrics = metrics
  summary_utils.scalar('num_predictions', self._num_predictions)
def _InputBatch(self):
  """Builds the input pipeline and returns the next batch from it.

  Examples from `_DatasetOfExamples()` are processed, optionally packed
  (`p.enable_packing`), and batched to `InfeedBatchSize()` with the remainder
  dropped. An initializable iterator is created over the dataset and its
  initializer is stored in `self._initializer`.

  Returns:
    A batched `.NestedMap` (the iterator's next element, not a Dataset) whose
    tensors all have shape [InfeedBatchSize(), p.source_max_length], with
    floating-point fields cast to the forward-prop dtype. `segment_ids` and
    `segment_pos` are synthesized when the batch does not already carry them.
  """
  p = self.params
  dataset = self._DatasetOfExamples()
  dataset = dataset.map(
      self._ProcessSingleExample,
      num_parallel_calls=tf.data.AUTOTUNE).unbatch()
  # take(-1) keeps the whole dataset when p.num_samples is not positive.
  dataset = dataset.take(p.num_samples if p.num_samples > 0 else -1)

  if p.enable_packing:
    # Pack groups of p.prepacking_batch_size examples, then flatten back to
    # single (packed) rows.
    dataset = dataset.batch(
        p.prepacking_batch_size,
        drop_remainder=True,
        num_parallel_calls=tf.data.AUTOTUNE if p.shuffle else 1).map(
            self._Pack, num_parallel_calls=tf.data.AUTOTUNE).unbatch()

  dataset = dataset.batch(
      self.InfeedBatchSize(),
      drop_remainder=True,
      num_parallel_calls=tf.data.AUTOTUNE if p.shuffle else 1,
  )
  dataset = dataset.prefetch(tf.data.AUTOTUNE)

  iterator = tf.data.make_initializable_iterator(dataset)
  self._initializer = iterator.initializer
  batch = iterator.get_next()

  if not hasattr(batch, 'segment_ids'):
    # Supply segment_ids and segment_pos with no packing: every non-padded
    # token belongs to segment 1, and its position is its time index
    # (0 at padded positions).
    batch.segment_ids = 1.0 - batch.paddings
    segpos = tf.cast(tf.range(p.source_max_length), dtype=tf.float32)
    batch.segment_pos = tf.cast(batch.segment_ids * segpos, dtype=tf.int32)

  def ShapeAndCast(x):
    # Pins the static shape and casts floating tensors to the fprop dtype.
    x = tf.ensure_shape(x, (self.InfeedBatchSize(), p.source_max_length))
    if x.dtype.is_floating:
      x = tf.cast(x, py_utils.FPropDtype(p))
    return x

  batch = batch.Transform(ShapeAndCast)

  # The max segment id per row is the number of examples packed into it.
  num_samples = tf.math.reduce_max(batch.segment_ids, axis=1)
  summary_utils.scalar('examples/num_packed_samples',
                       tf.reduce_sum(num_samples))

  return batch
def PostTrainingStepUpdate(self, global_step):
  """Update the cap value.

  Recomputes the clip ratio and fake-quant ratio for `global_step` and
  assigns them to `self.vars.clip_ratio` / `self.vars.fq_ratio`.

  Returns:
    None when `p.is_inference` is set, otherwise a grouped assign op.
  """
  p = self.params
  if p.is_inference:
    return

  # The params may hold ints; all arithmetic below is done in float32.
  clip_end, clip_start, quant_start, step = [
      tf.cast(value, tf.float32) for value in (
          p.clip_end_step, p.clip_start_step, p.quant_start_step, global_step)
  ]

  # Negative before clipping starts; reaches 1.0 at clip_end and stays there
  # (assuming clip_end_step > clip_start_step).
  clip_span = clip_end - clip_start
  new_clip_ratio = tf.minimum(clip_span, step - clip_start) / clip_span

  # fq is currently either on (1.0) or off (-1.0); progressive quantization
  # may later use values in between.
  new_fq_ratio = tf.where(step < quant_start, -1.0, 1.0)

  summary_utils.scalar(p, 'clip_ratio', new_clip_ratio)
  summary_utils.scalar(p, 'fq_ratio', new_fq_ratio)
  return tf.group(
      self.vars.clip_ratio.assign(new_clip_ratio),
      self.vars.fq_ratio.assign(new_fq_ratio))
def AddSummary(self, lr, optimizer, var_grad):
  """Adds a scalar summary of the Distributed Shampoo learning rate.

  `optimizer` and `var_grad` are accepted only for interface compatibility.
  """
  del optimizer, var_grad  # Unused.
  summary_utils.scalar('distributed_shampoo', lr)
def _AddScalarSummary(self, key, value):
  """Adds a scalar summary tagged '<key>/<layer name>'."""
  tag = f'{key}/{self.params.name}'
  summary_utils.scalar(tag, value)
def PostTrainingStepUpdate(self, global_step):
  """Summarizes the cap value at `global_step`; no state is updated."""
  cap_value = self._Value(global_step)
  summary_utils.scalar('cap', cap_value)
  return tf.no_op()
def _BPropForVariables(self, vmap):
  """Constructs the backward graph for the given variables.

  Computes gradients of `self.loss`, applies optional L2/L1 regularization,
  per-input gradient masking and `ScaleGradients`, then builds
  `self._train_op`. That op groups the optimizer update, batch-norm updates,
  stats updates, post-training-step updates, global step increments and the
  mask update op.

  Args:
    vmap: a `.NestedMap` of variables.
  """
  p = self.params
  tp = p.train

  # Compute gradients.
  self._var_grads = py_utils.ComputeGradients(self.loss, vmap)

  # L2 regularizer.
  if tp.l2_regularizer_weight is not None:
    l2_loss, self._var_grads = py_utils.AdjustGradientsWithLpLoss(
        self._var_grads, tp.l2_regularizer_weight, p=2.0)
    summary_utils.scalar(p, 'l2_loss', l2_loss)

  # L1 regularizer.
  if tp.l1_regularizer_weight is not None:
    l1_loss, self._var_grads = py_utils.AdjustGradientsWithLpLoss(
        self._var_grads, tp.l1_regularizer_weight, p=1.0)
    summary_utils.scalar(p, 'l1_loss', l1_loss)

  # Mask gradients only if the mask is set.
  if self._per_input_gradient_mask:
    bprop_onehot = self.input_generator.GetInputSourceOneHot()
    self._var_grads = py_utils.MaskGradients(
        self._var_grads, self._per_input_gradient_mask, bprop_onehot)

  # Apply gradient clipping. has_nan_or_inf is used below for the NaN stats.
  has_nan_or_inf, _, self._var_grads = self.ScaleGradients(self._var_grads)

  # Histogram summary.
  summary_utils.CollectVarHistogram(p, self._var_grads)

  # Effective learning rate = base learning rate * schedule value.
  lrs = self.lr_schedule.Value(self._global_step)
  summary_utils.scalar(p, 'lr_schedule', lrs)
  lr = tp.learning_rate * lrs

  var_update_op = self.optimizer.Apply(lr, self._var_grads)

  # Increment the shared global step and, if present, the per-task step, each
  # colocated with its variable.
  increment_global_step_ops = []
  with tf.colocate_with(self._shared_global_step):
    increment_global_step_ops.append(
        tf.assign_add(self._shared_global_step, 1))
  if self._task_global_step:
    with tf.colocate_with(self._task_global_step):
      increment_global_step_ops.append(
          tf.assign_add(self._task_global_step, 1))
  increment_global_steps = tf.group(*increment_global_step_ops)

  # Only batch-norm updates that FindRelevantBatchNormUpdates associates with
  # self.loss are run.
  relevant_bn_updates, _ = py_utils.FindRelevantBatchNormUpdates(
      self.loss, tf.get_collection(py_utils.BATCH_NORM_UPDATES))
  batch_norm_updates = tf.group(*relevant_bn_updates)

  # Update stats.
  stats_updates = tf.group(
      self.IncrementTotalSamples(),
      self.IncrementTotalNans(tf.to_int32(has_nan_or_inf)))

  # Post training step update.
  post_training_step_updates = self.PostTrainingStepUpdate(self._global_step)

  # Get the op to update the weight masks and thresholds
  mask_update_op = self._GetMaskUpdateOp()

  # TODO(rpang): try to structure _train_op as:
  #   tf.cond(skip_step, <only update skip stats>, <all updates>)
  # so that we skip all other updates when a step is skipped.
  #
  if p.contiguous:
    # Also run the op stored in self.last_state_group_op with the variable
    # update.
    var_update_op = tf.group(var_update_op, self.last_state_group_op)

  self._train_op = tf.group(
      var_update_op,
      batch_norm_updates,
      stats_updates,
      post_training_step_updates,
      increment_global_steps,
      mask_update_op,
      name='train')
def __init__(self, params):
  """Initializes the task: global step, eval input settings, input, learners.

  Args:
    params: task params; `params.cls` must be a subclass of `BaseTask`.
  """
  assert issubclass(params.cls, BaseTask)
  # Ensure global_step exists before calling super.
  py_utils.GetOrCreateGlobalStepVar()
  super().__init__(params)

  p = self.params

  self._encoder = None
  self._online_encoder = None
  self._decoder = None

  self._loss = None
  self._num_predictions = None
  self._train_op = None
  self._post_train_ops = []
  self._eval_metrics = {}
  self._per_example = {}

  # Create the gradient mask,
  self._per_input_gradient_mask = None

  if p.task_global_step:
    # Per-task step variable, created at the root name scope inside the
    # global variable scope.
    with tf.name_scope(None), tf.variable_scope(
        py_utils.GetGlobalVariableScope()):
      var_name = p.name + '_global_step'
      # Create the variable immediately.
      self._CreateVariableInternal(
          var_name,
          base_layer.CreateVariableMeta(
              var_params=py_utils.WeightParams(
                  [], py_utils.WeightInit.Constant(0), tf.int64),
              theta_fn=None,
              kwargs=dict(
                  trainable=False,
                  collections=[tf.GraphKeys.GLOBAL_VARIABLES])))
      summary_utils.scalar(var_name, self._private_vars[var_name])
      self._global_step_var = self._private_vars[var_name]
  else:
    self._global_step_var = py_utils.GetOrCreateGlobalStepVar()

  if p.input:
    # TODO(zhifengc): Consider a simpler way to ensure the input
    # generator stops after one epoch.
    if self.do_eval and p.eval:
      seq_inp = issubclass(p.input.cls,
                           base_input_generator.BaseInputGeneratorFromFiles)
      if p.input.num_samples == 0:
        # Dataset size is unknown. Computes eval summary based on num_samples.
        assert p.eval.samples_per_summary > 0
      elif (p.eval.samples_per_summary == 0) or (p.input.num_samples <
                                                 p.eval.samples_per_summary):
        # If we know the dataset size and we want to evaluate the full
        # set, we need to coordinate the input generator to flush out
        # all samples so the evaler and decoder compute metrics on the
        # whole set for each summary step.
        if seq_inp:
          p.input.flush_every_n = p.input.num_samples
        p.eval.samples_per_summary = p.input.num_samples
      if seq_inp and p.input.num_batcher_threads > 1:
        tf.logging.warning(
            'input.num_batcher_threads > 1 inside eval mode.  '
            'The input generator may not iterate over exactly '
            'one epoch per run')
    tf.logging.info('input_params: %s', p.input)
    input_params = self.cluster.PlaceInput(p.input)

    # For TPU training, we create the input generator in a
    # different scope and AddChild it in later.
    if 'skip_create_child' not in p.input:
      self.CreateChild('input', input_params)

  tp = p.train

  # p.train can be None if this task is the teacher/student task in a
  # DistillationTask.
  if tp:
    self._SetLearnerFromLegacyParams(tp)
    if tp.learner is not None:
      # A single learner param is wrapped in a list so 'learners' is always
      # created as a list of children.
      if isinstance(tp.learner, (list, tuple)):
        self.CreateChildren('learners', tp.learner)
      else:
        self.CreateChildren('learners', [tp.learner])
  self._UpdateVnConfig()
def FProp(self, theta, x, paddings=None, update=False):
  """Computes distances of the given input 'x' to all centroids.

  When `p.apply_layer_norm` is set, 'x' is layer-normalized first and the
  returned 'dists' is computed from the normalized 'x'. When `update` is True,
  the centroids (`self.vars.means`) are also moved towards the mean of the
  inputs currently assigned to them.

  Args:
    theta: A `.NestedMap` of weights' values of this layer. `theta.means` is
      the [N, K, H] centroid tensor.
    x: A tensor of shape [B, L, N, H].
    paddings: If not None, a tensor of shape [B, L]. Defaults to all zeros.
    update: bool, whether to update centroids using x.

  Returns:
    dists: "distances" (negative dot products) of the given input 'x' to all
      centroids. Shape [B, L, N, K]. Padded positions hold a very large value.
      When `update` is True, this tensor carries a control dependency on the
      centroid update.
    k_means_loss: the squared distance to the closest centroid, summed over
      H, averaged over heads and over non-padded positions; a scalar. The
      denominator is clamped to at least 1 so a fully padded batch does not
      produce NaN.
  """
  p = self.params
  if paddings is None:
    paddings = tf.zeros_like(x[:, :, 0, 0])
  # Shape [B, L, 1, 1]
  paddings_4d = paddings[:, :, None, None]

  if p.apply_layer_norm:
    x = KMeansClusteringForAtten.LayerNorm(x, p.epsilon)

  # 'x' is normalized (but theta.means is not), we use negative dot product to
  # approximate the Euclidean distance here.
  dists = -tf.einsum('BLNH, NKH -> BLNK', x, theta.means)

  # For padded positions we update the distances to very large numbers.
  very_large_dists = tf.ones_like(dists) * tf.constant(
      0.1, dtype=dists.dtype) * dists.dtype.max
  paddings_tiled = tf.tile(paddings_4d, [1, 1, p.num_heads, p.num_clusters])
  dists = tf.where(paddings_tiled > 0.0, very_large_dists, dists)

  # Shape [B, L, N, K], the same as 'dists' above.
  nearest_one_hot = tf.one_hot(
      tf.math.argmin(dists, axis=-1),
      p.num_clusters,
      dtype=py_utils.FPropDtype(p))

  # Same shape as the input 'x'.
  nearest_centroid = tf.einsum('BLNK, NKH -> BLNH', nearest_one_hot,
                               theta.means)
  diff = tf.math.squared_difference(x, tf.stop_gradient(nearest_centroid))
  diff = py_utils.ApplyPadding(paddings_4d, diff)
  # Average over heads: [B, L, N, H] -> [B, L, H].
  diff = tf.math.reduce_mean(diff, axis=2)

  # The commitment loss which when back proped against encourages the 'x'
  # values to commit to their chosen centroids. The number of non-padded
  # positions is clamped to >= 1 (as for new_means below) so an all-padded
  # batch does not divide by zero.
  num_real_positions = tf.math.reduce_sum(1.0 - paddings)
  k_means_loss = tf.math.reduce_sum(diff) / tf.maximum(
      tf.constant(1.0, dtype=num_real_positions.dtype), num_real_positions)
  summary_utils.scalar('k_means/squared_distance_loss', k_means_loss)

  # TODO(zhouwk): investigate normalizing theta.means after each update.
  # Per-centroid L2 norms, shape [N, K]. Without axis=-1, tf.norm returns a
  # single norm of the whole tensor and the min/mean summaries would coincide.
  means_norm = tf.norm(theta.means, axis=-1)
  summary_utils.scalar('k_means/centroid_l2_norm/min',
                       tf.math.reduce_min(means_norm))
  summary_utils.scalar('k_means/centroid_l2_norm/mean',
                       tf.math.reduce_mean(means_norm))

  if not update:
    return dists, k_means_loss

  # To update the centroids (self.vars.means), we apply gradient descent on
  # the mini-batch of input 'x', which yields the following:
  #   new_centroid = centroid + (1 - decay) * (x_mean - centroid)
  # where x_mean is the average over all the input vectors closest to this
  # centroid.
  #
  # Note that this approach is equivalent with backprop via
  #    loss = tf.math.reduce_mean(
  #        tf.math.squared_difference(tf.stop_gradient(x), nearest_centroid)))
  # , except that here the learning rate is independently set via 'decay'.

  # Ensure that the padded positions are not used to update the centroids.
  nearest_one_hot = py_utils.ApplyPadding(paddings_4d, nearest_one_hot)

  # Sum away batch and sequence length dimensions to get per cluster count.
  # Shape: [N, K]
  per_cluster_count = tf.reduce_sum(nearest_one_hot, axis=[0, 1])
  summary_utils.histogram('k_means/per_cluster_vec_count', per_cluster_count)

  # Sum of the input 'x' per each closest centroid.
  sum_x = tf.einsum('BLNK, BLNH -> NKH', nearest_one_hot, x)

  if py_utils.use_tpu():
    # Aggregate counts and sums across replicas before averaging.
    per_cluster_count = tf.tpu.cross_replica_sum(per_cluster_count)
    sum_x = tf.tpu.cross_replica_sum(sum_x)

  # If per_cluster_count for a cluster is 0, then 'nearest_one_hot' in that
  # cluster's position will always be 0, hence 'sum_x' in that dimension will
  # be 0.
  new_means = sum_x / tf.maximum(
      tf.constant(1.0, dtype=per_cluster_count.dtype),
      tf.expand_dims(per_cluster_count, axis=-1))

  # We use exponential moving average. TODO(zhouwk): investigate smooth this
  # over an exponentially moving averaged per cluster count.
  #
  # Note that we intentionally do not normalize the means after this update
  # as empirically this works better.
  update_means_diff = tf.cast((1.0 - p.decay) * (new_means - theta.means),
                              self.vars.means.dtype)
  return py_utils.with_dependencies(
      [tf.assign_add(self.vars.means, update_means_diff)],
      dists), k_means_loss
def AddSummary(self, lr, optimizer, var_grad):
  """Adds a scalar summary of the Adam learning rate.

  `optimizer` and `var_grad` are accepted only for interface compatibility.
  """
  del optimizer, var_grad  # Unused.
  summary_utils.scalar(self.params, 'adam_lr', lr)
def ScaleGradients(self, var_grads, gradient_adjuster=None):
  """Scales gradients according to training params.

  Args:
    var_grads: a `.NestedMap` whose values are (var, grad) pairs.
    gradient_adjuster: if not None, a function that mutates a given var_grads.

  Returns:
    A `.NestedMap` containing
    - final_var_grads: a `.NestedMap` whose values are (var, grad) pairs,
      where gradients have already been scaled.
    - grad_scale: the gradient scale. 0 if gradient updates should be skipped
      for the step. (Optional, only returned in case global norm clipping is
      used.)

  Raises:
    ValueError: if both clip_gradient_single_norm_to_value and
      clip_gradient_norm_to_value are set.
  """
  p = self.params

  # Computes gradients' norm and adds their summaries. Note that all_grad_norm
  # may be nan, which may cause grad_scale to be nan.
  for name, vg in var_grads.FlattenItems():
    summary_utils.AddNormSummary(
        py_utils.SanitizeScopeKey(name) + '/' + p.name, vg)
  flatten = py_utils.Flatten(var_grads)
  # Global L2 norms over all gradients and over all variables.
  all_grad_norm = tf.sqrt(py_utils.SumSquared([g for (_, g) in flatten]))
  all_var_norm = tf.sqrt(py_utils.SumSquared([v for (v, _) in flatten]))
  grad_norm_is_nan_or_inf = tf.math.logical_or(
      tf.math.is_nan(all_grad_norm), tf.math.is_inf(all_grad_norm))

  # Optional gradient adjustment. Note that this happens after computing
  # all_grad_norm.
  if gradient_adjuster is not None:
    tf.logging.info('gradient_adjuster=%s', gradient_adjuster)
    var_grads = gradient_adjuster(var_grads)

  # Handles NaN/Inf gradients.
  has_nan_or_inf = py_utils.HasNanOrInfGradient(var_grads)
  # Grad norm can still be inf even if none of the individual grad is inf.
  has_nan_or_inf = tf.math.logical_or(has_nan_or_inf, grad_norm_is_nan_or_inf)
  self._AddEvalMetric('has_nan_or_inf', has_nan_or_inf, tf.constant(1.0))

  return_values = py_utils.NestedMap()
  if p.clip_gradient_single_norm_to_value:
    # Currently using both types of clipping simultaneously is unsupported.
    if p.clip_gradient_norm_to_value:
      raise ValueError(
          'Cannot use clip_gradient_single_norm_to_value=%f and '
          'clip_gradient_norm_to_value=%f.' %
          (p.clip_gradient_single_norm_to_value, p.clip_gradient_norm_to_value))
    # Per-gradient norm clipping. NOTE(review): has_nan_or_inf is not used on
    # this path, so NaN/Inf gradients are not zeroed here — confirm intended.
    final_var_grads = py_utils.ApplyGradNormClipping(
        var_grads, p.clip_gradient_single_norm_to_value)

  else:
    grad_scale = self._GetGlobalGradScale(all_grad_norm, has_nan_or_inf)
    # grad_norm/all is both a eval metric(collected by trainer) and a summary
    # (collected by controller).
    summary_utils.scalar(f'grad_norm/all/{p.name}', all_grad_norm)
    self._AddEvalMetric('grad_norm/all', all_grad_norm, tf.constant(1.0))
    self._AddEvalMetric('var_norm/all', all_var_norm, tf.constant(1.0))
    self._AddEvalMetric('grad_scale_all', grad_scale, tf.constant(1.0))
    final_var_grads = py_utils.ApplyGradMultiplier(var_grads, grad_scale)
    # grad_scale is only defined (and returned) on the global-norm path.
    return_values.grad_scale = grad_scale

  return_values.final_var_grads = final_var_grads
  return return_values
def _ApplyPacking(self, batch):
  """Packs a given batch.

  Note that this may change the batch size.

  This function packs the input batch and adds .segment_ids and .segment_pos
  fields to its `src` and `tgt` fields. When `params.packing_factor` is not
  set, no packing happens and the segment fields are derived from
  `ids_indicator` directly.

  Args:
    batch: a `.NestedMap` of input tensors to be packed. It is modified in
      place.
  """
  # Per-example non-padded lengths, from the ids_indicator masks.
  src_actual_seq_len = tf.math.reduce_sum(
      tf.cast(batch.src.ids_indicator, tf.int32), axis=1)
  tgt_actual_seq_len = tf.math.reduce_sum(
      tf.cast(batch.tgt.ids_indicator, tf.int32), axis=1)
  summary_utils.histogram('source_seq_lengths', src_actual_seq_len)
  summary_utils.histogram('target_seq_lengths', tgt_actual_seq_len)

  if not self.params.packing_factor:
    # Supply segment_ids and segment_pos with no packing.
    # NOTE(review): segment_ids keeps ids_indicator's dtype here, while the
    # packed path below casts segment_ids to float32 — confirm consumers
    # accept both.
    batch.src.segment_ids = batch.src.ids_indicator
    batch.src.segment_pos = _GetSegmentPos(batch.src.ids_indicator)
    batch.tgt.segment_ids = batch.tgt.ids_indicator
    batch.tgt.segment_pos = _GetSegmentPos(batch.tgt.ids_indicator)
    return

  (src_segment_ids, src_segment_pos, src_indices_in_input, tgt_segment_ids,
   tgt_segment_pos, tgt_indices_in_input) = ops.pack_sequences(
       src_actual_seq_len, tgt_actual_seq_len, self._ScaledBatchSize(),
       self.params.source_max_length, self.params.target_max_length)

  # Summaries of the lengths of the input examples referenced by the packed
  # output (each referenced index counted once).
  uniq_src_indices_in_input = tf.unique(
      tf.reshape(src_indices_in_input, [-1])).y
  uniq_tgt_indices_in_input = tf.unique(
      tf.reshape(tgt_indices_in_input, [-1])).y
  summary_utils.histogram(
      'packed_source_seq_lengths',
      tf.gather(src_actual_seq_len, uniq_src_indices_in_input, axis=0))
  summary_utils.histogram(
      'packed_target_seq_lengths',
      tf.gather(tgt_actual_seq_len, uniq_tgt_indices_in_input, axis=0))

  # Ratio of number of non-padded tokens. If < 1.0, we are dropping
  # input data due to p.packing_factor too high.
  src_orig_tokens_count = tf.cast(
      tf.reduce_sum(src_actual_seq_len), tf.float32)
  src_packed_tokens_count = tf.reduce_sum(
      tf.cast(src_segment_ids > 0, tf.float32))
  summary_utils.scalar('examples/src_packed_token_ratio',
                       src_packed_tokens_count / src_orig_tokens_count)
  tgt_orig_tokens_count = tf.cast(
      tf.reduce_sum(tgt_actual_seq_len), tf.float32)
  tgt_packed_tokens_count = tf.reduce_sum(
      tf.cast(tgt_segment_ids > 0, tf.float32))
  summary_utils.scalar('examples/tgt_packed_token_ratio',
                       tgt_packed_tokens_count / tgt_orig_tokens_count)

  # We deferred adding .paddings and use its complement .ids_indicator
  # exclusively so that we can apply the packing with padding set to 0 for all
  # fields.
  # NOTE(review): .paddings is nevertheless present here. It is packed
  # separately with pad value 1 and written back after the 0-padded Transform
  # (which would otherwise leave empty slots marked as non-padding).
  def ApplyPackingToSource(x):
    # String fields are padded with '\t'; everything else with 0.
    if x.dtype == tf.string:
      return ops.apply_packing(x, '\t', src_segment_ids, src_indices_in_input)
    return ops.apply_packing(x, 0, src_segment_ids, src_indices_in_input)

  src_paddings = ops.apply_packing(batch.src.paddings, 1, src_segment_ids,
                                   src_indices_in_input)
  batch.src = batch.src.Transform(ApplyPackingToSource)
  batch.src.paddings = src_paddings
  batch.src.segment_ids = tf.cast(src_segment_ids, tf.float32)
  batch.src.segment_pos = src_segment_pos

  def ApplyPackingToTarget(x):
    # String fields are padded with '\t'; everything else with 0.
    if x.dtype == tf.string:
      return ops.apply_packing(x, '\t', tgt_segment_ids, tgt_indices_in_input)
    return ops.apply_packing(x, 0, tgt_segment_ids, tgt_indices_in_input)

  tgt_paddings = ops.apply_packing(batch.tgt.paddings, 1, tgt_segment_ids,
                                   tgt_indices_in_input)
  batch.tgt = batch.tgt.Transform(ApplyPackingToTarget)
  batch.tgt.paddings = tgt_paddings
  batch.tgt.segment_ids = tf.cast(tgt_segment_ids, tf.float32)
  batch.tgt.segment_pos = tgt_segment_pos

  # The number of examples is indicated by the segment_ids of the target.
  num_segments = tf.math.reduce_max(batch.tgt.segment_ids, axis=1)
  num_examples = tf.reduce_sum(num_segments)
  # Note that this is per infeed value when p.use_per_host_infeed = True.
  metric_name = 'examples/num_packed_examples'
  summary_utils.scalar(metric_name, num_examples)
def PostTrainingStepUpdate(self, global_step):
  """Update the cap value.

  Returns:
    The op assigning `self.Value(global_step)` to `self.vars.cap`.
  """
  new_cap = self.Value(global_step)
  summary_utils.scalar('cap', new_cap)
  return self.vars.cap.assign(new_cap)
def _Gradient(inputs, _, original_grad):
  """GradDrop gradient: per element, keeps only gradients of a sampled sign.

  Each loss in `self._losses` is a (loss, leak_ratio) pair. For each element,
  a sign is sampled with probability from equation (1) of the GradDrop paper.
  Per-loss gradients of the other sign are dropped, except for a `leak_ratio`
  fraction that is let through.

  Args:
    inputs: forward input tensor; it (or its sign, with
      `p.use_input_sign_only`) weights the gradients used to compute the keep
      probability.
    _: unused.
    original_grad: incoming gradient; sets the output dtype and, with
      `p.keep_gradnorm_constant`, the norm the result is rescaled to.

  Returns:
    The transformed gradient, cast to `original_grad.dtype`.

  Raises:
    ValueError: if no loss produces a gradient w.r.t. the output tensor.
  """
  # Compute the gradients for each loss w.r.t. self._output_tensor. Losses
  # without a gradient are skipped, so their leak ratios are skipped too to
  # keep leak_ratios aligned with per_loss_grads.
  # TODO(jngiam): Look into whether TF dedups this computation.
  per_loss_grads = []
  leak_ratios = []
  for loss, leak_ratio in self._losses:
    per_loss_grad = tf.gradients(loss, self._output_tensor)[0]
    if per_loss_grad is None:
      tf.logging.warning(
          'Loss %s did not result in a gradient during '
          'GradDrop computation.', loss)
    else:
      per_loss_grads.append(per_loss_grad)
      leak_ratios.append(leak_ratio)

  if not per_loss_grads:
    raise ValueError('No valid gradients for GradDrop.')

  # Multiply the gradients with the inputs.
  grads = per_loss_grads
  if p.use_input_sign_only:
    # Adding 1 where |inputs| <= epsilon avoids dividing by ~0 below.
    input_abs = tf.abs(
        tf.cast(tf.abs(inputs) <= p.epsilon, tf.float32) + inputs)
    grads = [grad * ((inputs) / (input_abs)) for grad in grads]
  else:
    grads = [grad * inputs for grad in grads]

  # Sum gradient over batch, assuming that batch is always on dim 0.
  if p.marginalize_batch_dim:
    grads = [tf.reduce_sum(grad, axis=0, keepdims=True) for grad in grads]

  # First discretize all gradients into their sign values.
  grad_sign_positive = [tf.cast(grad > 0.0, tf.float32) for grad in grads]
  grad_sign_negative = [tf.cast(grad < 0.0, tf.float32) for grad in grads]

  # Calculate the probability of positive gradients based on equation (1)
  # in the GradDrop paper.
  grad_abs_sum = tf.add_n([tf.abs(grad) for grad in grads])
  prob_pos = (tf.add_n(grads) / (2. * grad_abs_sum + p.epsilon))
  # Implementation of different scales for the keep function. Larger
  # scales result in steeper keep functions.
  prob_pos *= p.keep_prob_function_scale

  if p.keep_prob_function == 'sigmoid':
    # Standard sigmoid has derivative of 0.25 at 0 so the factor of 4.0
    # allows the function scale in sigmoid to be compatible with the
    # function scale in the linear case.
    prob_pos = tf.sigmoid(4.0 * prob_pos)
  elif p.keep_prob_function == 'linear':
    prob_pos += 0.5

  # The main, default mode of GradDrop. Only gradients of one sign are kept,
  # and which sign is calculated via equation (1) of the main paper.
  # sampled_sign is +0.5 (keep positive) or -0.5 (keep negative). The dynamic
  # shape is used so partially-unknown static shapes (e.g. batch) still work.
  sampled_sign = tf.cast(
      prob_pos >= tf.random.uniform(tf.shape(prob_pos)), tf.float32) - 0.5
  # An element is kept when its sign agrees with the sampled sign (zero
  # gradients are always kept).
  grad_masks = [
      (gsp - gsn) * sampled_sign >= 0
      for (gsn, gsp) in zip(grad_sign_negative, grad_sign_positive)
  ]

  # This diag value gives us the percentage of grads which are kept.
  gradmask_diag = [tf.cast(gm, tf.float32) for gm in grad_masks]
  diag = tf.reduce_mean(tf.add_n(gradmask_diag) / len(grad_masks))
  summary_utils.scalar('average_grad_mask', diag)

  # Dropped elements still pass a `leak` fraction of their gradient.
  transformed_per_loss_grads = [
      grad * (leak + (1.0 - leak) * tf.cast(grad_mask, tf.float32))
      for (leak, grad, grad_mask) in zip(leak_ratios, per_loss_grads,
                                         grad_masks)
  ]

  transformed_grad = tf.cast(
      tf.add_n(transformed_per_loss_grads), original_grad.dtype)

  if not p.keep_gradnorm_constant:
    return transformed_grad

  # Rescale so the result has (approximately) the original gradient's norm.
  transformed_grad_norm = tf.sqrt(tf.reduce_sum(transformed_grad**2))
  original_grad_norm = tf.sqrt(tf.reduce_sum(original_grad**2))
  return transformed_grad * original_grad_norm / (
      transformed_grad_norm + p.epsilon)
def AddSummary(self, lr, optimizer, var_grad):
  """Adds a scalar summary of the Adafactor learning rate.

  `optimizer` and `var_grad` are accepted only for interface compatibility.
  """
  del optimizer, var_grad  # Unused.
  summary_utils.scalar('adafactor_lr', lr)
def AddSummary(self, lr, optimizer, var_grad):
  """Adds a scalar summary of the SGD learning rate.

  `optimizer` and `var_grad` are accepted only for interface compatibility.
  """
  del optimizer, var_grad  # Unused.
  summary_utils.scalar('sgd_lr', lr)
def AddSummary(self, lr, optimizer, var_grad):
  """Adds the AdaGraft learning-rate summary and, if enabled, step norms.

  With `params.diagnostic` set, also summarizes each variable's
  'm_step_norm' / 'd_step_norm' optimizer slots, their global L2 norms, and
  the ratio of the two global norms ('optimizer/norm_correction').
  """
  summary_utils.scalar('adagraft_lr', lr)
  if not self.params.diagnostic:
    return

  m_sq_sum = 0.0
  d_sq_sum = 0.0
  for var, _ in var_grad.Flatten():
    # Layer-wise step norms.
    m_norm = optimizer.get_slot(var, 'm_step_norm')
    d_norm = optimizer.get_slot(var, 'd_step_norm')
    summary_utils.scalar(f'optimizer/m_step_norm_{var.name}', m_norm)
    summary_utils.scalar(f'optimizer/d_step_norm_{var.name}', d_norm)
    m_sq_sum += m_norm**2
    d_sq_sum += d_norm**2

  # Global step norms over all variables.
  m_global = m_sq_sum**0.5
  d_global = d_sq_sum**0.5
  summary_utils.scalar('optimizer/m_step_norm', m_global)
  summary_utils.scalar('optimizer/d_step_norm', d_global)
  summary_utils.scalar('optimizer/norm_correction', m_global / d_global)