Example #1
    def __init__(self,
                 env: str,
                 gamma: float = 0.99,
                 lr: float = 0.01,
                 batch_size: int = 8,
                 n_steps: int = 10,
                 avg_reward_len: int = 100,
                 entropy_beta: float = 0.01,
                 epoch_len: int = 1000,
                 **kwargs) -> None:
        """
        Args:
            env: gym environment tag
            gamma: discount factor
            lr: learning rate
            batch_size: size of minibatch pulled from the DataLoader
            n_steps: number of steps per discounted experience
            entropy_beta: coefficient of the entropy bonus added to the loss
            avg_reward_len: how many episodes to take into account when calculating the avg reward
            epoch_len: how many batches before pseudo epoch
        """
        super().__init__()

        if not _GYM_AVAILABLE:  # pragma: no cover
            raise ModuleNotFoundError(
                "This Module requires gym environment which is not installed yet."
            )

        # Hyperparameters
        self.lr = lr
        self.batch_size = batch_size
        self.batches_per_epoch = self.batch_size * epoch_len
        self.entropy_beta = entropy_beta
        self.gamma = gamma
        self.n_steps = n_steps

        self.save_hyperparameters()

        # Model components
        self.env = gym.make(env)
        self.net = MLP(self.env.observation_space.shape,
                       self.env.action_space.n)
        self.agent = PolicyAgent(self.net)

        # Tracking metrics
        self.total_rewards = []
        self.episode_rewards = []
        self.done_episodes = 0
        self.avg_rewards = 0
        self.avg_reward_len = avg_reward_len
        self.eps = np.finfo(np.float32).eps.item()
        self.batch_states = []
        self.batch_actions = []

        self.state = self.env.reset()
    def setUp(self) -> None:
        self.env = ToTensor(gym.make("CartPole-v0"))
        self.obs_shape = self.env.observation_space.shape
        self.n_actions = self.env.action_space.n
        self.net = MLP(self.obs_shape, self.n_actions)
        self.agent = Agent(self.net)
        self.xp_stream = EpisodicExperienceStream(self.env,
                                                  self.agent,
                                                  Mock(),
                                                  episodes=4)
        self.rl_dataloader = DataLoader(self.xp_stream)

        parent_parser = argparse.ArgumentParser(add_help=False)
        parent_parser = cli.add_base_args(parent=parent_parser)
        parent_parser = DQN.add_model_specific_args(parent_parser)
        args_list = [
            "--algo",
            "dqn",
            "--warm_start_steps",
            "500",
            "--episode_length",
            "100",
        ]
        self.hparams = parent_parser.parse_args(args_list)
        self.model = Reinforce(**vars(self.hparams))
Example #3
    def build_networks(self) -> None:
        """Initializes the SAC policy and q networks (with targets)"""
        action_bias = torch.from_numpy((self.env.action_space.high + self.env.action_space.low) / 2)
        action_scale = torch.from_numpy((self.env.action_space.high - self.env.action_space.low) / 2)
        self.policy = ContinuousMLP(self.obs_shape, self.n_actions, action_bias=action_bias, action_scale=action_scale)

        concat_shape = [self.obs_shape[0] + self.n_actions]
        self.q1 = MLP(concat_shape, 1)
        self.q2 = MLP(concat_shape, 1)
        self.target_q1 = MLP(concat_shape, 1)
        self.target_q2 = MLP(concat_shape, 1)
        self.target_q1.load_state_dict(self.q1.state_dict())
        self.target_q2.load_state_dict(self.q2.state_dict())
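
For context, here is a minimal sketch of what the action_bias and action_scale terms above provide, assuming a standard Gym Box action space; the tanh rescaling line illustrates how a ContinuousMLP-style policy is typically mapped onto the action bounds, and is not the exact internals of ContinuousMLP (which are not shown in this example):

import gym
import numpy as np

env = gym.make("Pendulum-v0")  # continuous Box action space ("Pendulum-v1" on newer Gym releases)
low, high = env.action_space.low, env.action_space.high

action_bias = (high + low) / 2   # centre of the action range
action_scale = (high - low) / 2  # half-width of the action range

# A tanh-squashed network output in [-1, 1] is mapped back onto [low, high]:
raw_output = np.array([0.5])
action = np.tanh(raw_output) * action_scale + action_bias
assert np.all(low <= action) and np.all(action <= high)

The target Q networks above start as exact copies of the online Q networks; during SAC training they are usually kept close to them with soft (Polyak) updates.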
Example #4
    def setUp(self) -> None:
        self.env = ToTensor(gym.make("CartPole-v0"))
        self.obs_shape = self.env.observation_space.shape
        self.n_actions = self.env.action_space.n
        self.net = MLP(self.obs_shape, self.n_actions)
        self.agent = Agent(self.net)

        parent_parser = argparse.ArgumentParser(add_help=False)
        parent_parser = VanillaPolicyGradient.add_model_specific_args(parent_parser)
        args_list = [
            "--env", "CartPole-v0",
            "--batch_size", "32"
        ]
        self.hparams = parent_parser.parse_args(args_list)
        self.model = VanillaPolicyGradient(**vars(self.hparams))
    def setUp(self) -> None:
        self.env = ToTensor(gym.make("CartPole-v0"))
        self.obs_shape = self.env.observation_space.shape
        self.n_actions = self.env.action_space.n
        self.net = MLP(self.obs_shape, self.n_actions)
        self.agent = Agent(self.net)
        self.exp_source = DiscountedExperienceSource(self.env, self.agent)

        parent_parser = argparse.ArgumentParser(add_help=False)
        parent_parser = Reinforce.add_model_specific_args(parent_parser)
        args_list = [
            "--env", "CartPole-v0", "--batch_size", "32", "--gamma", "0.99"
        ]
        self.hparams = parent_parser.parse_args(args_list)
        self.model = Reinforce(**vars(self.hparams))

        self.rl_dataloader = self.model.train_dataloader()
class VanillaPolicyGradient(pl.LightningModule):
    """
    PyTorch Lightning implementation of `Vanilla Policy Gradient
    <https://papers.nips.cc/paper/
    1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf>`_
    Paper authors: Richard S. Sutton, David McAllester, Satinder Singh, Yishay Mansour
    Model implemented by:

        - `Donal Byrne <https://github.com/djbyrne>`_

    Example:
        >>> from pl_bolts.models.rl.vanilla_policy_gradient_model import VanillaPolicyGradient
        ...
        >>> model = VanillaPolicyGradient("CartPole-v0")

    Train::
        trainer = Trainer()
        trainer.fit(model)

    Note:
        This example is based on:
        https://github.com/PacktPublishing/Deep-Reinforcement-Learning-Hands-On-Second-Edition/blob/master/Chapter11/04_cartpole_pg.py

    Note:
        Currently only supports CPU and single GPU training with `distributed_backend=dp`
    """
    def __init__(self,
                 env: str,
                 gamma: float = 0.99,
                 lr: float = 0.01,
                 batch_size: int = 8,
                 n_steps: int = 10,
                 avg_reward_len: int = 100,
                 entropy_beta: float = 0.01,
                 epoch_len: int = 1000,
                 **kwargs) -> None:
        """
        Args:
            env: gym environment tag
            gamma: discount factor
            lr: learning rate
            batch_size: size of minibatch pulled from the DataLoader
            n_steps: number of steps per discounted experience
            entropy_beta: coefficient of the entropy bonus added to the loss
            avg_reward_len: how many episodes to take into account when calculating the avg reward
            epoch_len: how many batches before pseudo epoch
        """
        super().__init__()

        if not _GYM_AVAILABLE:
            raise ModuleNotFoundError(
                'This Module requires gym environment which is not installed yet.'
            )

        # Hyperparameters
        self.lr = lr
        self.batch_size = batch_size
        self.batches_per_epoch = self.batch_size * epoch_len
        self.entropy_beta = entropy_beta
        self.gamma = gamma
        self.n_steps = n_steps

        self.save_hyperparameters()

        # Model components
        self.env = gym.make(env)
        self.net = MLP(self.env.observation_space.shape,
                       self.env.action_space.n)
        self.agent = PolicyAgent(self.net)

        # Tracking metrics
        self.total_rewards = []
        self.episode_rewards = []
        self.done_episodes = 0
        self.avg_rewards = 0
        self.avg_reward_len = avg_reward_len
        self.eps = np.finfo(np.float32).eps.item()
        self.batch_states = []
        self.batch_actions = []

        self.state = self.env.reset()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Passes in a state x through the network and gets the q_values of each action as an output

        Args:
            x: environment state

        Returns:
            q values
        """
        output = self.net(x)
        return output

    def train_batch(
        self,
    ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]:
        """
        Contains the logic for generating a new batch of data to be passed to the DataLoader
        Returns:
            yields a tuple of Lists containing tensors for states, actions and rewards of the batch.
        """

        while True:

            action = self.agent(self.state, self.device)

            next_state, reward, done, _ = self.env.step(action[0])

            self.episode_rewards.append(reward)
            self.batch_actions.append(action)
            self.batch_states.append(self.state)
            self.state = next_state

            if done:
                self.done_episodes += 1
                self.state = self.env.reset()
                self.total_rewards.append(sum(self.episode_rewards))
                self.avg_rewards = float(
                    np.mean(self.total_rewards[-self.avg_reward_len:]))

                returns = self.compute_returns(self.episode_rewards)

                for idx in range(len(self.batch_actions)):
                    yield self.batch_states[idx], self.batch_actions[
                        idx], returns[idx]

                self.batch_states = []
                self.batch_actions = []
                self.episode_rewards = []

    def compute_returns(self, rewards):
        """
        Calculate the discounted rewards of the batched rewards

        Args:
            rewards: list of batched rewards

        Returns:
            list of discounted rewards
        """
        reward = 0
        returns = []

        for r in rewards[::-1]:
            reward = r + self.gamma * reward
            returns.insert(0, reward)

        returns = torch.tensor(returns)
        returns = (returns - returns.mean()) / (returns.std() + self.eps)
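        # Worked example (illustrative): with gamma = 0.99 and episode rewards [1.0, 1.0, 1.0],
        # the raw returns are [2.9701, 1.99, 1.0]; the line above then standardises them to
        # zero mean and unit variance before they scale the log-probabilities in the loss.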

        return returns

    def loss(self, states, actions, scaled_rewards) -> torch.Tensor:
        """
        Calculates the loss for VPG

        Args:
            states: batched states
            actions: batched actions
            scaled_rewards: batched Q values

        Returns:
            loss for the current batch
        """

        logits = self.net(states)

        # policy loss
        log_prob = log_softmax(logits, dim=1)
        log_prob_actions = scaled_rewards * log_prob[range(self.batch_size),
                                                     actions[0]]
        policy_loss = -log_prob_actions.mean()

        # entropy loss
        prob = softmax(logits, dim=1)
        entropy = -(prob * log_prob).sum(dim=1).mean()
        entropy_loss = -self.entropy_beta * entropy
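        # Minimising -entropy_beta * entropy maximises the policy entropy, which discourages
        # the policy from collapsing prematurely to a deterministic one.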

        # total loss
        loss = policy_loss + entropy_loss

        return loss

    def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor],
                      _) -> OrderedDict:
        """
        Carries out a single step through the environment to update the replay buffer.
        Then calculates loss based on the minibatch received

        Args:
            batch: current mini batch of replay data
            _: batch number, not used

        Returns:
            Training loss and log metrics
        """
        states, actions, scaled_rewards = batch

        loss = self.loss(states, actions, scaled_rewards)

        log = {
            "episodes": self.done_episodes,
            "reward": self.total_rewards[-1],
            "avg_reward": self.avg_rewards,
        }
        return OrderedDict({
            "loss": loss,
            "avg_reward": self.avg_rewards,
            "log": log,
            "progress_bar": log,
        })

    def configure_optimizers(self) -> List[Optimizer]:
        """ Initialize Adam optimizer"""
        optimizer = optim.Adam(self.net.parameters(), lr=self.lr)
        return [optimizer]

    def _dataloader(self) -> DataLoader:
        """Initialize the Replay Buffer dataset used for retrieving experiences"""
        dataset = ExperienceSourceDataset(self.train_batch)
        dataloader = DataLoader(dataset=dataset, batch_size=self.batch_size)
        return dataloader

    def train_dataloader(self) -> DataLoader:
        """Get train loader"""
        return self._dataloader()

    def get_device(self, batch) -> str:
        """Retrieve device currently being used by minibatch"""
        return batch[0][0][0].device.index if self.on_gpu else "cpu"

    @staticmethod
    def add_model_specific_args(arg_parser) -> argparse.ArgumentParser:
        """
        Adds arguments for DQN model

        Note: these params are fine tuned for Pong env

        Args:
            arg_parser: the current argument parser to add to

        Returns:
            arg_parser with model specific cargs added
        """

        arg_parser.add_argument("--entropy_beta",
                                type=float,
                                default=0.01,
                                help="entropy value")
        arg_parser.add_argument("--batches_per_epoch",
                                type=int,
                                default=10000,
                                help="number of batches in an epoch")
        arg_parser.add_argument("--batch_size",
                                type=int,
                                default=32,
                                help="size of the batches")
        arg_parser.add_argument("--lr",
                                type=float,
                                default=1e-3,
                                help="learning rate")
        arg_parser.add_argument("--env",
                                type=str,
                                required=True,
                                help="gym environment tag")
        arg_parser.add_argument("--gamma",
                                type=float,
                                default=0.99,
                                help="discount factor")
        arg_parser.add_argument("--seed",
                                type=int,
                                default=123,
                                help="seed for training run")

        arg_parser.add_argument(
            "--avg_reward_len",
            type=int,
            default=100,
            help="how many episodes to include in avg reward",
        )

        return arg_parser
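
A short end-to-end sketch of how the class above is meant to be driven, based on its docstring and on add_model_specific_args; it assumes pl_bolts and pytorch_lightning are installed and that the import path from the docstring is available:

import argparse
import pytorch_lightning as pl
from pl_bolts.models.rl.vanilla_policy_gradient_model import VanillaPolicyGradient

# Build hyperparameters from the model's own argument hook, then train.
parser = argparse.ArgumentParser(add_help=False)
parser = VanillaPolicyGradient.add_model_specific_args(parser)
hparams = parser.parse_args(["--env", "CartPole-v0", "--batch_size", "32"])

model = VanillaPolicyGradient(**vars(hparams))
trainer = pl.Trainer(max_epochs=1)  # max_epochs=1 is an arbitrary choice for the sketch
trainer.fit(model)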
Example #7
    def __init__(
        self,
        env: str,
        gamma: float = 0.99,
        lr: float = 0.01,
        batch_size: int = 8,
        n_steps: int = 10,
        avg_reward_len: int = 100,
        entropy_beta: float = 0.01,
        epoch_len: int = 1000,
        num_batch_episodes: int = 4,
        **kwargs
    ) -> None:
        """
        PyTorch Lightning implementation of `REINFORCE
        <https://papers.nips.cc/paper/
        1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf>`_
        Paper authors: Richard S. Sutton, David McAllester, Satinder Singh, Yishay Mansour
        Model implemented by:

            - `Donal Byrne <https://github.com/djbyrne>`_

        Example:
            >>> from pl_bolts.models.rl.reinforce_model import Reinforce
            ...
            >>> model = Reinforce("CartPole-v0")

        Train::

            trainer = Trainer()
            trainer.fit(model)

        Args:
            env: gym environment tag
            gamma: discount factor
            lr: learning rate
            batch_size: size of minibatch pulled from the DataLoader
            n_steps: number of steps per discounted experience
            entropy_beta: entropy coefficient
            epoch_len: how many batches before pseudo epoch
            num_batch_episodes: how many episodes to rollout for each batch of training
            avg_reward_len: how many episodes to take into account when calculating the avg reward

        Note:
            This example is based on:
            https://github.com/PacktPublishing/Deep-Reinforcement-Learning-Hands-On-Second-Edition/blob/master/Chapter11/02_cartpole_reinforce.py

        Note:
            Currently only supports CPU and single GPU training with `distributed_backend=dp`
        """
        super().__init__()

        if not _GYM_AVAILABLE:
            raise ModuleNotFoundError('This Module requires gym environment which is not installed yet.')

        # Hyperparameters
        self.lr = lr
        self.batch_size = batch_size
        self.batches_per_epoch = self.batch_size * epoch_len
        self.entropy_beta = entropy_beta
        self.gamma = gamma
        self.n_steps = n_steps
        self.num_batch_episodes = num_batch_episodes

        self.save_hyperparameters()

        # Model components
        self.env = gym.make(env)
        self.net = MLP(self.env.observation_space.shape, self.env.action_space.n)
        self.agent = PolicyAgent(self.net)

        # Tracking metrics
        self.total_steps = 0
        self.total_rewards = [0]
        self.done_episodes = 0
        self.avg_rewards = 0
        self.reward_sum = 0.0
        self.batch_episodes = 0
        self.avg_reward_len = avg_reward_len

        self.batch_states = []
        self.batch_actions = []
        self.batch_qvals = []
        self.cur_rewards = []

        self.state = self.env.reset()
Example #8
class Reinforce(pl.LightningModule):

    def __init__(
        self,
        env: str,
        gamma: float = 0.99,
        lr: float = 0.01,
        batch_size: int = 8,
        n_steps: int = 10,
        avg_reward_len: int = 100,
        num_envs: int = 1,
        entropy_beta: float = 0.01,
        epoch_len: int = 1000,
        num_batch_episodes: int = 4,
        **kwargs
    ) -> None:
        """
        PyTorch Lightning implementation of `REINFORCE
        <https://papers.nips.cc/paper/
        1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf>`_
        Paper authors: Richard S. Sutton, David McAllester, Satinder Singh, Yishay Mansour
        Model implemented by:

            - `Donal Byrne <https://github.com/djbyrne>`_

        Example:
            >>> from pl_bolts.models.rl.reinforce_model import Reinforce
            ...
            >>> model = Reinforce("PongNoFrameskip-v4")

        Train::

            trainer = Trainer()
            trainer.fit(model)

        Args:
            env: gym environment tag
            gamma: discount factor
            lr: learning rate
            batch_size: size of minibatch pulled from the DataLoader
            num_batch_episodes: how many episodes to rollout for each batch of training
            avg_reward_len: how many episodes to take into account when calculating the avg reward

        Note:
            This example is based on:
            https://github.com/PacktPublishing/Deep-Reinforcement-Learning-Hands-On-Second-Edition/blob/master/Chapter11/02_cartpole_reinforce.py

        Note:
            Currently only supports CPU and single GPU training with `distributed_backend=dp`
        """
        super().__init__()

        # Hyperparameters
        self.lr = lr
        self.batch_size = batch_size * num_envs
        self.batches_per_epoch = self.batch_size * epoch_len
        self.entropy_beta = entropy_beta
        self.gamma = gamma
        self.n_steps = n_steps
        self.num_batch_episodes = num_batch_episodes

        self.save_hyperparameters()

        # Model components
        self.env = [gym.make(env) for _ in range(num_envs)]
        self.net = MLP(self.env[0].observation_space.shape, self.env[0].action_space.n)
        self.agent = PolicyAgent(self.net)
        self.exp_source = DiscountedExperienceSource(
            self.env, self.agent, gamma=gamma, n_steps=self.n_steps
        )

        # Tracking metrics
        self.total_steps = 0
        self.total_rewards = [0]
        self.done_episodes = 0
        self.avg_rewards = 0
        self.reward_sum = 0.0
        self.batch_episodes = 0
        self.avg_reward_len = avg_reward_len

        self.batch_states = []
        self.batch_actions = []
        self.batch_qvals = []
        self.cur_rewards = []

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Passes in a state x through the network and gets the q_values of each action as an output

        Args:
            x: environment state

        Returns:
            q values
        """
        output = self.net(x)
        return output

    def calc_qvals(self, rewards: List[float]) -> List[float]:
        """Calculate the discounted rewards of all rewards in list

        Args:
            rewards: list of rewards from latest batch

        Returns:
            list of discounted rewards
        """
        assert isinstance(rewards[0], float)

        cumul_reward = []
        sum_r = 0.0

        for r in reversed(rewards):
            sum_r = (sum_r * self.gamma) + r
            cumul_reward.append(sum_r)

        return list(reversed(cumul_reward))

    def train_batch(
        self,
    ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]:
        """
        Contains the logic for generating a new batch of data to be passed to the DataLoader

        Yield:
            yields a tuple of Lists containing tensors for states, actions and rewards of the batch.
        """
        for exp in self.exp_source.runner(self.device):

            self.batch_states.append(exp.state)
            self.batch_actions.append(exp.action)
            self.cur_rewards.append(exp.reward)

            # Check if episode is completed and update trackers
            if exp.done:
                self.batch_qvals.extend(self.calc_qvals(self.cur_rewards))
                self.cur_rewards.clear()
                self.batch_episodes += 1

            # Check if episodes have finished and use total reward
            new_rewards = self.exp_source.pop_total_rewards()
            if new_rewards:
                for reward in new_rewards:
                    self.done_episodes += 1
                    self.total_rewards.append(reward)
                    self.avg_rewards = float(
                        np.mean(self.total_rewards[-self.avg_reward_len:])
                    )

            self.total_steps += 1

            if self.batch_episodes >= self.num_batch_episodes:
                for state, action, qval in zip(
                    self.batch_states, self.batch_actions, self.batch_qvals
                ):
                    yield state, action, qval

                self.batch_episodes = 0

            # Simulates epochs
            if self.total_steps % self.batches_per_epoch == 0:
                break

    def loss(self, states, actions, scaled_rewards) -> torch.Tensor:
        logits = self.net(states)

        # policy loss
        log_prob = log_softmax(logits, dim=1)
        log_prob_actions = scaled_rewards * log_prob[range(self.batch_size), actions]
        loss = -log_prob_actions.mean()

        return loss

    def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], _) -> OrderedDict:
        """
        Carries out a single step through the environment to update the replay buffer.
        Then calculates loss based on the minibatch received

        Args:
            batch: current mini batch of replay data
            _: batch number, not used

        Returns:
            Training loss and log metrics
        """
        states, actions, scaled_rewards = batch

        loss = self.loss(states, actions, scaled_rewards)

        log = {
            "episodes": self.done_episodes,
            "reward": self.total_rewards[-1],
            "avg_reward": self.avg_rewards,
        }

        return OrderedDict(
            {
                "loss": loss,
                "avg_reward": self.avg_rewards,
                "log": log,
                "progress_bar": log,
            }
        )

    def configure_optimizers(self) -> List[Optimizer]:
        """ Initialize Adam optimizer"""
        optimizer = optim.Adam(self.net.parameters(), lr=self.lr)
        return [optimizer]

    def _dataloader(self) -> DataLoader:
        """Initialize the Replay Buffer dataset used for retrieving experiences"""
        dataset = ExperienceSourceDataset(self.train_batch)
        dataloader = DataLoader(dataset=dataset, batch_size=self.batch_size)
        return dataloader

    def train_dataloader(self) -> DataLoader:
        """Get train loader"""
        return self._dataloader()

    def get_device(self, batch) -> str:
        """Retrieve device currently being used by minibatch"""
        return batch[0][0][0].device.index if self.on_gpu else "cpu"

    @staticmethod
    def add_model_specific_args(arg_parser) -> argparse.ArgumentParser:
        """
        Adds arguments for DQN model

        Note: these params are fine tuned for Pong env

        Args:
            arg_parser: the current argument parser to add to

        Returns:
            arg_parser with model specific cargs added
        """

        arg_parser.add_argument(
            "--entropy_beta", type=float, default=0.01, help="entropy value",
        )

        return arg_parser
    def __init__(self,
                 env: str,
                 gamma: float = 0.99,
                 lr: float = 0.01,
                 batch_size: int = 8,
                 n_steps: int = 10,
                 avg_reward_len: int = 100,
                 entropy_beta: float = 0.01,
                 epoch_len: int = 1000,
                 num_batch_episodes: int = 4,
                 **kwargs) -> None:
        """
        Args:
            env: gym environment tag
            gamma: discount factor
            lr: learning rate
            batch_size: size of minibatch pulled from the DataLoader
            n_steps: number of steps per discounted experience
            entropy_beta: entropy coefficient
            epoch_len: how many batches before pseudo epoch
            num_batch_episodes: how many episodes to rollout for each batch of training
            avg_reward_len: how many episodes to take into account when calculating the avg reward
        """
        super().__init__()

        if not _GYM_AVAILABLE:
            raise ModuleNotFoundError(
                'This Module requires gym environment which is not installed yet.'
            )

        # Hyperparameters
        self.lr = lr
        self.batch_size = batch_size
        self.batches_per_epoch = self.batch_size * epoch_len
        self.entropy_beta = entropy_beta
        self.gamma = gamma
        self.n_steps = n_steps
        self.num_batch_episodes = num_batch_episodes

        self.save_hyperparameters()

        # Model components
        self.env = gym.make(env)
        self.net = MLP(self.env.observation_space.shape,
                       self.env.action_space.n)
        self.agent = PolicyAgent(self.net)

        # Tracking metrics
        self.total_steps = 0
        self.total_rewards = [0]
        self.done_episodes = 0
        self.avg_rewards = 0
        self.reward_sum = 0.0
        self.batch_episodes = 0
        self.avg_reward_len = avg_reward_len

        self.batch_states = []
        self.batch_actions = []
        self.batch_qvals = []
        self.cur_rewards = []

        self.state = self.env.reset()
class Reinforce(LightningModule):
    r"""PyTorch Lightning implementation of REINFORCE_.

    Paper authors: Richard S. Sutton, David McAllester, Satinder Singh, Yishay Mansour

    Model implemented by:

        - `Donal Byrne <https://github.com/djbyrne>`_

    Example:
        >>> from pl_bolts.models.rl.reinforce_model import Reinforce
        ...
        >>> model = Reinforce("CartPole-v0")

    Train::

        trainer = Trainer()
        trainer.fit(model)

    Note:
        This example is based on:
        https://github.com/PacktPublishing/Deep-Reinforcement-Learning-Hands-On-Second-Edition/blob/master/Chapter11/02_cartpole_reinforce.py

    Note:
        Currently only supports CPU and single GPU training with `accelerator=dp`

    .. _REINFORCE:
        https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf
    """
    def __init__(self,
                 env: str,
                 gamma: float = 0.99,
                 lr: float = 0.01,
                 batch_size: int = 8,
                 n_steps: int = 10,
                 avg_reward_len: int = 100,
                 entropy_beta: float = 0.01,
                 epoch_len: int = 1000,
                 num_batch_episodes: int = 4,
                 **kwargs) -> None:
        """
        Args:
            env: gym environment tag
            gamma: discount factor
            lr: learning rate
            batch_size: size of minibatch pulled from the DataLoader
            n_steps: number of steps per discounted experience
            entropy_beta: entropy coefficient
            epoch_len: how many batches before pseudo epoch
            num_batch_episodes: how many episodes to rollout for each batch of training
            avg_reward_len: how many episodes to take into account when calculating the avg reward
        """
        super().__init__()

        if not _GYM_AVAILABLE:  # pragma: no cover
            raise ModuleNotFoundError(
                "This Module requires gym environment which is not installed yet."
            )

        # Hyperparameters
        self.lr = lr
        self.batch_size = batch_size
        self.batches_per_epoch = self.batch_size * epoch_len
        self.entropy_beta = entropy_beta
        self.gamma = gamma
        self.n_steps = n_steps
        self.num_batch_episodes = num_batch_episodes

        self.save_hyperparameters()

        # Model components
        self.env = gym.make(env)
        self.net = MLP(self.env.observation_space.shape,
                       self.env.action_space.n)
        self.agent = PolicyAgent(self.net)

        # Tracking metrics
        self.total_steps = 0
        self.total_rewards = [0]
        self.done_episodes = 0
        self.avg_rewards = 0
        self.reward_sum = 0.0
        self.batch_episodes = 0
        self.avg_reward_len = avg_reward_len

        self.batch_states = []
        self.batch_actions = []
        self.batch_qvals = []
        self.cur_rewards = []

        self.state = self.env.reset()

    def forward(self, x: Tensor) -> Tensor:
        """Passes in a state x through the network and gets the q_values of each action as an output.

        Args:
            x: environment state

        Returns:
            q values
        """
        output = self.net(x)
        return output

    def calc_qvals(self, rewards: List[float]) -> List[float]:
        """Calculate the discounted rewards of all rewards in list.

        Args:
            rewards: list of rewards from latest batch

        Returns:
            list of discounted rewards
        """
        assert isinstance(rewards[0], float)

        cumul_reward = []
        sum_r = 0.0

        for r in reversed(rewards):
            sum_r = (sum_r * self.gamma) + r
            cumul_reward.append(sum_r)

        return list(reversed(cumul_reward))

    def discount_rewards(self, experiences: Tuple[Experience]) -> float:
        """Calculates the discounted reward over N experiences.

        Args:
            experiences: Tuple of Experience

        Returns:
            total discounted reward
        """
        total_reward = 0.0
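        # e.g. for experiences with rewards (r0, r1, r2) and gamma = 0.99 this computes
        # r0 + 0.99 * r1 + 0.99 ** 2 * r2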
        for exp in reversed(experiences):
            total_reward = (self.gamma * total_reward) + exp.reward
        return total_reward

    def train_batch(self, ) -> Tuple[List[Tensor], List[Tensor], List[Tensor]]:
        """Contains the logic for generating a new batch of data to be passed to the DataLoader.

        Yield:
            yields a tuple of Lists containing tensors for states, actions and rewards of the batch.
        """

        while True:

            action = self.agent(self.state, self.device)

            next_state, reward, done, _ = self.env.step(action[0])

            self.batch_states.append(self.state)
            self.batch_actions.append(action[0])
            self.cur_rewards.append(reward)

            self.state = next_state
            self.total_steps += 1

            if done:
                self.batch_qvals.extend(self.calc_qvals(self.cur_rewards))
                self.batch_episodes += 1
                self.done_episodes += 1
                self.total_rewards.append(sum(self.cur_rewards))
                self.avg_rewards = float(
                    np.mean(self.total_rewards[-self.avg_reward_len:]))
                self.cur_rewards = []
                self.state = self.env.reset()

            if self.batch_episodes >= self.num_batch_episodes:
                for state, action, qval in zip(self.batch_states,
                                               self.batch_actions,
                                               self.batch_qvals):
                    yield state, action, qval

                self.batch_episodes = 0

                self.batch_states.clear()
                self.batch_actions.clear()
                self.batch_qvals.clear()

            # Simulates epochs
            if self.total_steps % self.batches_per_epoch == 0:
                break

    def loss(self, states, actions, scaled_rewards) -> Tensor:
        logits = self.net(states)

        # policy loss
        log_prob = log_softmax(logits, dim=1)
        log_prob_actions = scaled_rewards * log_prob[range(len(log_prob)),
                                                     actions]
        loss = -log_prob_actions.mean()

        return loss

    def training_step(self, batch: Tuple[Tensor, Tensor], _) -> OrderedDict:
        """Carries out a single step through the environment to update the replay buffer. Then calculates loss
        based on the minibatch received.

        Args:
            batch: current mini batch of replay data
            _: batch number, not used

        Returns:
            Training loss and log metrics
        """
        states, actions, scaled_rewards = batch

        loss = self.loss(states, actions, scaled_rewards)

        log = {
            "episodes": self.done_episodes,
            "reward": self.total_rewards[-1],
            "avg_reward": self.avg_rewards,
        }

        return OrderedDict({
            "loss": loss,
            "avg_reward": self.avg_rewards,
            "log": log,
            "progress_bar": log,
        })

    def configure_optimizers(self) -> List[Optimizer]:
        """Initialize Adam optimizer."""
        optimizer = optim.Adam(self.net.parameters(), lr=self.lr)
        return [optimizer]

    def _dataloader(self) -> DataLoader:
        """Initialize the Replay Buffer dataset used for retrieving experiences."""
        dataset = ExperienceSourceDataset(self.train_batch)
        dataloader = DataLoader(dataset=dataset, batch_size=self.batch_size)
        return dataloader

    def train_dataloader(self) -> DataLoader:
        """Get train loader."""
        return self._dataloader()

    def get_device(self, batch) -> str:
        """Retrieve device currently being used by minibatch."""
        return batch[0][0][0].device.index if self.on_gpu else "cpu"

    @staticmethod
    def add_model_specific_args(arg_parser) -> argparse.ArgumentParser:
        """Adds arguments for DQN model.

        Note:
            These params are fine tuned for Pong env.

        Args:
            arg_parser: the current argument parser to add to

        Returns:
            arg_parser with model specific cargs added
        """
        arg_parser.add_argument("--batches_per_epoch",
                                type=int,
                                default=10000,
                                help="number of batches in an epoch")
        arg_parser.add_argument("--batch_size",
                                type=int,
                                default=32,
                                help="size of the batches")
        arg_parser.add_argument("--lr",
                                type=float,
                                default=1e-3,
                                help="learning rate")

        arg_parser.add_argument("--env",
                                type=str,
                                required=True,
                                help="gym environment tag")
        arg_parser.add_argument("--gamma",
                                type=float,
                                default=0.99,
                                help="discount factor")

        arg_parser.add_argument(
            "--avg_reward_len",
            type=int,
            default=100,
            help="how many episodes to include in avg reward",
        )

        arg_parser.add_argument(
            "--entropy_beta",
            type=float,
            default=0.01,
            help="entropy value",
        )

        return arg_parser
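
Mirroring the Example and Train sections of the class docstring above, a minimal sketch of running this Reinforce model end to end (assuming pl_bolts and pytorch_lightning are installed; max_epochs=1 is an arbitrary choice):

import pytorch_lightning as pl
from pl_bolts.models.rl.reinforce_model import Reinforce

# Construct the model directly with keyword arguments instead of the CLI parser.
model = Reinforce("CartPole-v0", gamma=0.99, lr=0.01, batch_size=8)
trainer = pl.Trainer(max_epochs=1)
trainer.fit(model)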
Example #11
class Reinforce(pl.LightningModule):
    """ Basic REINFORCE Policy Model """
    def __init__(self,
                 env: str,
                 gamma: float = 0.99,
                 lr: float = 1e-4,
                 batch_size: int = 32,
                 batch_episodes: int = 4,
                 **kwargs) -> None:
        """
        PyTorch Lightning implementation of `REINFORCE
        <https://papers.nips.cc/paper/
        1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf>`_

        Paper authors: Richard S. Sutton, David McAllester, Satinder Singh, Yishay Mansour

        Model implemented by:

            - `Donal Byrne <https://github.com/djbyrne>`_

        Example:

            >>> from pl_bolts.models.rl.reinforce_model import Reinforce
            ...
            >>> model = Reinforce("PongNoFrameskip-v4")

        Train::

            trainer = Trainer()
            trainer.fit(model)

        Args:
            env: gym environment tag
            gamma: discount factor
            lr: learning rate
            batch_size: size of minibatch pulled from the DataLoader
            batch_episodes: how many episodes to rollout for each batch of training

        .. note::
            This example is based on:
             https://github.com/PacktPublishing/Deep-Reinforcement-Learning-Hands-On-Second-Edition\
             /blob/master/Chapter11/02_cartpole_reinforce.py

        .. note:: Currently only supports CPU and single GPU training with `distributed_backend=dp`

        """
        super().__init__()

        # self.env = wrappers.make_env(self.hparams.env)    # use for Atari
        self.env = ToTensor(gym.make(env))  # use for Box2D/Control
        self.env.seed(123)

        self.obs_shape = self.env.observation_space.shape
        self.n_actions = self.env.action_space.n

        self.net = None
        self.build_networks()

        self.agent = PolicyAgent(self.net)

        self.gamma = gamma
        self.lr = lr
        self.batch_size = batch_size
        self.batch_episodes = batch_episodes

        self.total_reward = 0
        self.episode_reward = 0
        self.episode_count = 0
        self.episode_steps = 0
        self.total_episode_steps = 0

        self.reward_list = []
        for _ in range(100):
            self.reward_list.append(0)
        self.avg_reward = 0

    def build_networks(self) -> None:
        """Initializes the DQN train and target networks"""
        self.net = MLP(self.obs_shape, self.n_actions)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Passes in a state x through the network and gets the q_values of each action as an output

        Args:
            x: environment state

        Returns:
            q values
        """
        output = self.net(x)
        return output

    def calc_qvals(self, rewards: List[Tensor]) -> List[Tensor]:
        """
        Takes in the rewards of a single episode and returns the list of discounted Q values for that episode

        Args:
            rewards: list of rewards for one episode in the batch

        Returns:
            List of discounted Q values for the episode
        """
        res = []
        sum_r = 0.0
        for reward in reversed(rewards):
            sum_r *= self.gamma
            sum_r += reward
            res.append(deepcopy(sum_r))
        return list(reversed(res))

    def process_batch(
        self, batch: List[List[Experience]]
    ) -> Tuple[List[Tensor], List[Tensor], List[Tensor]]:
        """
        Takes in a batch of episodes and retrieves the q vals, the states and the actions for the batch

        Args:
            batch: list of episodes, each containing a list of Experiences

        Returns:
            q_vals, states and actions used for calculating the loss
        """
        # get outputs for each episode
        batch_rewards, batch_states, batch_actions = [], [], []
        for episode in batch:
            ep_rewards, ep_states, ep_actions = [], [], []

            # log the outputs for each step
            for step in episode:
                ep_rewards.append(step[2].float())
                ep_states.append(step[0])
                ep_actions.append(step[1])

            # add episode outputs to the batch
            batch_rewards.append(ep_rewards)
            batch_states.append(ep_states)
            batch_actions.append(ep_actions)

        # get qvals
        batch_qvals = []
        for reward in batch_rewards:
            batch_qvals.append(self.calc_qvals(reward))

        # flatten the batched outputs
        batch_actions, batch_qvals, batch_rewards, batch_states = self.flatten_batch(
            batch_actions, batch_qvals, batch_rewards, batch_states)

        return batch_qvals, batch_states, batch_actions, batch_rewards

    @staticmethod
    def flatten_batch(
        batch_actions: List[List[Tensor]],
        batch_qvals: List[List[Tensor]],
        batch_rewards: List[List[Tensor]],
        batch_states: List[List[Tuple[Tensor, Tensor]]],
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """
        Takes in the outputs of the processed batch and flattens the several episodes into a single tensor for each
        batched output

        Args:
            batch_actions: actions taken in each batch episodes
            batch_qvals: Q vals for each batch episode
            batch_rewards: reward for each batch episode
            batch_states: states for each batch episodes

        Returns:
            The input batched results flattened into a single tensor
        """
        # flatten all episode steps into a single list
        batch_qvals = list(chain.from_iterable(batch_qvals))
        batch_states = list(chain.from_iterable(batch_states))
        batch_actions = list(chain.from_iterable(batch_actions))
        batch_rewards = list(chain.from_iterable(batch_rewards))

        # stack steps into single tensor and remove extra dimension
        batch_qvals = torch.stack(batch_qvals).squeeze()
        batch_states = torch.stack(batch_states).squeeze()
        batch_actions = torch.stack(batch_actions).squeeze()
        batch_rewards = torch.stack(batch_rewards).squeeze()

        return batch_actions, batch_qvals, batch_rewards, batch_states

    def loss(
        self,
        batch_qvals: List[Tensor],
        batch_states: List[Tensor],
        batch_actions: List[Tensor],
    ) -> torch.Tensor:
        """
        Calculates the mse loss using a batch of states, actions and Q values from several episodes. These have all
        been flattend into a single tensor.

        Args:
            batch_qvals: current mini batch of q values
            batch_actions: current batch of actions
            batch_states: current batch of states

        Returns:
            loss
        """
        logits = self.net(batch_states)
        log_prob = log_softmax(logits, dim=1)
        log_prob_actions = (batch_qvals *
                            log_prob[range(len(batch_states)), batch_actions])
        loss = -log_prob_actions.mean()
        return loss

    def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor],
                      _) -> OrderedDict:
        """
        Carries out a single step through the environment to update the replay buffer.
        Then calculates loss based on the minibatch received

        Args:
            batch: current mini batch of replay data
            _: batch number, not used

        Returns:
            Training loss and log metrics
        """
        device = self.get_device(batch)

        batch_qvals, batch_states, batch_actions, batch_rewards = self.process_batch(
            batch)

        # get avg reward over the batched episodes
        self.episode_reward = sum(batch_rewards) / len(batch)
        self.reward_list.append(self.episode_reward)
        self.avg_reward = sum(self.reward_list) / len(self.reward_list)

        # calculates training loss
        loss = self.loss(batch_qvals, batch_states, batch_actions)

        if self.trainer.use_dp or self.trainer.use_ddp2:
            loss = loss.unsqueeze(0)

        self.episode_count += self.batch_episodes

        log = {
            "episode_reward": torch.tensor(self.episode_reward).to(device),
            "train_loss": loss,
            "avg_reward": self.avg_reward,
        }
        status = {
            "steps": torch.tensor(self.global_step).to(device),
            "episode_reward": torch.tensor(self.episode_reward).to(device),
            "episodes": torch.tensor(self.episode_count),
            "avg_reward": self.avg_reward,
        }

        self.episode_reward = 0

        return OrderedDict({
            "loss": loss,
            "reward": self.avg_reward,
            "log": log,
            "progress_bar": status,
        })

    def configure_optimizers(self) -> List[Optimizer]:
        """ Initialize Adam optimizer"""
        optimizer = optim.Adam(self.net.parameters(), lr=self.lr)
        return [optimizer]

    def _dataloader(self) -> DataLoader:
        """Initialize the Replay Buffer dataset used for retrieving experiences"""
        dataset = EpisodicExperienceStream(self.env,
                                           self.agent,
                                           self.device,
                                           episodes=self.batch_episodes)
        dataloader = DataLoader(dataset=dataset)
        return dataloader

    def train_dataloader(self) -> DataLoader:
        """Get train loader"""
        return self._dataloader()

    def get_device(self, batch) -> str:
        """Retrieve device currently being used by minibatch"""
        return batch[0][0][0].device.index if self.on_gpu else "cpu"

    @staticmethod
    def add_model_specific_args(arg_parser) -> argparse.ArgumentParser:
        """
        Adds arguments for DQN model

        Note: these params are fine tuned for Pong env

        Args:
            arg_parser: the current argument parser to add to

        Returns:
            arg_parser with model specific cargs added
        """

        arg_parser.add_argument(
            "--batch_episodes",
            type=int,
            default=4,
            help="how many episodes to run per batch",
        )

        return arg_parser
Example #12
    def __init__(
        self,
        env: str,
        gamma: float = 0.99,
        lam: float = 0.95,
        lr_actor: float = 3e-4,
        lr_critic: float = 1e-3,
        max_episode_len: float = 200,
        batch_size: int = 512,
        steps_per_epoch: int = 2048,
        nb_optim_iters: int = 4,
        clip_ratio: float = 0.2,
        **kwargs: Any,
    ) -> None:
        """
        Args:
            env: gym environment tag
            gamma: discount factor
            lam: advantage discount factor (lambda in the paper)
            lr_actor: learning rate of actor network
            lr_critic: learning rate of critic network
            max_episode_len: maximum number of interactions (actions) in an episode
            batch_size: batch size used when training the network; controls the number of policy updates performed per epoch
            steps_per_epoch: how many action-state pairs to rollout for trajectory collection per epoch
            nb_optim_iters: how many steps of gradient descent to perform on each batch
            clip_ratio: hyperparameter for clipping in the policy objective
        """
        super().__init__()

        if not _GYM_AVAILABLE:  # pragma: no cover
            raise ModuleNotFoundError(
                "This Module requires gym environment which is not installed yet."
            )

        # Hyperparameters
        self.lr_actor = lr_actor
        self.lr_critic = lr_critic
        self.steps_per_epoch = steps_per_epoch
        self.nb_optim_iters = nb_optim_iters
        self.batch_size = batch_size
        self.gamma = gamma
        self.lam = lam
        self.max_episode_len = max_episode_len
        self.clip_ratio = clip_ratio
        self.save_hyperparameters()

        self.env = gym.make(env)
        # value network
        self.critic = MLP(self.env.observation_space.shape, 1)
        # policy network (agent)
        if isinstance(self.env.action_space, gym.spaces.box.Box):
            act_dim = self.env.action_space.shape[0]
            actor_mlp = MLP(self.env.observation_space.shape, act_dim)
            self.actor = ActorContinous(actor_mlp, act_dim)
        elif isinstance(self.env.action_space, gym.spaces.discrete.Discrete):
            actor_mlp = MLP(self.env.observation_space.shape,
                            self.env.action_space.n)
            self.actor = ActorCategorical(actor_mlp)
        else:
            raise NotImplementedError(
                "Env action space should be of type Box (continous) or Discrete (categorical). "
                f"Got type: {type(self.env.action_space)}")

        self.batch_states = []
        self.batch_actions = []
        self.batch_adv = []
        self.batch_qvals = []
        self.batch_logp = []

        self.ep_rewards = []
        self.ep_values = []
        self.epoch_rewards = []

        self.episode_step = 0
        self.avg_ep_reward = 0
        self.avg_ep_len = 0
        self.avg_reward = 0

        self.state = torch.FloatTensor(self.env.reset())
Example #13
class PPO(LightningModule):
    """PyTorch Lightning implementation of `Proximal Policy Optimization.

    <https://arxiv.org/abs/1707.06347>`_

    Paper authors: John Schulman, Filip Wolski, Prafulla Dhariwal, Alec Radford, Oleg Klimov

    Model implemented by:
        `Sidhant Sundrani <https://github.com/sid-sundrani>`_

    Example:
        >>> from pl_bolts.models.rl.ppo_model import PPO
        >>> model = PPO("CartPole-v0")

    Note:
        This example is based on OpenAI's
        `PPO <https://github.com/openai/spinningup/blob/master/spinup/algos/pytorch/ppo/ppo.py>`_ and
        `PPO2 <https://github.com/openai/baselines/blob/master/baselines/ppo2/ppo2.py>`_.

    Note:
        Currently only supports CPU and single GPU training with ``accelerator=dp``
    """
    def __init__(
        self,
        env: str,
        gamma: float = 0.99,
        lam: float = 0.95,
        lr_actor: float = 3e-4,
        lr_critic: float = 1e-3,
        max_episode_len: float = 200,
        batch_size: int = 512,
        steps_per_epoch: int = 2048,
        nb_optim_iters: int = 4,
        clip_ratio: float = 0.2,
        **kwargs: Any,
    ) -> None:
        """
        Args:
            env: gym environment tag
            gamma: discount factor
            lam: advantage discount factor (lambda in the paper)
            lr_actor: learning rate of actor network
            lr_critic: learning rate of critic network
            max_episode_len: maximum number of interactions (actions) in an episode
            batch_size: batch size used when training the network; controls the number of policy updates performed per epoch
            steps_per_epoch: how many action-state pairs to rollout for trajectory collection per epoch
            nb_optim_iters: how many steps of gradient descent to perform on each batch
            clip_ratio: hyperparameter for clipping in the policy objective
        """
        super().__init__()

        if not _GYM_AVAILABLE:  # pragma: no cover
            raise ModuleNotFoundError(
                "This Module requires gym environment which is not installed yet."
            )

        # Hyperparameters
        self.lr_actor = lr_actor
        self.lr_critic = lr_critic
        self.steps_per_epoch = steps_per_epoch
        self.nb_optim_iters = nb_optim_iters
        self.batch_size = batch_size
        self.gamma = gamma
        self.lam = lam
        self.max_episode_len = max_episode_len
        self.clip_ratio = clip_ratio
        self.save_hyperparameters()

        self.env = gym.make(env)
        # value network
        self.critic = MLP(self.env.observation_space.shape, 1)
        # policy network (agent)
        if isinstance(self.env.action_space, gym.spaces.box.Box):
            act_dim = self.env.action_space.shape[0]
            actor_mlp = MLP(self.env.observation_space.shape, act_dim)
            self.actor = ActorContinous(actor_mlp, act_dim)
        elif isinstance(self.env.action_space, gym.spaces.discrete.Discrete):
            actor_mlp = MLP(self.env.observation_space.shape,
                            self.env.action_space.n)
            self.actor = ActorCategorical(actor_mlp)
        else:
            raise NotImplementedError(
                "Env action space should be of type Box (continous) or Discrete (categorical). "
                f"Got type: {type(self.env.action_space)}")

        self.batch_states = []
        self.batch_actions = []
        self.batch_adv = []
        self.batch_qvals = []
        self.batch_logp = []

        self.ep_rewards = []
        self.ep_values = []
        self.epoch_rewards = []

        self.episode_step = 0
        self.avg_ep_reward = 0
        self.avg_ep_len = 0
        self.avg_reward = 0

        self.state = torch.FloatTensor(self.env.reset())

    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        """Passes in a state x through the network and returns the policy and a sampled action.

        Args:
            x: environment state

        Returns:
            Tuple of policy, sampled action and value estimate
        """
        pi, action = self.actor(x)
        value = self.critic(x)

        return pi, action, value

    def discount_rewards(self, rewards: List[float],
                         discount: float) -> List[float]:
        """Calculate the discounted rewards of all rewards in list.

        Args:
            rewards: list of rewards/advantages
            discount: discount factor

        Returns:
            list of discounted rewards/advantages
        """
        assert isinstance(rewards[0], float)

        cumul_reward = []
        sum_r = 0.0

        for r in reversed(rewards):
            sum_r = (sum_r * discount) + r
            cumul_reward.append(sum_r)

        return list(reversed(cumul_reward))

    def calc_advantage(self, rewards: List[float], values: List[float],
                       last_value: float) -> List[float]:
        """Calculate the advantage given rewards, state values, and the last value of episode.

        Args:
            rewards: list of episode rewards
            values: list of state values from critic
            last_value: value of last state of episode

        Returns:
            list of advantages
        """
        rews = rewards + [last_value]
        vals = values + [last_value]
        # GAE
        delta = [
            rews[i] + self.gamma * vals[i + 1] - vals[i]
            for i in range(len(rews) - 1)
        ]
        adv = self.discount_rewards(delta, self.gamma * self.lam)
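        # GAE in brief: delta_t = r_t + gamma * V(s_{t+1}) - V(s_t), and the advantage is
        # A_t = sum_{l >= 0} (gamma * lam) ** l * delta_{t+l}, which is exactly what
        # discount_rewards(delta, gamma * lam) computes over the collected deltas.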

        return adv

    def generate_trajectory_samples(
            self) -> Tuple[List[Tensor], List[Tensor], List[Tensor]]:
        """Contains the logic for generating trajectory data to train policy and value network.

        Yield:
           Tuple of Lists containing tensors for states, actions, log probs, qvals and advantage
        """

        for step in range(self.steps_per_epoch):
            self.state = self.state.to(device=self.device)

            with torch.no_grad():
                pi, action, value = self(self.state)
                log_prob = self.actor.get_log_prob(pi, action)

            next_state, reward, done, _ = self.env.step(action.cpu().numpy())

            self.episode_step += 1

            self.batch_states.append(self.state)
            self.batch_actions.append(action)
            self.batch_logp.append(log_prob)

            self.ep_rewards.append(reward)
            self.ep_values.append(value.item())

            self.state = torch.FloatTensor(next_state)

            epoch_end = step == (self.steps_per_epoch - 1)
            terminal = len(self.ep_rewards) == self.max_episode_len

            if epoch_end or done or terminal:
                # if trajectory ends abruptly, bootstrap value of next state
                if (terminal or epoch_end) and not done:
                    self.state = self.state.to(device=self.device)
                    with torch.no_grad():
                        _, _, value = self(self.state)
                        last_value = value.item()
                        steps_before_cutoff = self.episode_step
                else:
                    last_value = 0
                    steps_before_cutoff = 0

                # discounted cumulative reward
                self.batch_qvals += self.discount_rewards(
                    self.ep_rewards + [last_value], self.gamma)[:-1]
                # advantage
                self.batch_adv += self.calc_advantage(self.ep_rewards,
                                                      self.ep_values,
                                                      last_value)
                # logs
                self.epoch_rewards.append(sum(self.ep_rewards))
                # reset params
                self.ep_rewards = []
                self.ep_values = []
                self.episode_step = 0
                self.state = torch.FloatTensor(self.env.reset())

            if epoch_end:
                train_data = zip(self.batch_states, self.batch_actions,
                                 self.batch_logp, self.batch_qvals,
                                 self.batch_adv)

                for state, action, logp_old, qval, adv in train_data:
                    yield state, action, logp_old, qval, adv

                self.batch_states.clear()
                self.batch_actions.clear()
                self.batch_adv.clear()
                self.batch_logp.clear()
                self.batch_qvals.clear()

                # logging
                self.avg_reward = sum(
                    self.epoch_rewards) / self.steps_per_epoch

                # if epoch ended abruptly, exclude the last cut-short episode to prevent skewing the stats
                epoch_rewards = self.epoch_rewards
                if not done:
                    epoch_rewards = epoch_rewards[:-1]

                total_epoch_reward = sum(epoch_rewards)
                nb_episodes = len(epoch_rewards)

                self.avg_ep_reward = total_epoch_reward / nb_episodes
                self.avg_ep_len = (self.steps_per_epoch -
                                   steps_before_cutoff) / nb_episodes

                self.epoch_rewards.clear()

    def actor_loss(self, state, action, logp_old, adv) -> Tensor:
        """Calculates the clipped-surrogate PPO policy loss for a batch."""
        pi, _ = self.actor(state)
        logp = self.actor.get_log_prob(pi, action)
        ratio = torch.exp(logp - logp_old)
        clip_adv = torch.clamp(ratio, 1 - self.clip_ratio,
                               1 + self.clip_ratio) * adv
        loss_actor = -(torch.min(ratio * adv, clip_adv)).mean()
        return loss_actor
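
    # For reference: actor_loss is the PPO clipped surrogate objective,
    #   L_clip = E[min(r_t * A_t, clip(r_t, 1 - eps, 1 + eps) * A_t)]
    # with probability ratio r_t = exp(logp - logp_old) and eps = clip_ratio;
    # the returned loss is the negated mean of that objective.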

    def critic_loss(self, state, qval) -> Tensor:
        """Calculates the mean-squared error between the critic's value estimates and the discounted returns."""
        value = self.critic(state)
        loss_critic = (qval - value).pow(2).mean()
        return loss_critic

    def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx,
                      optimizer_idx):
        """Carries out a single update to actor and critic network from a batch of replay buffer.

        Args:
            batch: batch of replay buffer/trajectory data
            batch_idx: not used
            optimizer_idx: idx that controls optimizing actor or critic network

        Returns:
            loss
        """
        state, action, old_logp, qval, adv = batch

        # normalize advantages
        adv = (adv - adv.mean()) / adv.std()

        self.log("avg_ep_len",
                 self.avg_ep_len,
                 prog_bar=True,
                 on_step=False,
                 on_epoch=True)
        self.log("avg_ep_reward",
                 self.avg_ep_reward,
                 prog_bar=True,
                 on_step=False,
                 on_epoch=True)
        self.log("avg_reward",
                 self.avg_reward,
                 prog_bar=True,
                 on_step=False,
                 on_epoch=True)

        if optimizer_idx == 0:
            loss_actor = self.actor_loss(state, action, old_logp, adv)
            self.log("loss_actor",
                     loss_actor,
                     on_step=False,
                     on_epoch=True,
                     prog_bar=True,
                     logger=True)

            return loss_actor

        if optimizer_idx == 1:
            loss_critic = self.critic_loss(state, qval)
            self.log("loss_critic",
                     loss_critic,
                     on_step=False,
                     on_epoch=True,
                     prog_bar=False,
                     logger=True)

            return loss_critic

        raise NotImplementedError(
            f"Got optimizer_idx: {optimizer_idx}. Expected only 2 optimizers from configure_optimizers. "
            "Modify optimizer logic in training_step to account for this. ")

    def configure_optimizers(self) -> List[Optimizer]:
        """Initialize Adam optimizer."""
        optimizer_actor = torch.optim.Adam(self.actor.parameters(),
                                           lr=self.lr_actor)
        optimizer_critic = torch.optim.Adam(self.critic.parameters(),
                                            lr=self.lr_critic)

        return optimizer_actor, optimizer_critic

    def optimizer_step(self, *args, **kwargs):
        """Run ``nb_optim_iters`` number of iterations of gradient descent on actor and critic for each data
        sample."""
        for _ in range(self.nb_optim_iters):
            super().optimizer_step(*args, **kwargs)

    def _dataloader(self) -> DataLoader:
        """Initialize the Replay Buffer dataset used for retrieving experiences."""
        dataset = ExperienceSourceDataset(self.generate_trajectory_samples)
        dataloader = DataLoader(dataset=dataset, batch_size=self.batch_size)
        return dataloader

    def train_dataloader(self) -> DataLoader:
        """Get train loader."""
        return self._dataloader()

    @staticmethod
    def add_model_specific_args(parent_parser):  # pragma: no cover
        parser = argparse.ArgumentParser(parents=[parent_parser])
        parser.add_argument("--env", type=str, default="CartPole-v0")
        parser.add_argument("--gamma",
                            type=float,
                            default=0.99,
                            help="discount factor")
        parser.add_argument("--lam",
                            type=float,
                            default=0.95,
                            help="advantage discount factor")
        parser.add_argument("--lr_actor",
                            type=float,
                            default=3e-4,
                            help="learning rate of actor network")
        parser.add_argument("--lr_critic",
                            type=float,
                            default=1e-3,
                            help="learning rate of critic network")
        parser.add_argument("--max_episode_len",
                            type=int,
                            default=1000,
                            help="capacity of the replay buffer")
        parser.add_argument("--batch_size",
                            type=int,
                            default=512,
                            help="batch_size when training network")
        parser.add_argument(
            "--steps_per_epoch",
            type=int,
            default=2048,
            help=
            "how many action-state pairs to rollout for trajectory collection per epoch",
        )
        parser.add_argument(
            "--nb_optim_iters",
            type=int,
            default=4,
            help="how many steps of gradient descent to perform on each batch")
        parser.add_argument(
            "--clip_ratio",
            type=float,
            default=0.2,
            help="hyperparameter for clipping in the policy objective")

        return parser
Beispiel #14
0
class SAC(LightningModule):
    def __init__(
        self,
        env: str,
        eps_start: float = 1.0,
        eps_end: float = 0.02,
        eps_last_frame: int = 150000,
        sync_rate: int = 1,
        gamma: float = 0.99,
        policy_learning_rate: float = 3e-4,
        q_learning_rate: float = 3e-4,
        target_alpha: float = 5e-3,
        batch_size: int = 128,
        replay_size: int = 1000000,
        warm_start_size: int = 10000,
        avg_reward_len: int = 100,
        min_episode_reward: int = -21,
        seed: int = 123,
        batches_per_epoch: int = 10000,
        n_steps: int = 1,
        **kwargs,
    ):
        super().__init__()

        # Environment
        self.env = gym.make(env)
        self.test_env = gym.make(env)

        self.obs_shape = self.env.observation_space.shape
        self.n_actions = self.env.action_space.shape[0]

        # Model Attributes
        self.buffer = None
        self.dataset = None

        self.policy = None
        self.q1 = None
        self.q2 = None
        self.target_q1 = None
        self.target_q2 = None
        self.build_networks()

        self.agent = SoftActorCriticAgent(self.policy)

        # Hyperparameters
        self.save_hyperparameters()

        # Metrics
        self.total_episode_steps = [0]
        self.total_rewards = [0]
        self.done_episodes = 0
        self.total_steps = 0

        # Average Rewards
        self.avg_reward_len = avg_reward_len

        for _ in range(avg_reward_len):
            self.total_rewards.append(
                torch.tensor(min_episode_reward, device=self.device))

        self.avg_rewards = float(
            np.mean(self.total_rewards[-self.avg_reward_len:]))

        self.state = self.env.reset()

        self.automatic_optimization = False

    def run_n_episodes(self, env, n_episodes: int = 1) -> List[int]:
        """Carries out N episodes of the environment with the current agent without exploration.

        Args:
            env: environment to use, either train environment or test environment
            n_episodes: number of episodes to run
        """
        total_rewards = []

        for _ in range(n_episodes):
            episode_state = env.reset()
            done = False
            episode_reward = 0

            while not done:
                action = self.agent.get_action(episode_state, self.device)
                next_state, reward, done, _ = env.step(action[0])
                episode_state = next_state
                episode_reward += reward

            total_rewards.append(episode_reward)

        return total_rewards

    def populate(self, warm_start: int) -> None:
        """Populates the buffer with initial experience."""
        if warm_start > 0:
            self.state = self.env.reset()

            for _ in range(warm_start):
                action = self.agent(self.state, self.device)
                next_state, reward, done, _ = self.env.step(action[0])
                exp = Experience(state=self.state,
                                 action=action[0],
                                 reward=reward,
                                 done=done,
                                 new_state=next_state)
                self.buffer.append(exp)
                self.state = next_state

                if done:
                    self.state = self.env.reset()

    def build_networks(self) -> None:
        """Initializes the SAC policy and q networks (with targets)"""
        action_bias = torch.from_numpy(
            (self.env.action_space.high + self.env.action_space.low) / 2)
        action_scale = torch.from_numpy(
            (self.env.action_space.high - self.env.action_space.low) / 2)
        self.policy = ContinuousMLP(self.obs_shape,
                                    self.n_actions,
                                    action_bias=action_bias,
                                    action_scale=action_scale)

        concat_shape = [self.obs_shape[0] + self.n_actions]
        self.q1 = MLP(concat_shape, 1)
        self.q2 = MLP(concat_shape, 1)
        self.target_q1 = MLP(concat_shape, 1)
        self.target_q2 = MLP(concat_shape, 1)
        self.target_q1.load_state_dict(self.q1.state_dict())
        self.target_q2.load_state_dict(self.q2.state_dict())

    def soft_update_target(self, q_net, target_net):
        """Update the weights in target network using a weighted sum.

        w_target := (1-a) * w_target + a * w_q

        Args:
            q_net: the critic (q) network
            target_net: the target (q) network
        """
        for q_param, target_param in zip(q_net.parameters(),
                                         target_net.parameters()):
            target_param.data.copy_((1.0 - self.hparams.target_alpha) *
                                    target_param.data +
                                    self.hparams.target_alpha * q_param)
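
    # For reference: this is the standard Polyak (exponential moving average)
    # update used for SAC target networks. Equivalently,
    #   target <- target + target_alpha * (q - target)
    # so a small target_alpha (default 5e-3 above) makes the targets move slowly
    # towards the online critics.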

    def forward(self, x: Tensor) -> Tensor:
        """Passes in a state x through the network and gets the q_values of each action as an output.

        Args:
            x: environment state

        Returns:
            q values
        """
        output = self.policy(x).sample()
        return output

    def train_batch(self) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
        """Contains the logic for generating a new batch of data to be passed to the DataLoader.

        Yields:
            tuples of (state, action, reward, done, next_state) tensors sampled from the replay buffer.
        """
        episode_reward = 0
        episode_steps = 0

        while True:
            self.total_steps += 1
            action = self.agent(self.state, self.device)

            next_state, r, is_done, _ = self.env.step(action[0])

            episode_reward += r
            episode_steps += 1

            exp = Experience(state=self.state,
                             action=action[0],
                             reward=r,
                             done=is_done,
                             new_state=next_state)

            self.buffer.append(exp)
            self.state = next_state

            if is_done:
                self.done_episodes += 1
                self.total_rewards.append(episode_reward)
                self.total_episode_steps.append(episode_steps)
                self.avg_rewards = float(
                    np.mean(self.total_rewards[-self.avg_reward_len:]))
                self.state = self.env.reset()
                episode_steps = 0
                episode_reward = 0

            states, actions, rewards, dones, new_states = self.buffer.sample(
                self.hparams.batch_size)

            for idx, _ in enumerate(dones):
                yield states[idx], actions[idx], rewards[idx], dones[
                    idx], new_states[idx]

            # Simulates epochs
            if self.total_steps % self.hparams.batches_per_epoch == 0:
                break
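
    # For reference: each iteration of the loop above takes exactly one step in
    # the environment and appends it to the replay buffer, then yields a freshly
    # sampled minibatch one transition at a time; `batches_per_epoch` only
    # controls when the generator stops to simulate an epoch boundary.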

    def loss(
        self, batch: Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]
    ) -> Tuple[Tensor, Tensor, Tensor]:
        """Calculates the loss for SAC which contains a total of 3 losses.

        Args:
            batch: a batch of states, actions, rewards, dones, and next states
        """
        states, actions, rewards, dones, next_states = batch
        rewards = rewards.unsqueeze(-1)
        dones = dones.float().unsqueeze(-1)

        # actor
        dist = self.policy(states)
        new_actions, new_logprobs = dist.rsample_and_log_prob()
        new_logprobs = new_logprobs.unsqueeze(-1)

        new_states_actions = torch.cat((states, new_actions), 1)
        new_q1_values = self.q1(new_states_actions)
        new_q2_values = self.q2(new_states_actions)
        new_qmin_values = torch.min(new_q1_values, new_q2_values)

        policy_loss = (new_logprobs - new_qmin_values).mean()

        # critic
        states_actions = torch.cat((states, actions), 1)
        q1_values = self.q1(states_actions)
        q2_values = self.q2(states_actions)

        with torch.no_grad():
            next_dist = self.policy(next_states)
            new_next_actions, new_next_logprobs = next_dist.rsample_and_log_prob(
            )
            new_next_logprobs = new_next_logprobs.unsqueeze(-1)

            new_next_states_actions = torch.cat(
                (next_states, new_next_actions), 1)
            next_q1_values = self.target_q1(new_next_states_actions)
            next_q2_values = self.target_q2(new_next_states_actions)
            next_qmin_values = torch.min(next_q1_values,
                                         next_q2_values) - new_next_logprobs
            target_values = rewards + (
                1.0 - dones) * self.hparams.gamma * next_qmin_values

        q1_loss = F.mse_loss(q1_values, target_values)
        q2_loss = F.mse_loss(q2_values, target_values)

        return policy_loss, q1_loss, q2_loss
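
    # For reference: the three losses above are the standard SAC objectives with
    # the entropy temperature fixed at 1 (this implementation does not learn it):
    #   policy_loss = E[log pi(a|s) - min(Q1(s, a), Q2(s, a))], a ~ pi (reparameterized)
    #   qi_loss = E[(Qi(s, a) - y)^2], with
    #   y = r + gamma * (1 - done) * (min(target_Q1, target_Q2)(s', a') - log pi(a'|s'))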

    def training_step(self, batch: Tuple[Tensor, Tensor], _, optimizer_idx):
        """Carries out a single step through the environment to update the replay buffer. Then calculates loss
        based on the minibatch received.

        Args:
            batch: current mini batch of replay data
            _: batch number, not used
            optimizer_idx: not used
        """
        policy_optim, q1_optim, q2_optim = self.optimizers()
        policy_loss, q1_loss, q2_loss = self.loss(batch)

        policy_optim.zero_grad()
        self.manual_backward(policy_loss)
        policy_optim.step()

        q1_optim.zero_grad()
        self.manual_backward(q1_loss)
        q1_optim.step()

        q2_optim.zero_grad()
        self.manual_backward(q2_loss)
        q2_optim.step()

        # Soft update of target network
        if self.global_step % self.hparams.sync_rate == 0:
            self.soft_update_target(self.q1, self.target_q1)
            self.soft_update_target(self.q2, self.target_q2)

        self.log_dict({
            "total_reward": self.total_rewards[-1],
            "avg_reward": self.avg_rewards,
            "policy_loss": policy_loss,
            "q1_loss": q1_loss,
            "q2_loss": q2_loss,
            "episodes": self.done_episodes,
            "episode_steps": self.total_episode_steps[-1],
        })

    def test_step(self, *args, **kwargs) -> Dict[str, Tensor]:
        """Evaluate the agent for 10 episodes."""
        test_reward = self.run_n_episodes(self.test_env, 1)
        avg_reward = sum(test_reward) / len(test_reward)
        return {"test_reward": avg_reward}

    def test_epoch_end(self, outputs) -> Dict[str, Tensor]:
        """Log the avg of the test results."""
        rewards = [x["test_reward"] for x in outputs]
        avg_reward = sum(rewards) / len(rewards)
        self.log("avg_test_reward", avg_reward)
        return {"avg_test_reward": avg_reward}

    def _dataloader(self) -> DataLoader:
        """Initialize the Replay Buffer dataset used for retrieving experiences."""
        self.buffer = MultiStepBuffer(self.hparams.replay_size,
                                      self.hparams.n_steps)
        self.populate(self.hparams.warm_start_size)

        self.dataset = ExperienceSourceDataset(self.train_batch)
        return DataLoader(dataset=self.dataset,
                          batch_size=self.hparams.batch_size)

    def train_dataloader(self) -> DataLoader:
        """Get train loader."""
        return self._dataloader()

    def test_dataloader(self) -> DataLoader:
        """Get test loader."""
        return self._dataloader()

    def configure_optimizers(self) -> Tuple[Optimizer]:
        """Initialize Adam optimizer."""
        policy_optim = optim.Adam(self.policy.parameters(),
                                  self.hparams.policy_learning_rate)
        q1_optim = optim.Adam(self.q1.parameters(),
                              self.hparams.q_learning_rate)
        q2_optim = optim.Adam(self.q2.parameters(),
                              self.hparams.q_learning_rate)
        return policy_optim, q1_optim, q2_optim

    @staticmethod
    def add_model_specific_args(
        arg_parser: argparse.ArgumentParser, ) -> argparse.ArgumentParser:
        """Adds arguments for DQN model.

        Note:
            These params are fine tuned for Pong env.

        Args:
            arg_parser: parent parser
        """
        arg_parser.add_argument(
            "--sync_rate",
            type=int,
            default=1,
            help="how many frames do we update the target network",
        )
        arg_parser.add_argument(
            "--replay_size",
            type=int,
            default=1000000,
            help="capacity of the replay buffer",
        )
        arg_parser.add_argument(
            "--warm_start_size",
            type=int,
            default=10000,
            help=
            "how many samples do we use to fill our buffer at the start of training",
        )
        arg_parser.add_argument("--batches_per_epoch",
                                type=int,
                                default=10000,
                                help="number of batches in an epoch")
        arg_parser.add_argument("--batch_size",
                                type=int,
                                default=128,
                                help="size of the batches")
        arg_parser.add_argument("--policy_lr",
                                type=float,
                                default=3e-4,
                                help="policy learning rate")
        arg_parser.add_argument("--q_lr",
                                type=float,
                                default=3e-4,
                                help="q learning rate")
        arg_parser.add_argument("--env",
                                type=str,
                                required=True,
                                help="gym environment tag")
        arg_parser.add_argument("--gamma",
                                type=float,
                                default=0.99,
                                help="discount factor")

        arg_parser.add_argument(
            "--avg_reward_len",
            type=int,
            default=100,
            help="how many episodes to include in avg reward",
        )
        arg_parser.add_argument(
            "--n_steps",
            type=int,
            default=1,
            help="how many frames do we update the target network",
        )

        return arg_parser
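

# Usage sketch (not part of the original example): train the SAC model defined
# above on a continuous-action gym task. The environment id and Trainer settings
# below are illustrative assumptions only.
if __name__ == "__main__":
    from pytorch_lightning import Trainer

    model = SAC(env="Pendulum-v0")
    trainer = Trainer(max_epochs=1)
    trainer.fit(model)
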
class PolicyGradient(pl.LightningModule):
    """ Vanilla Policy Gradient Model """
    def __init__(self,
                 env: str,
                 gamma: float = 0.99,
                 lr: float = 1e-4,
                 batch_size: int = 32,
                 entropy_beta: float = 0.01,
                 batch_episodes: int = 4,
                 *args,
                 **kwargs) -> None:
        """
        PyTorch Lightning implementation of `Vanilla Policy Gradient
        <https://papers.nips.cc/paper/
        1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf>`_

        Paper authors: Richard S. Sutton, David McAllester, Satinder Singh, Yishay Mansour

        Model implemented by:

            - `Donal Byrne <https://github.com/djbyrne>`_

        Example:

            >>> from pl_bolts.models.rl.vanilla_policy_gradient_model import PolicyGradient
            ...
            >>> model = PolicyGradient("PongNoFrameskip-v4")

        Train::

            trainer = Trainer()
            trainer.fit(model)

        Args:
            env: gym environment tag
            gamma: discount factor
            lr: learning rate
            batch_size: size of minibatch pulled from the DataLoader
            batch_episodes: how many episodes to rollout for each batch of training
            entropy_beta: dictates the level of entropy per batch

        .. note::
            This example is based on:
             https://github.com/PacktPublishing/Deep-Reinforcement-Learning-Hands-On-Second-Edition\
             /blob/master/Chapter11/04_cartpole_pg.py

        .. note:: Currently only supports CPU and single GPU training with `distributed_backend=dp`

        """
        super().__init__()

        # self.env = wrappers.make_env(self.hparams.env)    # use for Atari
        self.env = ToTensor(gym.make(env))  # use for Box2D/Control
        self.env.seed(123)

        self.obs_shape = self.env.observation_space.shape
        self.n_actions = self.env.action_space.n

        self.net = None
        self.build_networks()

        self.agent = PolicyAgent(self.net)
        self.source = NStepExperienceSource(env=self.env,
                                            agent=self.agent,
                                            n_steps=10)

        self.gamma = gamma
        self.lr = lr
        self.batch_size = batch_size
        self.batch_episodes = batch_episodes
        self.entropy_beta = entropy_beta
        self.baseline = 0

        # Metrics

        self.reward_sum = 0
        self.env_steps = 0
        self.total_steps = 0
        self.total_reward = 0
        self.episode_count = 0

        self.reward_list = []
        for _ in range(100):
            self.reward_list.append(torch.tensor(0, device=self.device))
        self.avg_reward = 0

    def build_networks(self) -> None:
        """Initializes the DQN train and target networks"""
        self.net = MLP(self.obs_shape, self.n_actions)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Passes in a state x through the network and gets the q_values of each action as an output

        Args:
            x: environment state

        Returns:
            q values
        """
        output = self.net(x)
        return output

    def calc_qvals(self, rewards: List[Tensor]) -> List[Tensor]:
        """
        Takes in the rewards for each batched episode and returns list of qvals for each batched episode

        Args:
            rewards: list of rewards for each episodes in the batch

        Returns:
            List of qvals for each episodes
        """
        res = []
        sum_r = 0.0
        for reward in reversed(rewards):
            sum_r *= self.gamma
            sum_r += reward
            res.append(deepcopy(sum_r))
        res = list(reversed(res))
        # Subtract the mean (baseline) from the q_vals to reduce the high variance
        mean_q = sum(res) / len(res)
        return [q - mean_q for q in res]
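
    # Worked example for reference (illustrative numbers, not from the original
    # source): with rewards [1.0, 1.0, 1.0] and gamma = 0.5 the raw Q-values are
    # [1.75, 1.5, 1.0]; subtracting their mean (~1.417) gives roughly
    # [0.33, 0.08, -0.42].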

    def loss(
        self,
        batch_scales: List[Tensor],
        batch_states: List[Tensor],
        batch_actions: List[Tensor],
    ) -> torch.Tensor:
        """
        Calculates the policy-gradient loss (with entropy bonus) using a batch of states, actions and Q values from
        several episodes. These have all been flattened into a single tensor.

        Args:
            batch_scales: current mini batch of rewards minus the baseline
            batch_actions: current batch of actions
            batch_states: current batch of states

        Returns:
            loss
        """
        logits = self.net(batch_states)

        log_prob, policy_loss = self.calc_policy_loss(batch_actions,
                                                      batch_scales,
                                                      batch_states, logits)

        entropy_loss_v = self.calc_entropy_loss(log_prob, logits)

        loss = policy_loss + entropy_loss_v

        return loss

    def calc_entropy_loss(self, log_prob: Tensor, logits: Tensor) -> Tensor:
        """
        Calculates the entropy to be added to the loss function
        Args:
            log_prob: log probabilities for each action
            logits: the raw outputs of the network

        Returns:
            entropy penalty for each state
        """
        prob_v = softmax(logits, dim=1)
        entropy_v = -(prob_v * log_prob).sum(dim=1).mean()
        entropy_loss_v = -self.entropy_beta * entropy_v
        return entropy_loss_v
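
    # For reference: entropy_v above is the mean policy entropy
    #   H(pi) = -sum_a pi(a|s) * log pi(a|s)
    # and the returned entropy_loss_v = -entropy_beta * H(pi), so minimizing the
    # total loss encourages higher entropy (an exploration bonus).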

    @staticmethod
    def calc_policy_loss(batch_actions: Tensor, batch_qvals: Tensor,
                         batch_states: Tensor,
                         logits: Tensor) -> Tuple[List, Tensor]:
        """
        Calculate the policy loss given the batch outputs and logits
        Args:
            batch_actions: actions from batched episodes
            batch_qvals: Q values from batched episodes
            batch_states: states from batched episodes
            logits: raw output of the network given the batch_states

        Returns:
            log probabilities and the policy loss
        """
        log_prob = log_softmax(logits, dim=1)
        log_prob_actions = (batch_qvals *
                            log_prob[range(len(batch_states)), batch_actions])
        policy_loss = -log_prob_actions.mean()
        return log_prob, policy_loss
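
    # For reference: this is the REINFORCE / vanilla policy-gradient loss,
    #   L = -E[(Q_t - baseline) * log pi(a_t | s_t)]
    # where log_prob[range(N), batch_actions] picks out the log probability of
    # the action actually taken in each state and batch_qvals supplies the
    # baseline-adjusted returns.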

    def train_batch(
        self
    ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]:
        """
        Contains the logic for generating a new batch of data to be passed to the DataLoader

        Yields:
            a tuple of (state, action, scale) tensors for each step, where scale is the reward minus the baseline.
        """

        for _ in range(self.batch_size):

            # take a step in the env
            exp, reward, done = self.source.step(self.device)
            self.env_steps += 1
            self.total_steps += 1

            # update the baseline
            self.reward_sum += exp.reward
            self.baseline = self.reward_sum / self.total_steps
            self.total_reward += reward

            # gather the experience data
            scale = exp.reward - self.baseline
            yield exp.new_state, exp.action, scale

            if done:
                # tracking metrics
                self.episode_count += 1
                self.reward_list.append(self.total_reward)
                self.avg_reward = sum(self.reward_list[-100:]) / 100

                self.logger.experiment.add_scalar("reward", self.total_reward,
                                                  self.total_steps)

                # reset metrics
                self.total_reward = 0
                self.env_steps = 0

    def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor],
                      _) -> OrderedDict:
        """
        Carries out a single step through the environment to update the replay buffer.
        Then calculates loss based on the minibatch received

        Args:
            batch: current mini batch of replay data
            _: batch number, not used

        Returns:
            Training loss and log metrics
        """
        states, actions, scales = batch

        # calculates training loss
        loss = self.loss(scales, states, actions)

        if self.trainer.use_dp or self.trainer.use_ddp2:
            loss = loss.unsqueeze(0)

        log = {
            "train_loss": loss,
            "avg_reward": self.avg_reward,
            "episode_count": self.episode_count,
            "baseline": self.baseline
        }

        return OrderedDict({"loss": loss, "log": log, "progress_bar": log})

    def configure_optimizers(self) -> List[Optimizer]:
        """ Initialize Adam optimizer"""
        optimizer = optim.Adam(self.net.parameters(), lr=self.lr)
        return [optimizer]

    def _dataloader(self) -> DataLoader:
        """Initialize the Replay Buffer dataset used for retrieving experiences"""
        dataset = ExperienceSourceDataset(self.train_batch)
        dataloader = DataLoader(dataset=dataset, batch_size=self.batch_size)
        return dataloader

    def train_dataloader(self) -> DataLoader:
        """Get train loader"""
        return self._dataloader()

    def get_device(self, batch) -> str:
        """Retrieve device currently being used by minibatch"""
        return batch[0][0][0].device.index if self.on_gpu else "cpu"

    @staticmethod
    def add_model_specific_args(arg_parser) -> argparse.ArgumentParser:
        """
        Adds arguments for DQN model

        Note: these params are fine tuned for Pong env

        Args:
            parent
        """
        arg_parser.add_argument(
            "--batch_episodes",
            type=int,
            default=4,
            help="how many episodes to run per batch",
        )
        arg_parser.add_argument("--entropy_beta",
                                type=float,
                                default=0.01,
                                help="entropy beta")
        return arg_parser
    def __init__(self,
                 env: str,
                 gamma: float = 0.99,
                 lr: float = 0.01,
                 batch_size: int = 8,
                 n_steps: int = 10,
                 avg_reward_len: int = 100,
                 num_envs: int = 4,
                 entropy_beta: float = 0.01,
                 epoch_len: int = 1000,
                 **kwargs) -> None:
        """
        PyTorch Lightning implementation of `Vanilla Policy Gradient
        <https://papers.nips.cc/paper/
        1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf>`_
        Paper authors: Richard S. Sutton, David McAllester, Satinder Singh, Yishay Mansour
        Model implemented by:

            - `Donal Byrne <https://github.com/djbyrne>`_

        Example:
            >>> from pl_bolts.models.rl.vanilla_policy_gradient_model import VanillaPolicyGradient
            ...
            >>> model = VanillaPolicyGradient("PongNoFrameskip-v4")

        Train::

            trainer = Trainer()
            trainer.fit(model)

        Args:
            env: gym environment tag
            gamma: discount factor
            lr: learning rate
            batch_size: size of minibatch pulled from the DataLoader
            n_steps: number of steps to unroll for each discounted experience sample
            num_envs: number of parallel gym environments to rollout
            entropy_beta: dictates the level of entropy per batch
            avg_reward_len: how many episodes to take into account when calculating the avg reward
            epoch_len: how many batches make up one pseudo epoch

        Note:
            This example is based on:
            https://github.com/PacktPublishing/Deep-Reinforcement-Learning-Hands-On-Second-Edition/blob/master/Chapter11/04_cartpole_pg.py

        Note:
            Currently only supports CPU and single GPU training with `distributed_backend=dp`
        """
        super().__init__()

        # Hyperparameters
        self.lr = lr
        self.batch_size = batch_size * num_envs
        self.batches_per_epoch = self.batch_size * epoch_len
        self.entropy_beta = entropy_beta
        self.gamma = gamma
        self.n_steps = n_steps

        self.save_hyperparameters()

        # Model components
        self.env = [gym.make(env) for _ in range(num_envs)]
        self.net = MLP(self.env[0].observation_space.shape,
                       self.env[0].action_space.n)
        self.agent = PolicyAgent(self.net)
        self.exp_source = DiscountedExperienceSource(self.env,
                                                     self.agent,
                                                     gamma=gamma,
                                                     n_steps=self.n_steps)

        # Tracking metrics
        self.total_steps = 0
        self.total_rewards = [0]
        self.done_episodes = 0
        self.avg_rewards = 0
        self.reward_sum = 0.0
        self.baseline = 0
        self.avg_reward_len = avg_reward_len