Esempio n. 1
0
    def reset(self, **kwargs):
        """Rebuild the networks, optimizers and bookkeeping of the agent.

        Creates a new policy network (``cat_policy``) and value network
        (``value_net``) on ``self.device``, one optimizer for each, and a
        second policy instance (``cat_policy_old``) whose weights are copied
        from ``cat_policy``. The rollout memory, the episode counter, the
        per-episode reward arrays and the periodic writer are reset as well.
        ``kwargs`` is accepted for interface compatibility and is ignored.
        """
        device = self.device

        # Trainable policy and its optimizer.
        policy_model = self.policy_net_fn(self.env, **self.policy_net_kwargs)
        self.cat_policy = policy_model.to(device)
        self.policy_optimizer = optimizer_factory(
            self.cat_policy.parameters(), **self.optimizer_kwargs)

        # Critic and its optimizer.
        value_model = self.value_net_fn(self.env, **self.value_net_kwargs)
        self.value_net = value_model.to(device)
        self.value_optimizer = optimizer_factory(
            self.value_net.parameters(), **self.optimizer_kwargs)

        # Second policy instance, synchronised with the trainable one.
        old_model = self.policy_net_fn(self.env, **self.policy_net_kwargs)
        self.cat_policy_old = old_model.to(device)
        self.cat_policy_old.load_state_dict(self.cat_policy.state_dict())

        self.MseLoss = nn.MSELoss()
        self.memory = Memory()
        self.episode = 0

        # Per-episode reward arrays, one slot per episode.
        n_eps = self.n_episodes
        self._rewards = np.zeros(n_eps)
        self._cumul_rewards = np.zeros(n_eps)

        # Writer period is a multiple of the logger's effective level.
        self.writer = PeriodicWriter(
            self.name, log_every=5 * logger.getEffectiveLevel())
Esempio n. 2
0
    def reset(self, **kwargs):
        """Re-create the policy network, its optimizer and the rollout memory.

        The network is built from ``self.policy_net_fn`` and moved to
        ``self.device``; the episode counter goes back to zero.
        ``kwargs`` is accepted for interface compatibility and is ignored.
        """
        network = self.policy_net_fn(self.env, **self.policy_net_kwargs)
        self.policy_net = network.to(self.device)
        self.policy_optimizer = optimizer_factory(
            self.policy_net.parameters(), **self.optimizer_kwargs)

        self.memory, self.episode = Memory(), 0
Esempio n. 3
0
    def reset(self, **kwargs):
        """Rebuild actor, critic, their optimizers and the rollout memory.

        ``cat_policy_old`` is a second policy instance whose weights are
        copied from the freshly built ``cat_policy``. ``kwargs`` is accepted
        for interface compatibility and is ignored.
        """

        def _build(net_fn, net_kwargs):
            # Instantiate a network for self.env and move it to self.device.
            return net_fn(self.env, **net_kwargs).to(self.device)

        opt_kwargs = self.optimizer_kwargs

        self.cat_policy = _build(self.policy_net_fn, self.policy_net_kwargs)
        self.policy_optimizer = optimizer_factory(
            self.cat_policy.parameters(), **opt_kwargs)

        self.value_net = _build(self.value_net_fn, self.value_net_kwargs)
        self.value_optimizer = optimizer_factory(
            self.value_net.parameters(), **opt_kwargs)

        self.cat_policy_old = _build(self.policy_net_fn,
                                     self.policy_net_kwargs)
        self.cat_policy_old.load_state_dict(self.cat_policy.state_dict())

        self.MseLoss = nn.MSELoss()
        self.memory = Memory()
        self.episode = 0
Esempio n. 4
0
    def reset(self, **kwargs):
        """Re-initialise the policy network, its optimizer and bookkeeping.

        Also clears the rollout memory, the episode counter, the two
        per-episode reward arrays (length ``self.n_episodes``) and replaces
        the periodic writer. ``kwargs`` is accepted and ignored.
        """
        net = self.policy_net_fn(self.env, **self.policy_net_kwargs)
        self.policy_net = net.to(self.device)
        self.policy_optimizer = optimizer_factory(
            self.policy_net.parameters(), **self.optimizer_kwargs)

        self.memory = Memory()
        self.episode = 0

        # Reward arrays, one slot per episode.
        self._rewards, self._cumul_rewards = (np.zeros(self.n_episodes),
                                              np.zeros(self.n_episodes))

        # Writer period is a multiple of the logger's effective level.
        period = 5 * logger.getEffectiveLevel()
        self.writer = PeriodicWriter(self.name, log_every=period)
Esempio n. 5
0
    def reset(self, **kwargs):
        """Rebuild actor, critic, optimizers and the rollout buffers.

        ``cat_policy_old`` is a second policy instance initialised with the
        weights of ``cat_policy``. The ``returns`` and ``advantages`` lists
        are emptied and the episode counter is reset. ``kwargs`` is ignored.
        """
        dev = self.device
        opt_kwargs = self.optimizer_kwargs

        self.cat_policy = self.policy_net_fn(
            self.env, **self.policy_net_kwargs).to(dev)
        self.policy_optimizer = optimizer_factory(
            self.cat_policy.parameters(), **opt_kwargs)

        self.value_net = self.value_net_fn(
            self.env, **self.value_net_kwargs).to(dev)
        self.value_optimizer = optimizer_factory(
            self.value_net.parameters(), **opt_kwargs)

        self.cat_policy_old = self.policy_net_fn(
            self.env, **self.policy_net_kwargs).to(dev)
        self.cat_policy_old.load_state_dict(self.cat_policy.state_dict())

        # TODO: turn the value loss into an argument
        self.MseLoss = nn.MSELoss()

        # TODO: improve memory to include returns and advantages
        self.memory = Memory()
        # TODO: move returns and advantages into memory
        self.returns, self.advantages = [], []

        self.episode = 0
Esempio n. 6
0
class AVECPPOAgent(AgentWithSimplePolicy):
    """
    AVEC uses a modification of the training objective for the critic in
    actor-critic algorithms to better approximate the value function (critic).
    The new state-value function approximation learns the *relative* value of
    the states rather than their *absolute* value as in conventional
    actor-critic. This modification is:
    - well-motivated by recent studies [1,2];
    - theoretically sound;
    - intuitively supported by the need to improve the approximation error
    of the critic.

    The application of Actor with Variance Estimated Critic (AVEC) to
    state-of-the-art policy gradient methods produces considerable
    gains in performance (on average +26% for SAC and +40% for PPO)
    over the standard actor-critic training.

    Parameters
    ----------
    env : Model
        model with continuous (Box) state space and discrete actions
    batch_size : int
        Number of episodes to wait before updating the policy.
    horizon : int or None
        Maximum number of steps per episode. If None and gamma < 1,
        set to round(1 / (1 - gamma)).
    gamma : double
        Discount factor in [0, 1]. If gamma is 1.0, the problem is set
        to be finite-horizon.
    entr_coef : double
        Entropy coefficient.
    vf_coef : double
        Value function (MSE) loss coefficient.
    avec_coef : double
        Coefficient of the AVEC critic loss, i.e. the residual variance
        between the empirical returns and the predicted state values.
    learning_rate : double
        Learning rate.
    optimizer_type: str
        Type of optimizer. 'ADAM' by default.
    eps_clip : double
        PPO clipping range (epsilon).
    k_epochs : int
        Number of epochs per update.
    policy_net_fn : function(env, **kwargs)
        Function that returns an instance of a policy network (pytorch).
        If None, a default net is used.
    value_net_fn : function(env, **kwargs)
        Function that returns an instance of a value network (pytorch).
        If None, a default net is used.
    policy_net_kwargs : dict
        kwargs for policy_net_fn
    value_net_kwargs : dict
        kwargs for value_net_fn
    use_bonus : bool, default = False
        If true, compute an 'exploration_bonus' and add it to the reward.
        See also UncertaintyEstimatorWrapper.
    uncertainty_estimator_kwargs : dict
        Arguments for the UncertaintyEstimatorWrapper. None is treated as
        an empty dict.
    device : str
        Device to put the tensors on
    **kwargs
        Forwarded to ``AgentWithSimplePolicy.__init__``.

    References
    ----------
    Flet-Berliac, Y., Ouhamma, R., Maillard, O. A., & Preux, P. (2021).
    "Is Standard Deviation the New Standard? Revisiting the Critic in Deep
    Policy Gradients."
    In International Conference on Learning Representations.

    [1] Ilyas, A., Engstrom, L., Santurkar, S., Tsipras, D., Janoos, F.,
    Rudolph, L. & Madry, A. (2020).
    "A closer look at deep policy gradients."
    In International Conference on Learning Representations.

    [2] Tucker, G., Bhupatiraju, S., Gu, S., Turner, R., Ghahramani, Z. &
    Levine, S. (2018).
    "The mirage of action-dependent baselines in reinforcement learning."
    In International Conference on Machine Learning, pp. 5015–5024.
    """

    name = "AVECPPO"

    def __init__(self,
                 env,
                 batch_size=8,
                 horizon=256,
                 gamma=0.99,
                 entr_coef=0.01,
                 vf_coef=0.0,
                 avec_coef=1.0,
                 learning_rate=0.0003,
                 optimizer_type="ADAM",
                 eps_clip=0.2,
                 k_epochs=10,
                 policy_net_fn=None,
                 value_net_fn=None,
                 policy_net_kwargs=None,
                 value_net_kwargs=None,
                 use_bonus=False,
                 uncertainty_estimator_kwargs=None,
                 device="cuda:best",
                 **kwargs):
        # For all parameters, define self.param = param.
        # NOTE: this must stay the first statement, so that only the
        # constructor arguments are present in the frame's locals.
        _, _, _, values = inspect.getargvalues(inspect.currentframe())
        values.pop("self")
        for arg, val in values.items():
            setattr(self, arg, val)

        AgentWithSimplePolicy.__init__(self, env, **kwargs)

        # Default horizon promised by the docstring.
        if self.horizon is None:
            if not gamma < 1.0:
                raise ValueError("horizon cannot be None when gamma >= 1.")
            self.horizon = max(1, int(round(1.0 / (1.0 - gamma))))

        self.use_bonus = use_bonus
        if self.use_bonus:
            # `or {}` avoids a TypeError on `**None` with the default value.
            self.env = UncertaintyEstimatorWrapper(
                self.env, **(uncertainty_estimator_kwargs or {}))

        self.device = choose_device(device)

        self.policy_net_kwargs = policy_net_kwargs or {}
        self.value_net_kwargs = value_net_kwargs or {}

        self.state_dim = self.env.observation_space.shape[0]
        self.action_dim = self.env.action_space.n

        # network factories (defaults when not provided)
        self.policy_net_fn = policy_net_fn or default_policy_net_fn
        self.value_net_fn = value_net_fn or default_value_net_fn

        self.optimizer_kwargs = {
            "optimizer_type": optimizer_type,
            "lr": learning_rate
        }

        # check environment
        assert isinstance(self.env.observation_space, spaces.Box)
        assert isinstance(self.env.action_space, spaces.Discrete)

        self.cat_policy = None  # categorical policy function

        # initialize
        self.reset()

    def reset(self, **kwargs):
        """Re-create actor, critic, their optimizers and the rollout memory.

        ``cat_policy_old`` is a second policy instance whose weights are
        copied from ``cat_policy``; it is the one used to act. ``kwargs``
        is accepted for interface compatibility and is ignored.
        """
        self.cat_policy = self.policy_net_fn(
            self.env, **self.policy_net_kwargs).to(self.device)
        self.policy_optimizer = optimizer_factory(self.cat_policy.parameters(),
                                                  **self.optimizer_kwargs)

        self.value_net = self.value_net_fn(
            self.env, **self.value_net_kwargs).to(self.device)
        self.value_optimizer = optimizer_factory(self.value_net.parameters(),
                                                 **self.optimizer_kwargs)

        self.cat_policy_old = self.policy_net_fn(
            self.env, **self.policy_net_kwargs).to(self.device)
        self.cat_policy_old.load_state_dict(self.cat_policy.state_dict())

        self.MseLoss = nn.MSELoss()

        self.memory = Memory()

        self.episode = 0

    def policy(self, observation):
        """Sample an action for ``observation`` from ``cat_policy_old``."""
        state = observation
        assert self.cat_policy is not None
        state = torch.from_numpy(state).float().to(self.device)
        action_dist = self.cat_policy_old(state)
        action = action_dist.sample().item()

        return action

    def fit(self, budget: int, **kwargs):
        """Run ``budget`` training episodes; extra ``kwargs`` are ignored."""
        del kwargs
        for _ in range(budget):
            self._run_episode()

    def _select_action(self, state):
        """Sample an action with ``cat_policy_old`` and record it in memory.

        The state tensor, the sampled action and its log-probability are
        appended to the rollout memory; the action is returned as a Python
        scalar.
        """
        state = torch.from_numpy(state).float().to(self.device)
        action_dist = self.cat_policy_old(state)
        action = action_dist.sample()
        action_logprob = action_dist.log_prob(action)

        self.memory.states.append(state)
        self.memory.actions.append(action)
        self.memory.logprobs.append(action_logprob)

        return action.item()

    def _run_episode(self):
        """Run one episode of at most ``self.horizon`` steps and store it.

        The stored reward includes the exploration bonus when ``use_bonus``
        is set, while the returned (and logged) episode reward does not.
        Every ``batch_size`` episodes the networks are updated and the
        memory is cleared.

        Returns
        -------
        Sum of the environment rewards collected during the episode.
        """
        episode_rewards = 0
        state = self.env.reset()
        for step in range(self.horizon):
            # running policy_old
            action = self._select_action(state)
            next_state, reward, done, info = self.env.step(action)

            # check whether to use bonus
            bonus = 0.0
            if self.use_bonus:
                if info is not None and "exploration_bonus" in info:
                    bonus = info["exploration_bonus"]

            # Flag the last transition of the episode, whether the episode
            # ends because the environment is done or because the horizon
            # is reached: _update() resets the discounted return on this
            # flag, so a missing flag would leak the next episode's return.
            is_last = bool(done) or step == self.horizon - 1

            # save in batch
            self.memory.rewards.append(reward + bonus)  # add bonus here
            self.memory.is_terminals.append(is_last)
            episode_rewards += reward

            if done:
                break

            # update state
            state = next_state

        # update
        self.episode += 1

        # logging
        if self.writer is not None:
            self.writer.add_scalar("episode_rewards", episode_rewards,
                                   self.episode)

        # policy / value update every `batch_size` episodes
        if self.episode % self.batch_size == 0:
            self._update()
            self.memory.clear_memory()

        return episode_rewards

    def _update(self):
        """Run ``k_epochs`` PPO + AVEC gradient steps on the stored rollouts.

        Monte Carlo discounted returns are computed per episode (reset at
        every flagged episode end) and normalized. The loss combines the
        clipped PPO surrogate, the AVEC residual-variance critic loss
        (weighted by ``avec_coef``), an MSE critic loss (weighted by
        ``vf_coef``) and an entropy bonus (weighted by ``entr_coef``).
        Afterwards the weights of ``cat_policy`` are copied into
        ``cat_policy_old``.
        """
        # monte carlo estimate of rewards, built backwards in O(n)
        returns = []
        discounted_reward = 0.0
        for reward, is_terminal in zip(reversed(self.memory.rewards),
                                       reversed(self.memory.is_terminals)):
            if is_terminal:
                discounted_reward = 0.0
            discounted_reward = reward + (self.gamma * discounted_reward)
            returns.append(discounted_reward)
        returns.reverse()

        # normalizing the rewards
        rewards = torch.tensor(returns, dtype=torch.float32,
                               device=self.device)
        rewards = (rewards - rewards.mean()) / (rewards.std() + 1e-5)

        # convert list to tensor
        old_states = torch.stack(self.memory.states).to(self.device).detach()
        old_actions = torch.stack(self.memory.actions).to(self.device).detach()
        old_logprobs = torch.stack(self.memory.logprobs).to(
            self.device).detach()

        # optimize policy for K epochs
        for _ in range(self.k_epochs):
            # evaluate old actions and values
            action_dist = self.cat_policy(old_states)
            logprobs = action_dist.log_prob(old_actions)
            state_values = torch.squeeze(self.value_net(old_states))
            dist_entropy = action_dist.entropy()

            # find ratio (pi_theta / pi_theta__old)
            ratios = torch.exp(logprobs - old_logprobs.detach())

            # normalize the advantages
            advantages = rewards - state_values.detach()
            advantages = (advantages - advantages.mean()) / (advantages.std() +
                                                             1e-8)
            # find surrogate loss
            surr1 = ratios * advantages
            surr2 = (
                torch.clamp(ratios, 1 - self.eps_clip, 1 + self.eps_clip) *
                advantages)
            loss = (-torch.min(surr1, surr2) +
                    self.avec_coef * self._avec_loss(state_values, rewards) +
                    self.vf_coef * self.MseLoss(state_values, rewards) -
                    self.entr_coef * dist_entropy)

            # take gradient step
            self.policy_optimizer.zero_grad()
            self.value_optimizer.zero_grad()

            loss.mean().backward()

            self.policy_optimizer.step()
            self.value_optimizer.step()

        # copy new weights into old policy
        self.cat_policy_old.load_state_dict(self.cat_policy.state_dict())

    def _avec_loss(self, y_pred, y_true):
        """
        Computes the objective function used in AVEC for the learning
        of the value function:
        the residual variance between the state-values and the
        empirical returns.

        Returns Var[y_true - y_pred]
        :param y_pred: (torch.Tensor, 1-D) the predicted state values
        :param y_true: (torch.Tensor, 1-D) the empirical (normalized) returns
        :return: (torch.Tensor, 0-D) residual variance of y_true - y_pred
        """
        assert y_true.ndim == 1 and y_pred.ndim == 1

        return torch.var(y_true - y_pred)

    #
    # For hyperparameter optimization
    #
    @classmethod
    def sample_parameters(cls, trial):
        """Sample a hyperparameter dict from an optuna-style ``trial``."""
        batch_size = trial.suggest_categorical("batch_size", [1, 4, 8, 16, 32])
        gamma = trial.suggest_categorical("gamma", [0.9, 0.95, 0.99])
        learning_rate = trial.suggest_loguniform("learning_rate", 1e-5, 1)

        entr_coef = trial.suggest_loguniform("entr_coef", 1e-8, 0.1)

        eps_clip = trial.suggest_categorical("eps_clip", [0.1, 0.2, 0.3])

        k_epochs = trial.suggest_categorical("k_epochs", [1, 5, 10, 20])

        return {
            "batch_size": batch_size,
            "gamma": gamma,
            "learning_rate": learning_rate,
            "entr_coef": entr_coef,
            "eps_clip": eps_clip,
            "k_epochs": k_epochs,
        }
