Example No. 1
    def learn(self):
        if self.resume:
            self.load_agent_checkpoint()
        else:
            # delete tensorboard log file
            log_tools.del_all_files_in_dir(self.result_dir)
        explore_before_train(self.env, self.replay_buffer, self.explore_step)
        print(
            "==============================start train==================================="
        )
        obs = self.env.reset()
        done = False

        episode_reward = 0
        episode_length = 0

        while self.train_step < self.max_train_step:
            action = self.choose_action(np.array(obs))
            next_obs, reward, done, info = self.env.step(action)
            episode_reward += reward

            self.replay_buffer.add(obs, action, reward, next_obs, done)
            obs = next_obs
            episode_length += 1

            actor_loss, critic_loss = self.train()

            if done:
                obs = self.env.reset()
                done = False
                self.episode_num += 1

                print(
                    f"Total T: {self.train_step} Episode Num: {self.episode_num} "
                    f"Episode Length: {episode_length} Episode Reward: {episode_reward:.3f}"
                )
                self.tensorboard_writer.log_learn_data(
                    {
                        "episode_length": episode_length,
                        "episode_reward": episode_reward
                    }, self.train_step)
                episode_reward = 0
                episode_length = 0

            if self.train_step % self.log_interval == 0:
                self.store_agent_checkpoint()
                self.tensorboard_writer.log_train_data(
                    {
                        "actor_loss": actor_loss,
                        "critic_loss": critic_loss
                    }, self.train_step)

            if self.eval_freq > 0 and self.train_step % self.eval_freq == 0:
                avg_reward, avg_length = evaluate(agent=self, episode_num=10)
                self.tensorboard_writer.log_eval_data(
                    {
                        "eval_episode_length": avg_length,
                        "eval_episode_reward": avg_reward
                    }, self.train_step)
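
The loop above relies on an explore_before_train helper that warms up the replay buffer with random transitions before any gradient updates. A minimal sketch of what such a helper can look like, assuming a Gym-style env and the replay_buffer.add() signature used above (this is an illustration, not necessarily the library's implementation):

    def explore_before_train(env, replay_buffer, explore_step):
        """Fill the replay buffer with explore_step transitions from a uniform random policy."""
        obs = env.reset()
        for _ in range(explore_step):
            action = env.action_space.sample()        # random exploratory action
            next_obs, reward, done, info = env.step(action)
            replay_buffer.add(obs, action, reward, next_obs, done)
            obs = env.reset() if done else next_obs   # start a fresh episode when one ends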
Example No. 2
    def learn(self):
        """Train TD3_BC without interacting with the environment (offline)"""
        if self.resume:
            self.load_agent_checkpoint()
        else:
            # delete tensorboard log file
            log_tools.del_all_files_in_dir(self.result_dir)

        while self.train_step < (int(self.max_train_step)):
            actor_loss, critic_loss1, critic_loss2 = self.train()

            if self.train_step % self.log_interval == 0:
                self.store_agent_checkpoint()
                self.tensorboard_writer.log_train_data(
                    {
                        "actor_loss": actor_loss,
                        "critic_loss1": critic_loss1,
                        "critic_loss2": critic_loss2
                    }, self.train_step)

            if self.eval_freq > 0 and self.train_step % self.eval_freq == 0:
                avg_reward, avg_length = evaluate(agent=self, episode_num=10)
                self.tensorboard_writer.log_eval_data(
                    {
                        "eval_episode_length": avg_length,
                        "eval_episode_reward": avg_reward
                    }, self.train_step)
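
Here self.train() hides the TD3+BC update on batches sampled from the offline dataset. As a rough sketch, the actor objective in TD3+BC combines a normalised Q term with a behaviour-cloning penalty; the helper below illustrates that loss with made-up names (actor, critic, alpha) and is only an approximation of what the agent's train() computes:

    import torch.nn.functional as F

    def td3_bc_actor_loss(actor, critic, obs, act, alpha=2.5):
        """Actor loss = -lambda * Q(s, pi(s)) + MSE(pi(s), a), as in the TD3+BC paper."""
        pi = actor(obs)                           # actions proposed by the current policy
        q = critic(obs, pi)                       # critic's evaluation of those actions
        lam = alpha / q.abs().mean().detach()     # scale factor that normalises the Q term
        return -lam * q.mean() + F.mse_loss(pi, act)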
Example No. 3
    def learn(self):
        if self.resume:
            self.load_agent_checkpoint()
        else:
            # delete tensorboard log file
            log_tools.del_all_files_in_dir(self.result_dir)

        while self.train_step < (int(self.max_train_step)):
            # train
            q_loss1, q_loss2, actor_loss, alpha_prime_loss = self.train()

            if self.train_step % self.eval_freq == 0:
                avg_reward, avg_length = evaluate(agent=self, episode_num=10)
                self.tensorboard_writer.log_eval_data(
                    {
                        "eval_episode_length": avg_length,
                        "eval_episode_reward": avg_reward
                    }, self.train_step)

            if self.train_step % self.log_interval == 0:
                self.store_agent_checkpoint()
                self.tensorboard_writer.log_train_data(
                    {
                        "q_loss_1": q_loss1,
                        "q_loss_2": q_loss2,
                        "actor_loss": actor_loss,
                        "alpha_prime_loss": alpha_prime_loss
                    }, self.train_step)
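
Every example logs through self.tensorboard_writer, which exposes log_train_data, log_eval_data and log_learn_data taking a dict of scalars plus a step. A minimal sketch of such a wrapper around torch.utils.tensorboard.SummaryWriter (the class and tag names are illustrative; the library's own writer may differ):

    from torch.utils.tensorboard import SummaryWriter

    class TensorboardLogger:
        def __init__(self, log_dir):
            self.writer = SummaryWriter(log_dir=log_dir)

        def _log(self, prefix, scalars, step):
            # write each scalar under a shared prefix, e.g. "train/actor_loss"
            for name, value in scalars.items():
                self.writer.add_scalar(f"{prefix}/{name}", value, step)

        def log_train_data(self, scalars, step):
            self._log("train", scalars, step)

        def log_eval_data(self, scalars, step):
            self._log("eval", scalars, step)

        def log_learn_data(self, scalars, step):
            self._log("learn", scalars, step)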
Example No. 4
    def learn(self):
        """Train PLAS without interacting with the environment (offline)"""
        if self.resume:
            self.load_agent_checkpoint()
        else:
            # delete tensorboard log file
            log_tools.del_all_files_in_dir(self.result_dir)

        # Train CVAE before train agent
        print(
            "==============================Start to train CVAE=============================="
        )

        while self.cvae_iterations < (int(self.max_cvae_iterations)):
            cvae_loss = self.train_cvae()
            if self.cvae_iterations % 1000 == 0:
                print("CVAE iteration:", self.cvae_iterations, "\t",
                      "CVAE Loss:", cvae_loss)
                self.tensorboard_writer.log_train_data(
                    {"cvae_loss": cvae_loss}, self.cvae_iterations)

        # Train Agent
        print(
            "==============================Start to train Agent=============================="
        )
        while self.train_step < (int(self.max_train_step)):
            critic_loss, actor_loss = self.train()

            if self.eval_freq > 0 and self.train_step % self.eval_freq == 0:
                avg_reward, avg_length = evaluate(agent=self, episode_num=10)
                self.tensorboard_writer.log_eval_data(
                    {
                        "eval_episode_length": avg_length,
                        "eval_episode_reward": avg_reward
                    }, self.train_step)

            if self.train_step % self.log_interval == 0:
                self.store_agent_checkpoint()
                self.tensorboard_writer.log_train_data(
                    {
                        "critic_loss": critic_loss,
                        "actor_loss": actor_loss
                    }, self.train_step)
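
The CVAE stage above only shows the outer iteration; train_cvae() fits a conditional VAE that reconstructs dataset actions from observations, and its decoder is later used by the PLAS policy. A minimal sketch of such a model and its loss, assuming continuous observations and actions (layer sizes, clamping range and the KL weight are illustrative choices, not the library's exact values):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class CVAE(nn.Module):
        """Encode (obs, act) into a latent z; decode (obs, z) back into an action."""
        def __init__(self, obs_dim, act_dim, latent_dim, hidden=256):
            super().__init__()
            self.encoder = nn.Sequential(
                nn.Linear(obs_dim + act_dim, hidden), nn.ReLU(),
                nn.Linear(hidden, 2 * latent_dim))             # outputs mean and log-std
            self.decoder = nn.Sequential(
                nn.Linear(obs_dim + latent_dim, hidden), nn.ReLU(),
                nn.Linear(hidden, act_dim), nn.Tanh())

        def loss(self, obs, act):
            mean, log_std = self.encoder(torch.cat([obs, act], dim=-1)).chunk(2, dim=-1)
            log_std = log_std.clamp(-4, 15)
            std = log_std.exp()
            z = mean + std * torch.randn_like(std)             # reparameterisation trick
            recon = self.decoder(torch.cat([obs, z], dim=-1))
            recon_loss = F.mse_loss(recon, act)
            kl = -0.5 * (1 + 2 * log_std - mean.pow(2) - std.pow(2)).mean()
            return recon_loss + 0.5 * kl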
Example No. 5
        replay_buffer = None
    else:
        replay_buffer = ReplayBuffer(obs_dim=obs_dim,
                                     act_dim=1,
                                     capacity=args.capacity,
                                     batch_size=args.batch_size)

    agent = DQN_Agent(env=env,
                      replay_buffer=replay_buffer,
                      Q_net=Q_net,
                      qf_lr=1e-4,
                      gamma=0.99,
                      initial_eps=0.1,
                      end_eps=0.001,
                      eps_decay_period=1000000,
                      eval_eps=0.001,
                      target_update_freq=1000,
                      train_interval=1,
                      explore_step=args.explore_step,
                      eval_freq=args.eval_freq,
                      max_train_step=args.max_train_step,
                      train_id=args.train_id,
                      log_interval=args.log_interval,
                      resume=args.resume,
                      device=args.device)

    if args.show:
        train_tools.evaluate(agent, 10, show=True)
    else:
        agent.learn()
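
The initial_eps / end_eps / eps_decay_period arguments suggest an epsilon-greedy policy whose epsilon is annealed linearly over the first eps_decay_period steps. A minimal sketch of that schedule and of a matching action-selection rule (the function names and the decay shape are assumptions about the agent, shown only to illustrate the parameters):

    import numpy as np

    def epsilon_by_step(step, initial_eps=0.1, end_eps=0.001, eps_decay_period=1_000_000):
        """Linearly anneal epsilon from initial_eps to end_eps over eps_decay_period steps."""
        frac = min(step / eps_decay_period, 1.0)
        return initial_eps + frac * (end_eps - initial_eps)

    def epsilon_greedy_action(q_values, eps):
        # with probability eps pick a random action, otherwise the greedy one
        if np.random.random() < eps:
            return np.random.randint(len(q_values))
        return int(np.argmax(q_values))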
Example No. 6
    def learn(self):
        if self.resume:
            self.load_agent_checkpoint()
        else:
            # delete tensorboard log file
            log_tools.del_all_files_in_dir(self.result_dir)
        print(
            "==============================start train==================================="
        )
        obs = self.env.reset()
        done = False

        episode_reward = 0
        episode_length = 0
        trajectory_length = 0
        actor_loss, critic_loss = 0.0, 0.0  # placeholders until the first policy update

        while self.time_step < self.max_time_step:
            action, log_prob = self.choose_action(np.array(obs))
            next_obs, reward, done, info = self.env.step(action)
            value = self.critic_net(torch.tensor([obs],
                                                 dtype=torch.float32)).item()
            episode_reward += reward
            self.trajectory_buffer.add(obs, action, reward, done, log_prob,
                                       value)
            obs = next_obs
            episode_length += 1
            trajectory_length += 1
            self.time_step += 1

            if done:
                obs = self.env.reset()
                self.episode_num += 1

                print(
                    f"Time Step: {self.time_step} Episode Num: {self.episode_num} "
                    f"Episode Length: {episode_length} Episode Reward: {episode_reward:.3f}"
                )
                self.tensorboard_writer.log_learn_data(
                    {
                        "episode_length": episode_length,
                        "episode_reward": episode_reward
                    }, self.time_step)
                episode_reward = 0
                episode_length = 0

            if trajectory_length == self.trajectory_length:
                # bootstrap with V(obs) only when the segment was cut off mid-episode;
                # a finished episode contributes no value beyond its last reward
                last_val = 0 if done else self.critic_net(
                    torch.tensor([obs], dtype=torch.float32)).item()
                self.trajectory_buffer.finish_path(
                    last_val=last_val,
                    gamma=self.gamma,
                    gae_lambda=self.gae_lambda,
                    gae_normalize=self.gae_normalize)
                actor_loss, critic_loss = self.train()
                trajectory_length = 0

            if self.time_step % self.log_interval == 0:
                self.store_agent_checkpoint()
                self.tensorboard_writer.log_train_data(
                    {
                        "actor_loss": actor_loss,
                        "critic_loss": critic_loss
                    }, self.time_step)

            if self.eval_freq > 0 and self.time_step % self.eval_freq == 0:
                avg_reward, avg_length = evaluate(agent=self, episode_num=10)
                self.tensorboard_writer.log_eval_data(
                    {
                        "eval_episode_length": avg_length,
                        "eval_episode_reward": avg_reward
                    }, self.time_step)
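
finish_path() above hides the generalised advantage estimation over the segment just collected. A minimal sketch of that backward pass, assuming per-step rewards and value estimates for the segment (a full buffer would also reset the accumulator at the episode boundaries it recorded; names here are illustrative):

    import numpy as np

    def compute_gae(rewards, values, last_val, gamma=0.99, gae_lambda=0.95, normalize=True):
        """Backward GAE pass over one trajectory segment; returns advantages and value targets."""
        values = np.append(values, last_val)        # bootstrap value for the state after the segment
        advantages = np.zeros(len(rewards), dtype=np.float32)
        gae = 0.0
        for t in reversed(range(len(rewards))):
            delta = rewards[t] + gamma * values[t + 1] - values[t]
            gae = delta + gamma * gae_lambda * gae
            advantages[t] = gae
        returns = advantages + values[:-1]          # critic regression targets
        if normalize:
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
        return advantages, returns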