Example #1
    def __init__(self, decoder_type, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._job_name = 'decoder_' + decoder_type
        self.params.cluster.do_eval = True
        self._cluster = cluster_factory.Cluster(self.params.cluster)
        self._decoder_dir = GetDecoderDir(self._logdir, self._job_name,
                                          self._model_task_name)
        tf.io.gfile.makedirs(self._decoder_dir)

        self._decode_path = None
        # Multitask params don't have 'task'.
        if 'task' in self.params:
            self._decode_path = checkpointer.GetSpecificCheckpoint(
                self.params.task.eval.load_checkpoint_from)

        self._should_report_metrics = self._job_name.startswith(
            self._cluster.reporting_job)

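        # Build the decode graph: summary writers, the model, its input
        # pipeline, and the decode ops.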
        with self._graph.as_default(), tf.container(self._container_id):
            self._summary_writer = self._CreateSummaryWriter(self._decoder_dir)
            self._CreateTF2SummaryWriter(self._decoder_dir)
            with self._cluster, tf.device(
                    self._cluster.GetPlacer()), self._TF2SummaryContext():
                self._model = self.params.Instantiate()
                self._params = self._model.params
                self._task = self._model.GetTask(self._model_task_name)
                # Note: different graphs are constructed for different model tasks,
                # which may result in different node names being chosen. Variable
                # names, however, have to stay the same between train and decode.
                cluster = self._cluster
                with tf.device(cluster.input_device):
                    input_batch = (
                        self._task.input_generator.GetPreprocessedInputBatch())

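                # Build the decode ops and merge any summaries added so far.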
                self._dec_output = self._task.Decode(input_batch)
                self._summary_op = tf.summary.merge_all()
                self.checkpointer = self._CreateCheckpointer(
                    self._train_dir, self._model)
            self._CreateTF2SummaryOps()
            self._initialize_tables = tf.tables_initializer()
            self._initialize_local_vars = tf.local_variables_initializer()
            # No queues are allowed for decoder models.
            self.enqueue_ops = tf.get_collection(py_utils.ENQUEUE_OPS)
            assert not self.enqueue_ops

        # Saves the params and, on the first task, the graph def.
        self._WriteToLog(self.params.ToText(), self._decoder_dir, 'params.txt')
        if self.params.cluster.task == 0:
            tf.io.write_graph(self._graph.as_graph_def(), self._decoder_dir,
                              '%s.pbtxt' % self._job_name)
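
Example #1 builds everything inside an explicit tf.Graph and finishes by dumping the graph def. Below is a minimal standalone sketch of that pattern, assuming plain TF1-compat APIs rather than Lingvo's wrappers; the placeholder ops and the /tmp path are illustrative only.

import tensorflow.compat.v1 as tf

graph = tf.Graph()
with graph.as_default():
  # Stand-ins for the input pipeline and task.Decode() above.
  inputs = tf.placeholder(tf.float32, shape=[None, 4], name='decode_input')
  weights = tf.get_variable('decode_w', shape=[4, 2])
  logits = tf.matmul(inputs, weights, name='decode_logits')
  tf.summary.scalar('logit_mean', tf.reduce_mean(logits))
  summary_op = tf.summary.merge_all()
  initialize_tables = tf.tables_initializer()
  initialize_local_vars = tf.local_variables_initializer()

# Same final step as above: persist the graph def for later inspection.
tf.io.write_graph(graph.as_graph_def(), '/tmp/decoder_dir', 'decoder_example.pbtxt')
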
Example #2
  def Start(self):
    """Start."""
    super().Start()

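    # Instantiate the model, checkpointer, and task under the cluster scope.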
    with self._cluster:
      self._model = self._params.Instantiate()
      self._checkpointer = self._CreateCheckpointer(self._train_dir,
                                                    self._model)
      self._task = self._model.GetTask(self._model_task_name)

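    # Resolve an explicit checkpoint path if eval.load_checkpoint_from is set.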
    self._decode_path = checkpointer.GetSpecificCheckpoint(
        self._task.params.eval.load_checkpoint_from)

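    # Decode the pinned checkpoint once, every checkpoint, or the latest ones.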
    if self._decode_path:
      self._DecodeOnce(path=self._decode_path)
      py_utils.UpdateProcessedCheckpoints(self._decoder_dir, self._decode_path)
    elif self._task.params.eval.decode_all_checkpoints:
      self._RunOnAllCheckpoints(
          runner_fn=self._DecodeOnce, runner_dir=self._decoder_dir)
    else:
      self._RunOnLatestCheckpoints(
          runner_fn=self._DecodeOnce, runner_dir=self._decoder_dir)
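
The dispatch in Start() has three modes. Here is a rough standalone sketch of that control flow in plain Python, not Lingvo; the helper names and the ckpt-* file naming are assumptions for illustration.

import glob
import os
import time


def decode_once(ckpt_prefix):
  # Stand-in for self._DecodeOnce(path=...).
  print('decoding', ckpt_prefix)


def dispatch(train_dir, pinned_ckpt=None, decode_all=False):
  if pinned_ckpt:
    # A specific checkpoint was pinned (cf. eval.load_checkpoint_from).
    decode_once(pinned_ckpt)
  elif decode_all:
    # Cf. eval.decode_all_checkpoints: decode every checkpoint in train_dir.
    for index_file in sorted(
        glob.glob(os.path.join(train_dir, 'ckpt-*.index')),
        key=os.path.getmtime):
      decode_once(index_file[:-len('.index')])
  else:
    # Default: keep decoding whichever checkpoint is newest.
    last_seen = None
    while True:
      candidates = glob.glob(os.path.join(train_dir, 'ckpt-*.index'))
      latest = max(candidates, key=os.path.getmtime, default=None)
      if latest and latest != last_seen:
        decode_once(latest[:-len('.index')])
        last_seen = latest
      time.sleep(60)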