Example #1
    def _save(self, agent_context: core.AgentContext):
        """Saves the current policy to a subdirectory of self.directory."""
        assert agent_context
        assert agent_context.train, "TrainContext not set."

        tc = agent_context.train
        # eval_rewards maps episodes_done_in_training to (min, avg, max) rewards.
        min_reward, avg_reward, max_reward = tc.eval_rewards[
            tc.episodes_done_in_training]
        current_dir = f'episode_{tc.episodes_done_in_training}-avg_reward_{avg_reward}'
        current_dir = os.path.join(self.directory, current_dir)
        agent_context._agent_saver(directory=current_dir)
        # Remember (episode, avg_reward, directory) for each saved policy.
        self.saved_agents.append(
            (tc.episodes_done_in_training, avg_reward, current_dir))
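The (episode, avg_reward, directory) tuples collected in self.saved_agents make it easy to pick the best checkpoint after training. A minimal sketch, assuming only that tuple layout (best_saved_agent is a hypothetical helper, not part of the library):

    from typing import List, Tuple

    def best_saved_agent(saved_agents: List[Tuple[int, float, str]]) -> Tuple[int, float, str]:
        """Returns the (episode, avg_reward, directory) entry with the highest avg_reward."""
        assert saved_agents, "no policies were saved"
        return max(saved_agents, key=lambda entry: entry[1])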
Example #2
    def on_play_end(self, agent_context: core.AgentContext):
        if agent_context._is_plot_ready(core.PlotType.TRAIN_EVAL):
            self._display_plots(agent_context)
        if agent_context.is_play:
            self._display_plots(agent_context)
            # Assumption: the flag is an instance attribute; the bare name in
            # the original snippet would raise a NameError.
            if self._on_play_end_clear_jupyter_display:
                self._clear_jupyter_plots(agent_context, wait=False)
Example #3
    def _refresh_subplot(self, agent_context: core.AgentContext,
                         plot_type: core.PlotType):
        """Activates this callback's axes and calls plot() if this callback is
        registered for at least one of the plot types in plot_type."""
        assert self.axes is not None
        # Restrict plot_type to the types this callback was registered for.
        plot_type = plot_type & self._plot_type
        if agent_context._is_plot_ready(plot_type):
            pyc = agent_context.pyplot
            if not pyc.is_jupyter_active:
                plt.figure(pyc.figure.number)
                if plt.gcf() is pyc.figure:
                    plt.sca(self.axes)
            self.plot(agent_context, plot_type)
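The expression plot_type & self._plot_type relies on core.PlotType being a bitmask, so the plot types requested by the framework can be intersected with the types a callback registered for. A self-contained sketch of that pattern, assuming PlotType behaves like Python's enum.Flag (the class below is a stand-in, not the library's definition):

    from enum import Flag, auto

    class PlotType(Flag):
        NONE = 0
        TRAIN_EVAL = auto()
        TRAIN_ITERATION = auto()
        PLAY_EPISODE = auto()
        PLAY_STEP = auto()

    registered = PlotType.TRAIN_EVAL | PlotType.PLAY_EPISODE  # what this callback plots
    requested = PlotType.PLAY_EPISODE | PlotType.PLAY_STEP    # what the framework asks for
    assert (requested & registered) == PlotType.PLAY_EPISODE  # only the overlap is plotted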
Example #4
    def on_play_end(self, agent_context: core.AgentContext):
        if agent_context._is_plot_ready(core.PlotType.PLAY_EPISODE):
            self._close(agent_context)
Example #5
    def on_train_iteration_end(self, agent_context: core.AgentContext):
        if agent_context._is_plot_ready(core.PlotType.TRAIN_ITERATION):
            self._write_figure_to_video(agent_context)
Example #6
    def on_play_step_end(self, agent_context: core.AgentContext, action,
                         step_result: Tuple):
        if agent_context._is_plot_ready(core.PlotType.PLAY_STEP):
            self._write_figure_to_video(agent_context)
Example #7
    def on_play_episode_end(self, agent_context: core.AgentContext):
        if agent_context._is_plot_ready(core.PlotType.PLAY_EPISODE
                                        | core.PlotType.TRAIN_EVAL):
            self._write_figure_to_video(agent_context)
Example #8
    def on_train_iteration_end(self, agent_context: core.AgentContext):
        if agent_context._is_plot_ready(core.PlotType.TRAIN_ITERATION):
            self._display_plots(agent_context)
Example #9
    def on_train_iteration_begin(self, agent_context: core.AgentContext):
        # Display the initial evaluation before training starts.
        if (agent_context.train.iterations_done_in_training == 0 and
                agent_context._is_plot_ready(core.PlotType.TRAIN_EVAL)):
            self._display_plots(agent_context)
Example #10
    def on_play_step_end(self, agent_context: core.AgentContext, action,
                         step_result: Tuple):
        if agent_context._is_plot_ready(core.PlotType.PLAY_STEP):
            self._display_plots(agent_context)
Example #11
    def on_play_episode_end(self, agent_context: core.AgentContext):
        if agent_context._is_plot_ready(core.PlotType.PLAY_EPISODE):
            self._display_plots(agent_context)
Example #12
    def on_play_begin(self, agent_context: core.AgentContext):
        agent_context.play.max_steps_per_episode = self._max_steps_per_episode
        # Assumption: the intent is to test the train sub-context; the original
        # snippet tested agent_context itself, which is never an
        # EpisodesTrainContext, so the branch below was dead code.
        if isinstance(agent_context.train, core.EpisodesTrainContext):
            agent_context.train.num_episodes_per_iteration = self._num_episodes_per_iteration
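All of these callbacks hook into the same train/play lifecycle and are passed to an agent's train() or play() call. A minimal usage sketch along the lines of easyagents' README (treat the exact agent and callback class names as assumptions about the library's public API):

    from easyagents.agents import PpoAgent
    from easyagents.callbacks import duration, plot

    # Plot callbacks like the ones above fire on the lifecycle events shown
    # in the examples; duration.Fast() shortens training for a smoke test.
    agent = PpoAgent('CartPole-v0')
    agent.train([plot.State(), plot.Rewards(), duration.Fast()])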