Esempio n. 7
0
class A2CAgent(IncrementalAgent):
    """
    Parameters
    ----------
    env : Model
        Online model with continuous (Box) state space and discrete actions
    n_episodes : int
        Number of episodes
    batch_size : int
        Number of episodes to wait before updating the policy.
    horizon : int
        Maximum number of steps per episode.
    gamma : double
        Discount factor in [0, 1].
    entr_coef : double
        Entropy coefficient.
    learning_rate : double
        Learning rate.
    optimizer_type: str
        Type of optimizer. 'ADAM' by default.
    k_epochs : int
        Number of epochs per update.
    policy_net_fn : function(env, **kwargs)
        Function that returns an instance of a policy network (pytorch).
        If None, a default net is used.
    value_net_fn : function(env, **kwargs)
        Function that returns an instance of a value network (pytorch).
        If None, a default net is used.
    policy_net_kwargs : dict
        kwargs for policy_net_fn
    value_net_kwargs : dict
        kwargs for value_net_fn
    use_bonus : bool, default = False
        If true, compute an 'exploration_bonus' and add it to the reward.
        See also UncertaintyEstimatorWrapper.
    uncertainty_estimator_kwargs : dict
        Arguments for the UncertaintyEstimatorWrapper. None is treated as
        an empty dict.
    device : str
        Device to put the tensors on
    **kwargs
        Forwarded to ``IncrementalAgent.__init__``.

    References
    ----------
    Mnih, V., Badia, A.P., Mirza, M., Graves, A., Lillicrap, T., Harley, T.,
    Silver, D. & Kavukcuoglu, K. (2016).
    "Asynchronous methods for deep reinforcement learning."
    In International Conference on Machine Learning (pp. 1928-1937).
    """

    name = "A2C"

    def __init__(self,
                 env,
                 n_episodes=4000,
                 batch_size=8,
                 horizon=256,
                 gamma=0.99,
                 entr_coef=0.01,
                 learning_rate=0.01,
                 optimizer_type='ADAM',
                 k_epochs=5,
                 policy_net_fn=None,
                 value_net_fn=None,
                 policy_net_kwargs=None,
                 value_net_kwargs=None,
                 use_bonus=False,
                 uncertainty_estimator_kwargs=None,
                 device="cuda:best",
                 **kwargs):
        self.use_bonus = use_bonus
        if self.use_bonus:
            # `or {}` avoids a TypeError on `**None` with the default value.
            env = UncertaintyEstimatorWrapper(
                env, **(uncertainty_estimator_kwargs or {}))
        IncrementalAgent.__init__(self, env, **kwargs)

        self.n_episodes = n_episodes
        self.batch_size = batch_size
        self.horizon = horizon
        self.gamma = gamma
        self.entr_coef = entr_coef
        self.learning_rate = learning_rate
        self.k_epochs = k_epochs
        self.device = choose_device(device)

        self.policy_net_kwargs = policy_net_kwargs or {}
        self.value_net_kwargs = value_net_kwargs or {}

        self.state_dim = self.env.observation_space.shape[0]
        self.action_dim = self.env.action_space.n

        # network factories (defaults when not provided)
        self.policy_net_fn = policy_net_fn or default_policy_net_fn
        self.value_net_fn = value_net_fn or default_value_net_fn

        self.optimizer_kwargs = {
            'optimizer_type': optimizer_type,
            'lr': learning_rate
        }

        # check environment
        assert isinstance(self.env.observation_space, spaces.Box)
        assert isinstance(self.env.action_space, spaces.Discrete)

        self.cat_policy = None  # categorical policy function

        # initialize
        self.reset()

    def reset(self, **kwargs):
        """Re-create networks, optimizers, memory, reward arrays and writer.

        ``cat_policy_old`` is a second policy instance whose weights are
        copied from ``cat_policy``; it is the one used to act. ``kwargs``
        is accepted for interface compatibility and is ignored.
        """
        self.cat_policy = self.policy_net_fn(
            self.env, **self.policy_net_kwargs).to(self.device)
        self.policy_optimizer = optimizer_factory(self.cat_policy.parameters(),
                                                  **self.optimizer_kwargs)

        self.value_net = self.value_net_fn(
            self.env, **self.value_net_kwargs).to(self.device)

        self.value_optimizer = optimizer_factory(self.value_net.parameters(),
                                                 **self.optimizer_kwargs)

        self.cat_policy_old = self.policy_net_fn(
            self.env, **self.policy_net_kwargs).to(self.device)
        self.cat_policy_old.load_state_dict(self.cat_policy.state_dict())

        self.MseLoss = nn.MSELoss()

        self.memory = Memory()

        self.episode = 0

        # useful data
        self._rewards = np.zeros(self.n_episodes)
        self._cumul_rewards = np.zeros(self.n_episodes)

        # default writer
        log_every = 5 * logger.getEffectiveLevel()
        self.writer = PeriodicWriter(self.name, log_every=log_every)

    def policy(self, state, **kwargs):
        """Sample an action for ``state`` from ``cat_policy_old``."""
        assert self.cat_policy is not None
        state = torch.from_numpy(state).float().to(self.device)
        action_dist = self.cat_policy_old(state)
        action = action_dist.sample().item()
        return action

    def partial_fit(self, fraction: float, **kwargs):
        """Run ``ceil(fraction * n_episodes)`` more episodes (capped).

        Returns a dict with the number of episodes run so far and their
        rewards.
        """
        assert 0.0 < fraction <= 1.0
        n_episodes_to_run = int(np.ceil(fraction * self.n_episodes))
        count = 0
        while count < n_episodes_to_run and self.episode < self.n_episodes:
            self._run_episode()
            count += 1

        info = {
            "n_episodes": self.episode,
            "episode_rewards": self._rewards[:self.episode]
        }
        return info

    def _select_action(self, state):
        """Sample an action with ``cat_policy_old`` and record it in memory."""
        state = torch.from_numpy(state).float().to(self.device)
        action_dist = self.cat_policy_old(state)
        action = action_dist.sample()
        action_logprob = action_dist.log_prob(action)

        self.memory.states.append(state)
        self.memory.actions.append(action)
        self.memory.logprobs.append(action_logprob)

        return action.item()

    def _run_episode(self):
        """Run one episode of at most ``self.horizon`` steps and store it.

        The stored reward includes the exploration bonus when ``use_bonus``
        is set, while the recorded episode reward does not. Every
        ``batch_size`` episodes the networks are updated and the memory is
        cleared.

        Returns
        -------
        Sum of the environment rewards collected during the episode.
        """
        episode_rewards = 0
        state = self.env.reset()
        for step in range(self.horizon):
            # running policy_old
            action = self._select_action(state)
            next_state, reward, done, info = self.env.step(action)

            # check whether to use bonus
            bonus = 0.0
            if self.use_bonus:
                if info is not None and 'exploration_bonus' in info:
                    bonus = info['exploration_bonus']

            # Flag the last transition of the episode, whether the episode
            # ends because the environment is done or because the horizon
            # is reached: _update() resets the discounted return on this
            # flag, so a missing flag would leak the next episode's return.
            is_last = bool(done) or step == self.horizon - 1

            # save in batch
            self.memory.rewards.append(reward + bonus)  # add bonus here
            self.memory.is_terminals.append(is_last)
            episode_rewards += reward

            if done:
                break

            # update state
            state = next_state

        # update
        ep = self.episode
        self._rewards[ep] = episode_rewards
        self._cumul_rewards[ep] = episode_rewards \
            + self._cumul_rewards[max(0, ep - 1)]
        self.episode += 1

        # logging
        if self.writer is not None:
            self.writer.add_scalar("episode", self.episode, None)
            self.writer.add_scalar("ep reward", episode_rewards)

        # policy / value update every `batch_size` episodes
        if self.episode % self.batch_size == 0:
            self._update()
            self.memory.clear_memory()

        return episode_rewards

    def _update(self):
        """Run ``k_epochs`` actor-critic gradient steps on stored rollouts.

        Monte Carlo discounted returns are computed per episode (reset at
        every flagged episode end) and normalized. The loss is the policy
        gradient term with normalized advantages, plus 0.5 times the MSE
        critic loss, minus ``entr_coef`` times the entropy. Afterwards the
        weights of ``cat_policy`` are copied into ``cat_policy_old``.
        """
        # monte carlo estimate of rewards, built backwards in O(n)
        returns = []
        discounted_reward = 0.0
        for reward, is_terminal in zip(reversed(self.memory.rewards),
                                       reversed(self.memory.is_terminals)):
            if is_terminal:
                discounted_reward = 0.0
            discounted_reward = reward + (self.gamma * discounted_reward)
            returns.append(discounted_reward)
        returns.reverse()

        # normalize the rewards
        rewards = torch.tensor(returns, dtype=torch.float32,
                               device=self.device)
        rewards = (rewards - rewards.mean()) / (rewards.std() + 1e-5)

        # convert list to tensor
        old_states = torch.stack(self.memory.states).to(self.device).detach()
        old_actions = torch.stack(self.memory.actions).to(self.device).detach()

        # optimize policy for K epochs
        for _ in range(self.k_epochs):
            # evaluate old actions and values
            action_dist = self.cat_policy(old_states)
            logprobs = action_dist.log_prob(old_actions)
            state_values = torch.squeeze(self.value_net(old_states))
            dist_entropy = action_dist.entropy()

            # normalize the advantages
            advantages = rewards - state_values.detach()
            advantages = (advantages - advantages.mean()) \
                / (advantages.std() + 1e-8)
            # find pg loss
            pg_loss = -logprobs * advantages
            loss = pg_loss \
                + 0.5 * self.MseLoss(state_values, rewards) \
                - self.entr_coef * dist_entropy

            # take gradient step
            self.policy_optimizer.zero_grad()
            self.value_optimizer.zero_grad()

            loss.mean().backward()

            self.policy_optimizer.step()
            self.value_optimizer.step()

        # copy new weights into old policy
        self.cat_policy_old.load_state_dict(self.cat_policy.state_dict())

    #
    # For hyperparameter optimization
    #
    @classmethod
    def sample_parameters(cls, trial):
        """Sample a hyperparameter dict from an optuna-style ``trial``."""
        batch_size = trial.suggest_categorical('batch_size', [1, 4, 8, 16, 32])
        gamma = trial.suggest_categorical('gamma', [0.9, 0.95, 0.99])
        learning_rate = trial.suggest_loguniform('learning_rate', 1e-5, 1)

        entr_coef = trial.suggest_loguniform('entr_coef', 1e-8, 0.1)

        k_epochs = trial.suggest_categorical('k_epochs', [1, 5, 10, 20])

        return {
            'batch_size': batch_size,
            'gamma': gamma,
            'learning_rate': learning_rate,
            'entr_coef': entr_coef,
            'k_epochs': k_epochs,
        }
