def ComputeLoss(self, theta, predictions, input_batch):
  """Computes cross-entropy loss (and optionally top-1/top-5 accuracy)."""
  p = self.params
  batch = tf.shape(input_batch.data)[0]
  act = predictions.act
  with tf.colocate_with(act):
    tf.logging.info("{}'s device: {}".format(act, act.device))
    # Softmax
    labels = tf.to_int64(input_batch.label)
    onehot_labels = tf.one_hot(labels, p.softmax.num_classes)
    if p.label_smoothing > 0:
      smooth_positives = 1.0 - p.label_smoothing
      smooth_negatives = p.label_smoothing / p.softmax.num_classes
      onehot_labels = onehot_labels * smooth_positives + smooth_negatives
    xent = self.softmax.FProp(
        theta=theta.softmax,
        inputs=act,
        class_weights=input_batch.weight,
        class_probabilities=onehot_labels)
  self._AddSummary(input_batch, xent.per_example_argmax)
  rets = {
      'loss': (xent.avg_xent, batch),
      'log_pplx': (xent.avg_xent, batch),
      'num_preds': (batch, 1),
  }
  if p.is_eval or p.compute_accuracy_for_training:
    acc1 = self._Accuracy(1, xent.logits, labels, input_batch.weight)
    acc5 = self._Accuracy(5, xent.logits, labels, input_batch.weight)
    rets.update(accuracy=(acc1, batch), acc5=(acc5, batch))
  return rets, {}
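# A minimal, self-contained sketch of what a weighted top-k accuracy helper
# such as _Accuracy(k, logits, labels, weights) above is assumed to compute.
# This is an illustration only (hypothetical name and shapes); the actual
# helper in the codebase may differ in shape handling and edge cases.
import tensorflow as tf

def _TopKAccuracySketch(k, logits, labels, weights):
  """Weighted fraction of examples whose true label is in the top-k logits."""
  # Keyword arguments keep this compatible with both the TF1 and TF2 in_top_k
  # signatures, which differ in positional argument order.
  correct = tf.cast(
      tf.nn.in_top_k(predictions=logits, targets=labels, k=k), tf.float32)
  weights = tf.reshape(tf.cast(weights, tf.float32), [-1])
  return tf.reduce_sum(correct * weights) / tf.maximum(
      tf.reduce_sum(weights), 1e-8)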
def _create_slots(self, var_list):
  """Creates slots for both wrapped optimizers plus scratch/diagnostic slots."""
  self.magnitude_optimizer._create_slots(var_list)  # pylint: disable=protected-access
  self.direction_optimizer._create_slots(var_list)  # pylint: disable=protected-access

  for v in var_list:
    with tf.colocate_with(v):
      self._zeros_slot(v, "scratch_copy", self._name)
      if self.diagnostic or self.use_global_norm:
        self._get_or_make_slot(v, tf.constant(0.0), "m_step_norm", self._name)
        self._get_or_make_slot(v, tf.constant(0.0), "d_step_norm", self._name)
def PostTrainingStepUpdate(self, global_step):
  """Updates moving_mean, moving_variance after each training step."""
  p = self.params
  # Get sufficient stats that accumulate over microbatches.
  counts = self.accumulators.counts.GetValue()
  mean_ss = self.accumulators.mean_ss.GetValue()
  variance_ss = self.accumulators.variance_ss.GetValue()
  # Compute batch mean and batch variance from sufficient stats.
  mean, variance = tf.nn.normalize_moments(counts, mean_ss, variance_ss, None)
  decay = tf.convert_to_tensor(1.0 - p.decay, p.dtype)
  # Update moving_mean, moving_variance from batch mean and batch variance.
  with tf.name_scope(p.name) as scope:
    with tf.colocate_with(self.vars.moving_mean):
      mean_update = tf.assign_sub(
          self.vars.moving_mean,
          tf.where(tf.greater(counts, 0.5),
                   (self.vars.moving_mean - tf.cast(mean, p.dtype)) * decay,
                   tf.zeros_like(self.vars.moving_mean)),
          name='moving_mean_update')
    with tf.colocate_with(self.vars.moving_variance):
      var_update = tf.assign_sub(
          self.vars.moving_variance,
          tf.where(tf.greater(counts, 0.5),
                   (self.vars.moving_variance - tf.cast(variance, p.dtype)) *
                   decay,
                   tf.zeros_like(self.vars.moving_variance)),
          name='moving_variance_update')
    py_utils.CheckNumerics(
        self.vars.moving_mean,
        'moving mean of {} failed numeric check'.format(scope))
    py_utils.CheckNumerics(
        self.vars.moving_variance,
        'moving variance of {} failed numeric check'.format(scope))
  self.accumulators.counts.Reset()
  self.accumulators.mean_ss.Reset()
  self.accumulators.variance_ss.Reset()
  return tf.group(mean_update, var_update)
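# A small numeric sketch (hypothetical values, NumPy only) verifying that the
# assign_sub form used above,
#   moving -= (moving - batch_stat) * (1 - p.decay),
# is the usual exponential moving average
#   moving = p.decay * moving + (1 - p.decay) * batch_stat.
import numpy as np

decay = 0.99                        # stands in for p.decay (assumed value)
moving = np.array([0.0, 2.0])       # current moving mean / variance
batch_stat = np.array([1.0, 4.0])   # batch mean / variance from normalize_moments

assign_sub_form = moving - (moving - batch_stat) * (1.0 - decay)
ema_form = decay * moving + (1.0 - decay) * batch_stat
assert np.allclose(assign_sub_form, ema_form)  # both give [0.01, 2.02]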
def _BPropForVariables(self, vmap):
  """Constructs the backward graph."""
  bprop_variable_filters = self.input_generator.GetBpropVariableFilters()
  # Only compute the mask if the variable filters are not empty.
  if bprop_variable_filters != [''] * len(bprop_variable_filters):
    self._ComputeGradientMask(bprop_variable_filters)
  train_ops = {}  # mapping from op name to op.
  gradient_mask = None
  if self._per_input_gradient_mask:
    # TODO(neerajgaur): Change this to use source_selected from input_batch.
    onehot = self.input_generator.GetInputSourceOneHot()
    gradient_mask = {
        k: tf.tensordot(v, onehot, 1)
        for k, v in six.iteritems(self._per_input_gradient_mask)
    }
  all_losses = []
  for optimization in self.learners:
    loss_name = optimization.params.name
    metric = self._metrics.get(loss_name, None)
    if metric is None:
      raise ValueError('Loss %s not found in metrics %s' %
                       (loss_name, list(self._metrics.keys())))
    loss = metric[0]
    all_losses.append(loss)
    train_ops['train/%s' % loss_name], eval_metrics = optimization.Apply(
        loss,
        vmap,
        gradient_mask=gradient_mask,
        gradient_adjuster=self.AdjustGradients)
    for key, (value, weight) in six.iteritems(eval_metrics):
      self.AddEvalMetric(key + '/' + loss_name, value, weight)

  relevant_bn_updates, _ = py_utils.FindRelevantBatchNormUpdates(
      all_losses, tf.get_collection(py_utils.BATCH_NORM_UPDATES))
  train_ops['bn_updates'] = relevant_bn_updates

  # Get the op to update the weight masks and thresholds.
  train_ops['mask_updates'] = self._GetMaskUpdateOp()

  # Post training step update.
  train_ops['post_step'] = self.PostTrainingStepUpdate(self.global_step)

  with tf.control_dependencies(tf.nest.flatten(train_ops)):
    true_global_step = py_utils.GetOrCreateGlobalStepVar()
    with tf.colocate_with(true_global_step):
      increment_global_steps = tf.assign_add(true_global_step, 1)
    if self._global_step_var != true_global_step:
      with tf.colocate_with(self._global_step_var):
        increment_global_steps = tf.group(
            increment_global_steps, tf.assign_add(self._global_step_var, 1))
  train_ops['global_step'] = increment_global_steps

  # If we are using Tpu Embeddings, generate the monolithic send
  # gradient op.
  tpu_embedding_activations = tf.get_collection(
      py_utils.TPU_EMBEDDING_ACTIVATIONS)
  if tpu_embedding_activations:
    tpu_embedding_activations_dict = tpu_embedding_activations[0]
    tpu_embedding = tf.get_collection(py_utils.TPU_EMBEDDING)[0]
    tpu_embedding_send_gradient_op = py_utils.ComputeTpuEmbeddingGradients(
        self.loss, tpu_embedding_activations_dict, tpu_embedding)
    train_ops['tpu_embedding'] = tpu_embedding_send_gradient_op

  for op_name, op in six.iteritems(train_ops):
    assert op is not None, op_name

  # TODO(rpang): try to structure _train_op as:
  #   tf.cond(skip_step, <only update skip stats>, <all updates>)
  # so that we skip all other updates when a step is skipped.
  self._train_op = tf.group(*tf.nest.flatten(train_ops), name='bprop')
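# A small sketch (hypothetical values) of the per-input gradient mask logic
# above: tf.tensordot(per_source_mask, onehot, 1) picks out the mask entry for
# whichever input source produced the current batch.
import tensorflow as tf

per_source_mask = tf.constant([1.0, 0.0, 1.0])   # train only on sources 0 and 2
onehot = tf.one_hot(1, 3)                        # this batch came from source 1
mask = tf.tensordot(per_source_mask, onehot, 1)  # -> 0.0, so gradients are zeroed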