Example #1
0
    def before_run(self, run_context):  # pylint: disable=unused-argument
        """Build the fetches for the upcoming ``session.run`` call.

        The global-episode tensor is always requested. When ``can_run_hook``
        accepts ``run_context``, ``self._request_summary`` is recomputed as
        ``self._current_episode >= self._next_episode``; if it is True and a
        summary op is available, that op is requested as well.

        Args:
            run_context: The run context supplied by the session; passed to
                ``can_run_hook``.

        Returns:
            ``basic_session_run_hooks.SessionRunArgs`` wrapping a dict with the
            key ``"global_episode"`` and, when a summary is requested and
            available, the key ``"summary"``.
        """
        requests = {"global_episode": self._global_episode_tensor}
        if can_run_hook(run_context):
            self._request_summary = self._current_episode >= self._next_episode
            if self._request_summary:
                # Look the summary op up once and reuse it (previously fetched twice).
                summary_op = self._get_summary_op()
                if summary_op is not None:
                    requests["summary"] = summary_op
        # NOTE(review): when can_run_hook() is False, self._request_summary keeps
        # whatever value it had from a previous call; confirm after_run does not
        # rely on it in that case.

        return basic_session_run_hooks.SessionRunArgs(requests)
Example #2
0
    def before_run(self, run_context):  # pylint: disable=unused-argument
        """Write the graph artifacts once, then request the global-episode tensor.

        On the first call where ``can_run_hook`` accepts ``run_context`` and the
        timer has not triggered yet (``last_triggered_episode()`` is None), this:
        writes ``graph.pbtxt`` into ``self._checkpoint_dir``, builds a
        MetaGraphDef (including the saver's ``saver_def`` when a saver exists),
        and adds both the graph and the meta graph to ``self._summary_writer``.

        Args:
            run_context: The run context supplied by the session; passed to
                ``can_run_hook``.

        Returns:
            ``basic_session_run_hooks.SessionRunArgs`` wrapping
            ``self._global_episode_tensor``.
        """
        if can_run_hook(run_context) and self._timer.last_triggered_episode() is None:
            # We write the graph and saver_def at the first call of before_run.
            # We cannot do this in begin, since we let other hooks change the graph
            # and add variables in begin. The graph is finalized after all begin calls.
            graph = tf.get_default_graph()
            # Serialize the graph once; the same GraphDef feeds both the
            # graph.pbtxt file and the MetaGraphDef below.
            graph_def = graph.as_graph_def(add_shapes=True)
            training_util.write_graph(graph_def, self._checkpoint_dir, "graph.pbtxt")
            # Resolve the saver once instead of calling _get_saver() twice.
            saver = self._get_saver()
            saver_def = saver.saver_def if saver else None
            meta_graph_def = meta_graph.create_meta_graph_def(
                graph_def=graph_def,
                saver_def=saver_def)
            self._summary_writer.add_graph(graph)
            self._summary_writer.add_meta_graph(meta_graph_def)

        return basic_session_run_hooks.SessionRunArgs(self._global_episode_tensor)
Example #3
0
 def after_run(self, run_context, run_values):
     """Log the fetched tensors when the episode timer fires.

     Reads ``run_values.results['global_episode']``. If ``can_run_hook``
     accepts ``run_context`` and the timer should trigger for that episode,
     the timer is updated and one log line is emitted. The line comes either
     from ``self._formatter(run_values.results)`` or from ``"tag = value"``
     pairs in ``self._tag_order`` order, followed by the elapsed seconds when
     known.

     Numpy print options are switched to ``suppress=True`` while logging and
     are always restored afterwards, even if logging raises.

     Args:
         run_context: The run context supplied by the session; passed to
             ``can_run_hook``.
         run_values: Results of the ``session.run`` call; ``results`` must
             contain ``'global_episode'`` and, without a formatter, every tag
             in ``self._tag_order``.
     """
     global_episode = run_values.results['global_episode']
     if not can_run_hook(run_context):
         return
     if not self._timer.should_trigger_for_episode(global_episode):
         return

     original = np.get_printoptions()
     # Print floats in fixed-point notation while logging.
     np.set_printoptions(suppress=True)
     try:
         elapsed_secs, _ = self._timer.update_last_triggered_episode(global_episode)
         if self._formatter:
             logging.info(self._formatter(run_values.results))
         else:
             stats = ["%s = %s" % (tag, run_values.results[tag])
                      for tag in self._tag_order]
             if elapsed_secs is not None:
                 logging.info("%s (%.3f sec)", ", ".join(stats), elapsed_secs)
             else:
                 logging.info("%s", ", ".join(stats))
     finally:
         # Restore the process-wide print options even if the timer,
         # formatter or logging raised; otherwise suppress=True would leak.
         np.set_printoptions(**original)