Esempio n. 8
0
class PPOAgent(AgentWithSimplePolicy):
    """
    Proximal Policy Optimization Agent.

    Policy gradient methods for reinforcement learning, which alternate between
    sampling data through interaction with the environment, and optimizing a
    "surrogate" objective function using stochastic gradient ascent.

    Parameters
    ----------
    env : Model
        Online model with continuous (Box) state space and discrete actions
    batch_size : int
        Number of *transitions* per minibatch used for each gradient step.
    update_frequency : int
        Number of *episodes* to collect before updating the policy.
    horizon : int
        Maximum number of steps per episode.
    gamma : double
        Discount factor in [0, 1].
    entr_coef : double
        Entropy coefficient.
    vf_coef : double
        Value function loss coefficient.
    learning_rate : double
        Learning rate.
    optimizer_type: str
        Type of optimizer. 'ADAM' by default.
    eps_clip : double
        PPO clipping range (epsilon).
    k_epochs : int
        Number of epochs per update.
    use_gae : bool, default = True
        If True, advantages are computed with Generalized Advantage
        Estimation; otherwise as discounted returns minus state values.
    gae_lambda : double
        GAE lambda parameter (only used when use_gae is True).
    policy_net_fn : function(env, **kwargs)
        Function that returns an instance of a policy network (pytorch).
        If None, a default net is used.
    value_net_fn : function(env, **kwargs)
        Function that returns an instance of a value network (pytorch).
        If None, a default net is used.
    policy_net_kwargs : dict
        kwargs for policy_net_fn
    value_net_kwargs : dict
        kwargs for value_net_fn
    device: str
        Device to put the tensors on
    use_bonus : bool, default = False
        If true, compute the environment 'exploration_bonus'
        and add it to the reward. See also UncertaintyEstimatorWrapper.
    uncertainty_estimator_kwargs : dict
        kwargs for UncertaintyEstimatorWrapper (None means no kwargs).

    References
    ----------
    Schulman, J., Wolski, F., Dhariwal, P., Radford, A. & Klimov, O. (2017).
    "Proximal Policy Optimization Algorithms."
    arXiv preprint arXiv:1707.06347.

    Schulman, J., Levine, S., Abbeel, P., Jordan, M., & Moritz, P. (2015).
    "Trust region policy optimization."
    In International Conference on Machine Learning (pp. 1889-1897).
    """

    name = "PPO"

    def __init__(self,
                 env,
                 batch_size=64,
                 update_frequency=8,
                 horizon=256,
                 gamma=0.99,
                 entr_coef=0.01,
                 vf_coef=0.5,
                 learning_rate=0.01,
                 optimizer_type="ADAM",
                 eps_clip=0.2,
                 k_epochs=5,
                 use_gae=True,
                 gae_lambda=0.95,
                 policy_net_fn=None,
                 value_net_fn=None,
                 policy_net_kwargs=None,
                 value_net_kwargs=None,
                 device="cuda:best",
                 use_bonus=False,
                 uncertainty_estimator_kwargs=None,
                 **kwargs):  # TODO: sort arguments

        # For all parameters, define self.param = param.
        # This must stay the first statement: getargvalues captures every
        # local variable that exists at this point.
        _, _, _, values = inspect.getargvalues(inspect.currentframe())
        values.pop("self")
        for arg, val in values.items():
            setattr(self, arg, val)
        AgentWithSimplePolicy.__init__(self, env, **kwargs)

        # bonus
        self.use_bonus = use_bonus
        if self.use_bonus:
            # uncertainty_estimator_kwargs defaults to None; unpacking None
            # with ** raises TypeError, so fall back to an empty dict.
            self.env = UncertaintyEstimatorWrapper(
                self.env, **(uncertainty_estimator_kwargs or {}))

        # options
        # TODO: add reward normalization option
        #       add observation normalization option
        #       add orthogonal weight initialization option
        #       add value function clip option
        #       add ... ?
        self.normalize_advantages = True  # TODO: turn into argument

        self.use_gae = use_gae
        self.gae_lambda = gae_lambda

        # function approximators
        self.policy_net_kwargs = policy_net_kwargs or {}
        self.value_net_kwargs = value_net_kwargs or {}

        self.state_dim = self.env.observation_space.shape[0]
        self.action_dim = self.env.action_space.n

        self.policy_net_fn = policy_net_fn or default_policy_net_fn
        self.value_net_fn = value_net_fn or default_value_net_fn

        self.device = choose_device(device)

        self.optimizer_kwargs = {
            "optimizer_type": optimizer_type,
            "lr": learning_rate
        }

        # check environment
        assert isinstance(self.env.observation_space, spaces.Box)
        assert isinstance(self.env.action_space, spaces.Discrete)

        self.cat_policy = None  # categorical policy function

        # initialize
        self.reset()

    @classmethod
    def from_config(cls, **kwargs):
        """Build an agent from a config dict, resolving network functions.

        String values of ``policy_net_fn`` / ``value_net_fn`` are evaluated
        to obtain the callables. Missing or non-string values are passed
        through unchanged (None selects the default network).
        """
        # NOTE(review): eval executes arbitrary code; only use this with
        # trusted configuration files.
        for key in ("policy_net_fn", "value_net_fn"):
            if isinstance(kwargs.get(key), str):
                kwargs[key] = eval(kwargs[key])
        return cls(**kwargs)

    def reset(self, **kwargs):
        """(Re)create networks, optimizers and rollout buffers."""
        self.cat_policy = self.policy_net_fn(
            self.env, **self.policy_net_kwargs).to(self.device)
        self.policy_optimizer = optimizer_factory(self.cat_policy.parameters(),
                                                  **self.optimizer_kwargs)

        self.value_net = self.value_net_fn(
            self.env, **self.value_net_kwargs).to(self.device)
        self.value_optimizer = optimizer_factory(self.value_net.parameters(),
                                                 **self.optimizer_kwargs)

        # frozen copy of the policy, used to collect data and compute ratios
        self.cat_policy_old = self.policy_net_fn(
            self.env, **self.policy_net_kwargs).to(self.device)
        self.cat_policy_old.load_state_dict(self.cat_policy.state_dict())

        self.MseLoss = nn.MSELoss()  # TODO: turn into argument

        # TODO: Improve memory to include returns and advantages
        self.memory = Memory()
        self.returns = []  # TODO: add to memory
        self.advantages = []  # TODO: add to memory

        self.episode = 0

    def policy(self, observation):
        """Sample an action for ``observation`` from the old (behavior) policy."""
        state = observation
        assert self.cat_policy is not None
        state = torch.from_numpy(state).float().to(self.device)
        action_dist = self.cat_policy_old(state)
        action = action_dist.sample().item()
        return action

    def fit(self, budget: int, **kwargs):
        """Train for ``budget`` episodes."""
        del kwargs
        for _ in range(budget):
            self._run_episode()

    def _run_episode(self):
        """Run one episode of at most ``self.horizon`` steps.

        Transitions, returns and advantages are appended to the rollout
        buffers; every ``self.update_frequency`` episodes the networks are
        updated and the buffers are cleared.

        Returns
        -------
        Sum of the environment rewards of the episode (without bonus).
        """
        # to store transitions
        states = []
        actions = []
        action_logprobs = []
        rewards = []
        is_terminals = []

        # interact for at most H steps
        episode_rewards = 0
        state = self.env.reset()
        next_state = state
        done = False

        for _ in range(self.horizon):
            # running policy_old
            state = torch.from_numpy(state).float().to(self.device)

            action_dist = self.cat_policy_old(state)
            action = action_dist.sample()
            action_logprob = action_dist.log_prob(action)

            next_state, reward, done, info = self.env.step(action.item())

            # check whether to use bonus
            bonus = 0.0
            if self.use_bonus:
                if info is not None and "exploration_bonus" in info:
                    bonus = info["exploration_bonus"]

            # save transition
            states.append(state)
            actions.append(action)
            action_logprobs.append(action_logprob)
            rewards.append(reward + bonus)  # bonus added here
            is_terminals.append(done)

            episode_rewards += reward

            if done:
                break

            # update state
            state = next_state

        # compute state values of the visited states; view(-1) (rather than
        # squeeze) keeps a list even when the episode has a single step
        state_values = self.value_net(torch.stack(states).to(
            self.device)).detach()
        state_values = state_values.view(-1).tolist()

        # Bootstrap value for the last transition: V(s_T) of the state reached
        # after the last step. When the episode terminated it is multiplied
        # by (1 - done) == 0, so it is not computed.
        if done:
            last_value = 0.0
        else:
            last_state = torch.from_numpy(next_state).float().to(self.device)
            last_value = self.value_net(
                last_state.unsqueeze(0)).detach().view(-1)[0].item()

        # TODO: add the option to normalize before computing returns/advantages?
        returns, advantages = self._compute_returns_avantages(
            rewards, is_terminals, state_values, last_value=last_value)

        # save in batch
        self.memory.states.extend(states)
        self.memory.actions.extend(actions)
        self.memory.logprobs.extend(action_logprobs)
        self.memory.rewards.extend(rewards)
        self.memory.is_terminals.extend(is_terminals)

        self.returns.extend(returns)  # TODO: add to memory (cf reset)
        self.advantages.extend(advantages)  # TODO: add to memory (cf reset)

        # increment ep counter
        self.episode += 1

        # log
        if self.writer is not None:
            self.writer.add_scalar("episode_rewards", episode_rewards,
                                   self.episode)

        # update
        # TODO: maybe change to update in function of n_steps instead
        if self.episode % self.update_frequency == 0:
            self._update()
            self.memory.clear_memory()
            del self.returns[:]  # TODO: add to memory (cf reset)
            del self.advantages[:]  # TODO: add to memory (cf reset)

        return episode_rewards

    def _update(self):
        """Run ``k_epochs`` of minibatch PPO updates on the collected data."""
        # convert list to tensor
        full_old_states = torch.stack(self.memory.states).to(
            self.device).detach()
        full_old_actions = torch.stack(self.memory.actions).to(
            self.device).detach()
        full_old_logprobs = torch.stack(self.memory.logprobs).to(
            self.device).detach()
        full_old_returns = torch.stack(self.returns).to(self.device).detach()
        full_old_advantages = torch.stack(self.advantages).to(
            self.device).detach()

        n_samples = full_old_actions.size(0)
        # ceil division: the last, possibly smaller, minibatch is used too
        # (and at least one minibatch runs even if n_samples < batch_size)
        n_batches = (n_samples + self.batch_size - 1) // self.batch_size

        surr_loss = None
        dist_entropy = None

        # optimize policy for K epochs
        for _ in range(self.k_epochs):

            # shuffle samples
            rd_indices = self.rng.choice(n_samples,
                                         size=n_samples,
                                         replace=False)
            shuffled_states = full_old_states[rd_indices]
            shuffled_actions = full_old_actions[rd_indices]
            shuffled_logprobs = full_old_logprobs[rd_indices]
            shuffled_returns = full_old_returns[rd_indices]
            shuffled_advantages = full_old_advantages[rd_indices]

            for k in range(n_batches):

                # sample batch
                batch_idx = np.arange(
                    k * self.batch_size,
                    min((k + 1) * self.batch_size, n_samples))
                old_states = shuffled_states[batch_idx]
                old_actions = shuffled_actions[batch_idx]
                old_logprobs = shuffled_logprobs[batch_idx]
                old_returns = shuffled_returns[batch_idx]
                old_advantages = shuffled_advantages[batch_idx]

                # evaluate old actions and values
                action_dist = self.cat_policy(old_states)
                logprobs = action_dist.log_prob(old_actions)
                # view(-1) keeps a 1-D tensor even for a batch of one
                state_values = self.value_net(old_states).view(-1)
                dist_entropy = action_dist.entropy()

                # find ratio (pi_theta / pi_theta__old)
                ratios = torch.exp(logprobs - old_logprobs)

                # TODO: add this option
                # normalizing the rewards
                # rewards = (rewards - rewards.mean()) / (rewards.std() + 1e-5)

                # normalize the advantages; std() of a single element is NaN,
                # so only normalize batches with at least two samples
                old_advantages = old_advantages.view(-1)
                if self.normalize_advantages and old_advantages.numel() > 1:
                    old_advantages = (old_advantages - old_advantages.mean()
                                      ) / (old_advantages.std() + 1e-10)

                # compute clipped surrogate objective
                surr1 = ratios * old_advantages
                surr2 = (
                    torch.clamp(ratios, 1 - self.eps_clip, 1 + self.eps_clip) *
                    old_advantages)
                surr_loss = torch.min(surr1, surr2)

                # compute value function loss
                loss_vf = self.vf_coef * self.MseLoss(state_values,
                                                      old_returns)

                # compute entropy loss
                loss_entropy = self.entr_coef * dist_entropy

                # compute total loss
                loss = -surr_loss + loss_vf - loss_entropy

                # take gradient step
                self.policy_optimizer.zero_grad()
                self.value_optimizer.zero_grad()

                loss.mean().backward()

                self.policy_optimizer.step()
                self.value_optimizer.step()

        # log (only if at least one gradient step was taken)
        if self.writer and surr_loss is not None:
            self.writer.add_scalar(
                "fit/surrogate_loss",
                surr_loss.mean().cpu().detach().numpy(),
                self.episode,
            )
            self.writer.add_scalar(
                "fit/entropy_loss",
                dist_entropy.mean().cpu().detach().numpy(),
                self.episode,
            )

        # copy new weights into old policy
        self.cat_policy_old.load_state_dict(self.cat_policy.state_dict())

    def _compute_returns_avantages(self, rewards, is_terminals, state_values,
                                   last_value=None):
        """Compute discounted returns and advantages for one episode.

        Parameters
        ----------
        rewards : list
            Rewards of the episode (possibly shorter than ``self.horizon``).
        is_terminals : list
            Done flags, same length as ``rewards``.
        state_values : list of float
            V(s_t) for each visited state, same length as ``rewards``.
        last_value : float, optional
            Value used to bootstrap after the last transition. Defaults to
            ``state_values[-1]`` (previous behavior).

        Returns
        -------
        (returns, advantages) : 1-D tensors of length ``len(rewards)``.
        """
        n_steps = len(rewards)
        if last_value is None:
            last_value = state_values[-1]

        returns = torch.zeros(n_steps, device=self.device)
        advantages = torch.zeros(n_steps, device=self.device)

        # quantities "one step ahead" of t, seeded with the bootstrap value
        next_return = last_value
        next_value = last_value
        last_adv = 0
        for t in reversed(range(n_steps)):
            not_done = 1 - is_terminals[t]
            returns[t] = rewards[t] + self.gamma * not_done * next_return

            if self.use_gae:
                td_error = (rewards[t] + self.gamma * not_done * next_value -
                            state_values[t])
                last_adv = (self.gae_lambda * self.gamma * not_done *
                            last_adv + td_error)
                advantages[t] = last_adv
            else:
                advantages[t] = returns[t] - state_values[t]

            next_return = returns[t]
            next_value = state_values[t]

        return returns, advantages

    #
    # For hyperparameter optimization
    #
    @classmethod
    def sample_parameters(cls, trial):
        """Sample a hyperparameter configuration from a search trial."""
        batch_size = trial.suggest_categorical("batch_size", [1, 4, 8, 16, 32])
        gamma = trial.suggest_categorical("gamma", [0.9, 0.95, 0.99])
        learning_rate = trial.suggest_loguniform("learning_rate", 1e-5, 1)

        entr_coef = trial.suggest_loguniform("entr_coef", 1e-8, 0.1)

        eps_clip = trial.suggest_categorical("eps_clip", [0.1, 0.2, 0.3])

        k_epochs = trial.suggest_categorical("k_epochs", [1, 5, 10, 20])

        return {
            "batch_size": batch_size,
            "gamma": gamma,
            "learning_rate": learning_rate,
            "entr_coef": entr_coef,
            "eps_clip": eps_clip,
            "k_epochs": k_epochs,
        }
