Exemple #1
0
 def begin(self):
     """Initialise the summary writer and episode tensor, then start listeners.

     Fetches the cached summary writer for ``self._checkpoint_dir``, looks up
     the global episode tensor and calls ``begin()`` on every registered
     listener.

     Raises:
       RuntimeError: if no global episode tensor has been created.
     """
     self._summary_writer = SummaryWriterCache.get(self._checkpoint_dir)
     self._global_episode_tensor = get_global_episode()
     if self._global_episode_tensor is None:
         # The check is on the global *episode* tensor, so the message must
         # name that tensor (it previously referred to a "Global step").
         raise RuntimeError(
             "Global episode should be created to use CheckpointSaverHook.")
     for listener in self._listeners:
         listener.begin()
Exemple #2
0
 def begin(self):
     """Initialise the summary writer and the episode bookkeeping.

     A summary writer is fetched from ``SummaryWriterCache`` only when none
     was supplied and an output directory is configured.

     Raises:
       RuntimeError: if no global episode tensor has been created.
     """
     needs_writer = self._summary_writer is None and self._output_dir
     if needs_writer:
         self._summary_writer = SummaryWriterCache.get(self._output_dir)
     # Clear both episode markers.
     self._next_episode = self._current_episode = None
     episode_tensor = get_global_episode()
     self._global_episode_tensor = episode_tensor
     if episode_tensor is None:
         raise RuntimeError("Global episode should be created to use EpisodeSummarySaverHook.")
Exemple #3
0
 def begin(self):
     """Initialise the summary writer, episode tensor and summary tags.

     A summary writer is fetched from ``SummaryWriterCache`` only when none
     was supplied and an output directory is configured. The ``/sec`` and
     ``/steps`` summary tags are derived from the op name of the global
     episode tensor, and the step counter is reset to zero.

     Raises:
       RuntimeError: if no global episode tensor has been created.
     """
     if self._summary_writer is None and self._output_dir:
         self._summary_writer = SummaryWriterCache.get(self._output_dir)
     self._global_episode_tensor = get_global_episode()
     if self._global_episode_tensor is None:
         # The check is on the global *episode* tensor, so the message must
         # name that tensor (it previously referred to a "Global step").
         raise RuntimeError("Global episode should be created to use EpisodeCounterHook.")
     op_name = self._global_episode_tensor.op.name
     self._summary_sec_tag = op_name + "/sec"
     self._summary_steps_tag = op_name + "/steps"
     self._num_steps = 0
Exemple #4
0
    def begin(self):
        """Resolve the global episode tensor and the tensors to report.

        Each entry of ``self._tensors`` is passed through
        ``basic_session_run_hooks._as_graph_element``; the global episode
        tensor is then stored under the ``'global_episode'`` key.

        Raises:
          RuntimeError: if no global episode tensor has been created.
        """
        episode_tensor = get_global_episode()
        self._global_episode_tensor = episode_tensor
        if episode_tensor is None:
            raise RuntimeError("Global episode should be created to use StopAtEpisodeHook.")

        # Resolve each tag's tensor (or tensor name) into a graph element.
        resolved = {}
        for tag, tensor in self._tensors.items():
            resolved[tag] = basic_session_run_hooks._as_graph_element(tensor)  # pylint: disable=protected-access
        resolved['global_episode'] = episode_tensor
        self._current_tensors = resolved
Exemple #5
0
 def end(self, session):
     """Save a final checkpoint when needed and notify all listeners.

     The global episode is evaluated once. A save happens only when that
     episode differs from the one at which the timer last triggered
     (presumably to avoid saving the same episode twice). Every listener's
     ``end`` is then called with the session and that episode.

     Args:
       session: session used to evaluate the global episode tensor and
         passed on to ``_save`` and the listeners.
     """
     # Reuse the tensor cached by ``begin`` instead of looking it up again
     # with ``get_global_episode()``; assumes ``begin`` ran first, as it
     # does for session hooks.
     last_episode = session.run(self._global_episode_tensor)
     if last_episode != self._timer.last_triggered_episode():
         self._save(last_episode, session)
     for listener in self._listeners:
         listener.end(session, last_episode)
Exemple #6
0
 def begin(self):
     """Cache the global episode tensor, failing fast when it is missing.

     Raises:
       RuntimeError: if no global episode tensor has been created.
     """
     episode_tensor = get_global_episode()
     self._global_episode_tensor = episode_tensor
     if episode_tensor is None:
         raise RuntimeError("Global episode should be created to use StopAtEpisodeHook.")