Example #4
0
 def after_run(self, run_context, run_values):
     """Forward to the superclass ``after_run`` only when the hook may run.

     Args:
         run_context: The run context supplied by the session; checked with
             ``can_run_hook``.
         run_values: Results of the ``session.run`` call, forwarded unchanged.

     Returns:
         The superclass's return value, or ``None`` when ``can_run_hook``
         rejects ``run_context``.
     """
     if not can_run_hook(run_context):
         return None
     return super(NanTensorHook, self).after_run(run_context, run_values)
Example #5
0
 def before_run(self, run_context):  # pylint: disable=unused-argument
     """Forward to the superclass ``before_run`` only when the hook may run.

     Args:
         run_context: The run context supplied by the session; checked with
             ``can_run_hook`` and forwarded unchanged.

     Returns:
         The superclass's fetch request, or ``None`` (nothing extra to fetch)
         when ``can_run_hook`` rejects ``run_context``.
     """
     if not can_run_hook(run_context):
         return None
     return super(NanTensorHook, self).before_run(run_context)
Example #6
0
 def before_run(self, run_context):  # pylint: disable=unused-argument
     """Record whether the hook may run, then delegate if it may.

     ``self._should_trigger`` is set to the result of ``can_run_hook`` before
     the superclass is consulted. Whether the superclass call changes that
     attribute again is not visible here.

     Args:
         run_context: The run context supplied by the session; checked with
             ``can_run_hook`` and forwarded unchanged.

     Returns:
         The superclass's fetch request, or ``None`` when the hook may not run.
     """
     may_run = can_run_hook(run_context)
     self._should_trigger = may_run
     if not may_run:
         return None
     return super(StepLoggingTensorHook, self).before_run(run_context)
Example #7
0
 def after_run(self, run_context, run_values):
     """Call ``self._save`` when the hook may run and the episode timer fires.

     ``run_values.results`` is used directly as the global episode. This
     presumably matches a single tensor requested in ``before_run``; confirm
     against the paired ``before_run``.

     Args:
         run_context: The run context supplied by the session; checked with
             ``can_run_hook`` and providing ``session`` for ``self._save``.
         run_values: Results of the ``session.run`` call.
     """
     episode = run_values.results
     if not can_run_hook(run_context):
         return
     if not self._timer.should_trigger_for_episode(episode):
         return
     self._timer.update_last_triggered_episode(episode)
     self._save(episode, run_context.session)
Example #8
0
 def before_run(self, run_context):  # pylint: disable=unused-argument
     """Choose what to fetch in the upcoming ``session.run`` call.

     Args:
         run_context: The run context supplied by the session; checked with
             ``can_run_hook``.

     Returns:
         ``session_run_hook.SessionRunArgs`` over ``self._current_tensors``
         when ``can_run_hook`` accepts ``run_context``. Otherwise it wraps a
         dict holding only the ``'global_episode'`` tensor.
     """
     fetches = (self._current_tensors if can_run_hook(run_context)
                else {'global_episode': self._global_episode_tensor})
     return session_run_hook.SessionRunArgs(fetches)