Example #1
    def before_run(self, run_context):  # pylint: disable=unused-argument
        # Extend session.run(ops) so the ops can be executed in parallel.
        if self.rank != self.root_rank:
            return basic_session_run_hooks.SessionRunArgs(fetches={})

        # Only the root rank prints the log.
        return basic_session_run_hooks.SessionRunArgs(fetches=self.fetches)
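A hook like the one in Example #1 only works if rank, root_rank, and fetches are initialized elsewhere in the class. A minimal sketch of such a hook, assuming those attributes come from the constructor and that printing the results in after_run is all the root rank needs to do (class and constructor names are hypothetical):

    import tensorflow as tf


    class RootRankLoggingHook(tf.train.SessionRunHook):
        """Hypothetical hook: only the root rank fetches and prints the given tensors."""

        def __init__(self, fetches, rank, root_rank=0):
            self.fetches = fetches        # dict mapping names to tensors to fetch
            self.rank = rank              # rank of this worker
            self.root_rank = root_rank    # rank that is allowed to log

        def before_run(self, run_context):  # pylint: disable=unused-argument
            if self.rank != self.root_rank:
                return tf.train.SessionRunArgs(fetches={})
            return tf.train.SessionRunArgs(fetches=self.fetches)

        def after_run(self, run_context, run_values):  # pylint: disable=unused-argument
            # run_values.results mirrors the structure returned by before_run.
            if self.rank == self.root_rank:
                print(run_values.results)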
Example #2
    def before_run(self, run_context):  # pylint: disable=unused-argument
        # Extend session.run(ops) so the ops can be executed in parallel.
        self.fetches.update({
            'global_step': self.global_step,
            'run_ops': self.run_ops
        })
        return basic_session_run_hooks.SessionRunArgs(fetches=self.fetches)
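The keys used in the fetches dict reappear under run_values.results in after_run. A sketch of a matching after_run for Example #2, assuming the hook only needs to log the fetched global step (the logging call is illustrative):

    def after_run(self, run_context, run_values):  # pylint: disable=unused-argument
        # Results are keyed by the same names used in before_run's fetches dict.
        global_step = run_values.results['global_step']
        tf.logging.info('run_ops finished at global step %d', global_step)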
Example #3
    def before_run(self, run_context):  # pylint: disable=unused-argument
        requests = {"global_episode": self._global_episode_tensor}
        if can_run_hook(run_context):
            self._request_summary = self._current_episode == self._next_episode
            if self._request_summary:
                if self._get_summary_op() is not None:
                    requests["summary"] = self._get_summary_op()

        return basic_session_run_hooks.SessionRunArgs(requests)
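Example #3 only requests the summary op when a new episode has started, so the serialized summary shows up in after_run only on those runs. A sketch of the counterpart after_run, assuming the hook holds a tf.summary.FileWriter in a (hypothetical) _summary_writer attribute:

    def after_run(self, run_context, run_values):  # pylint: disable=unused-argument
        global_episode = run_values.results["global_episode"]
        if self._request_summary and "summary" in run_values.results:
            # Write the serialized summary, using the episode counter as the step.
            self._summary_writer.add_summary(run_values.results["summary"], global_episode)
            self._summary_writer.flush()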
Example #4
    def before_run(self, run_context):  # pylint: disable=unused-argument
        if can_run_hook(run_context) and self._timer.last_triggered_episode() is None:
            # We write the graph and saver_def at the first call of before_run.
            # We cannot do this in begin, since we let other hooks change the graph
            # and add variables in begin. The graph is finalized after all begin calls.
            training_util.write_graph(
                tf.get_default_graph().as_graph_def(add_shapes=True),
                self._checkpoint_dir,
                "graph.pbtxt")
            saver_def = self._get_saver().saver_def if self._get_saver() else None
            graph = tf.get_default_graph()
            meta_graph_def = meta_graph.create_meta_graph_def(
                graph_def=graph.as_graph_def(add_shapes=True),
                saver_def=saver_def)
            self._summary_writer.add_graph(graph)
            self._summary_writer.add_meta_graph(meta_graph_def)

        return basic_session_run_hooks.SessionRunArgs(self._global_episode_tensor)
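When SessionRunArgs wraps a single tensor instead of a dict, run_values.results in after_run is the corresponding single value rather than a dict. A sketch of an after_run that reads the episode counter fetched above, assuming logging it is all the hook needs to do here:

    def after_run(self, run_context, run_values):  # pylint: disable=unused-argument
        # A single-tensor fetch comes back as a single value, not a dict.
        global_episode = run_values.results
        tf.logging.info("global episode: %d", global_episode)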
Example #5
    def before_run(self, run_context):  # pylint: disable=unused-argument
        # Extend session.run(ops) so the ops can be executed in parallel.
        return basic_session_run_hooks.SessionRunArgs(fetches=self.fetches)
Example #6
    def before_run(self, run_context):
        fetches = {
            'summary': self.merged_ops,
            'global_step': self._global_step_tensor
        }
        return basic_session_run_hooks.SessionRunArgs(fetches=fetches)
Example #7
    def before_run(self, run_context):  # pylint: disable=unused-argument
        return basic_session_run_hooks.SessionRunArgs(self._global_step_tensor)
Example #8
    def before_run(self, run_context):  # pylint: disable=unused-argument
        return basic_session_run_hooks.SessionRunArgs(self._metrics.values())
Example #9
    def before_run(self, run_context):
        total_losses = tf.add_n(tf.get_collection("total_losses"))
        return basic_session_run_hooks.SessionRunArgs(
            [self._global_step_tensor, total_losses])
Example #10
    def before_run(self, run_context):
        return basic_session_run_hooks.SessionRunArgs(self._global_step_tensor)
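All of these hooks are driven the same way: a monitored session calls before_run ahead of every session.run and hands the results to after_run. A minimal usage sketch, assuming train_op is an existing training op and my_hook is any SessionRunHook instance from the examples above (both names are placeholders):

    import tensorflow as tf

    # train_op and my_hook are placeholders for an actual training op and
    # any of the SessionRunHook implementations shown above.
    with tf.train.MonitoredTrainingSession(hooks=[my_hook]) as sess:
        while not sess.should_stop():
            sess.run(train_op)  # before_run/after_run wrap this call automatically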