Example #1
    def ComputeLoss(self, theta, predictions, input_batch):
        """Computes softmax cross-entropy loss and eval metrics for a batch."""
        p = self.params
        batch = tf.shape(input_batch.data)[0]
        act = predictions.act
        with tf.colocate_with(act):
            tf.logging.info("{}'s device: {}".format(act, act.device))
            # Softmax
            labels = tf.to_int64(input_batch.label)
            onehot_labels = tf.one_hot(labels, p.softmax.num_classes)
            if p.label_smoothing > 0:
                smooth_positives = 1.0 - p.label_smoothing
                smooth_negatives = p.label_smoothing / p.softmax.num_classes
                onehot_labels = onehot_labels * smooth_positives + smooth_negatives

            xent = self.softmax.FProp(theta=theta.softmax,
                                      inputs=act,
                                      class_weights=input_batch.weight,
                                      class_probabilities=onehot_labels)

        self._AddSummary(input_batch, xent.per_example_argmax)

        rets = {
            'loss': (xent.avg_xent, batch),
            'log_pplx': (xent.avg_xent, batch),
            'num_preds': (batch, 1),
        }
        if p.is_eval or p.compute_accuracy_for_training:
            acc1 = self._Accuracy(1, xent.logits, labels, input_batch.weight)
            acc5 = self._Accuracy(5, xent.logits, labels, input_batch.weight)
            rets.update(accuracy=(acc1, batch), acc5=(acc5, batch))
        return rets, {}
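
The pattern above builds the softmax cross-entropy inside tf.colocate_with(act), so the loss ops are placed on the same device as the activations. Below is a minimal sketch of that colocation, assuming a TF1-style graph; `act` and the weight tensors are illustrative stand-ins for predictions.act and the softmax layer:

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

act = tf.ones([8, 128], name='act')  # stand-in for predictions.act
with tf.colocate_with(act):
    # Ops created in this scope are constrained to the device that
    # produces `act`, avoiding a cross-device copy of the activations.
    logits = tf.matmul(act, tf.ones([128, 10]))
    probs = tf.nn.softmax(logits)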
Example #2
    def _create_slots(self, var_list):
        """Creates slots for both wrapped optimizers plus colocated scratch slots."""
        self.magnitude_optimizer._create_slots(var_list)  # pylint: disable=protected-access
        self.direction_optimizer._create_slots(var_list)  # pylint: disable=protected-access

        for v in var_list:
            with tf.colocate_with(v):
                self._zeros_slot(v, "scratch_copy", self._name)
                if self.diagnostic or self.use_global_norm:
                    self._get_or_make_slot(v, tf.constant(0.0), "m_step_norm",
                                           self._name)
                    self._get_or_make_slot(v, tf.constant(0.0), "d_step_norm",
                                           self._name)
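
Here tf.colocate_with(v) ensures each optimizer slot is created on the same device as the variable it shadows. A minimal sketch of the same idea, assuming a TF1 tf.train.Optimizer subclass; SketchOptimizer and the 'scratch_copy' slot name are illustrative:

import tensorflow.compat.v1 as tf

class SketchOptimizer(tf.train.GradientDescentOptimizer):

    def _create_slots(self, var_list):
        for v in var_list:
            with tf.colocate_with(v):
                # The slot variable is placed on the same device as `v`.
                self._zeros_slot(v, 'scratch_copy', self._name)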
Example #3
  def PostTrainingStepUpdate(self, global_step):
    """Updates moving_mean, moving_variance after each training step."""
    p = self.params
    # Get sufficient stats that accumulate over microbatches.
    counts = self.accumulators.counts.GetValue()
    mean_ss = self.accumulators.mean_ss.GetValue()
    variance_ss = self.accumulators.variance_ss.GetValue()
    # Compute batch mean and batch variance from sufficient stats.
    mean, variance = tf.nn.normalize_moments(counts, mean_ss, variance_ss,
                                             None)
    decay = tf.convert_to_tensor(1.0 - p.decay, p.dtype)
    # Update moving_mean and moving_variance from batch mean and batch
    # variance.
    with tf.name_scope(p.name) as scope:
      with tf.colocate_with(self.vars.moving_mean):
        mean_update = tf.assign_sub(
            self.vars.moving_mean,
            tf.where(tf.greater(counts, 0.5),
                     (self.vars.moving_mean - tf.cast(mean, p.dtype)) * decay,
                     tf.zeros_like(self.vars.moving_mean)),
            name='moving_mean_update')
      with tf.colocate_with(self.vars.moving_variance):
        var_update = tf.assign_sub(
            self.vars.moving_variance,
            tf.where(tf.greater(counts, 0.5),
                     (self.vars.moving_variance -
                      tf.cast(variance, p.dtype)) * decay,
                     tf.zeros_like(self.vars.moving_variance)),
            name='moving_variance_update')
      py_utils.CheckNumerics(
          self.vars.moving_mean,
          'moving mean of {} failed numeric check'.format(scope))
      py_utils.CheckNumerics(
          self.vars.moving_variance,
          'moving variance of {} failed numeric check'.format(scope))
    self.accumulators.counts.Reset()
    self.accumulators.mean_ss.Reset()
    self.accumulators.variance_ss.Reset()
    return tf.group(mean_update, var_update)
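
In this example the assign ops are colocated with the variables they modify, so each update runs on the device that holds moving_mean or moving_variance. A minimal sketch of one such colocated moving-average update, assuming TF1 graph mode; the shape and decay value are illustrative:

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

moving_mean = tf.get_variable(
    'moving_mean', shape=[128], trainable=False,
    initializer=tf.zeros_initializer())
batch_mean = tf.reduce_mean(tf.random_normal([32, 128]), axis=0)
with tf.colocate_with(moving_mean):
    # The assign op runs on whichever device owns the variable.
    mean_update = tf.assign_sub(moving_mean,
                                (moving_mean - batch_mean) * 0.01)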
Example #4
  def _BPropForVariables(self, vmap):
    """Constructs the backward graph."""
    bprop_variable_filters = self.input_generator.GetBpropVariableFilters()
    # Only compute the mask if the variable filters are not empty.
    if bprop_variable_filters != [''] * len(bprop_variable_filters):
      self._ComputeGradientMask(bprop_variable_filters)
    train_ops = {}  # mapping from op name to op.
    gradient_mask = None
    if self._per_input_gradient_mask:
      # TODO(neerajgaur): Change this to use source_selected from input_batch.
      onehot = self.input_generator.GetInputSourceOneHot()
      gradient_mask = {
          k: tf.tensordot(v, onehot, 1)
          for k, v in six.iteritems(self._per_input_gradient_mask)
      }
    all_losses = []
    for optimization in self.learners:
      loss_name = optimization.params.name
      metric = self._metrics.get(loss_name, None)
      if metric is None:
        raise ValueError('Loss %s not found in metrics %s' %
                         (loss_name, list(self._metrics.keys())))
      loss = metric[0]
      all_losses.append(loss)
      train_ops['train/%s' % loss_name], eval_metrics = optimization.Apply(
          loss,
          vmap,
          gradient_mask=gradient_mask,
          gradient_adjuster=self.AdjustGradients)
      for key, (value, weight) in six.iteritems(eval_metrics):
        self.AddEvalMetric(key + '/' + loss_name, value, weight)

    relevant_bn_updates, _ = py_utils.FindRelevantBatchNormUpdates(
        all_losses, tf.get_collection(py_utils.BATCH_NORM_UPDATES))
    train_ops['bn_updates'] = relevant_bn_updates

    # Get the op to update the weight masks and thresholds.
    train_ops['mask_updates'] = self._GetMaskUpdateOp()

    # Post training step update.
    train_ops['post_step'] = self.PostTrainingStepUpdate(self.global_step)

    with tf.control_dependencies(tf.nest.flatten(train_ops)):
      true_global_step = py_utils.GetOrCreateGlobalStepVar()
      with tf.colocate_with(true_global_step):
        increment_global_steps = tf.assign_add(true_global_step, 1)
      if self._global_step_var != true_global_step:
        with tf.colocate_with(self._global_step_var):
          increment_global_steps = tf.group(
              increment_global_steps, tf.assign_add(self._global_step_var, 1))
      train_ops['global_step'] = increment_global_steps

    # If we are using TPU embeddings, generate the monolithic send-gradient
    # op.
    tpu_embedding_activations = tf.get_collection(
        py_utils.TPU_EMBEDDING_ACTIVATIONS)
    if tpu_embedding_activations:
      tpu_embedding_activations_dict = tpu_embedding_activations[0]
      tpu_embedding = tf.get_collection(py_utils.TPU_EMBEDDING)[0]
      tpu_embedding_send_gradient_op = py_utils.ComputeTpuEmbeddingGradients(
          self.loss, tpu_embedding_activations_dict, tpu_embedding)
      train_ops['tpu_embedding'] = tpu_embedding_send_gradient_op

    for op_name, op in six.iteritems(train_ops):
      assert op is not None, op_name

    # TODO(rpang): try to structure _train_op as:
    #   tf.cond(skip_step, <only update skip stats>, <all updates>)
    # so that we skip all other updates when a step is skipped.
    self._train_op = tf.group(*tf.nest.flatten(train_ops), name='bprop')
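
Here the global-step increment is colocated with the global-step variable, so the counter is updated in place on its own device once all training ops have run. A minimal sketch of that final step, assuming TF1 graph mode; `update_ops` is a hypothetical stand-in for the train_ops values:

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

global_step = tf.train.get_or_create_global_step()
update_ops = [tf.no_op(name='weight_update')]  # stand-in for real updates
with tf.control_dependencies(update_ops):
    with tf.colocate_with(global_step):
        # The counter is incremented on the device that owns it, only
        # after every op in `update_ops` has executed.
        increment_global_step = tf.assign_add(global_step, 1)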