コード例 #1
0
ファイル: agent_vpg.py プロジェクト: Jeyhooon/gdrl
    def demo_last(self, title='{} Agent - Fully Trained ', n_episodes=3, max_n_videos=3):
        """Evaluate the most recent checkpoint in a recording env and render it as GIF HTML.

        Creates an evaluation env with rendering and recording enabled. It then
        loads the checkpoint whose key is largest in ``get_cleaned_checkpoints()``
        into ``self.policy_model``, runs ``n_episodes`` evaluation episodes and
        turns the captured ``env.videos`` into HTML via ``utils.get_gif_html``.

        Args:
            title: Format string; ``{}`` is replaced with the agent's class name.
            n_episodes: Number of evaluation episodes passed to ``self.evaluate``.
            max_n_videos: Maximum number of videos forwarded to ``utils.get_gif_html``.

        Returns:
            Tuple ``(html_data, title)``. ``html_data`` is the value returned by
            ``utils.get_gif_html``; ``title`` is the formatted title string.

        Raises:
            ValueError: If ``get_cleaned_checkpoints()`` returns no checkpoints.
        """
        env = self.make_env_fn(**self.make_env_kargs, monitor_mode='evaluation', render=True, record=True)
        title = title.format(self.__class__.__name__)

        try:
            checkpoint_paths = self.get_cleaned_checkpoints()
            if not checkpoint_paths:
                # max() on an empty mapping would raise an opaque ValueError; keep
                # the same exception type but with an actionable message.
                raise ValueError('demo_last: no checkpoints available to load')
            # NOTE(review): keys are presumably episode numbers, so the largest
            # key is the latest checkpoint — confirm against get_cleaned_checkpoints.
            last_ep = max(checkpoint_paths)
            self.policy_model.load_state_dict(torch.load(checkpoint_paths[last_ep]))

            self.evaluate(self.policy_model, env, n_episodes=n_episodes)
        finally:
            # Always release the recording env, even if loading or evaluation fails.
            env.close()

        # env.videos is read after close(), exactly as in the original flow.
        html_data = utils.get_gif_html(env_videos=env.videos,
                                       title=title,
                                       max_n_videos=max_n_videos)
        del env
        return html_data, title
コード例 #2
0
ファイル: agent_vpg.py プロジェクト: Jeyhooon/gdrl
    def demo_progression(self, title='{} Agent - Progression', max_n_videos=5):
        """Evaluate every saved checkpoint in order and render the progression as GIF HTML.

        Creates an evaluation env with rendering and recording enabled. For each
        checkpoint key from ``get_cleaned_checkpoints()``, in ascending order, it
        loads the weights into ``self.policy_model`` and runs one evaluation
        episode. The captured ``env.videos`` are then rendered via
        ``utils.get_gif_html``, with the sorted keys used as subtitles.

        Args:
            title: Format string; ``{}`` is replaced with the agent's class name.
            max_n_videos: Maximum number of videos forwarded to ``utils.get_gif_html``.

        Returns:
            Tuple ``(html_data, title)``. ``html_data`` is the value returned by
            ``utils.get_gif_html``; ``title`` is the formatted title string.
        """
        env = self.make_env_fn(**self.make_env_kargs, monitor_mode='evaluation', render=True, record=True)
        title = title.format(self.__class__.__name__)

        try:
            checkpoint_paths = self.get_cleaned_checkpoints()
            # Sort once and reuse for both the evaluation loop and the subtitles,
            # so the two are guaranteed to line up.
            # NOTE(review): keys are presumably episode numbers — confirm.
            checkpoint_eps = sorted(checkpoint_paths)
            for ep in checkpoint_eps:
                self.policy_model.load_state_dict(torch.load(checkpoint_paths[ep]))
                self.evaluate(self.policy_model, env, n_episodes=1)
        finally:
            # Always release the recording env, even if a load or evaluation fails.
            env.close()

        # env.videos is read after close(), exactly as in the original flow.
        html_data = utils.get_gif_html(env_videos=env.videos,
                                       title=title,
                                       subtitle_eps=checkpoint_eps,
                                       max_n_videos=max_n_videos)
        del env
        return html_data, title