Example #1
    def learn(self):
        if self.resume:
            self.load_agent_checkpoint()
        else:
            # clear any previous tensorboard logs in the result directory
            log_tools.del_all_files_in_dir(self.result_dir)
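        # pre-fill the replay buffer with random-policy transitions before learning starts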
        explore_before_train(self.env, self.replay_buffer, self.explore_step)
        print(
            "==============================start train==================================="
        )
        obs = self.env.reset()
        done = False

        episode_reward = 0
        episode_length = 0

        while self.train_step < self.max_train_step:
            action = self.choose_action(np.array(obs))
            next_obs, reward, done, info = self.env.step(action)
            episode_reward += reward

            self.replay_buffer.add(obs, action, reward, next_obs, done)
            obs = next_obs
            episode_length += 1

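            # one gradient update per environment step; returns the latest losses for logging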
            actor_loss, critic_loss = self.train()

            if done:
                obs = self.env.reset()
                done = False
                self.episode_num += 1

                print(
                    f"Total T: {self.train_step} Episode Num: {self.episode_num} "
                    f"Episode Length: {episode_length} Episode Reward: {episode_reward:.3f}"
                )
                self.tensorboard_writer.log_learn_data(
                    {
                        "episode_length": episode_length,
                        "episode_reward": episode_reward
                    }, self.train_step)
                episode_reward = 0
                episode_length = 0

            if self.train_step % self.log_interval == 0:
                self.store_agent_checkpoint()
                self.tensorboard_writer.log_train_data(
                    {
                        "actor_loss": actor_loss,
                        "critic_loss": critic_loss
                    }, self.train_step)

            if self.eval_freq > 0 and self.train_step % self.eval_freq == 0:
                avg_reward, avg_length = evaluate(agent=self, episode_num=10)
                self.tensorboard_writer.log_eval_data(
                    {
                        "eval_episode_length": avg_length,
                        "eval_episode_reward": avg_reward
                    }, self.train_step)
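
The helper explore_before_train called at the top of this method is not shown on this page. Below is a minimal sketch of what it might look like, assuming a Gym-style environment with an action_space.sample() method and the same four-value step() API used in the loop above; the project's actual helper may differ.

def explore_before_train(env, replay_buffer, explore_step):
    # Fill the replay buffer with transitions collected by a uniform random
    # policy for `explore_step` environment steps before learning begins.
    obs = env.reset()
    for _ in range(explore_step):
        action = env.action_space.sample()  # random exploration action
        next_obs, reward, done, info = env.step(action)
        replay_buffer.add(obs, action, reward, next_obs, done)
        obs = env.reset() if done else next_obs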
Example #2
    def learn(self):
        if self.resume:
            self.load_agent_checkpoint()
        else:
            # clear any previous tensorboard logs in the result directory
            log_tools.del_all_files_in_dir(self.result_dir)
        explore_before_train(self.env, self.replay_buffer, self.explore_step)
        print(
            "==============================start train==================================="
        )
        obs = self.env.reset()
        done = False

        episode_reward = 0
        episode_length = 0

        while self.train_step < self.max_train_step:
            action, _ = self.choose_action(np.array(obs))
            next_obs, reward, done, info = self.env.step(action)
            episode_reward += reward
            self.replay_buffer.add(obs, action, reward, next_obs, done)
            obs = next_obs
            episode_length += 1

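            # one update step; the four losses suggest a SAC-style agent (twin critics, entropy temperature alpha)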
            q_loss1, q_loss2, policy_loss, alpha_loss = self.train()
            if done:
                obs = self.env.reset()
                done = False
                self.episode_num += 1

                print(
                    f"Total T: {self.train_step} Episode Num: {self.episode_num} "
                    f"Episode Length: {episode_length} Episode Reward: {episode_reward:.3f}"
                )
                self.tensorboard_writer.log_learn_data(
                    {
                        "episode_length": episode_length,
                        "episode_reward": episode_reward
                    }, self.train_step)
                episode_reward = 0
                episode_length = 0

            if self.train_step % self.log_interval == 0:
                self.store_agent_checkpoint()
                self.tensorboard_writer.log_train_data(
                    {
                        "q_loss_1": q_loss1,
                        "q_loss_2": q_loss2,
                        "policy_loss": policy_loss,
                        "alpha_loss": alpha_loss
                    }, self.train_step)
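
Both examples push transitions into the buffer via replay_buffer.add(obs, action, reward, next_obs, done), but the buffer class itself is not shown. The following is a minimal FIFO sketch with that add signature plus the sample method a train() step would typically rely on; it is an illustrative assumption, not the project's actual implementation.

import numpy as np


class ReplayBuffer:
    def __init__(self, capacity):
        self.capacity = capacity
        self.storage = []
        self.pos = 0  # index of the next slot to overwrite once full

    def add(self, obs, action, reward, next_obs, done):
        transition = (obs, action, reward, next_obs, done)
        if len(self.storage) < self.capacity:
            self.storage.append(transition)
        else:
            self.storage[self.pos] = transition  # overwrite the oldest entry
        self.pos = (self.pos + 1) % self.capacity

    def sample(self, batch_size):
        # Uniformly sample a batch of transitions and stack each field.
        idx = np.random.randint(0, len(self.storage), size=batch_size)
        obs, action, reward, next_obs, done = zip(*(self.storage[i] for i in idx))
        return (np.array(obs), np.array(action), np.array(reward),
                np.array(next_obs), np.array(done))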