Esempio n. 9
0
class REINFORCEAgent(AgentWithSimplePolicy):
    """
    REINFORCE with entropy regularization.

    Parameters
    ----------
    env : Model
        Online model with continuous (Box) state space and discrete actions
    batch_size : int
        Number of episodes to wait before updating the policy.
    horizon : int
        Maximum number of steps per episode.
    gamma : double
        Discount factor in [0, 1].
    entr_coef : double
        Entropy coefficient.
    learning_rate : double
        Learning rate.
    normalize: bool
        If True, normalize the discounted returns before the update.
    optimizer_type: str
        Type of optimizer. 'ADAM' by default.
    policy_net_fn : function(env, **kwargs)
        Function that returns an instance of a policy network (pytorch).
        If None, a default net is used.
    policy_net_kwargs : dict
        kwargs for policy_net_fn
    use_bonus_if_available : bool, default = False
        If true, check if environment info has entry 'exploration_bonus'
        and add it to the reward. See also UncertaintyEstimatorWrapper.
    device: str
        Device to put the tensors on

    References
    ----------
    Williams, Ronald J.,
    "Simple statistical gradient-following algorithms for connectionist
    reinforcement learning."
    ReinforcementLearning.Springer,Boston,MA,1992.5-3
    """

    name = "REINFORCE"

    def __init__(self,
                 env,
                 batch_size=8,
                 horizon=256,
                 gamma=0.99,
                 entr_coef=0.01,
                 learning_rate=0.0001,
                 normalize=True,
                 optimizer_type="ADAM",
                 policy_net_fn=None,
                 policy_net_kwargs=None,
                 use_bonus_if_available=False,
                 device="cuda:best",
                 **kwargs):

        # For all parameters, define self.param = param.
        # This must stay the first statement: getargvalues captures every
        # local variable that exists at this point.
        _, _, _, values = inspect.getargvalues(inspect.currentframe())
        values.pop("self")
        for arg, val in values.items():
            setattr(self, arg, val)

        AgentWithSimplePolicy.__init__(self, env, **kwargs)

        self.device = choose_device(device)

        self.state_dim = self.env.observation_space.shape[0]
        self.action_dim = self.env.action_space.n

        self.policy_net_kwargs = policy_net_kwargs or {}

        self.policy_net_fn = policy_net_fn or default_policy_net_fn

        self.optimizer_kwargs = {
            "optimizer_type": optimizer_type,
            "lr": learning_rate
        }

        # check environment
        assert isinstance(self.env.observation_space, spaces.Box)
        assert isinstance(self.env.action_space, spaces.Discrete)

        self.policy_net = None  # policy network

        # initialize
        self.reset()

    def reset(self, **kwargs):
        """(Re)create the policy network, its optimizer and the memory."""
        self.policy_net = self.policy_net_fn(
            self.env, **self.policy_net_kwargs).to(self.device)

        self.policy_optimizer = optimizer_factory(self.policy_net.parameters(),
                                                  **self.optimizer_kwargs)

        self.memory = Memory()

        self.episode = 0

    def policy(self, observation):
        """Sample an action for ``observation`` from the current policy."""
        state = observation
        assert self.policy_net is not None
        state = torch.from_numpy(state).float().to(self.device)
        action_dist = self.policy_net(state)
        action = action_dist.sample().item()
        return action

    def fit(self, budget: int, **kwargs):
        """Train for ``budget`` episodes."""
        del kwargs
        for _ in range(budget):
            self._run_episode()

    def _run_episode(self):
        """Run one episode of at most ``self.horizon`` steps.

        Transitions are appended to ``self.memory``; every
        ``self.batch_size`` episodes the policy is updated and the memory
        cleared.

        Returns
        -------
        Sum of the environment rewards of the episode (without bonus).
        """
        episode_rewards = 0
        n_steps = 0
        state = self.env.reset()
        for _ in range(self.horizon):
            # running policy
            action = self.policy(state)
            next_state, reward, done, info = self.env.step(action)

            # check whether to use bonus
            bonus = 0.0
            if self.use_bonus_if_available:
                if info is not None and "exploration_bonus" in info:
                    bonus = info["exploration_bonus"]

            # save in batch
            self.memory.states.append(state)
            self.memory.actions.append(action)
            self.memory.rewards.append(reward + bonus)  # add bonus here
            self.memory.is_terminals.append(done)
            episode_rewards += reward
            n_steps += 1

            if done:
                break

            # update state
            state = next_state

        # Mark the last transition of the episode as an episode boundary even
        # when it was truncated by the horizon; otherwise _update would carry
        # the discounted return of the next episode in the batch into this one.
        if n_steps > 0:
            self.memory.is_terminals[-1] = True

        self.episode += 1

        if self.writer is not None:
            self.writer.add_scalar("episode_rewards", episode_rewards,
                                   self.episode)

        # update every batch_size episodes
        if self.episode % self.batch_size == 0:
            self._update()
            self.memory.clear_memory()

        return episode_rewards

    def _normalize(self, x):
        """Standardize ``x``; a single element is only centered (std is NaN)."""
        if x.numel() < 2:
            return x - x.mean()
        return (x - x.mean()) / (x.std() + 1e-5)

    def _update(self):
        """One REINFORCE gradient step on all transitions in memory."""
        # monte carlo estimate of the discounted returns, computed backwards
        # and reversed once at the end (O(n) instead of repeated insert(0))
        returns = []
        discounted_reward = 0
        for reward, is_terminal in zip(reversed(self.memory.rewards),
                                       reversed(self.memory.is_terminals)):
            if is_terminal:
                discounted_reward = 0
            discounted_reward = reward + (self.gamma * discounted_reward)
            returns.append(discounted_reward)
        returns.reverse()

        # convert lists to tensors (np.asarray stacks the list of arrays in
        # one go, which is much faster than torch.FloatTensor on a list)
        states = torch.as_tensor(np.asarray(self.memory.states,
                                            dtype=np.float32),
                                 device=self.device)
        actions = torch.as_tensor(self.memory.actions,
                                  dtype=torch.long,
                                  device=self.device)
        rewards = torch.as_tensor(returns,
                                  dtype=torch.float32,
                                  device=self.device)
        if self.normalize:
            rewards = self._normalize(rewards)

        # evaluate logprobs
        action_dist = self.policy_net(states)
        logprobs = action_dist.log_prob(actions)
        dist_entropy = action_dist.entropy()

        # compute loss
        loss = -logprobs * rewards - self.entr_coef * dist_entropy

        # take gradient step
        self.policy_optimizer.zero_grad()

        loss.mean().backward()

        self.policy_optimizer.step()

    #
    # For hyperparameter optimization
    #
    @classmethod
    def sample_parameters(cls, trial):
        """Sample a hyperparameter configuration from a search trial."""
        batch_size = trial.suggest_categorical("batch_size", [1, 4, 8, 16, 32])
        gamma = trial.suggest_categorical("gamma", [0.9, 0.95, 0.99])
        learning_rate = trial.suggest_loguniform("learning_rate", 1e-5, 1)

        entr_coef = trial.suggest_loguniform("entr_coef", 1e-8, 0.1)

        return {
            "batch_size": batch_size,
            "gamma": gamma,
            "learning_rate": learning_rate,
            "entr_coef": entr_coef,
        }
Esempio n. 10
0
class A2CAgent(AgentWithSimplePolicy):
    """
    Advantage Actor Critic Agent.

    A2C, or Advantage Actor Critic, is a synchronous version of the A3C policy
    gradient method. As an alternative to the asynchronous implementation of
    A3C, A2C is a synchronous, deterministic implementation that waits for each
    actor to finish its segment of experience before updating, averaging over
    all of the actors. This more effectively uses GPUs due to larger batch sizes.

    Parameters
    ----------
    env : Model
        Online model with continuous (Box) state space and discrete actions
    batch_size : int
        Number of episodes to wait before updating the policy.
    horizon : int
        Maximum number of steps per episode.
    gamma : double
        Discount factor in [0, 1].
    entr_coef : double
        Entropy coefficient.
    learning_rate : double
        Learning rate.
    optimizer_type: str
        Type of optimizer. 'ADAM' by default.
    k_epochs : int
        Number of epochs per update.
    policy_net_fn : function(env, **kwargs)
        Function that returns an instance of a policy network (pytorch).
        If None, a default net is used.
    value_net_fn : function(env, **kwargs)
        Function that returns an instance of a value network (pytorch).
        If None, a default net is used.
    policy_net_kwargs : dict
        kwargs for policy_net_fn
    value_net_kwargs : dict
        kwargs for value_net_fn
    use_bonus : bool, default = False
        If true, compute an 'exploration_bonus' and add it to the reward.
        See also UncertaintyEstimatorWrapper.
    uncertainty_estimator_kwargs : dict
        Arguments for the UncertaintyEstimatorWrapper (None means no kwargs).
    device : str
        Device to put the tensors on

    References
    ----------
    Mnih, V., Badia, A.P., Mirza, M., Graves, A., Lillicrap, T., Harley, T.,
    Silver, D. & Kavukcuoglu, K. (2016).
    "Asynchronous methods for deep reinforcement learning."
    In International Conference on Machine Learning (pp. 1928-1937).
    """

    name = "A2C"

    def __init__(self,
                 env,
                 batch_size=8,
                 horizon=256,
                 gamma=0.99,
                 entr_coef=0.01,
                 learning_rate=0.01,
                 optimizer_type="ADAM",
                 k_epochs=5,
                 policy_net_fn=None,
                 value_net_fn=None,
                 policy_net_kwargs=None,
                 value_net_kwargs=None,
                 use_bonus=False,
                 uncertainty_estimator_kwargs=None,
                 device="cuda:best",
                 **kwargs):

        AgentWithSimplePolicy.__init__(self, env, **kwargs)

        self.use_bonus = use_bonus
        if self.use_bonus:
            # uncertainty_estimator_kwargs defaults to None; unpacking None
            # with ** raises TypeError, so fall back to an empty dict.
            self.env = UncertaintyEstimatorWrapper(
                self.env, **(uncertainty_estimator_kwargs or {}))

        self.batch_size = batch_size
        self.horizon = horizon
        self.gamma = gamma
        self.entr_coef = entr_coef
        self.learning_rate = learning_rate
        self.k_epochs = k_epochs
        self.device = choose_device(device)

        self.policy_net_kwargs = policy_net_kwargs or {}
        self.value_net_kwargs = value_net_kwargs or {}

        self.state_dim = self.env.observation_space.shape[0]
        self.action_dim = self.env.action_space.n

        self.policy_net_fn = policy_net_fn or default_policy_net_fn
        self.value_net_fn = value_net_fn or default_value_net_fn

        self.optimizer_kwargs = {
            "optimizer_type": optimizer_type,
            "lr": learning_rate
        }

        # check environment
        assert isinstance(self.env.observation_space, spaces.Box)
        assert isinstance(self.env.action_space, spaces.Discrete)

        self.cat_policy = None  # categorical policy function

        # initialize
        self.reset()

    def reset(self, **kwargs):
        """(Re)create networks, optimizers and memory."""
        self.cat_policy = self.policy_net_fn(
            self.env, **self.policy_net_kwargs).to(self.device)
        self.policy_optimizer = optimizer_factory(self.cat_policy.parameters(),
                                                  **self.optimizer_kwargs)

        self.value_net = self.value_net_fn(
            self.env, **self.value_net_kwargs).to(self.device)

        self.value_optimizer = optimizer_factory(self.value_net.parameters(),
                                                 **self.optimizer_kwargs)

        # copy of the policy used to collect data between updates
        self.cat_policy_old = self.policy_net_fn(
            self.env, **self.policy_net_kwargs).to(self.device)
        self.cat_policy_old.load_state_dict(self.cat_policy.state_dict())

        self.MseLoss = nn.MSELoss()

        self.memory = Memory()

        self.episode = 0

    def policy(self, observation):
        """Sample an action for ``observation`` from the old (behavior) policy."""
        state = observation
        assert self.cat_policy is not None
        state = torch.from_numpy(state).float().to(self.device)
        action_dist = self.cat_policy_old(state)
        action = action_dist.sample().item()
        return action

    def fit(self, budget: int, **kwargs):
        """Train for ``budget`` episodes."""
        del kwargs
        for _ in range(budget):
            self._run_episode()

    def _select_action(self, state):
        """Sample an action with the old policy and store (s, a, log pi(a|s))."""
        state = torch.from_numpy(state).float().to(self.device)
        action_dist = self.cat_policy_old(state)
        action = action_dist.sample()
        action_logprob = action_dist.log_prob(action)

        self.memory.states.append(state)
        self.memory.actions.append(action)
        self.memory.logprobs.append(action_logprob)

        return action.item()

    def _run_episode(self):
        """Run one episode of at most ``self.horizon`` steps.

        Returns
        -------
        Sum of the environment rewards of the episode (without bonus).
        """
        episode_rewards = 0
        state = self.env.reset()
        for i in range(self.horizon):
            # running policy_old
            action = self._select_action(state)
            next_state, reward, done, info = self.env.step(action)

            # check whether to use bonus
            bonus = 0.0
            if self.use_bonus:
                if info is not None and "exploration_bonus" in info:
                    bonus = info["exploration_bonus"]

            # save in batch
            self.memory.rewards.append(reward + bonus)  # add bonus here
            self.memory.is_terminals.append(done)
            episode_rewards += reward

            if done:
                break
            # update state
            state = next_state

            # episode truncated by the horizon: mark it as an episode boundary
            # so that returns do not leak across episodes in _update
            if i == self.horizon - 1:
                self.memory.is_terminals[-1] = True

        self.episode += 1

        if self.writer is not None:
            self.writer.add_scalar("episode_rewards", episode_rewards,
                                   self.episode)

        # update every batch_size episodes
        if self.episode % self.batch_size == 0:
            self._update()
            self.memory.clear_memory()

        return episode_rewards

    def _update(self):
        """Run ``k_epochs`` actor-critic updates on the transitions in memory."""
        # monte carlo estimate of the discounted returns, computed backwards
        # and reversed once at the end (O(n) instead of repeated insert(0))
        returns = []
        discounted_reward = 0
        for reward, is_terminal in zip(reversed(self.memory.rewards),
                                       reversed(self.memory.is_terminals)):
            if is_terminal:
                discounted_reward = 0
            discounted_reward = reward + (self.gamma * discounted_reward)
            returns.append(discounted_reward)
        returns.reverse()

        # normalize the returns (used both as critic target and for advantages)
        rewards = torch.tensor(returns, dtype=torch.float32, device=self.device)
        rewards = (rewards - rewards.mean()) / (rewards.std() + 1e-5)

        # convert list to tensor
        old_states = torch.stack(self.memory.states).to(self.device).detach()
        old_actions = torch.stack(self.memory.actions).to(self.device).detach()

        # optimize policy for K epochs
        for _ in range(self.k_epochs):
            # evaluate old actions and values
            action_dist = self.cat_policy(old_states)
            logprobs = action_dist.log_prob(old_actions)
            state_values = torch.squeeze(self.value_net(old_states))
            dist_entropy = action_dist.entropy()

            # normalize the advantages (critic values detached: no gradient
            # flows into the value net through the policy loss)
            advantages = rewards - state_values.detach()
            advantages = (advantages - advantages.mean()) / (advantages.std() +
                                                             1e-8)
            # find pg loss
            pg_loss = -logprobs * advantages
            loss = (pg_loss + 0.5 * self.MseLoss(state_values, rewards) -
                    self.entr_coef * dist_entropy)

            # take gradient step
            self.policy_optimizer.zero_grad()
            self.value_optimizer.zero_grad()

            loss.mean().backward()

            self.policy_optimizer.step()
            self.value_optimizer.step()

        # copy new weights into old policy
        self.cat_policy_old.load_state_dict(self.cat_policy.state_dict())

    #
    # For hyperparameter optimization
    #
    @classmethod
    def sample_parameters(cls, trial):
        """Sample a hyperparameter configuration from a search trial."""
        batch_size = trial.suggest_categorical("batch_size", [1, 4, 8, 16, 32])
        gamma = trial.suggest_categorical("gamma", [0.9, 0.95, 0.99])
        learning_rate = trial.suggest_loguniform("learning_rate", 1e-5, 1)

        entr_coef = trial.suggest_loguniform("entr_coef", 1e-8, 0.1)

        k_epochs = trial.suggest_categorical("k_epochs", [1, 5, 10, 20])

        return {
            "batch_size": batch_size,
            "gamma": gamma,
            "learning_rate": learning_rate,
            "entr_coef": entr_coef,
            "k_epochs": k_epochs,
        }
