Example #1
 def _CreateQStateVar(self, t_name, suffix, params):
   """Creates and registers a non-trainable quantization state variable."""
   name = t_name + '_' + suffix
   assert name not in self._qvars, 'QState var already exists: %s' % name
   var_name = self._qvars_scope.name + '/' + name
   # Create under the global variable scope so placement does not depend on
   # the caller's current scope nesting.
   with tf.variable_scope(py_utils.GetGlobalVariableScope()):
     v = py_utils.CreateVariable(var_name, params, trainable=False)
   self._qvars[name] = v
   return v
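
A minimal usage sketch follows, assuming this method lives on a quantization-aware layer; the tensor name 'act', the suffix 'min', and the WeightParams values are illustrative, not from the source:

  # Hypothetical call site inside the same layer. The WeightParams arguments
  # mirror the (shape, init, dtype) positional usage seen in Example #2.
  state_params = py_utils.WeightParams([], py_utils.WeightInit.Constant(0.0),
                                       tf.float32)
  min_var = self._CreateQStateVar('act', 'min', state_params)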
Example #2
def CreateTaskGlobalStep(task_name):
  """Create if needed and return the global_step."""
  with tf.name_scope(None), tf.variable_scope(
      py_utils.GetGlobalVariableScope()):
    graph_collections = [tf.GraphKeys.GLOBAL_VARIABLES, 'TASK_GLOBAL_STEP']
    _, v = py_utils.CreateVariable(
        name=task_name + '_global_step',
        params=py_utils.WeightParams([], py_utils.WeightInit.Constant(0),
                                     tf.int64),
        trainable=False,
        collections=graph_collections)
    summary_utils.scalar(v.name, v)
    return v
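
A hedged usage sketch, with 'mt' as a hypothetical task name; the returned tf.int64 variable is named 'mt_global_step' and is also added to the 'TASK_GLOBAL_STEP' collection:

  task_step = CreateTaskGlobalStep('mt')  # 'mt' is illustrative
  # Advance the counter once per training step (TF1-style graph mode).
  increment_op = tf.assign_add(task_step, 1)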
Example #3
  def __init__(self, params):
    assert issubclass(params.cls, BaseTask)
    # Ensure global_step exists before calling super.
    py_utils.GetOrCreateGlobalStepVar()
    super().__init__(params)

    p = self.params

    self._encoder = None
    self._online_encoder = None
    self._decoder = None

    self._loss = None
    self._num_predictions = None
    self._train_op = None
    self._post_train_ops = []
    self._eval_metrics = {}
    self._per_example = {}

    # Create the gradient mask.
    self._per_input_gradient_mask = None

    if p.task_global_step:
      with tf.name_scope(None), tf.variable_scope(
          py_utils.GetGlobalVariableScope()):
        var_name = p.name + '_global_step'
        # Create the variable immediately.
        self._CreateVariableInternal(
            var_name,
            base_layer.CreateVariableMeta(
                var_params=py_utils.WeightParams(
                    [], py_utils.WeightInit.Constant(0), tf.int64),
                theta_fn=None,
                kwargs=dict(
                    trainable=False,
                    collections=[tf.GraphKeys.GLOBAL_VARIABLES])))
        summary_utils.scalar(var_name, self._private_vars[var_name])
        self._global_step_var = self._private_vars[var_name]
    else:
      self._global_step_var = py_utils.GetOrCreateGlobalStepVar()

    if p.input:
      # TODO(zhifengc): Consider a simpler way to ensure the input
      # generator stops after one epoch.
      if self.do_eval and p.eval:
        seq_inp = issubclass(p.input.cls,
                             base_input_generator.BaseInputGeneratorFromFiles)
        if p.input.num_samples == 0:
          # Dataset size is unknown; rely on samples_per_summary to bound
          # each eval pass.
          assert p.eval.samples_per_summary > 0
        elif (p.eval.samples_per_summary == 0) or (p.input.num_samples <
                                                   p.eval.samples_per_summary):
          # If we know the dataset size and we want to evaluate the full
          # set, we need to coordinate the input generator to flush out
          # all samples so the evaler and decoder compute metrics on the
          # whole set for each summary step.
          if seq_inp:
            p.input.flush_every_n = p.input.num_samples
          p.eval.samples_per_summary = p.input.num_samples
        if seq_inp and p.input.num_batcher_threads > 1:
          tf.logging.warning(
              'input.num_batcher_threads > 1 inside eval mode. '
              'The input generator may not iterate over exactly '
              'one epoch per run.')
      tf.logging.info('input_params: %s', p.input)
      input_params = self.cluster.PlaceInput(p.input)

      # For TPU training, we create the input generator in a
      # different scope and attach it via AddChild later.
      if 'skip_create_child' not in p.input:
        self.CreateChild('input', input_params)

    tp = p.train

    # p.train can be None if this task is the teacher/student task in a
    # DistillationTask.
    if tp:
      self._SetLearnerFromLegacyParams(tp)
      if tp.learner is not None:
        if isinstance(tp.learner, (list, tuple)):
          self.CreateChildren('learners', tp.learner)
        else:
          self.CreateChildren('learners', [tp.learner])
    self._UpdateVnConfig()
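
To make the task_global_step branch above concrete, here is a hedged configuration sketch; MyTask is a hypothetical BaseTask subclass, and only the fields the constructor reads are set:

  p = MyTask.Params()          # hypothetical subclass of BaseTask
  p.name = 'my_task'
  p.task_global_step = True    # the constructor creates 'my_task_global_step'
                               # instead of using the shared global step
  task = p.Instantiate()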