Example #1
    def testBatchNormUpdatesWithUpdateUseGlobalStatsForTraining(self):
        tf.random.set_seed(398847392)
        np.random.seed(12345)
        params = layers.BatchNormLayer.Params()
        params.name = 'bn'
        params.dim = 3
        params.use_moving_avg_in_training = True
        params.params_init = py_utils.WeightInit.Gaussian(0.1)

        bn_layer = layers.BatchNormLayer(params)
        in_padding1 = tf.zeros([2, 8, 1], dtype=tf.float32)
        bn_in1 = tf.constant(np.random.normal(0.1, 0.5, [2, 8, 3]),
                             dtype=tf.float32)

        bn_out = bn_layer.FPropDefaultTheta(bn_in1, in_padding1)
        sig1 = tf.reduce_sum(bn_out)
        sig2 = tf.reduce_sum(bn_out * bn_out)

        # IMPORTANT: Keep these values consistent with the corresponding
        # test in layers_test.py
        self.assertAllClose(2.6575434, sig1, atol=1e-5)
        self.assertAllClose(15.473802, sig2)

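        # With use_moving_avg_in_training=True the layer normalizes with its
        # moving averages directly, so no pending update ops are expected in
        # the collection: both lists returned below should be empty.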
        updates_collection = tf.get_collection(py_utils.BATCH_NORM_UPDATES)
        l1, l2 = py_utils.FindRelevantBatchNormUpdates(bn_out,
                                                       updates_collection)
        self.assertEqual(l1, [])
        self.assertEqual(l2, [])
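
The test above exercises the degenerate case in which no update ops are pending. The training-side pattern shared by Examples #2-#5 can be distilled into a short sketch; this is an illustration only, assuming TF1-style graph mode, the usual Lingvo imports, and a hypothetical helper name AddBatchNormUpdates:

from lingvo import compat as tf
from lingvo.core import py_utils


def AddBatchNormUpdates(train_op, loss):
  """Folds the batch-norm updates relevant to `loss` into `train_op`."""
  # BatchNormLayer instances (unless configured like the test above) register
  # their moving-average update ops in this graph collection during FProp.
  bn_updates = tf.get_collection(py_utils.BATCH_NORM_UPDATES)
  # Keep only the update ops that the given loss actually depends on.
  relevant, _ = py_utils.FindRelevantBatchNormUpdates(loss, bn_updates)
  return tf.group(train_op, *relevant, name='train_with_bn_updates')
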
Example #2
  def _BPropForVariables(self, vmap):
    """Constructs the backward graph for the given variables.

    Args:
      vmap: a `.NestedMap` of variables.
    """
    p = self.params
    tp = p.train

    # Compute gradients.
    self._var_grads = py_utils.ComputeGradients(self.loss, vmap)

    # L2 regularizer.
    if tp.l2_regularizer_weight is not None:
      l2_loss, self._var_grads = py_utils.AdjustGradientsWithLpLoss(
          self._var_grads, tp.l2_regularizer_weight, p=2.0)
      summary_utils.scalar(p, 'l2_loss', l2_loss)

    # L1 regularizer.
    if tp.l1_regularizer_weight is not None:
      l1_loss, self._var_grads = py_utils.AdjustGradientsWithLpLoss(
          self._var_grads, tp.l1_regularizer_weight, p=1.0)
      summary_utils.scalar(p, 'l1_loss', l1_loss)

    # Mask gradients only if the mask is set.
    if self._per_input_gradient_mask:
      bprop_onehot = self.input_generator.GetInputSourceOneHot()
      self._var_grads = py_utils.MaskGradients(
          self._var_grads, self._per_input_gradient_mask, bprop_onehot)

    # Apply gradient clipping.
    has_nan_or_inf, _, self._var_grads = self.ScaleGradients(self._var_grads)

    # Histogram summary.
    summary_utils.CollectVarHistogram(p, self._var_grads)

    lrs = self.lr_schedule.Value(self._global_step)
    summary_utils.scalar(p, 'lr_schedule', lrs)
    lr = tp.learning_rate * lrs

    var_update_op = self.optimizer.Apply(lr, self._var_grads)

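    # Bump the shared global step and, when this task keeps its own counter,
    # the task-level global step as well.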
    increment_global_step_ops = []
    with tf.colocate_with(self._shared_global_step):
      increment_global_step_ops.append(
          tf.assign_add(self._shared_global_step, 1))
    if self._task_global_step:
      with tf.colocate_with(self._task_global_step):
        increment_global_step_ops.append(
            tf.assign_add(self._task_global_step, 1))
    increment_global_steps = tf.group(*increment_global_step_ops)

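    # Collect the batch-norm moving-average updates that this loss depends on;
    # they are folded into the train op below.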
    relevant_bn_updates, _ = py_utils.FindRelevantBatchNormUpdates(
        self.loss, tf.get_collection(py_utils.BATCH_NORM_UPDATES))
    batch_norm_updates = tf.group(*relevant_bn_updates)

    # Update stats.
    stats_updates = tf.group(
        self.IncrementTotalSamples(),
        self.IncrementTotalNans(tf.to_int32(has_nan_or_inf)))

    # Post training step update.
    post_training_step_updates = self.PostTrainingStepUpdate(self._global_step)

    # Get the op to update the weight masks and thresholds
    mask_update_op = self._GetMaskUpdateOp()

    # TODO(rpang): try to structure _train_op as:
    #   tf.cond(skip_step, <only update skip stats>, <all updates>)
    # so that we skip all other updates when a step is skipped.
    
    if p.contiguous:
      var_update_op = tf.group(var_update_op, self.last_state_group_op)

    self._train_op = tf.group(
        var_update_op,
        batch_norm_updates,
        stats_updates,
        post_training_step_updates,
        increment_global_steps,
        mask_update_op,
        name='train')
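
Once _BPropForVariables has run, a single session call executes the whole step. A sketch, assuming a TF1-style session and that `task` is the object whose graph was just built (real code would normally go through the task's public accessors rather than the private attribute):

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  # One training step: gradient application, batch-norm updates, stats and
  # mask updates, and the global-step increments all run as one group.
  sess.run(task._train_op)
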
Example #3
  def _BPropForVariables(self, vmap):
    """Constructs the backward graph."""
    bprop_variable_filters = self.input_generator.GetBpropVariableFilters()
    # Only compute the mask if the variable filters are not empty.
    if bprop_variable_filters != [''] * len(bprop_variable_filters):
      self._ComputeGradientMask(bprop_variable_filters)
    train_ops = {}  # mapping from op name to op.
    gradient_mask = None
    if self._per_input_gradient_mask:
      # TODO(neerajgaur): Change this to use source_selected from input_batch.
      onehot = self.input_generator.GetInputSourceOneHot()
      gradient_mask = {
          k: tf.tensordot(v, onehot, 1)
          for k, v in six.iteritems(self._per_input_gradient_mask)
      }
    all_losses = []
    for optimization in self.learners:
      loss_name = optimization.params.name
      metric = self._metrics.get(loss_name, None)
      if metric is None:
        raise ValueError('Loss %s not found in metrics %s' %
                         (loss_name, list(self._metrics.keys())))
      loss = metric[0]
      all_losses.append(loss)
      train_ops['train/%s' % loss_name], eval_metrics = optimization.Apply(
          loss,
          vmap,
          gradient_mask=gradient_mask,
          gradient_adjuster=self.AdjustGradients)
      for key, (value, weight) in six.iteritems(eval_metrics):
        self.AddEvalMetric(key + '/' + loss_name, value, weight)

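    # Unlike the single-loss case, the batch-norm updates here are filtered
    # against all learner losses at once.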
    relevant_bn_updates, _ = py_utils.FindRelevantBatchNormUpdates(
        all_losses, tf.get_collection(py_utils.BATCH_NORM_UPDATES))
    train_ops['bn_updates'] = relevant_bn_updates

    # Get the op to update the weight masks and thresholds
    train_ops['mask_updates'] = self._GetMaskUpdateOp()

    # Post training step update.
    train_ops['post_step'] = self.PostTrainingStepUpdate(self.global_step)

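    # Make the global-step increment wait for every op accumulated so far.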
    with tf.control_dependencies(tf.nest.flatten(train_ops)):
      true_global_step = py_utils.GetOrCreateGlobalStepVar()
      with tf.colocate_with(true_global_step):
        increment_global_steps = tf.assign_add(true_global_step, 1)
      if self._global_step_var != true_global_step:
        with tf.colocate_with(self._global_step_var):
          increment_global_steps = tf.group(
              increment_global_steps, tf.assign_add(self._global_step_var, 1))
      train_ops['global_step'] = increment_global_steps

    # If we are using Tpu Embeddings, generate the monolithic send
    # gradient op.
    tpu_embedding_activations = tf.get_collection(
        py_utils.TPU_EMBEDDING_ACTIVATIONS)
    if tpu_embedding_activations:
      tpu_embedding_activations_dict = tpu_embedding_activations[0]
      tpu_embedding = tf.get_collection(py_utils.TPU_EMBEDDING)[0]
      tpu_embedding_send_gradient_op = py_utils.ComputeTpuEmbeddingGradients(
          self.loss, tpu_embedding_activations_dict, tpu_embedding)
      train_ops['tpu_embedding'] = tpu_embedding_send_gradient_op

    for op_name, op in six.iteritems(train_ops):
      assert op is not None, op_name

    # TODO(rpang): try to structure _train_op as:
    #   tf.cond(skip_step, <only update skip stats>, <all updates>)
    # so that we skip all other updates when a step is skipped.
    self._train_op = tf.group(*tf.nest.flatten(train_ops), name='bprop')
Example #4
  def _BPropGenTrainOps(self, vmap, metrics=None, add_summary=True):
    """Populates the train_ops dictionary in a backwards pass."""
    metrics = metrics or self._metrics

    bprop_variable_filters = self.input_generator.GetBpropVariableFilters()
    # Only compute the mask if the variable filters are not empty.
    if bprop_variable_filters != [''] * len(bprop_variable_filters):
      self._ComputeGradientMask(bprop_variable_filters)
    train_ops = {}  # mapping from op name to op.
    gradient_mask = None
    if self._per_input_gradient_mask:
      # TODO(neerajgaur): Change this to use source_selected from input_batch.
      onehot = self.input_generator.GetInputSourceOneHot()
      gradient_mask = {
          k: tf.tensordot(v, onehot, 1)
          for k, v in self._per_input_gradient_mask.items()
      }
    all_losses = []
    for optimization in self.learners:
      learner_name = optimization.params.name
      loss_name = optimization.params.loss_name or learner_name
      metric = metrics.get(loss_name, None)
      if metric is None:
        raise ValueError('Loss %s not found in metrics %s' %
                         (loss_name, list(metrics.keys())))
      loss = metric[0]
      all_losses.append(loss)
      train_ops['train/%s' % learner_name], eval_metrics = optimization.Apply(
          loss,
          vmap,
          gradient_mask=gradient_mask,
          gradient_adjuster=self.AdjustGradients)
      if add_summary:
        for key, (value, weight) in eval_metrics.items():
          self.AddEvalMetric(key + '/' + learner_name, value, weight)

    relevant_bn_updates, _ = py_utils.FindRelevantBatchNormUpdates(
        all_losses, tf.get_collection(py_utils.BATCH_NORM_UPDATES))
    train_ops['bn_updates'] = relevant_bn_updates

    var_update_ops = [
        tf.group(*tf.nest.flatten(train_ops), name='var_update_ops')
    ]
    # Post training step update.
    with tf.control_dependencies(var_update_ops):
      post_step_op = self.PostTrainingStepUpdate(self.global_step)

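    # Start a fresh train_ops dict: everything below is chained, via control
    # dependencies, to run strictly after the variable updates and the
    # post-step op.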
    train_ops = {}
    with tf.control_dependencies([post_step_op]):
      # Get the op to update the weight masks and thresholds
      mask_update_op = self._GetMaskUpdateOp()
      train_ops['mask_updates'] = mask_update_op
      with tf.control_dependencies([mask_update_op]):
        true_global_step = py_utils.GetOrCreateGlobalStepVar()
        with tf.ops.colocate_with(true_global_step):
          increment_global_steps = tf.assign_add(true_global_step, 1)
        if self._global_step_var != true_global_step:
          with tf.ops.colocate_with(self._global_step_var):
            increment_global_steps = tf.group(
                increment_global_steps, tf.assign_add(self._global_step_var, 1))
        train_ops['global_step'] = increment_global_steps

    # If we are using Tpu Embeddings, generate the monolithic send
    # gradient op.
    tpu_embedding_activations = tf.get_collection(
        py_utils.TPU_EMBEDDING_ACTIVATIONS)
    if tpu_embedding_activations:
      tpu_embedding_activations_dict = tpu_embedding_activations[0]
      tpu_embedding = tf.get_collection(py_utils.TPU_EMBEDDING)[0]
      tpu_embedding_send_gradient_op = py_utils.ComputeTpuEmbeddingGradients(
          self.loss, tpu_embedding_activations_dict, tpu_embedding)
      train_ops['tpu_embedding'] = tpu_embedding_send_gradient_op

    for op_name, op in train_ops.items():
      assert op is not None, op_name
    return train_ops
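
A caller would typically flatten and group the returned dict into a single op, mirroring what Example #3 does inline (sketch; `task` and the direct call to the private method are for illustration only):

train_ops = task._BPropGenTrainOps(vmap)
train_op = tf.group(*tf.nest.flatten(train_ops), name='bprop')
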
Example #5
  def _BPropGenTrainOps(self, vmap, metrics=None, add_summary=True):
    """Populates the train_ops dictionary in a backwards pass."""
    metrics = metrics or self._metrics

    bprop_variable_filters = self.input_generator.GetBpropVariableFilters()
    # Only compute the mask if the variable filters are not empty.
    if bprop_variable_filters != [''] * len(bprop_variable_filters):
      self._ComputeGradientMask(bprop_variable_filters)
    train_ops = {}  # mapping from op name to op.
    gradient_mask = None
    if self._per_input_gradient_mask:
      # TODO(neerajgaur): Change this to use source_selected from input_batch.
      onehot = self.input_generator.GetInputSourceOneHot()
      gradient_mask = {
          k: tf.tensordot(v, onehot, 1)
          for k, v in self._per_input_gradient_mask.items()
      }
    all_losses = []
    for optimization in self.learners:
      learner_name = optimization.params.name
      (losses, train_ops['train/%s' % learner_name],
       eval_metrics) = optimization.Apply(
           metrics,
           vmap,
           gradient_mask=gradient_mask,
           gradient_adjuster=self.AdjustGradients)
      all_losses.extend(losses)
      if add_summary:
        for key, (value, weight) in eval_metrics.items():
          self.AddEvalMetric(key + '/' + learner_name, value, weight)

    relevant_bn_updates, _ = py_utils.FindRelevantBatchNormUpdates(
        all_losses, tf.get_collection(py_utils.BATCH_NORM_UPDATES))
    train_ops['bn_updates'] = relevant_bn_updates

    var_update_ops = [
        tf.group(*tf.nest.flatten(train_ops), name='var_update_ops')
    ]
    # Post training step update.
    with tf.control_dependencies(var_update_ops):
      post_step_op = self.PostTrainingStepUpdate()

    train_ops = {}
    with tf.control_dependencies([post_step_op]):
      # Get the op to update the weight masks and thresholds
      mask_update_op = self._GetMaskUpdateOp()
      train_ops['mask_updates'] = mask_update_op
      with tf.control_dependencies([mask_update_op]):
        true_global_step = py_utils.GetOrCreateGlobalStepVar()
        with tf.ops.colocate_with(true_global_step):
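          # With defer_global_step_update set, the true global step is left
          # untouched here rather than incremented as part of this train op.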
          if self.params.defer_global_step_update:
            increment_global_steps = true_global_step
          else:
            increment_global_steps = tf.assign_add(true_global_step, 1)
        if self._global_step_var != true_global_step:
          with tf.ops.colocate_with(self._global_step_var):
            increment_global_steps = tf.group(
                increment_global_steps, tf.assign_add(self._global_step_var, 1))
        train_ops['global_step'] = increment_global_steps

    # If we are using Tpu Embeddings, generate the monolithic send
    # gradient op.
    if tf.get_collection(py_utils.TPU_EMBEDDING):
      tpu_embedding = tf.get_collection(py_utils.TPU_EMBEDDING)[0]
      sparse_grads = (
          tpu_embedding_gradient.get_gradients_through_dummy_table_variables(
              tpu_embedding))
      tpu_embedding_send_gradient_op = tpu_embedding.generate_send_gradients_op(
          sparse_grads, py_utils.GetGlobalStep())
      train_ops['tpu_embedding'] = tpu_embedding_send_gradient_op

      tpu_embedding_summary_tensors = tf.get_collection(
          py_utils.TPU_EMBEDDING_SUMMARY_TENSORS)
      if add_summary:
        for name, value, weight in tpu_embedding_summary_tensors:
          self.AddEvalMetric(name, value, weight, raise_if_already_added=False)

    for op_name, op in train_ops.items():
      assert op is not None, op_name
    return train_ops