Esempio n. 11
0
class PPOAgent(IncrementalAgent):
    """
    Proximal Policy Optimization agent (clipped surrogate objective) for
    environments with a continuous (Box) observation space and a discrete
    action space.

    Parameters
    ----------
    env : Model
        Online model with continuous (Box) state space and discrete actions
    n_episodes : int
        Number of episodes
    batch_size : int
        Number of *episodes* to wait before updating the policy.
    horizon : int
        Horizon (maximum number of steps per episode).
    gamma : double
        Discount factor in [0, 1].
    entr_coef : double
        Entropy coefficient.
    vf_coef : double
        Value function loss coefficient.
    learning_rate : double
        Learning rate.
    optimizer_type: str
        Type of optimizer. 'ADAM' by default.
    eps_clip : double
        PPO clipping range (epsilon).
    k_epochs : int
        Number of epochs per update.
    use_gae : bool, default = True
        If True, advantages are computed with Generalized Advantage
        Estimation; otherwise they are bootstrapped discounted returns minus
        the value baseline.
    gae_lambda : double, default = 0.95
        GAE lambda, in [0, 1]. Only used when ``use_gae`` is True.
    policy_net_fn : function(env, **kwargs)
        Function that returns an instance of a policy network (pytorch).
        If None, a default net is used.
    value_net_fn : function(env, **kwargs)
        Function that returns an instance of a value network (pytorch).
        If None, a default net is used.
    policy_net_kwargs : dict
        kwargs for policy_net_fn
    value_net_kwargs : dict
        kwargs for value_net_fn
    device: str
        Device to put the tensors on
    use_bonus : bool, default = False
        If true, compute the environment 'exploration_bonus'
        and add it to the reward. See also UncertaintyEstimatorWrapper.
    uncertainty_estimator_kwargs : dict
        kwargs for UncertaintyEstimatorWrapper

    References
    ----------
    Schulman, J., Wolski, F., Dhariwal, P., Radford, A. & Klimov, O. (2017).
    "Proximal Policy Optimization Algorithms."
    arXiv preprint arXiv:1707.06347.

    Schulman, J., Levine, S., Abbeel, P., Jordan, M., & Moritz, P. (2015).
    "Trust region policy optimization."
    In International Conference on Machine Learning (pp. 1889-1897).
    """

    name = "PPO"

    def __init__(self,
                 env,
                 n_episodes=4000,
                 batch_size=8,
                 horizon=256,
                 gamma=0.99,
                 entr_coef=0.01,
                 vf_coef=0.5,
                 learning_rate=0.01,
                 optimizer_type='ADAM',
                 eps_clip=0.2,
                 k_epochs=5,
                 use_gae=True,
                 gae_lambda=0.95,
                 policy_net_fn=None,
                 value_net_fn=None,
                 policy_net_kwargs=None,
                 value_net_kwargs=None,
                 device="cuda:best",
                 use_bonus=False,
                 uncertainty_estimator_kwargs=None,
                 **kwargs):
        self.use_bonus = use_bonus
        if self.use_bonus:
            # uncertainty_estimator_kwargs defaults to None; `**None` would
            # raise TypeError, so fall back to no extra kwargs.
            env = UncertaintyEstimatorWrapper(
                env, **(uncertainty_estimator_kwargs or {}))
        IncrementalAgent.__init__(self, env, **kwargs)

        self.n_episodes = n_episodes
        self.batch_size = batch_size
        self.horizon = horizon
        self.gamma = gamma
        self.entr_coef = entr_coef
        self.vf_coef = vf_coef
        self.learning_rate = learning_rate
        self.eps_clip = eps_clip
        self.k_epochs = k_epochs
        self.use_gae = use_gae
        self.gae_lambda = gae_lambda

        self.policy_net_kwargs = policy_net_kwargs or {}
        self.value_net_kwargs = value_net_kwargs or {}

        self.state_dim = self.env.observation_space.shape[0]
        self.action_dim = self.env.action_space.n

        # network factories (called as fn(env, **kwargs) in reset)
        self.policy_net_fn = policy_net_fn or default_policy_net_fn
        self.value_net_fn = value_net_fn or default_value_net_fn

        self.device = choose_device(device)

        self.optimizer_kwargs = {
            'optimizer_type': optimizer_type,
            'lr': learning_rate
        }

        # check environment
        assert isinstance(self.env.observation_space, spaces.Box)
        assert isinstance(self.env.action_space, spaces.Discrete)

        self.cat_policy = None  # categorical policy function

        # initialize
        self.reset()

    @classmethod
    def from_config(cls, **kwargs):
        """Build an agent from config kwargs.

        If ``policy_net_fn`` or ``value_net_fn`` is given as source code (for
        example the name of a factory), it is evaluated into the function.
        Values that are missing, None, or already callable are passed through
        unchanged, so the constructor defaults apply.
        """
        for key in ("policy_net_fn", "value_net_fn"):
            value = kwargs.get(key)
            if value is not None and not callable(value):
                # SECURITY: eval() runs arbitrary code. Only load configs
                # from trusted sources.
                kwargs[key] = eval(value)
        return cls(**kwargs)

    def reset(self, **kwargs):
        """(Re)create networks, optimizers, memory and training statistics."""
        self.cat_policy = self.policy_net_fn(
            self.env, **self.policy_net_kwargs).to(self.device)
        self.policy_optimizer = optimizer_factory(self.cat_policy.parameters(),
                                                  **self.optimizer_kwargs)

        self.value_net = self.value_net_fn(
            self.env, **self.value_net_kwargs).to(self.device)
        self.value_optimizer = optimizer_factory(self.value_net.parameters(),
                                                 **self.optimizer_kwargs)

        # behaviour policy used to collect data; synced after each update
        self.cat_policy_old = self.policy_net_fn(
            self.env, **self.policy_net_kwargs).to(self.device)
        self.cat_policy_old.load_state_dict(self.cat_policy.state_dict())

        self.MseLoss = nn.MSELoss()

        self.memory = Memory()

        self.episode = 0

        # useful data
        self._rewards = np.zeros(self.n_episodes)
        self._cumul_rewards = np.zeros(self.n_episodes)

        # default writer
        self.writer = PeriodicWriter(self.name,
                                     log_every=5 * logger.getEffectiveLevel())

    def policy(self, state, **kwargs):
        """Sample an action for ``state`` from the behaviour policy."""
        assert self.cat_policy is not None
        state = torch.from_numpy(state).float().to(self.device)
        # acting needs no autograd graph
        with torch.no_grad():
            action_dist = self.cat_policy_old(state)
            action = action_dist.sample().item()
        return action

    def partial_fit(self, fraction: float, **kwargs):
        """Run ceil(fraction * n_episodes) more episodes, capped at n_episodes.

        Returns a dict with the number of episodes run so far
        (``n_episodes``) and their total rewards (``episode_rewards``).
        """
        assert 0.0 < fraction <= 1.0
        n_episodes_to_run = int(np.ceil(fraction * self.n_episodes))
        count = 0
        while count < n_episodes_to_run and self.episode < self.n_episodes:
            self._run_episode()
            count += 1

        info = {
            "n_episodes": self.episode,
            "episode_rewards": self._rewards[:self.episode]
        }
        return info

    def _select_action(self, state):
        """Sample an action from ``cat_policy_old`` and store the step data.

        The state, action and log-probability are appended to memory. They
        are computed without autograd because ``_update`` only uses them
        detached; keeping a graph per step until the next update would only
        waste memory.
        """
        state = torch.from_numpy(state).float().to(self.device)
        with torch.no_grad():
            action_dist = self.cat_policy_old(state)
            action = action_dist.sample()
            action_logprob = action_dist.log_prob(action)

        self.memory.states.append(state)
        self.memory.actions.append(action)
        self.memory.logprobs.append(action_logprob)

        return action.item()

    def _run_episode(self):
        """Play one episode (at most ``horizon`` steps) and record it.

        The exploration bonus, when enabled and reported in ``info``, is
        added only to the stored reward, not to the returned total. Every
        ``batch_size`` episodes the networks are updated and memory cleared.
        """
        # interact for H steps
        episode_rewards = 0
        state = self.env.reset()
        for _ in range(self.horizon):
            # running policy_old
            action = self._select_action(state)
            next_state, reward, done, info = self.env.step(action)

            # check whether to use bonus
            bonus = 0.0
            if self.use_bonus:
                if info is not None and 'exploration_bonus' in info:
                    bonus = info['exploration_bonus']

            # save in batch
            self.memory.rewards.append(reward + bonus)  # bonus added here
            self.memory.is_terminals.append(done)
            episode_rewards += reward

            if done:
                break

            # update state
            state = next_state

        # update
        ep = self.episode
        self._rewards[ep] = episode_rewards
        self._cumul_rewards[ep] = episode_rewards \
            + self._cumul_rewards[max(0, ep - 1)]
        self.episode += 1

        #
        if self.writer is not None:
            self.writer.add_scalar("fit/total_reward", episode_rewards,
                                   self.episode)

        #
        if self.episode % self.batch_size == 0:
            self._update()
            self.memory.clear_memory()

        return episode_rewards

    def _update(self):
        """Optimize the policy and value networks on the transitions in memory.

        Runs ``k_epochs`` passes of the clipped PPO objective over every step
        collected since the last update, logs the surrogate and entropy
        terms, then copies the new policy weights into ``cat_policy_old``.
        """
        if len(self.memory.rewards) == 0:
            return

        # Monte Carlo discounted return of every step, restarted at each
        # terminal transition; used as the value-function target.
        mc_returns = []
        discounted_reward = 0
        for reward, is_terminal in zip(reversed(self.memory.rewards),
                                       reversed(self.memory.is_terminals)):
            if is_terminal:
                discounted_reward = 0
            discounted_reward = reward + (self.gamma * discounted_reward)
            mc_returns.append(discounted_reward)
        mc_returns.reverse()  # built back-to-front; append+reverse is O(n)

        # convert lists to tensors once: they do not change across epochs
        mc_returns = torch.tensor(mc_returns, dtype=torch.float32,
                                  device=self.device)
        raw_rewards = torch.tensor(self.memory.rewards, dtype=torch.float32,
                                   device=self.device)
        not_done = 1.0 - torch.tensor(self.memory.is_terminals,
                                      dtype=torch.float32, device=self.device)
        old_states = torch.stack(self.memory.states).to(self.device).detach()
        old_actions = torch.stack(self.memory.actions).to(self.device).detach()
        old_logprobs = torch.stack(self.memory.logprobs).to(
            self.device).detach()

        surr_loss = None
        dist_entropy = None

        # optimize policy for K epochs
        for _ in range(self.k_epochs):
            # evaluate old actions and values under the current networks
            action_dist = self.cat_policy(old_states)
            logprobs = action_dist.log_prob(old_actions)
            # reshape (not squeeze) so a single-step batch stays 1-D;
            # assumes value_net returns one value per state
            state_values = self.value_net(old_states).reshape(-1)
            dist_entropy = action_dist.entropy()

            # find ratio (pi_theta / pi_theta_old)
            ratios = torch.exp(logprobs - old_logprobs)

            # Advantages are constants for the policy loss. They are built
            # from detached values so no gradient flows into value_net here.
            advantages = self._compute_advantages(raw_rewards, not_done,
                                                  state_values.detach())

            # normalize the advantages (std is undefined for one sample)
            if advantages.numel() > 1:
                advantages = (advantages - advantages.mean()) / (
                    advantages.std() + 1e-10)

            # clipped surrogate objective
            surr1 = ratios * advantages
            surr2 = torch.clamp(ratios, 1 - self.eps_clip,
                                1 + self.eps_clip) * advantages
            surr_loss = torch.min(surr1, surr2)
            loss = - surr_loss \
                + self.vf_coef * self.MseLoss(state_values, mc_returns) \
                - self.entr_coef * dist_entropy

            # take gradient step
            self.policy_optimizer.zero_grad()
            self.value_optimizer.zero_grad()

            loss.mean().backward()

            self.policy_optimizer.step()
            self.value_optimizer.step()

        # surr_loss is None when k_epochs == 0 (no epoch ran)
        if self.writer is not None and surr_loss is not None:
            self.writer.add_scalar("fit/surrogate_loss",
                                   surr_loss.mean().cpu().detach().numpy(),
                                   self.episode)
            self.writer.add_scalar("fit/entropy_loss",
                                   dist_entropy.mean().cpu().detach().numpy(),
                                   self.episode)

        # copy new weights into old policy
        self.cat_policy_old.load_state_dict(self.cat_policy.state_dict())

    def _compute_advantages(self, rewards, not_done, values):
        """Return one advantage estimate per stored step.

        Parameters
        ----------
        rewards : torch.Tensor
            1-D tensor of raw (undiscounted) per-step rewards.
        not_done : torch.Tensor
            1-D tensor, 0.0 where the step was terminal and 1.0 otherwise.
        values : torch.Tensor
            1-D tensor of detached state-value estimates, one per step.

        The recursion runs over every stored step, not just ``horizon``
        steps. The last step is bootstrapped with ``values[-1]``, the value
        of the last stored state.
        NOTE(review): an episode cut by the horizon (not terminal) in the
        middle of a batch is bootstrapped with the first state of the next
        episode, because memory does not record truncation.
        """
        advantages = torch.zeros_like(rewards)
        next_value = values[-1]
        next_return = values[-1]
        next_advantage = rewards.new_zeros(())
        for t in reversed(range(rewards.shape[0])):
            discount = self.gamma * not_done[t]
            if self.use_gae:
                td_error = rewards[t] + discount * next_value - values[t]
                # GAE: A_t = delta_t + gamma * lambda * A_{t+1}
                next_advantage = td_error \
                    + self.gae_lambda * discount * next_advantage
            else:
                next_return = rewards[t] + discount * next_return
                next_advantage = next_return - values[t]
            advantages[t] = next_advantage
            next_value = values[t]
        return advantages

    #
    # For hyperparameter optimization
    #
    @classmethod
    def sample_parameters(cls, trial):
        batch_size = trial.suggest_categorical('batch_size', [1, 4, 8, 16, 32])
        gamma = trial.suggest_categorical('gamma', [0.9, 0.95, 0.99])
        learning_rate = trial.suggest_loguniform('learning_rate', 1e-5, 1)

        entr_coef = trial.suggest_loguniform('entr_coef', 1e-8, 0.1)

        eps_clip = trial.suggest_categorical('eps_clip', [0.1, 0.2, 0.3])

        k_epochs = trial.suggest_categorical('k_epochs', [1, 5, 10, 20])

        return {
            'batch_size': batch_size,
            'gamma': gamma,
            'learning_rate': learning_rate,
            'entr_coef': entr_coef,
            'eps_clip': eps_clip,
            'k_epochs': k_epochs,
        }
