Example #1
File: base.py Project: veds12/genrl
    def get_logging_params(self) -> Dict[str, Any]:
        """Gets relevant parameters for logging

        Returns:
            logs (:obj:`dict`): Logging parameters for monitoring training
        """
        logs = {
            "value_loss": safe_mean(self.logs["value_loss"]),
            "epsilon": safe_mean(self.logs["epsilon"]),
        }
        self.empty_logs()
        return logs
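Example #1 depends on two helpers that are not shown here: `safe_mean`, which averages a list of metric values, and `empty_logs`, which clears the accumulated metrics. A minimal sketch of what a `safe_mean` helper could look like (an assumption about its behavior, not the actual genrl implementation):

import numpy as np

def safe_mean(values):
    # Average a possibly empty list of metric values without raising;
    # an empty list simply yields 0.0 for that metric.
    return float(np.mean(values)) if len(values) > 0 else 0.0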
Example #2
    def get_logging_params(self) -> Dict[str, Any]:
        """Gets relevant parameters for logging

        Returns:
            logs (:obj:`dict`): Logging parameters for monitoring training
        """
        logs = {
            "policy_loss": safe_mean(self.logs["policy_loss"]),
            "value_loss": safe_mean(self.logs["value_loss"]),
            "policy_entropy": safe_mean(self.logs["policy_entropy"]),
            "mean_reward": safe_mean(self.rewards),
        }

        self.empty_logs()
        return logs
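Both logging examples follow the same accumulate-then-flush pattern: losses are appended to `self.logs` during parameter updates, averaged when the trainer asks for them, and then cleared so the next logging window starts fresh. A standalone sketch of that pattern (the class and attribute names are illustrative, not taken from genrl):

from typing import Any, Dict, List


class LossTracker:
    """Accumulates per-update metrics and flushes their means on demand."""

    def __init__(self) -> None:
        self.logs: Dict[str, List[float]] = {"policy_loss": [], "value_loss": []}

    def record(self, name: str, value: float) -> None:
        # Called once per gradient update with a scalar loss value.
        self.logs[name].append(value)

    def flush(self) -> Dict[str, Any]:
        # Average everything collected since the last flush, then reset.
        means = {k: (sum(v) / len(v) if v else 0.0) for k, v in self.logs.items()}
        for v in self.logs.values():
            v.clear()
        return means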
Example #3
    def log(self, timestep: int) -> None:
        """Helper function to log

        Sends useful parameters to the logger.

        Args:
            timestep (int): Current timestep of training
        """
        self.logger.write(
            {
                "timestep": timestep,
                "Episode": self.episodes,
                **self.agent.get_logging_params(),
                "Episode Reward": safe_mean(self.training_rewards),
            },
            self.log_key,
        )
        self.training_rewards = []
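Example #3 only shows the trainer side of the call; the logger object just needs a `write(kvs, log_key)` method (and a `close()` method, used in Example #4). A minimal stand-in that satisfies this interface, purely for illustration and not the actual genrl logger:

from typing import Any, Dict


class PrintLogger:
    """Tiny logger exposing the write(kvs, log_key) interface used above."""

    def write(self, kvs: Dict[str, Any], log_key: str) -> None:
        # Print the highlighted key first, then the remaining metrics.
        header = f"{log_key}={kvs.get(log_key)}"
        rest = ", ".join(f"{k}: {v}" for k, v in kvs.items() if k != log_key)
        print(f"[{header}] {rest}")

    def close(self) -> None:
        pass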
Example #4
    def train(self) -> None:
        """Main training method"""
        if self.load_model is not None:
            self.load()

        state, episode_len, episode = (
            self.env.reset(),
            np.zeros(self.env.n_envs),
            np.zeros(self.env.n_envs),
        )
        total_steps = self.max_ep_len * self.epochs * self.env.n_envs

        if "noise" in self.agent.__dict__ and self.agent.noise is not None:
            self.agent.noise.reset()

        assert self.update_interval % self.env.n_envs == 0

        self.rewards = []

        for timestep in range(0, total_steps, self.env.n_envs):
            self.agent.update_params_before_select_action(timestep)

            if timestep < self.warmup_steps:
                action = np.array(self.env.sample())
            else:
                action = self.agent.select_action(state)

            next_state, reward, done, _ = self.env.step(action)

            if self.render:
                self.env.render()

            episode_len += 1

            done = [
                False if ep_len == self.max_ep_len else done[i]
                for i, ep_len in enumerate(episode_len)
            ]

            self.buffer.push((state, action, reward, next_state, done))
            state = next_state.copy()

            if np.any(done) or np.any(episode_len == self.max_ep_len):
                if "noise" in self.agent.__dict__ and self.agent.noise is not None:
                    self.agent.noise.reset()

                if sum(episode) % self.log_interval == 0:
                    self.logger.write(
                        {
                            "timestep":
                            timestep,
                            "Episode":
                            sum(episode),
                            **self.agent.get_logging_params(),
                            "Episode Reward":
                            safe_mean(self.rewards),
                        },
                        self.log_key,
                    )
                    self.rewards = []

                for i, di in enumerate(done):
                    if di:
                        self.rewards.append(self.env.episode_reward[i])
                        self.env.episode_reward[i] = 0
                        episode_len[i] = 0
                        episode[i] += 1

            if timestep >= self.start_update and timestep % self.update_interval == 0:
                self.agent.update_params(self.update_interval)

            if (timestep >= self.start_update and self.save_interval != 0
                    and timestep % self.save_interval == 0):
                self.save(timestep)

            if timestep >= self.max_timesteps:
                break

        self.env.close()
        self.logger.close()
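One subtle step in `train` is the `done` masking before the transition is pushed to the buffer: when a sub-environment hits `max_ep_len`, its `done` flag is forced to `False` so the stored transition does not mark the timeout as a true termination. The same idea in isolation, with toy values rather than genrl objects:

import numpy as np

max_ep_len = 5
episode_len = np.array([5, 3, 5, 2])        # steps taken in each of 4 sub-envs
done = np.array([True, True, True, False])  # raw done flags from env.step()

# Timeouts (envs 0 and 2) are marked non-terminal, while the genuine
# termination in env 1 keeps done=True.
masked_done = [
    False if ep_len == max_ep_len else bool(d)
    for ep_len, d in zip(episode_len, done)
]
print(masked_done)  # [False, True, False, False]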