Example No. 1
    def train(self,
              num_iterations: int,
              save_path: Optional[str] = None,
              disable_tqdm: bool = False,
              **collect_kwargs):

        print(f"Begin training, logged in {self.path}")
        timer = Timer()
        step_timer = Timer()

        # Store the first agent
        # saved_agents = [copy.deepcopy(self.agent.model.state_dict())]

        if save_path:
            torch.save(self.agent.model,
                       os.path.join(save_path, "base_agent.pt"))

        rewards = []

        for step in trange(num_iterations, disable=disable_tqdm):
            ########################################### Collect the data ###############################################
            timer.checkpoint()

            # data_batch = self.collector.collect_data(num_episodes=self.config["episodes"])
            data_batch = self.collector.collect_data(
                num_steps=self.config["steps"])
            data_time = timer.checkpoint()
            ############################################## Update policy ##############################################
            # Perform the PPO update
            metrics = self.ppo.train_on_data(data_batch,
                                             step,
                                             writer=self.writer)

            eval_batch = self.collector.collect_data(num_steps=1001)
            reward = eval_batch['rewards'].sum().item()
            rewards.append(reward)

            end_time = step_timer.checkpoint()

            if step % 500 == 0:
                print(
                    f"Current reward: {reward:.3f}, Avg over last 100 iterations: {np.mean(rewards[-100:]):.3f}"
                )

            # Save the agent to disk
            if save_path:
                torch.save(self.agent.model.state_dict(),
                           os.path.join(save_path, f"weights_{step + 1}"))

            # Write training time metrics to tensorboard
            time_metrics = {
                "agent/time_data": data_time,
                "agent/time_total": end_time,
                "agent/eval_reward": reward
            }

            write_dict(time_metrics, step, self.writer)

        return rewards
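
Both train loops above time their phases with a Timer object: checkpoint() is called once before data collection to mark a reference point, and again afterwards to read off the elapsed seconds. The project's own Timer is not shown here; the following is a minimal sketch of that assumed interface, not the actual implementation.

import time

class Timer:
    """Assumed interface: checkpoint() returns seconds since the previous checkpoint."""

    def __init__(self):
        self._last = time.perf_counter()

    def checkpoint(self) -> float:
        # Return the time elapsed since the last checkpoint and reset the reference point.
        now = time.perf_counter()
        elapsed = now - self._last
        self._last = now
        return elapsed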
Example No. 2
    def train(self,
              num_iterations: int,
              save_path: Optional[str] = None,
              disable_tqdm: bool = False,
              **collect_kwargs):

        print(f"Begin training, logged in {self.path}")
        timer = Timer()
        step_timer = Timer()

        # Store the first agent
        # saved_agents = [copy.deepcopy(self.agent.model.state_dict())]

        if save_path:
            for path, (agent_id, agent) in zip(self.agent_paths,
                                               self.agents.items()):
                torch.save(agent.model, os.path.join(str(path),
                                                     "base_agent.pt"))

        rewards = []

        for step in trange(num_iterations, disable=disable_tqdm):
            ########################################### Collect the data ###############################################
            timer.checkpoint()

            # data_batch = self.collector.collect_data(num_episodes=self.config["episodes"])
            data_batch = self.collector.collect_data(
                num_steps=self.config["steps"],
                gamma=self.config["gamma"],
                tau=self.config["tau"])
            data_time = timer.checkpoint()
            ############################################## Update policy ##############################################
            # Perform the PPO update
            metrics = self.ppo.train_on_data(data_batch,
                                             step,
                                             writer=self.writer)

            # eval_batch = self.collector.collect_data(num_steps=1001)
            # reward = eval_batch['rewards'].sum().item()
            # rewards.append(reward)

            end_time = step_timer.checkpoint()

            # Save the agent to disk
            if save_path:
                for path, (agent_id, agent) in zip(self.agent_paths,
                                                   self.agents.items()):
                    torch.save(agent.model.state_dict(),
                               os.path.join(str(path), f"weights_{step + 1}"))

            # Write training time metrics to tensorboard
            time_metrics = {
                "agent/time_data": data_time,
                "agent/time_total": end_time,
                # "agent/eval_reward": reward
            }

            write_dict(time_metrics, step, self.writer)
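
The per-step metrics are flushed through a write_dict helper. A minimal sketch of what such a helper could look like, assuming it simply forwards each tag/scalar pair to a TensorBoard SummaryWriter at the given step (the project's actual helper may differ):

from typing import Dict, Optional
from torch.utils.tensorboard import SummaryWriter

def write_dict(metrics: Dict[str, float],
               step: int,
               writer: Optional[SummaryWriter] = None):
    # Log each scalar under its tag at the given global step; do nothing if no writer is given.
    if writer is None:
        return
    for tag, value in metrics.items():
        writer.add_scalar(tag, value, global_step=step)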
Example No. 3
    def train(self,
              num_iterations: int,
              starting_iteration: int = 0,
              disable_tqdm: bool = False,
              finish_episode: bool = False,
              divide_rewards: Optional[int] = None):

        timer = Timer()
        step_timer = Timer()
        for step in trange(starting_iteration,
                           starting_iteration + num_iterations,
                           disable=disable_tqdm):
            timer.checkpoint()
            data_batch = self.collector.collect_data(
                num_steps=self.config["batch_size"],
                finish_episode=finish_episode,
                divide_rewards=divide_rewards,
                preserve_channels=self.config["preserve_channels"])
            data_time = timer.checkpoint()
            time_metric = {
                f"{agent_id}/time_data_collection": data_time
                for agent_id in self.agent_ids
            }

            self.train_on_data(data_batch,
                               step,
                               extra_metrics=time_metric,
                               timer=timer)
            total_time = step_timer.checkpoint()
            self.write_dict(
                {
                    f"{agent_id}/time_total": total_time
                    for agent_id in self.agent_ids
                }, step)

            if step % 50 == 0:
                for agent_id, agent in self.agents.items():
                    torch.save(
                        agent,
                        os.path.join(str(self.path), f"{agent_id}_{step}.pt"))
    def train_on_data(self, data_batch: DataBatch,
                      step: int = 0,
                      writer: Optional[SummaryWriter] = None) -> Dict[str, float]:
        """
        Performs a single update step with PPO on the given batch of data.

        Args:
            data_batch: DataBatch, a dictionary of experience tensors collected for the agent
            step: current training iteration, used as the global step when logging
            writer: optional TensorBoard SummaryWriter to log the metrics to

        Returns:
            Dictionary of training metrics for this update
        """
        metrics = {}
        timer = Timer()

        entropy_coeff = self.config["entropy_coeff"]

        agent = self.agent
        optimizer = self.optimizer

        agent_batch = data_batch

        ####################################### Unpack and prepare the data #######################################

        if self.config["use_gpu"]:
            agent_batch = batch_to_gpu(agent_batch)
            agent.cuda()

        # Initialize metrics
        kl_divergence = 0.
        ppo_step = -1
        value_loss = torch.tensor(0)
        policy_loss = torch.tensor(0)
        loss = torch.tensor(0)

        batcher = Batcher(agent_batch['dones'].size(0) // self.config["minibatches"],
                          [np.arange(agent_batch['dones'].size(0))])

        # Start a timer
        timer.checkpoint()

        for ppo_step in range(self.config["ppo_steps"]):
            batcher.shuffle()

            # for indices, agent_minibatch in minibatches(agent_batch, self.config["batch_size"], shuffle=True):
            while not batcher.end():
                batch_indices = batcher.next_batch()[0]
                batch_indices = torch.tensor(batch_indices).long()

                agent_minibatch = index_data(agent_batch, batch_indices)
                # Evaluate again after the PPO step, for new values and gradients
                logprob_batch, value_batch, entropy_batch = agent.evaluate_actions(agent_minibatch)

                advantages_batch = agent_minibatch['advantages']
                old_logprobs_minibatch = agent_minibatch['logprobs']  # logprobs of taken actions
                discounted_batch = agent_minibatch['rewards_to_go']

                ######################################### Compute the loss #############################################
                # Surrogate loss
                prob_ratio = torch.exp(logprob_batch - old_logprobs_minibatch)
                surr1 = prob_ratio * advantages_batch
                surr2 = prob_ratio.clamp(1. - self.eps, 1 + self.eps) * advantages_batch
                # surr2 = torch.where(advantages_batch > 0,
                #                     (1. + self.eps) * advantages_batch,
                #                     (1. - self.eps) * advantages_batch)

                policy_loss = -torch.min(surr1, surr2)
                value_loss = 0.5 * (value_batch - discounted_batch) ** 2
                # import pdb; pdb.set_trace()
                loss = (torch.mean(policy_loss)
                        + (self.config["value_loss_coeff"] * torch.mean(value_loss))
                        - (entropy_coeff * torch.mean(entropy_batch)))

                ############################################# Update step ##############################################
                optimizer.zero_grad()
                loss.backward()
                if self.config["max_grad_norm"] is not None:
                    nn.utils.clip_grad_norm_(agent.model.parameters(), self.config["max_grad_norm"])
                optimizer.step()

            # logprob_batch, value_batch, entropy_batch = agent.evaluate_actions(agent_batch)
            #
            # kl_divergence = torch.mean(old_logprobs_batch - logprob_batch).item()
            # if abs(kl_divergence) > self.config["target_kl"]:
            #     break

        agent.cpu()

        # Training-related metrics
        metrics[f"agent/time_update"] = timer.checkpoint()
        metrics[f"agent/kl_divergence"] = kl_divergence
        metrics[f"agent/ppo_steps_made"] = ppo_step + 1
        metrics[f"agent/policy_loss"] = torch.mean(policy_loss).cpu().item()
        metrics[f"agent/value_loss"] = torch.mean(value_loss).cpu().item()
        metrics[f"agent/total_loss"] = loss.detach().cpu().item()
        metrics[f"agent/rewards"] = agent_batch['rewards'].cpu().sum().item()
        metrics[f"agent/mean_std"] = agent.model.std.mean().item()

        # Other metrics
        # metrics[f"agent/mean_entropy"] = torch.mean(entropy_batch).item()

        # Write the metrics to tensorboard
        write_dict(metrics, step, writer)

        return metrics
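
The core of train_on_data is the clipped-surrogate PPO loss computed per minibatch. Below is a self-contained sketch of that loss on dummy tensors; eps and the two loss coefficients are illustrative values standing in for self.eps, config["value_loss_coeff"] and config["entropy_coeff"].

import torch

eps = 0.2                        # stands in for self.eps
logprob = torch.randn(8)         # new log-probs of the taken actions
old_logprob = torch.randn(8)     # log-probs recorded at collection time
advantages = torch.randn(8)
values = torch.randn(8)
rewards_to_go = torch.randn(8)
entropy = torch.rand(8)

# Probability ratio between the new and old policy for each action
ratio = torch.exp(logprob - old_logprob)
surr1 = ratio * advantages
surr2 = ratio.clamp(1. - eps, 1. + eps) * advantages

policy_loss = -torch.min(surr1, surr2)
value_loss = 0.5 * (values - rewards_to_go) ** 2

loss = (policy_loss.mean()
        + 0.5 * value_loss.mean()     # value_loss_coeff, illustrative
        - 0.01 * entropy.mean())      # entropy_coeff, illustrative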