Esempio n. 12
0
class PPOAgent(IncrementalAgent):
    """
    Proximal Policy Optimization agent (clipped surrogate objective) for
    environments with a continuous (Box) observation space and a discrete
    action space.

    Parameters
    ----------
    env : Model
        Online model with continuous (Box) state space and discrete actions
    n_episodes : int
        Number of episodes
    batch_size : int
        Number of episodes to wait before updating the policy.
    horizon : int
        Horizon (maximum number of steps per episode).
    gamma : double
        Discount factor in [0, 1].
    entr_coef : double
        Entropy coefficient.
    vf_coef : double
        Value function loss coefficient.
    learning_rate : double
        Learning rate.
    optimizer_type: str
        Type of optimizer. 'ADAM' by default.
    eps_clip : double
        PPO clipping range (epsilon).
    k_epochs : int
        Number of epochs per update.
    policy_net_fn : function
        Function called with no arguments that returns an instance of a
        policy network (pytorch). If None, a default net built from ``env``
        is used.
    value_net_fn : function
        Function called with no arguments that returns an instance of a
        value network (pytorch). If None, a default net built from ``env``
        is used.


    References
    ----------
    Schulman, J., Wolski, F., Dhariwal, P., Radford, A. & Klimov, O. (2017).
    "Proximal Policy Optimization Algorithms."
    arXiv preprint arXiv:1707.06347.

    Schulman, J., Levine, S., Abbeel, P., Jordan, M., & Moritz, P. (2015).
    "Trust region policy optimization."
    In International Conference on Machine Learning (pp. 1889-1897).
    """

    name = "PPO"
    fit_info = ("n_episodes", "episode_rewards")

    # NOTE(review): tensors are moved to a module-level `device` defined
    # elsewhere in this file.

    def __init__(self,
                 env,
                 n_episodes=4000,
                 batch_size=8,
                 horizon=256,
                 gamma=0.99,
                 entr_coef=0.01,
                 vf_coef=0.5,
                 learning_rate=0.01,
                 optimizer_type='ADAM',
                 eps_clip=0.2,
                 k_epochs=5,
                 policy_net_fn=None,
                 value_net_fn=None,
                 **kwargs):
        IncrementalAgent.__init__(self, env, **kwargs)

        self.n_episodes = n_episodes
        self.batch_size = batch_size
        self.horizon = horizon
        self.gamma = gamma
        self.entr_coef = entr_coef
        self.vf_coef = vf_coef
        self.learning_rate = learning_rate
        self.eps_clip = eps_clip
        self.k_epochs = k_epochs

        self.state_dim = self.env.observation_space.shape[0]
        self.action_dim = self.env.action_space.n

        # zero-argument network factories (called in reset)
        self.policy_net_fn = policy_net_fn \
            or (lambda: default_policy_net_fn(self.env))

        self.value_net_fn = value_net_fn \
            or (lambda: default_value_net_fn(self.env))

        self.optimizer_kwargs = {
            'optimizer_type': optimizer_type,
            'lr': learning_rate
        }

        # check environment
        assert isinstance(self.env.observation_space, spaces.Box)
        assert isinstance(self.env.action_space, spaces.Discrete)

        self.cat_policy = None  # categorical policy function

        # initialize
        self.reset()

    def reset(self, **kwargs):
        """(Re)create networks, optimizers, memory and training statistics."""
        self.cat_policy = self.policy_net_fn().to(device)
        self.policy_optimizer = optimizer_factory(self.cat_policy.parameters(),
                                                  **self.optimizer_kwargs)

        self.value_net = self.value_net_fn().to(device)
        self.value_optimizer = optimizer_factory(self.value_net.parameters(),
                                                 **self.optimizer_kwargs)

        # behaviour policy used to collect data; synced after each update
        self.cat_policy_old = self.policy_net_fn().to(device)
        self.cat_policy_old.load_state_dict(self.cat_policy.state_dict())

        self.MseLoss = nn.MSELoss()

        self.memory = Memory()

        self.episode = 0

        # useful data
        self._rewards = np.zeros(self.n_episodes)
        self._cumul_rewards = np.zeros(self.n_episodes)

        # default writer
        self.writer = PeriodicWriter(self.name,
                                     log_every=5 * logger.getEffectiveLevel())

    def policy(self, state, **kwargs):
        """Sample an action for ``state`` from the behaviour policy."""
        assert self.cat_policy is not None
        state = torch.from_numpy(state).float().to(device)
        # acting needs no autograd graph
        with torch.no_grad():
            action_dist = self.cat_policy_old(state)
            action = action_dist.sample().item()
        return action

    def fit(self, **kwargs):
        """Train until ``n_episodes`` episodes have been run in total.

        Episodes already run (for example by ``partial_fit``) count towards
        the total, so ``self._rewards`` is never indexed past its end.
        """
        while self.episode < self.n_episodes:
            self._run_episode()

        info = {
            "n_episodes": self.episode,
            "episode_rewards": self._rewards[:self.episode]
        }
        return info

    def partial_fit(self, fraction: float, **kwargs):
        """Run ceil(fraction * n_episodes) more episodes, capped at n_episodes."""
        assert 0.0 < fraction <= 1.0
        n_episodes_to_run = int(np.ceil(fraction * self.n_episodes))
        count = 0
        while count < n_episodes_to_run and self.episode < self.n_episodes:
            self._run_episode()
            count += 1

        info = {
            "n_episodes": self.episode,
            "episode_rewards": self._rewards[:self.episode]
        }
        return info

    def _select_action(self, state):
        """Sample an action from ``cat_policy_old`` and store the step data.

        Evaluated without autograd because ``_update`` only uses the stored
        log-probabilities detached.
        """
        state = torch.from_numpy(state).float().to(device)
        with torch.no_grad():
            action_dist = self.cat_policy_old(state)
            action = action_dist.sample()
            action_logprob = action_dist.log_prob(action)

        self.memory.states.append(state)
        self.memory.actions.append(action)
        self.memory.logprobs.append(action_logprob)

        return action.item()

    def _run_episode(self):
        """Play one episode (at most ``horizon`` steps) and record it.

        Every ``batch_size`` episodes the networks are updated and memory is
        cleared. Returns the episode's total reward.
        """
        # interact for H steps
        episode_rewards = 0
        state = self.env.reset()
        for _ in range(self.horizon):
            # running policy_old
            action = self._select_action(state)
            next_state, reward, done, _ = self.env.step(action)

            # save in batch
            self.memory.rewards.append(reward)
            self.memory.is_terminals.append(done)
            episode_rewards += reward

            if done:
                break

            # update state
            state = next_state

        # update
        ep = self.episode
        self._rewards[ep] = episode_rewards
        self._cumul_rewards[ep] = episode_rewards \
            + self._cumul_rewards[max(0, ep - 1)]
        self.episode += 1

        #
        if self.writer is not None:
            self.writer.add_scalar("episode", self.episode, None)
            self.writer.add_scalar("ep reward", episode_rewards)

        #
        if self.episode % self.batch_size == 0:
            self._update()
            self.memory.clear_memory()

        return episode_rewards

    def _update(self):
        """Run ``k_epochs`` of clipped PPO updates on the stored transitions,
        then copy the new policy weights into ``cat_policy_old``."""
        if len(self.memory.rewards) == 0:
            return

        # monte carlo estimate of discounted returns, reset at terminals
        rewards = []
        discounted_reward = 0
        for reward, is_terminal in zip(reversed(self.memory.rewards),
                                       reversed(self.memory.is_terminals)):
            if is_terminal:
                discounted_reward = 0
            discounted_reward = reward + (self.gamma * discounted_reward)
            rewards.append(discounted_reward)
        rewards.reverse()  # built back-to-front; append+reverse is O(n)

        # normalizing the rewards (std is undefined for a single sample)
        rewards = torch.tensor(rewards).to(device).float()
        if rewards.numel() > 1:
            rewards = (rewards - rewards.mean()) / (rewards.std() + 1e-5)

        # convert list to tensor
        old_states = torch.stack(self.memory.states).to(device).detach()
        old_actions = torch.stack(self.memory.actions).to(device).detach()
        old_logprobs = torch.stack(self.memory.logprobs).to(device).detach()

        # optimize policy for K epochs
        for _ in range(self.k_epochs):
            # evaluate old actions and values
            action_dist = self.cat_policy(old_states)
            logprobs = action_dist.log_prob(old_actions)
            # Flatten to one value per state. An (N, 1) output would
            # otherwise broadcast against the (N,) rewards into (N, N) in
            # both the advantages and the MSE loss.
            state_values = self.value_net(old_states).reshape(-1)
            dist_entropy = action_dist.entropy()

            # find ratio (pi_theta / pi_theta_old)
            ratios = torch.exp(logprobs - old_logprobs)

            # normalize the advantages (std is undefined for one sample)
            advantages = rewards - state_values.detach()
            if advantages.numel() > 1:
                advantages = (advantages - advantages.mean()) / \
                    (advantages.std() + 1e-8)
            # find surrogate loss
            surr1 = ratios * advantages
            surr2 = torch.clamp(ratios, 1 - self.eps_clip,
                                1 + self.eps_clip) * advantages
            loss = -torch.min(surr1, surr2) \
                + self.vf_coef * self.MseLoss(state_values, rewards) \
                - self.entr_coef * dist_entropy

            # take gradient step
            self.policy_optimizer.zero_grad()
            self.value_optimizer.zero_grad()

            loss.mean().backward()

            self.policy_optimizer.step()
            self.value_optimizer.step()

        # copy new weights into old policy
        self.cat_policy_old.load_state_dict(self.cat_policy.state_dict())

    #
    # For hyperparameter optimization
    #
    @classmethod
    def sample_parameters(cls, trial):
        batch_size = trial.suggest_categorical('batch_size',
                                               [1, 4, 8, 16, 32, 64])
        learning_rate = trial.suggest_loguniform('learning_rate', 1e-5, 1)
        return {
            'batch_size': batch_size,
            'learning_rate': learning_rate,
        }
