def __init__(self, decoder_type, *args, **kwargs):
    """Builds the decoding graph and dumps its params and graph def.

    Args:
      decoder_type: String appended to 'decoder_' to form this job's name.
      *args: Positional arguments forwarded to the base class constructor.
      **kwargs: Keyword arguments forwarded to the base class constructor.
    """
    super().__init__(*args, **kwargs)
    self._job_name = 'decoder_' + decoder_type
    self.params.cluster.do_eval = True
    self._cluster = cluster_factory.Cluster(self.params.cluster)
    self._decoder_dir = GetDecoderDir(
        self._logdir, self._job_name, self._model_task_name)
    tf.io.gfile.makedirs(self._decoder_dir)

    # Multitask params do not have a 'task' entry; in that case no specific
    # checkpoint is selected here.
    specific_ckpt = None
    if 'task' in self.params:
        load_from = self.params.task.eval.load_checkpoint_from
        specific_ckpt = checkpointer.GetSpecificCheckpoint(load_from)
    self._decode_path = specific_ckpt

    reporting_job = self._cluster.reporting_job
    self._should_report_metrics = self._job_name.startswith(reporting_job)

    with self._graph.as_default(), tf.container(self._container_id):
        self._summary_writer = self._CreateSummaryWriter(self._decoder_dir)
        self._CreateTF2SummaryWriter(self._decoder_dir)
        # The placer is looked up only after the cluster context is entered,
        # so it stays inline in the `with` header.
        with self._cluster, tf.device(
                self._cluster.GetPlacer()), self._TF2SummaryContext():
            self._model = self.params.Instantiate()
            self._params = self._model.params
            self._task = self._model.GetTask(self._model_task_name)
            # Note: a different graph is constructed for each model task,
            # which may result in different node names being chosen.
            # Variable names, however, have to stay the same between train
            # and decode.
            with tf.device(self._cluster.input_device):
                input_gen = self._task.input_generator
                input_batch = input_gen.GetPreprocessedInputBatch()

            self._dec_output = self._task.Decode(input_batch)
            self._summary_op = tf.summary.merge_all()
            self.checkpointer = self._CreateCheckpointer(
                self._train_dir, self._model)
        self._CreateTF2SummaryOps()
        self._initialize_tables = tf.tables_initializer()
        self._initialize_local_vars = tf.local_variables_initializer()
        # No queues are allowed for decoder models.
        self.enqueue_ops = tf.get_collection(py_utils.ENQUEUE_OPS)
        assert not self.enqueue_ops

    # Save the params and, on task 0 only, the graph def.
    params_text = self.params.ToText()
    self._WriteToLog(params_text, self._decoder_dir, 'params.txt')
    if self.params.cluster.task == 0:
        graph_def = self._graph.as_graph_def()
        tf.io.write_graph(
            graph_def, self._decoder_dir, f'{self._job_name}.pbtxt')
def Start(self):
    """Instantiates the model and decodes one or more checkpoints.

    If the task's eval params name a specific checkpoint, only that checkpoint
    is decoded and then recorded as processed. Otherwise decoding runs over
    all checkpoints when `decode_all_checkpoints` is set, and over the latest
    checkpoints when it is not.
    """
    super().Start()
    # NOTE(review): the whole body is kept inside the cluster context;
    # confirm this matches the intended scoping.
    with self._cluster:
        self._model = self._params.Instantiate()
        self._checkpointer = self._CreateCheckpointer(
            self._train_dir, self._model)
        self._task = self._model.GetTask(self._model_task_name)

        eval_params = self._task.params.eval
        self._decode_path = checkpointer.GetSpecificCheckpoint(
            eval_params.load_checkpoint_from)

        if self._decode_path:
            # A specific checkpoint was requested: decode it once and mark it
            # as processed.
            self._DecodeOnce(path=self._decode_path)
            py_utils.UpdateProcessedCheckpoints(
                self._decoder_dir, self._decode_path)
        else:
            run_over_checkpoints = (
                self._RunOnAllCheckpoints
                if self._task.params.eval.decode_all_checkpoints
                else self._RunOnLatestCheckpoints)
            run_over_checkpoints(
                runner_fn=self._DecodeOnce, runner_dir=self._decoder_dir)