예제 #1
0
    def __init__(self, *args, **kwargs):
        """Builds the trainer's TF graph and bookkeeping state.

        Creates per-task summary writers (multi-task models only), the
        chief's summary writers, the model's FProp/BProp graph, table and
        local-variable initializers, enqueue ops, and the staggered
        start-up delay for this worker.
        """
        super().__init__(*args, **kwargs)
        self._job_name = 'trainer'
        with self._graph.as_default(), tf.container(self._container_id):
            try:
                # NOTE(review): self._model is only assigned further below
                # (after Instantiate()); unless a base class set it earlier,
                # this access raises AttributeError and we fall into the
                # handler — confirm intended ordering.
                self._task_probs_summary_writers = []
                for task in self._model.task_schedule.tasks:
                    # Fix: the original wrapped this in a redundant
                    # single-argument os.path.join (a no-op).
                    path = os.path.join(self._train_dir, task)
                    tf.io.gfile.makedirs(path)
                    self._task_probs_summary_writers.append(
                        self._CreateSummaryWriter(path))
            except AttributeError:
                # Single-task models have no task_schedule; that's expected.
                tf.logging.info(
                    'AttributeError. Expected for single task models.')
                self._task_probs_summary_writers = []

            # Only the chief (task 0) writes training summaries.
            if self.params.cluster.task == 0:
                self._summary_writer = self._CreateSummaryWriter(
                    self._train_dir)
                self._CreateTF2SummaryWriter(self._train_dir)
            else:
                self._summary_writer = None

            with self._cluster, tf.device(
                    self._cluster.GetPlacer()), self._TF2SummaryContext():
                self._model = self.params.Instantiate()
                self._params = self._model.params
                self._model.ConstructFPropBPropGraph()
            self._CreateTF2SummaryOps()
            self._initialize_tables = tf.tables_initializer()
            self._initialize_local_vars = tf.local_variables_initializer()
            self.enqueue_ops = tf.get_collection(py_utils.ENQUEUE_OPS)
            tf.logging.info('Trainer number of enqueue ops: %d',
                            len(self.enqueue_ops))

        self._step_rate_tracker = summary_utils.StepRateTracker()

        # Saves the graph def.
        if self.params.cluster.task == 0:
            self._WriteToLog(self.params.ToText(), self._train_dir,
                             'trainer_params.txt')
            tf.io.write_graph(self._graph.as_graph_def(), self._train_dir,
                              'train.pbtxt')
        # Stagger worker start-up: worker i waits i*(i+1)/2 * delay steps,
        # so later workers begin progressively later.
        worker_id = self.params.cluster.task
        self._start_up_delay_steps = (((worker_id + 1) * worker_id / 2) *
                                      self.params.train.start_up_delay_steps)
예제 #2
0
파일: program.py 프로젝트: Singed-jj/lingvo
 def __init__(self, params):
     """Initializes the TrainProgram: step-rate tracker and program name."""
     # Modernized from the legacy super(TrainProgram, self).__init__(params)
     # to the equivalent Python 3 zero-argument form, matching the style
     # used by the other constructors in this file.
     super().__init__(params)
     # Tracks steps/sec and examples/sec for status reporting.
     self._step_rate_tracker = summary_utils.StepRateTracker()
     self._program_name = 'TrainProgram'
예제 #3
0
 def __init__(self, params, shared_model=None):
     """Sets up the TrainProgram, optionally sharing a model instance.

     Args:
       params: Program params, forwarded to the base class.
       shared_model: Optional model shared with other programs.
     """
     super().__init__(params, shared_model=shared_model)
     self._program_name = 'TrainProgram'
     # Tracks steps/sec and examples/sec for status reporting.
     self._step_rate_tracker = summary_utils.StepRateTracker()
예제 #4
0
  def Start(self):
    """Run training.

    Builds the model and a traced train function, restores (or freshly
    initializes) a checkpoint, then loops: run one train step, maybe
    checkpoint, and write step-rate/metric summaries, until
    `_ShouldStop` says to stop. Saves a final checkpoint on exit.
    """
    super().Start()

    with self._cluster:
      model = self._params.Instantiate()
      ckptr = self._CreateCheckpointer(self._train_dir, model)
      task = model.GetTask(self._model_task_name)

      # One traced training step: builds the FProp/BProp graph under a
      # persistent gradient tape and returns the task's metrics and
      # per-example outfeed tensors.
      @tf.function(autograph=False)
      def TrainFunc():
        with py_utils.GradientTape(persistent=True):
          model.ConstructFPropBPropGraph()
        return task.eval_metrics, task.per_example_tensors

      step_rate_tracker = summary_utils.StepRateTracker()
      summary_writer = tf.compat.v2.summary.create_file_writer(self._train_dir)

      # Attempt to restore the checkpoint.
      # A 'dummy run' to initialize the optimizer and related slot variables
      # before restoring. This is also needed for V2 checkpoints even though
      # they support delayed loading, in case the checkpoint already exceeds
      # max_steps. In that scenario the slot variables would be lost without
      # a dummy run due to checkpoint overwrites.
      _, _ = TrainFunc()
      path = ckptr.Restore()
      if path:
        tf.logging.info(f'Loaded checkpoints from {path}.')
      else:
        tf.logging.info('Did not find any checkpoints. Starting fresh.')
        # Reset global_step manually if we could not load any checkpoints,
        # since the dummy run above may have advanced it.
        global_step = py_utils.GetOrCreateGlobalStepVar()
        global_step.assign(0)

      # Main training loop; global_step is re-read after every step.
      global_step = model.global_step.numpy()
      while True:
        if self._ShouldStop(global_step):
          break

        tf.logging.info('Starting train function.')
        metrics_dict, outfeed = TrainFunc()
        tf.logging.info('Train function complete.')

        global_step = model.global_step.numpy()

        # The outfeed should be non-empty exactly when the task declares
        # per-example tensors.
        if not task.per_example_tensors:
          assert not outfeed
        else:
          # TODO(laigd): debugging only, remove later.
          tf.logging.info(f'outfeed: {outfeed}')

        # The checkpointer decides internally whether this step warrants
        # a save.
        ckptr.MaybeSave(gsteps=global_step)

        step_rate, example_rate, total_examples = (
            step_rate_tracker.ComputeStepRate(
                global_step, metrics_dict['num_samples_in_batch'][0].numpy()))

        msg = 'step:%6d, steps/sec: %0.2f, examples/sec: %0.2f' % (
            global_step, step_rate, example_rate)
        # Write summaries: rates plus every scalar metric from the step.
        with summary_writer.as_default():
          tf.compat.v2.summary.scalar(
              'global_step/sec', step_rate, step=global_step)
          tf.compat.v2.summary.scalar(
              'examples/sec', example_rate, step=global_step)
          tf.compat.v2.summary.scalar(
              'total_samples', total_examples, step=global_step)
          # metrics_dict maps name -> (value, weight); only the value is
          # logged and summarized.
          for key, (val, _) in sorted(metrics_dict.items()):
            msg += ' %s:%.8g' % (key, val)
            tf.compat.v2.summary.scalar(key, val, step=global_step)
          summary_writer.flush()

        # Log training progress.
        self._SetStatusMessage(msg)

      # Also save at the end of training.
      ckptr.Save(gsteps=global_step)