Esempio n. 13
0
class TRPOAgent(IncrementalAgent):
    """
    Parameters
    ----------
    env : Model
        Online model with continuous (Box) state space and discrete actions
    n_episodes : int
        Number of episodes
    batch_size : int
        Number of *episodes* to wait before updating the policy.
    horizon : int
        Horizon.
    gamma : double
        Discount factor in [0, 1].
    entr_coef : double
        Entropy coefficient.
    vf_coef : double
        Value function loss coefficient.
    learning_rate : double
        Learning rate.
    optimizer_type: str
        Type of optimizer. 'ADAM' by defaut.
    eps_clip : double
        PPO clipping range (epsilon).
    k_epochs : int
        Number of epochs per update.
    policy_net_fn : function(env, **kwargs)
        Function that returns an instance of a policy network (pytorch).
        If None, a default net is used.
    value_net_fn : function(env, **kwargs)
        Function that returns an instance of a value network (pytorch).
        If None, a default net is used.
    policy_net_kwargs : dict
        kwargs for policy_net_fn
    value_net_kwargs : dict
        kwargs for value_net_fn
    device: str
        Device to put the tensors on
    use_bonus : bool, default = False
        If true, compute the environment 'exploration_bonus'
        and add it to the reward. See also UncertaintyEstimatorWrapper.
    uncertainty_estimator_kwargs : dict
        kwargs for UncertaintyEstimatorWrapper

    References
    ----------
    Schulman, J., Levine, S., Abbeel, P., Jordan, M., & Moritz, P. (2015).
    "Trust region policy optimization."
    In International Conference on Machine Learning (pp. 1889-1897).
    """

    name = "TRPO"

    def __init__(self,
                 env,
                 n_episodes=4000,
                 batch_size=8,
                 horizon=256,
                 gamma=0.99,
                 entr_coef=0.01,
                 vf_coef=0.5,
                 learning_rate=0.01,
                 optimizer_type='ADAM',
                 k_epochs=5,
                 use_gae=True,
                 gae_lambda=0.95,
                 policy_net_fn=None,
                 value_net_fn=None,
                 policy_net_kwargs=None,
                 value_net_kwargs=None,
                 device="cuda:best",
                 use_bonus=False,
                 uncertainty_estimator_kwargs=None,
                 **kwargs):
        """Store hyperparameters, validate the env, and build state via reset().

        See the class docstring for the parameters. Extra ``kwargs`` are
        forwarded to ``IncrementalAgent.__init__``. With ``use_bonus`` set,
        ``env`` is wrapped in ``UncertaintyEstimatorWrapper``; a None
        ``uncertainty_estimator_kwargs`` means "no extra kwargs".
        """
        self.use_bonus = use_bonus
        if self.use_bonus:
            # `**None` would raise TypeError; treat None as empty kwargs
            env = UncertaintyEstimatorWrapper(
                env, **(uncertainty_estimator_kwargs or {}))
        IncrementalAgent.__init__(self, env, **kwargs)

        self.n_episodes = n_episodes
        self.batch_size = batch_size
        self.horizon = horizon
        self.gamma = gamma
        self.entr_coef = entr_coef
        self.vf_coef = vf_coef
        self.learning_rate = learning_rate
        self.k_epochs = k_epochs
        self.use_gae = use_gae
        self.gae_lambda = gae_lambda
        self.damping = 0  # TODO: turn into argument
        self.max_kl = 0.1  # TODO: turn into argument
        self.use_entropy = False  # TODO: test, and eventually turn into argument
        self.normalize_advantage = True  # TODO: turn into argument
        self.normalize_reward = False  # TODO: turn into argument

        self.policy_net_kwargs = policy_net_kwargs or {}
        self.value_net_kwargs = value_net_kwargs or {}

        self.state_dim = self.env.observation_space.shape[0]
        self.action_dim = self.env.action_space.n

        # network factories, called as fn(env, **kwargs) in reset()
        self.policy_net_fn = policy_net_fn or default_policy_net_fn
        self.value_net_fn = value_net_fn or default_value_net_fn

        self.device = choose_device(device)

        self.optimizer_kwargs = {
            'optimizer_type': optimizer_type,
            'lr': learning_rate
        }

        # check environment
        assert isinstance(self.env.observation_space, spaces.Box)
        assert isinstance(self.env.action_space, spaces.Discrete)

        # placeholders; every one of these is (re)assigned in reset()
        # TODO: check
        self.cat_policy = None  # categorical policy function
        self.policy_optimizer = None

        self.value_net = None
        self.value_optimizer = None

        self.cat_policy_old = None

        self.value_loss_fn = None

        self.memory = None

        self.episode = 0

        self._rewards = None
        self._cumul_rewards = None

        # initialize
        self.reset()

    @classmethod
    def from_config(cls, **kwargs):
        """Build an agent from config kwargs.

        If ``policy_net_fn`` or ``value_net_fn`` is given as source code (for
        example the name of a factory), it is evaluated into the function.
        Values that are missing, None, or already callable are passed through
        unchanged, so the constructor defaults apply.
        """
        for key in ("policy_net_fn", "value_net_fn"):
            value = kwargs.get(key)
            if value is not None and not callable(value):
                # SECURITY: eval() runs arbitrary code. Only load configs
                # from trusted sources.
                kwargs[key] = eval(value)
        return cls(**kwargs)

    def reset(self, **kwargs):
        """Rebuild networks, optimizers, rollout memory and statistics.

        The policy and value networks are created with their factories and
        moved to ``self.device``. ``cat_policy_old`` (the behaviour policy)
        starts as an exact copy of ``cat_policy``. The episode counter and
        the per-episode reward arrays are reset, and a fresh
        ``PeriodicWriter`` is attached.
        """
        def build_policy():
            # one fresh policy network on the agent's device
            net = self.policy_net_fn(self.env, **self.policy_net_kwargs)
            return net.to(self.device)

        self.cat_policy = build_policy()
        self.policy_optimizer = optimizer_factory(
            self.cat_policy.parameters(), **self.optimizer_kwargs)

        value_net = self.value_net_fn(self.env, **self.value_net_kwargs)
        self.value_net = value_net.to(self.device)
        self.value_optimizer = optimizer_factory(
            self.value_net.parameters(), **self.optimizer_kwargs)

        # behaviour policy: same architecture, weights copied from cat_policy
        self.cat_policy_old = build_policy()
        self.cat_policy_old.load_state_dict(self.cat_policy.state_dict())

        self.value_loss_fn = nn.MSELoss()  # TODO: turn into argument
        self.memory = Memory()
        self.episode = 0

        # per-episode statistics, filled in by _run_episode
        n_slots = self.n_episodes
        self._rewards, self._cumul_rewards = np.zeros(n_slots), np.zeros(n_slots)

        # default writer
        log_every = 5 * logger.getEffectiveLevel()
        self.writer = PeriodicWriter(self.name, log_every=log_every)

    def policy(self, state, **kwargs):
        """Sample an action for ``state`` from the behaviour policy.

        ``state`` is converted with ``torch.from_numpy``, so it is expected
        to be a numpy array. The network runs under ``torch.no_grad()``
        because acting needs no autograd graph. Returns the action as a
        Python scalar.
        """
        assert self.cat_policy is not None
        state = torch.from_numpy(state).float().to(self.device)
        with torch.no_grad():
            action_dist = self.cat_policy_old(state)
            action = action_dist.sample().item()
        return action

    def partial_fit(self, fraction: float, **kwargs):
        """Run a fraction of the total training budget.

        Parameters
        ----------
        fraction : float
            Share of ``n_episodes`` to run now, in (0, 1]. The episode count
            is rounded up, and training never exceeds ``n_episodes`` in total.

        Returns
        -------
        dict
            ``n_episodes``: episodes run so far; ``episode_rewards``: the
            total reward of each of those episodes.
        """
        assert 0.0 < fraction <= 1.0
        budget = int(np.ceil(fraction * self.n_episodes))
        for _ in range(budget):
            if self.episode >= self.n_episodes:
                break
            self._run_episode()

        return {
            "n_episodes": self.episode,
            "episode_rewards": self._rewards[:self.episode],
        }

    def _select_action(self, state):
        """Sample an action from ``cat_policy_old`` for a numpy ``state``.

        Returns
        -------
        (action, action_logprob) : tuple of torch.Tensor
            Tensors rather than Python scalars, so the caller can store them
            in memory.

        Evaluated under ``torch.no_grad()``: ``_update`` detaches the stored
        log-probabilities, so keeping an autograd graph alive for every step
        until the next update would only waste memory.
        """
        state = torch.from_numpy(state).float().to(self.device)
        with torch.no_grad():
            action_dist = self.cat_policy_old(state)
            action = action_dist.sample()
            action_logprob = action_dist.log_prob(action)

        return action, action_logprob

    def _run_episode(self):
        """Play one episode of at most ``horizon`` steps and store it.

        Actions come from the behaviour policy ``cat_policy_old``. Each
        transition is appended to ``self.memory``. When ``use_bonus`` is set
        and the env reports ``'exploration_bonus'`` in ``info``, the bonus is
        added to the stored reward only. Every ``batch_size`` episodes the
        networks are updated and the memory is cleared.

        Returns
        -------
        The sum of the environment rewards of the episode (bonus excluded).
        """
        memory = self.memory
        total_reward = 0
        obs = self.env.reset()
        for _step in range(self.horizon):
            # act with the old (behaviour) policy
            action, log_prob = self._select_action(obs)
            next_obs, reward, done, info = self.env.step(action.item())

            has_bonus = (self.use_bonus and info is not None
                         and 'exploration_bonus' in info)
            bonus = info['exploration_bonus'] if has_bonus else 0.0

            # record the transition; the bonus only shapes the stored reward
            memory.states.append(torch.from_numpy(obs).float().to(self.device))
            memory.actions.append(action)
            memory.logprobs.append(log_prob)
            memory.rewards.append(reward + bonus)
            memory.is_terminals.append(done)
            total_reward += reward

            if done:
                break
            obs = next_obs

        # per-episode statistics (slot 0 has no predecessor and stays 0)
        episode_idx = self.episode
        previous_total = self._cumul_rewards[max(0, episode_idx - 1)]
        self._rewards[episode_idx] = total_reward
        self._cumul_rewards[episode_idx] = total_reward + previous_total
        self.episode += 1

        if self.writer is not None:
            self.writer.add_scalar("fit/total_reward", total_reward,
                                   self.episode)

        # update after every batch_size episodes
        if self.episode % self.batch_size == 0:
            self._update()
            self.memory.clear_memory()

        return total_reward

    def _update(self):
        """Run one TRPO update on the transitions stored in ``self.memory``.

        For ``self.k_epochs`` epochs:
        - compute advantage estimates (bootstrapped returns, or GAE when
          ``self.use_gae``);
        - take a natural-gradient step on the policy, using conjugate
          gradients on the Fisher-vector product and a backtracking line
          search;
        - regress the value network onto the Monte Carlo returns.

        Afterwards the weights of ``self.cat_policy`` are copied into
        ``self.cat_policy_old``.
        """
        # Monte Carlo estimate of the discounted returns. The accumulator is
        # reset at terminal transitions so returns do not leak across
        # episodes that ended with ``done``.
        mc_returns = []
        discounted_reward = 0
        for reward, is_terminal in zip(reversed(self.memory.rewards),
                                       reversed(self.memory.is_terminals)):
            if is_terminal:
                discounted_reward = 0
            discounted_reward = reward + (self.gamma * discounted_reward)
            mc_returns.append(discounted_reward)
        mc_returns.reverse()

        # convert lists to tensors (once, outside the epoch loop)
        # TODO: shuffle samples for each epoch
        old_states = torch.stack(self.memory.states).to(self.device).detach()
        old_actions = torch.stack(self.memory.actions).to(self.device).detach()
        old_logprobs = torch.stack(self.memory.logprobs).to(
            self.device).detach()

        # One-step rewards and "episode continues" masks: the advantage
        # recursions need r_t, not the already-discounted return G_t.
        step_rewards = torch.tensor(
            [float(r) for r in self.memory.rewards],
            dtype=torch.float32,
            device=self.device)
        not_terminal = torch.tensor(
            [0.0 if done else 1.0 for done in self.memory.is_terminals],
            dtype=torch.float32,
            device=self.device)

        # regression target of the value function
        rewards = torch.tensor(mc_returns,
                               dtype=torch.float32,
                               device=self.device)
        if self.normalize_reward:
            rewards = (rewards - rewards.mean()) / (rewards.std() + 1e-5)

        old_action_dist = self.cat_policy_old(old_states)

        value_loss = None
        dist_entropy = None
        # optimize policy for K epochs
        for _ in range(self.k_epochs):
            # evaluate old actions and values
            action_dist = self.cat_policy(old_states)
            logprobs = action_dist.log_prob(old_actions)
            # flatten e.g. (N, 1) -> (N,); unlike squeeze, stays 1-D if N == 1
            state_values = self.value_net(old_states).reshape(-1)
            dist_entropy = action_dist.entropy()

            # find ratio (pi_theta / pi_theta__old)
            ratios = torch.exp(logprobs - old_logprobs)

            # advantages only feed the policy step: no gradient to value_net
            advantages = self._compute_advantages(step_rewards, not_terminal,
                                                  state_values.detach())

            # normalize the advantages
            if self.normalize_advantage:
                advantages = (advantages -
                              advantages.mean()) / (advantages.std() + 1e-10)

            # surrogate loss to be minimized
            loss = -ratios * advantages
            if self.use_entropy:
                loss = loss - self.entr_coef * dist_entropy

            # gradient of the surrogate loss w.r.t. the policy parameters
            # TODO: Check the gradients and if they flow correctly
            grads = torch.autograd.grad(loss.mean(),
                                        self.cat_policy.parameters(),
                                        retain_graph=True)
            loss_grad = torch.cat([grad.view(-1) for grad in grads]).data

            # natural-gradient direction: approximately solve F x = -g
            step_dir = self.conjugate_gradients(-loss_grad,
                                                old_action_dist,
                                                old_states,
                                                nsteps=10)

            # scale the step so the quadratic KL estimate equals max_kl
            shs = 0.5 * (step_dir * self.fisher_vp(
                step_dir, old_action_dist, old_states)).sum(0, keepdim=True)
            lagrange_mult = torch.sqrt(shs / self.max_kl).item()
            full_step = step_dir / lagrange_mult

            # expected first-order decrease of the loss along step_dir
            neggdotstepdir = (-loss_grad * step_dir).sum(0, keepdim=True)

            prev_params = self.get_flat_params_from(self.cat_policy)
            success, new_params = self.linesearch(
                old_states, old_actions, old_logprobs, advantages, prev_params,
                full_step, neggdotstepdir / lagrange_mult)
            # The line search writes trial parameters into the network;
            # install the accepted ones (or the previous ones on failure).
            self.set_flat_params_to(self.cat_policy, new_params)

            # fit value function by regression
            value_loss = self.vf_coef * self.value_loss_fn(
                state_values, rewards)

            self.value_optimizer.zero_grad()
            value_loss.mean().backward()
            self.value_optimizer.step()

        # log (nothing to log if no epoch was run)
        if self.writer is not None and value_loss is not None:
            self.writer.add_scalar("fit/value_loss",
                                   value_loss.mean().cpu().detach().numpy(),
                                   self.episode)
            self.writer.add_scalar("fit/entropy_loss",
                                   dist_entropy.mean().cpu().detach().numpy(),
                                   self.episode)

        # copy new weights into old policy
        self.cat_policy_old.load_state_dict(self.cat_policy.state_dict())

    def _compute_advantages(self, step_rewards, not_terminal, values):
        """Compute advantage estimates for one batch of transitions.

        Parameters
        ----------
        step_rewards : torch.Tensor
            One-step rewards r_t, shape (N,).
        not_terminal : torch.Tensor
            1 - done_t, shape (N,).
        values : torch.Tensor
            Detached value predictions V(s_t), shape (N,).

        Returns
        -------
        torch.Tensor of shape (N,)
            If ``self.use_gae`` is False: R_t - V(s_t), where R_t is the
            discounted return bootstrapped at the end of the batch.
            Otherwise: GAE(gamma, lambda) with ``self.gae_lambda``.
        """
        n_steps = step_rewards.shape[0]
        returns = torch.zeros_like(step_rewards)
        advantages = torch.zeros_like(step_rewards)
        for t in reversed(range(n_steps)):
            discount = self.gamma * not_terminal[t]
            if t == n_steps - 1:
                # The successor of the last stored transition is unknown:
                # bootstrap from the value of the last state itself.
                returns[t] = step_rewards[t] + discount * values[-1]
                next_value = values[-1]
                next_advantage = 0.0
            else:
                returns[t] = step_rewards[t] + discount * returns[t + 1]
                next_value = values[t + 1]
                next_advantage = advantages[t + 1]

            if self.use_gae:
                td_error = step_rewards[t] + discount * next_value - values[t]
                advantages[t] = (td_error +
                                 discount * self.gae_lambda * next_advantage)
            else:
                advantages[t] = returns[t] - values[t]
        return advantages

    def conjugate_gradients(self,
                            b,
                            old_action_dist,
                            old_states,
                            nsteps,
                            residual_tol=1e-10):
        """Approximately solve ``F x = b`` with the conjugate gradient method.

        ``F`` is only accessed through ``self.fisher_vp`` (matrix-free).

        Parameters
        ----------
        b : torch.Tensor
            Right-hand side, 1-D.
        old_action_dist, old_states
            Forwarded unchanged to ``self.fisher_vp``.
        nsteps : int
            Maximum number of CG iterations.
        residual_tol : float
            Stop once the squared residual norm falls below this value.

        Returns
        -------
        torch.Tensor
            Approximate solution, with the same shape, dtype and device as b.
        """
        # Allocate like ``b`` so the solution lives on b's device and dtype.
        x = torch.zeros_like(b)
        r = b.clone()
        p = b.clone()
        rdotr = torch.dot(r, r)
        for _ in range(nsteps):
            # Converged. This is also checked before the first iteration:
            # b == 0 would otherwise produce 0 / 0 = nan.
            if rdotr < residual_tol:
                break
            _Avp = self.fisher_vp(p, old_action_dist, old_states)
            alpha = rdotr / torch.dot(p, _Avp)
            x += alpha * p
            r -= alpha * _Avp
            new_rdotr = torch.dot(r, r)
            betta = new_rdotr / rdotr
            p = r + betta * p
            rdotr = new_rdotr
        return x

    def fisher_vp(self, v, old_action_dist, old_states):
        """Return the damped Fisher-vector product ``(F + damping * I) v``.

        F is approximated by the Hessian, w.r.t. the parameters of
        ``self.cat_policy``, of the mean KL divergence
        ``KL(old_action_dist || cat_policy(old_states))``. The product is
        computed by double backpropagation, without forming F.

        Parameters
        ----------
        v : torch.Tensor
            Flat vector with one entry per policy parameter.
        old_action_dist : torch.distributions.Distribution
            Reference distribution on ``old_states``.
        old_states : torch.Tensor
            Batch of states fed to the policy.

        Returns
        -------
        torch.Tensor
            Flat tensor (no autograd history) with the same length as v.
        """
        params = list(self.cat_policy.parameters())

        action_dist = self.cat_policy(old_states)
        kl = kl_divergence(old_action_dist, action_dist).mean()

        # first derivative, keeping its graph so it can be differentiated
        grads = torch.autograd.grad(kl, params, create_graph=True)
        flat_grad_kl = torch.cat([grad.view(-1) for grad in grads])

        # d/dtheta (grad_kl . v) = H v
        kl_v = (flat_grad_kl * v).sum()
        grads = torch.autograd.grad(kl_v, params, allow_unused=True)
        # With allow_unused=True a parameter absent from the graph yields
        # None; its Hessian-vector contribution is zero.
        flat_grad_grad_kl = torch.cat([
            (torch.zeros_like(param) if grad is None else grad)
            .contiguous().view(-1)
            for grad, param in zip(grads, params)
        ]).data

        return flat_grad_grad_kl + v * self.damping

    def linesearch(self,
                   old_states,
                   old_actions,
                   old_logprobs,
                   advantages,
                   params,
                   fullstep,
                   expected_improve_rate,
                   max_backtracks=10,
                   accept_ratio=.1):
        """Backtracking line search for the TRPO policy step.

        Tries ``params + 0.5**k * fullstep`` for k = 0, ..., max_backtracks-1.
        It accepts the first candidate that satisfies both:
        - the actual decrease of the surrogate loss
          ``-(ratios * advantages)`` is positive;
        - that decrease is at least ``accept_ratio`` times the expected
          decrease ``expected_improve_rate * 0.5**k``.

        Each candidate is written into ``self.cat_policy`` with
        ``self.set_flat_params_to`` before it is evaluated.

        Returns
        -------
        (bool, torch.Tensor)
            ``(True, accepted_params)`` on success. Otherwise
            ``(False, params)``, after writing ``params`` back into
            ``self.cat_policy``.
        """

        def surrogate_loss():
            # Same sign convention as the loss minimized by the caller, so a
            # positive (loss - new_loss) is an improvement.
            with torch.no_grad():
                action_dist = self.cat_policy(old_states)
                logprobs = action_dist.log_prob(old_actions)
                ratios = torch.exp(logprobs - old_logprobs)
                return -(ratios * advantages)

        loss = surrogate_loss()

        for k in range(max_backtracks):
            stepfrac = 0.5**k
            new_params = params + stepfrac * fullstep
            self.set_flat_params_to(self.cat_policy, new_params)
            new_loss = surrogate_loss()

            actual_improve = (loss - new_loss).mean()
            expected_improve = expected_improve_rate * stepfrac
            ratio = actual_improve / expected_improve

            if ratio.item() > accept_ratio and actual_improve.item() > 0:
                return True, new_params

        # no candidate accepted: undo the last trial step
        self.set_flat_params_to(self.cat_policy, params)
        return False, params

    def get_flat_params_from(self, model):
        """Concatenate every parameter of ``model`` into one 1-D tensor.

        Parameters are visited in ``model.parameters()`` order and read via
        ``.data``, so the result carries no autograd history.
        """
        return torch.cat([p.data.view(-1) for p in model.parameters()])

    def set_flat_params_to(self, model, flat_params):
        """Copy consecutive slices of ``flat_params`` into ``model``'s parameters.

        Parameters are visited in ``model.parameters()`` order; each one
        consumes as many entries as it has elements, reshaped to its size.
        """
        offset = 0
        for p in model.parameters():
            count = p.numel()
            chunk = flat_params[offset:offset + count]
            p.data.copy_(chunk.view(p.size()))
            offset += count

    #
    # For hyperparameter optimization
    #
    @classmethod
    def sample_parameters(cls, trial):
        """Draw a hyperparameter configuration from an Optuna-style ``trial``.

        Returns a dict with keys batch_size, gamma, learning_rate, entr_coef,
        eps_clip and k_epochs. The entries are suggested in that order.
        """
        # a dict display evaluates its values left to right, preserving the
        # order in which the trial is queried
        return {
            'batch_size':
                trial.suggest_categorical('batch_size', [1, 4, 8, 16, 32]),
            'gamma':
                trial.suggest_categorical('gamma', [0.9, 0.95, 0.99]),
            'learning_rate':
                trial.suggest_loguniform('learning_rate', 1e-5, 1),
            'entr_coef':
                trial.suggest_loguniform('entr_coef', 1e-8, 0.1),
            'eps_clip':
                trial.suggest_categorical('eps_clip', [0.1, 0.2, 0.3]),
            'k_epochs':
                trial.suggest_categorical('k_epochs', [1, 5, 10, 20]),
        }
