Exemple #1
0
  def _BPropForVariables(self, vmap):
    """Constructs the backward graph."""
    bprop_variable_filters = self.input_generator.GetBpropVariableFilters()
    # Only compute the mask if the variable filters are not empty.
    if bprop_variable_filters != [''] * len(bprop_variable_filters):
      self._ComputeGradientMask(bprop_variable_filters)
    train_ops = {}  # mapping from op name to op.
    gradient_mask = None
    if self._per_input_gradient_mask:
      # TODO(neerajgaur): Change this to use source_selected from input_batch.
      onehot = self.input_generator.GetInputSourceOneHot()
      gradient_mask = {
          k: tf.tensordot(v, onehot, 1)
          for k, v in six.iteritems(self._per_input_gradient_mask)
      }
    all_losses = []
    for optimization in self.learners:
      loss_name = optimization.params.name
      metric = self._metrics.get(loss_name, None)
      if metric is None:
        raise ValueError('Loss %s not found in metrics %s' %
                         (loss_name, list(self._metrics.keys())))
      loss = metric[0]
      all_losses.append(loss)
      train_ops['train/%s' % loss_name], eval_metrics = optimization.Apply(
          loss,
          vmap,
          gradient_mask=gradient_mask,
          gradient_adjuster=self.AdjustGradients)
      for key, (value, weight) in six.iteritems(eval_metrics):
        self.AddEvalMetric(key + '/' + loss_name, value, weight)

    relevant_bn_updates, _ = py_utils.FindRelevantBatchNormUpdates(
        all_losses, tf.get_collection(py_utils.BATCH_NORM_UPDATES))
    train_ops['bn_updates'] = relevant_bn_updates

    # Get the op to update the weight masks and thresholds
    train_ops['mask_updates'] = self._GetMaskUpdateOp()

    # Post training step update.
    train_ops['post_step'] = self.PostTrainingStepUpdate(self.global_step)

    with tf.control_dependencies(tf.nest.flatten(train_ops)):
      true_global_step = py_utils.GetOrCreateGlobalStepVar()
      with tf.colocate_with(true_global_step):
        increment_global_steps = tf.assign_add(true_global_step, 1)
      if self._global_step_var != true_global_step:
        with tf.colocate_with(self._global_step_var):
          increment_global_steps = tf.group(
              increment_global_steps, tf.assign_add(self._global_step_var, 1))
      train_ops['global_step'] = increment_global_steps

    # If we are using Tpu Embeddings, generate the monolithic send
    # gradient op.
    tpu_embedding_activations = tf.get_collection(
        py_utils.TPU_EMBEDDING_ACTIVATIONS)
    if tpu_embedding_activations:
      tpu_embedding_activations_dict = tpu_embedding_activations[0]
      tpu_embedding = tf.get_collection(py_utils.TPU_EMBEDDING)[0]
      tpu_embedding_send_gradient_op = py_utils.ComputeTpuEmbeddingGradients(
          self.loss, tpu_embedding_activations_dict, tpu_embedding)
      train_ops['tpu_embedding'] = tpu_embedding_send_gradient_op

    for op_name, op in six.iteritems(train_ops):
      assert op is not None, op_name

    # TODO(rpang): try to structure _train_op as:
    #   tf.cond(skip_step, <only update skip stats>, <all updates>)
    # so that we skip all other updates when a step is skipped.
    self._train_op = tf.group(*tf.nest.flatten(train_ops), name='bprop')
Exemple #2
0
  def _BPropGenTrainOps(self, vmap, metrics=None, add_summary=True):
    """Populates the train_ops dictionary in a backwards pass."""
    metrics = metrics or self._metrics

    bprop_variable_filters = self.input_generator.GetBpropVariableFilters()
    # Only compute the mask if the variable filters are not empty.
    if bprop_variable_filters != [''] * len(bprop_variable_filters):
      self._ComputeGradientMask(bprop_variable_filters)
    train_ops = {}  # mapping from op name to op.
    gradient_mask = None
    if self._per_input_gradient_mask:
      # TODO(neerajgaur): Change this to use source_selected from input_batch.
      onehot = self.input_generator.GetInputSourceOneHot()
      gradient_mask = {
          k: tf.tensordot(v, onehot, 1)
          for k, v in self._per_input_gradient_mask.items()
      }
    all_losses = []
    for optimization in self.learners:
      learner_name = optimization.params.name
      loss_name = optimization.params.loss_name or learner_name
      metric = metrics.get(loss_name, None)
      if metric is None:
        raise ValueError('Loss %s not found in metrics %s' %
                         (loss_name, list(metrics.keys())))
      loss = metric[0]
      all_losses.append(loss)
      train_ops['train/%s' % learner_name], eval_metrics = optimization.Apply(
          loss,
          vmap,
          gradient_mask=gradient_mask,
          gradient_adjuster=self.AdjustGradients)
      if add_summary:
        for key, (value, weight) in eval_metrics.items():
          self.AddEvalMetric(key + '/' + learner_name, value, weight)

    relevant_bn_updates, _ = py_utils.FindRelevantBatchNormUpdates(
        all_losses, tf.get_collection(py_utils.BATCH_NORM_UPDATES))
    train_ops['bn_updates'] = relevant_bn_updates

    var_update_ops = [
        tf.group(*tf.nest.flatten(train_ops), name='var_update_ops')
    ]
    # Post training step update.
    with tf.control_dependencies(var_update_ops):
      post_step_op = self.PostTrainingStepUpdate(self.global_step)

    train_ops = {}
    with tf.control_dependencies([post_step_op]):
      # Get the op to update the weight masks and thresholds
      mask_update_op = self._GetMaskUpdateOp()
      train_ops['mask_updates'] = mask_update_op
      with tf.control_dependencies([mask_update_op]):
        true_global_step = py_utils.GetOrCreateGlobalStepVar()
        with tf.ops.colocate_with(true_global_step):
          increment_global_steps = tf.assign_add(true_global_step, 1)
        if self._global_step_var != true_global_step:
          with tf.ops.colocate_with(self._global_step_var):
            increment_global_steps = tf.group(
                increment_global_steps, tf.assign_add(self._global_step_var, 1))
        train_ops['global_step'] = increment_global_steps

    # If we are using Tpu Embeddings, generate the monolithic send
    # gradient op.
    tpu_embedding_activations = tf.get_collection(
        py_utils.TPU_EMBEDDING_ACTIVATIONS)
    if tpu_embedding_activations:
      tpu_embedding_activations_dict = tpu_embedding_activations[0]
      tpu_embedding = tf.get_collection(py_utils.TPU_EMBEDDING)[0]
      tpu_embedding_send_gradient_op = py_utils.ComputeTpuEmbeddingGradients(
          self.loss, tpu_embedding_activations_dict, tpu_embedding)
      train_ops['tpu_embedding'] = tpu_embedding_send_gradient_op

    for op_name, op in train_ops.items():
      assert op is not None, op_name
    return train_ops