Esempio n. 14
0
class REINFORCEAgent(IncrementalAgent):
    """
    REINFORCE with entropy regularization.

    Parameters
    ----------
    env : Model
        Online model with continuous (Box) state space and discrete actions
    n_episodes : int
        Number of episodes
    batch_size : int
        Number of episodes to wait before updating the policy.
    horizon : int
        Horizon (maximum number of steps per episode).
    gamma : double
        Discount factor in [0, 1].
    entr_coef : double
        Entropy coefficient.
    learning_rate : double
        Learning rate.
    normalize: bool
        If True, normalize the discounted returns of each batch to zero
        mean and unit variance before the policy-gradient step.
    optimizer_type: str
        Type of optimizer. 'ADAM' by default.
    policy_net_fn : function(env, **kwargs)
        Function that returns an instance of a policy network (pytorch).
        If None, a default net is used.
    policy_net_kwargs : dict
        kwargs for policy_net_fn
    use_bonus_if_available : bool, default = False
        If true, check if environment info has entry 'exploration_bonus'
        and add it to the reward. See also UncertaintyEstimatorWrapper.
    device: str
        Device to put the tensors on

    References
    ----------
    Williams, Ronald J.,
    "Simple statistical gradient-following algorithms for connectionist
    reinforcement learning."
    Reinforcement Learning. Springer, Boston, MA, 1992. 5-32.
    """

    name = "REINFORCE"

    def __init__(self,
                 env,
                 n_episodes=4000,
                 batch_size=8,
                 horizon=256,
                 gamma=0.99,
                 entr_coef=0.01,
                 learning_rate=0.0001,
                 normalize=True,
                 optimizer_type='ADAM',
                 policy_net_fn=None,
                 policy_net_kwargs=None,
                 use_bonus_if_available=False,
                 device="cuda:best",
                 **kwargs):
        IncrementalAgent.__init__(self, env, **kwargs)

        # check environment before reading space-specific attributes
        assert isinstance(self.env.observation_space, spaces.Box)
        assert isinstance(self.env.action_space, spaces.Discrete)

        self.n_episodes = n_episodes
        self.batch_size = batch_size
        self.horizon = horizon
        self.gamma = gamma
        self.entr_coef = entr_coef
        self.learning_rate = learning_rate
        self.normalize = normalize
        self.use_bonus_if_available = use_bonus_if_available
        self.device = choose_device(device)

        self.state_dim = self.env.observation_space.shape[0]
        self.action_dim = self.env.action_space.n

        self.policy_net_kwargs = policy_net_kwargs or {}

        #
        self.policy_net_fn = policy_net_fn or default_policy_net_fn

        self.optimizer_kwargs = {
            'optimizer_type': optimizer_type,
            'lr': learning_rate
        }

        self.policy_net = None  # policy network

        # initialize
        self.reset()

    def reset(self, **kwargs):
        """(Re)build the policy network, its optimizer and training buffers."""
        self.policy_net = self.policy_net_fn(
            self.env,
            **self.policy_net_kwargs,
        ).to(self.device)

        self.policy_optimizer = optimizer_factory(self.policy_net.parameters(),
                                                  **self.optimizer_kwargs)

        self.memory = Memory()

        self.episode = 0

        # useful data
        self._rewards = np.zeros(self.n_episodes)
        self._cumul_rewards = np.zeros(self.n_episodes)

        # default writer
        log_every = 5 * logger.getEffectiveLevel()
        self.writer = PeriodicWriter(self.name, log_every=log_every)

    def policy(self, state, **kwargs):
        """Sample an action (int) from the current policy at ``state``."""
        assert self.policy_net is not None
        state = torch.from_numpy(state).float().to(self.device)
        # acting needs no gradients; avoid building an autograd graph
        with torch.no_grad():
            action_dist = self.policy_net(state)
            action = action_dist.sample().item()
        return action

    def partial_fit(self, fraction: float, **kwargs):
        """Run up to ``ceil(fraction * n_episodes)`` more training episodes."""
        assert 0.0 < fraction <= 1.0
        n_episodes_to_run = int(np.ceil(fraction * self.n_episodes))
        count = 0
        while count < n_episodes_to_run and self.episode < self.n_episodes:
            self._run_episode()
            count += 1

        info = {
            "n_episodes": self.episode,
            "episode_rewards": self._rewards[:self.episode]
        }
        return info

    def _run_episode(self):
        """Play one episode (at most ``horizon`` steps) and store it.

        The policy is updated every ``batch_size`` episodes. Returns the
        undiscounted sum of environment rewards (bonus excluded).
        """
        # interact for H steps
        episode_rewards = 0
        state = self.env.reset()
        for _ in range(self.horizon):
            # running policy
            action = self.policy(state)
            next_state, reward, done, info = self.env.step(action)

            # check whether to use bonus
            bonus = 0.0
            if self.use_bonus_if_available:
                if info is not None and 'exploration_bonus' in info:
                    bonus = info['exploration_bonus']

            # save in batch
            self.memory.states.append(state)
            self.memory.actions.append(action)
            self.memory.rewards.append(reward + bonus)  # add bonus here
            self.memory.is_terminals.append(done)
            episode_rewards += reward

            if done:
                break

            # update state
            state = next_state

        # update
        ep = self.episode
        self._rewards[ep] = episode_rewards
        self._cumul_rewards[ep] = episode_rewards \
            + self._cumul_rewards[max(0, ep - 1)]
        self.episode += 1

        #
        if self.writer is not None:
            self.writer.add_scalar("episode", self.episode, None)
            self.writer.add_scalar("ep reward", episode_rewards)

        #
        if self.episode % self.batch_size == 0:
            self._update()
            self.memory.clear_memory()

        return episode_rewards

    def _normalize(self, x):
        """Standardize ``x`` to zero mean and unit (unbiased) std.

        With fewer than two elements the unbiased std is nan, which would
        propagate into the loss and the network weights. In that case only
        the mean is removed.
        """
        if x.numel() < 2:
            return x - x.mean()
        return (x - x.mean()) / (x.std() + 1e-5)

    def _update(self):
        """One policy-gradient step on the transitions stored in memory."""
        # Monte Carlo estimate of the discounted returns (reset at terminal
        # transitions so they do not leak across episodes). Built backwards
        # and reversed once, instead of O(n) list.insert(0, ...) per step.
        returns = []
        discounted_reward = 0
        for reward, is_terminal in zip(reversed(self.memory.rewards),
                                       reversed(self.memory.is_terminals)):
            if is_terminal:
                discounted_reward = 0
            discounted_reward = reward + (self.gamma * discounted_reward)
            returns.append(discounted_reward)
        returns.reverse()

        # Convert lists to tensors. Stacking the states through numpy first
        # avoids torch's slow element-by-element conversion of a list of
        # arrays.
        states = torch.as_tensor(np.asarray(self.memory.states),
                                 dtype=torch.float32,
                                 device=self.device)
        actions = torch.tensor(self.memory.actions,
                               dtype=torch.long,
                               device=self.device)
        rewards = torch.tensor(returns,
                               dtype=torch.float32,
                               device=self.device)
        if self.normalize:
            rewards = self._normalize(rewards)

        # evaluate logprobs
        action_dist = self.policy_net(states)
        logprobs = action_dist.log_prob(actions)
        dist_entropy = action_dist.entropy()

        # compute loss
        loss = -logprobs * rewards - self.entr_coef * dist_entropy

        # take gradient step
        self.policy_optimizer.zero_grad()

        loss.mean().backward()

        self.policy_optimizer.step()

    #
    # For hyperparameter optimization
    #
    @classmethod
    def sample_parameters(cls, trial):
        """Draw a hyperparameter configuration from an Optuna-style trial."""
        batch_size = trial.suggest_categorical('batch_size', [1, 4, 8, 16, 32])
        gamma = trial.suggest_categorical('gamma', [0.9, 0.95, 0.99])
        learning_rate = trial.suggest_loguniform('learning_rate', 1e-5, 1)

        entr_coef = trial.suggest_loguniform('entr_coef', 1e-8, 0.1)

        return {
            'batch_size': batch_size,
            'gamma': gamma,
            'learning_rate': learning_rate,
            'entr_coef': entr_coef,
        }
Esempio n. 15
0
class AVECPPOAgent(IncrementalAgent):
    """
    AVEC uses a modification of the training objective for the critic in
    actor-critic algorithms to better approximate the value function (critic).
    The new state-value function approximation learns the *relative* value of
    the states rather than their *absolute* value as in conventional
    actor-critic. This modification is:
    - well-motivated by recent studies [1,2];
    - theoretically sound;
    - intuitively supported by the need to improve the approximation error
    of the critic.

    The application of Actor with Variance Estimated Critic (AVEC) to
    state-of-the-art policy gradient methods produces considerable
    gains in performance (on average +26% for SAC and +40% for PPO)
    over the standard actor-critic training.

    Parameters
    ----------
    env : Model
        model with continuous (Box) state space and discrete actions
    n_episodes : int
        Number of episodes
    batch_size : int
        Number of episodes to wait before updating the policy.
    horizon : int or None
        Horizon of the objective function (maximum number of steps per
        episode). If None, it is set to round(1/(1-gamma)), which requires
        gamma < 1.
    gamma : double
        Discount factor in [0, 1]. If gamma is 1.0, an explicit horizon
        must be given (finite-horizon problem).
    entr_coef : double
        Entropy coefficient.
    vf_coef : double
        Value function (MSE) loss coefficient.
    avec_coef : double
        Coefficient of the AVEC critic loss, i.e. the residual variance
        Var[returns - state_values] (see ``_avec_loss``).
    learning_rate : double
        Learning rate.
    optimizer_type: str
        Type of optimizer. 'ADAM' by default.
    eps_clip : double
        PPO clipping range (epsilon).
    k_epochs : int
        Number of epochs per update.
    policy_net_fn : function
        Function that returns an instance of a policy network (pytorch).
        If None, a default net is used.
    value_net_fn : function
        Function that returns an instance of a value network (pytorch).
        If None, a default net is used.


    References
    ----------
    Flet-Berliac, Y., Ouhamma, R., Maillard, O. A., & Preux, P. (2020).
    "Is Standard Deviation the New Standard? Revisiting the Critic in Deep
    Policy Gradients."
    arXiv preprint arXiv:2010.04440.

    [1] Ilyas, A., Engstrom, L., Santurkar, S., Tsipras, D., Janoos, F.,
    Rudolph, L. & Madry, A. (2020).
    "A closer look at deep policy gradients."
    In International Conference on Learning Representations.

    [2] Tucker, G., Bhupatiraju, S., Gu, S., Turner, R., Ghahramani, Z. &
    Levine, S. (2018).
    "The mirage of action-dependent baselines in reinforcement learning."
    In International Conference on Machine Learning, pp. 5015–5024.
    """

    name = "AVECPPO"
    fit_info = ("n_episodes", "episode_rewards")

    def __init__(self,
                 env,
                 n_episodes=4000,
                 batch_size=8,
                 horizon=256,
                 gamma=0.99,
                 entr_coef=0.01,
                 vf_coef=0.,
                 avec_coef=1.,
                 learning_rate=0.0003,
                 optimizer_type='ADAM',
                 eps_clip=0.2,
                 k_epochs=10,
                 policy_net_fn=None,
                 value_net_fn=None,
                 **kwargs):
        IncrementalAgent.__init__(self, env, **kwargs)

        # check environment before reading space-specific attributes
        assert isinstance(self.env.observation_space, spaces.Box)
        assert isinstance(self.env.action_space, spaces.Discrete)

        if horizon is None:
            if not gamma < 1.0:
                raise ValueError("horizon must be given when gamma >= 1.")
            # effective horizon of the discounted objective
            horizon = max(1, int(round(1.0 / (1.0 - gamma))))

        self.learning_rate = learning_rate
        self.gamma = gamma
        self.entr_coef = entr_coef
        self.vf_coef = vf_coef
        self.avec_coef = avec_coef
        self.eps_clip = eps_clip
        self.k_epochs = k_epochs
        self.horizon = horizon
        self.n_episodes = n_episodes
        self.batch_size = batch_size

        self.state_dim = self.env.observation_space.shape[0]
        self.action_dim = self.env.action_space.n

        # both factories are called without arguments in reset()
        self.policy_net_fn = policy_net_fn \
            or (lambda: default_policy_net_fn(self.env))

        self.value_net_fn = value_net_fn \
            or (lambda: default_value_net_fn(self.env))

        self.optimizer_kwargs = {
            'optimizer_type': optimizer_type,
            'lr': learning_rate
        }

        self.cat_policy = None  # categorical policy function

        # initialize
        self.reset()

    def reset(self, **kwargs):
        """(Re)build networks, optimizers, memory and statistics."""
        # NOTE(review): ``device`` is not an attribute of this class; it is
        # assumed to be a module-level torch device defined elsewhere in the
        # file -- confirm.
        self.cat_policy = self.policy_net_fn().to(device)
        self.policy_optimizer = optimizer_factory(self.cat_policy.parameters(),
                                                  **self.optimizer_kwargs)

        self.value_net = self.value_net_fn().to(device)
        self.value_optimizer = optimizer_factory(self.value_net.parameters(),
                                                 **self.optimizer_kwargs)

        self.cat_policy_old = self.policy_net_fn().to(device)
        self.cat_policy_old.load_state_dict(self.cat_policy.state_dict())

        self.MseLoss = nn.MSELoss()

        self.memory = Memory()

        self.episode = 0

        # useful data
        self._rewards = np.zeros(self.n_episodes)
        self._cumul_rewards = np.zeros(self.n_episodes)

        # default writer
        self.writer = PeriodicWriter(self.name,
                                     log_every=5 * logger.getEffectiveLevel())

    def policy(self, state, **kwargs):
        """Sample an action (int) from the old (behavior) policy."""
        assert self.cat_policy is not None
        state = torch.from_numpy(state).float().to(device)
        with torch.no_grad():
            action_dist = self.cat_policy_old(state)
            action = action_dist.sample().item()

        return action

    def fit(self, **kwargs):
        """Train until ``n_episodes`` episodes have been run in total.

        Starting right after ``reset()`` this runs exactly ``n_episodes``
        episodes. After a ``partial_fit`` it only runs the remaining ones
        instead of indexing past the per-episode statistics arrays.
        """
        while self.episode < self.n_episodes:
            self._run_episode()

        info = {
            "n_episodes": self.episode,
            "episode_rewards": self._rewards[:self.episode]
        }
        return info

    def partial_fit(self, fraction: float, **kwargs):
        """Run up to ``ceil(fraction * n_episodes)`` more training episodes."""
        assert 0.0 < fraction <= 1.0
        n_episodes_to_run = int(np.ceil(fraction * self.n_episodes))
        count = 0
        while count < n_episodes_to_run and self.episode < self.n_episodes:
            self._run_episode()
            count += 1

        info = {
            "n_episodes": self.episode,
            "episode_rewards": self._rewards[:self.episode]
        }
        return info

    def _select_action(self, state):
        """Sample an action with the old policy and record it in memory."""
        state = torch.from_numpy(state).float().to(device)
        # Stored log-probs are only used as constants in _update; keeping
        # their autograd graphs for the whole batch would waste memory.
        with torch.no_grad():
            action_dist = self.cat_policy_old(state)
            action = action_dist.sample()
            action_logprob = action_dist.log_prob(action)

        self.memory.states.append(state)
        self.memory.actions.append(action)
        self.memory.logprobs.append(action_logprob)

        return action.item()

    def _run_episode(self):
        """Play one episode (at most ``horizon`` steps) and store it.

        The policy is updated every ``batch_size`` episodes. Returns the
        undiscounted sum of rewards.
        """
        # interact for H steps
        episode_rewards = 0
        state = self.env.reset()
        for _ in range(self.horizon):
            # running policy_old
            action = self._select_action(state)
            next_state, reward, done, _ = self.env.step(action)

            # save in batch
            self.memory.rewards.append(reward)
            self.memory.is_terminals.append(done)
            episode_rewards += reward

            if done:
                break

            # update state
            state = next_state

        # update
        ep = self.episode
        self._rewards[ep] = episode_rewards
        self._cumul_rewards[ep] = episode_rewards \
            + self._cumul_rewards[max(0, ep - 1)]
        self.episode += 1

        #
        if self.writer is not None:
            self.writer.add_scalar("episode", self.episode, None)
            self.writer.add_scalar("ep reward", episode_rewards)

        #
        if self.episode % self.batch_size == 0:
            self._update()
            self.memory.clear_memory()

        return episode_rewards

    def _update(self):
        """PPO-clip update with the AVEC critic loss, for ``k_epochs`` epochs."""
        # Monte Carlo estimate of the discounted returns (reset at terminal
        # transitions); built backwards and reversed once.
        returns = []
        discounted_reward = 0
        for reward, is_terminal in zip(reversed(self.memory.rewards),
                                       reversed(self.memory.is_terminals)):
            if is_terminal:
                discounted_reward = 0
            discounted_reward = reward + (self.gamma * discounted_reward)
            returns.append(discounted_reward)
        returns.reverse()

        # normalizing the rewards
        rewards = torch.tensor(returns, dtype=torch.float32, device=device)
        rewards = (rewards - rewards.mean()) / (rewards.std() + 1e-5)

        # convert list to tensor
        old_states = torch.stack(self.memory.states).to(device).detach()
        old_actions = torch.stack(self.memory.actions).to(device).detach()
        old_logprobs = torch.stack(self.memory.logprobs).to(device).detach()

        # optimize policy for K epochs
        for _ in range(self.k_epochs):
            # evaluate old actions and values
            action_dist = self.cat_policy(old_states)
            logprobs = action_dist.log_prob(old_actions)
            # Flatten e.g. (N, 1) -> (N,). Otherwise `rewards - state_values`
            # broadcasts to (N, N) and _avec_loss's 1-D assert fails.
            state_values = self.value_net(old_states).reshape(-1)
            dist_entropy = action_dist.entropy()

            # find ratio (pi_theta / pi_theta__old)
            ratios = torch.exp(logprobs - old_logprobs)

            # normalize the advantages
            advantages = rewards - state_values.detach()
            advantages = (advantages - advantages.mean()) / \
                         (advantages.std() + 1e-8)
            # find surrogate loss
            surr1 = ratios * advantages
            surr2 = torch.clamp(ratios, 1 - self.eps_clip,
                                1 + self.eps_clip) * advantages
            loss = -torch.min(surr1, surr2) \
                + self.avec_coef * self._avec_loss(state_values, rewards) \
                + self.vf_coef * self.MseLoss(state_values, rewards) \
                - self.entr_coef * dist_entropy

            # take gradient step
            self.policy_optimizer.zero_grad()
            self.value_optimizer.zero_grad()

            loss.mean().backward()

            self.policy_optimizer.step()
            self.value_optimizer.step()

        # copy new weights into old policy
        self.cat_policy_old.load_state_dict(self.cat_policy.state_dict())

    def _avec_loss(self, y_pred, y_true):
        """
        Computes the objective function used in AVEC for the learning
        of the value function:
        the residual variance between the state-values and the
        empirical returns.

        Returns Var[y_true - y_pred]
        :param y_pred: (torch.Tensor, 1-D) the predicted state values
        :param y_true: (torch.Tensor, 1-D) the empirical returns
        :return: (torch.Tensor, 0-d) residual variance of y_true - y_pred
        """
        assert y_true.ndim == 1 and y_pred.ndim == 1

        return torch.var(y_true - y_pred)

    #
    # For hyperparameter optimization
    #
    @classmethod
    def sample_parameters(cls, trial):
        """Draw batch_size and learning_rate from an Optuna-style trial."""
        batch_size = trial.suggest_categorical('batch_size',
                                               [1, 4, 8, 16, 32, 64])
        learning_rate = trial.suggest_loguniform('learning_rate', 1e-5, 1)
        return {
            'batch_size': batch_size,
            'learning_rate': learning_rate,
        }