Exemplo n.º 1
0
    def __init__(self,
                 *,
                 env,
                 learning_rate=3e-4,
                 buffer_size=128,
                 batch_size=64,
                 n_epochs=10,
                 gamma=0.99,
                 gae_lam=0.95,
                 clip_range=0.1,
                 ent_coef=.01,
                 vf_coef=1.0,
                 max_grad_norm=0.5):
        """Set up a PPO learner on ``env``.

        All hyper-parameters are forwarded positionally to the base class.
        This constructor then builds the policy on ``self.device``, a rollout
        buffer of ``buffer_size`` steps for ``self.num_envs`` environments, an
        Adam optimiser over the policy parameters, and the initial observation
        returned by ``self.env.reset()``.
        """
        base_args = (env, learning_rate, buffer_size, batch_size, n_epochs,
                     gamma, gae_lam, clip_range, ent_coef, vf_coef,
                     max_grad_norm)
        super(PPO, self).__init__(*base_args)

        self.policy = Policy(env=env, device=self.device)

        obs_space = env.observation_space
        act_space = env.action_space
        self.rollout = RolloutStorage(buffer_size, self.num_envs, obs_space,
                                      act_space, gae_lam=gae_lam)

        policy_params = self.policy.parameters()
        self.optimizer = optim.Adam(policy_params, lr=learning_rate)

        # Seed the rollout loop with the first observation.
        self.last_obs = self.env.reset()
Exemplo n.º 2
0
    def __init__(self,
                 *,
                 env_id,
                 lr=3e-4,
                 int_lr=3e-4,
                 nstep=128,
                 batch_size=128,
                 n_epochs=10,
                 gamma=0.99,
                 gae_lam=0.95,
                 clip_range=0.2,
                 ent_coef=.01,
                 vf_coef=0.5,
                 max_grad_norm=0.2,
                 hidden_size=128,
                 int_hidden_size=32,
                 int_rew_integration=0.05,
                 beta=0.2,
                 policy_weight=1):
        """Build a PPO agent with an Intrinsic Curiosity Module (ICM).

        The PPO hyper-parameters are forwarded positionally to the base class.
        This constructor then creates the policy, the rollout buffer, the ICM,
        and one Adam optimiser each for the policy and for the ICM.

        :param int_lr: (float)               Adam learning rate of the ICM
        :param int_hidden_size: (int)        ``hidden_size`` passed to the ICM
        :param int_rew_integration: (float)  stored as ``self.int_rew_integration``
                                             (weight of intrinsic vs extrinsic reward)
        :param beta: (float)                 ICM loss weighting, stored as ``self.beta``
        :param policy_weight: (float)        stored as ``self.policy_weight``
        """
        super(PPO_ICM, self).__init__(env_id, lr, nstep, batch_size, n_epochs,
                                      gamma, gae_lam, clip_range, ent_coef,
                                      vf_coef, max_grad_norm)

        self.int_rew_integration = int_rew_integration

        self.policy = Policy(self.env, hidden_size)
        self.rollout = RolloutStorage(nstep,
                                      self.num_envs,
                                      self.env.observation_space,
                                      self.env.action_space,
                                      gae_lam=gae_lam)
        self.intrinsic_module = IntrinsicCuriosityModule(
            self.state_dim, self.action_converter, hidden_size=int_hidden_size)

        self.optimizer = optim.Adam(self.policy.parameters(), lr=lr)
        self.icm_optimizer = optim.Adam(self.intrinsic_module.parameters(),
                                        lr=int_lr)

        self.last_obs = self.env.reset()

        self.policy_weight = policy_weight
        # Honour the caller's ``beta``. It used to be hard-coded to 0.2, which
        # silently ignored the constructor argument.
        self.beta = beta
Exemplo n.º 3
0
    def __init__(self, config: Config):
        """Create the buffer, online/target Q-networks and optimiser from ``config``.

        ``config.Dueling_DQN`` chooses ``Dueling_DQN`` over ``CnnDQN`` for both
        networks. The target network starts as an exact copy of the online
        one. Networks are moved via ``self.cuda()`` when ``config.use_cuda``
        is set.
        """
        self.config = config
        self.is_training = True
        self.buffer = RolloutStorage(config)

        net_cls = Dueling_DQN if self.config.Dueling_DQN else CnnDQN
        state_shape = self.config.state_shape
        action_dim = self.config.action_dim
        self.model = net_cls(state_shape, action_dim)
        self.target_model = net_cls(state_shape, action_dim)
        # Synchronise the target with the freshly initialised online network.
        self.target_model.load_state_dict(self.model.state_dict())

        learning_rate = self.config.learning_rate
        self.model_optim = torch.optim.Adam(self.model.parameters(),
                                            lr=learning_rate)

        if not self.config.use_cuda:
            return
        self.cuda()
Exemplo n.º 4
0
    def __init__(self,
                 *,
                 env_id,
                 lr=3e-4,
                 nstep=128,
                 batch_size=128,
                 n_epochs=10,
                 gamma=0.99,
                 gae_lam=0.95,
                 clip_range=0.2,
                 ent_coef=.01,
                 vf_coef=1,
                 max_grad_norm=0.2,
                 hidden_size=128,
                 sim_hash=False,
                 sil=False):
        """Set up PPO with optional SimHash and self-imitation learning (SIL).

        Hyper-parameters are forwarded positionally to the base class. The
        rollout storage receives ``gae_lam``, ``gamma`` and ``sim_hash``. The
        optimiser covers ``self.policy.net`` only. When ``sil`` is true, a
        ``SilModule`` sharing that optimiser is attached as ``self.sil_module``.
        """
        super(PPO, self).__init__(env_id, lr, nstep, batch_size, n_epochs,
                                  gamma, gae_lam, clip_range, ent_coef,
                                  vf_coef, max_grad_norm)

        env = self.env
        self.policy = Policy(env, hidden_size)
        self.rollout = RolloutStorage(
            nstep, self.num_envs, env.observation_space, env.action_space,
            gae_lam=gae_lam, gamma=gamma, sim_hash=sim_hash)
        self.optimizer = optim.Adam(self.policy.net.parameters(), lr=lr)

        self.last_obs = env.reset()
        self.sim_hash = sim_hash
        self.sil = sil
        if not sil:
            return
        # The first argument (50000) is presumably the SIL replay capacity
        # -- TODO confirm against SilModule.
        self.sil_module = SilModule(50000, self.policy, self.optimizer,
                                    self.num_envs, env)
Exemplo n.º 5
0
class CnnDDQNAgent:
    """DQN / Double-DQN agent over pixel observations.

    Holds an online network (``self.model``), a periodically synchronised
    target network (``self.target_model``), a sample buffer and an Adam
    optimiser. All hyper-parameters are read from ``config``.
    """

    def __init__(self, config: "Config"):
        """Create the buffer, online/target Q-networks and optimiser from ``config``.

        ``config.Dueling_DQN`` chooses ``Dueling_DQN`` over ``CnnDQN`` for both
        networks. The target starts as an exact copy of the online network.
        Both are moved to ``config.device`` when ``config.use_cuda`` is set.
        """
        self.config = config
        self.is_training = True
        self.buffer = RolloutStorage(config)
        net_cls = Dueling_DQN if self.config.Dueling_DQN else CnnDQN
        self.model = net_cls(self.config.state_shape, self.config.action_dim)
        self.target_model = net_cls(self.config.state_shape,
                                    self.config.action_dim)
        self.target_model.load_state_dict(self.model.state_dict())
        self.model_optim = torch.optim.Adam(self.model.parameters(),
                                            lr=self.config.learning_rate)

        if self.config.use_cuda:
            self.cuda()

    def act(self, state, epsilon=None):
        """Pick an action epsilon-greedily (always greedy when not training).

        :param state: raw pixel observation(s); divided by 255 before the forward pass
        :param epsilon: exploration rate, defaults to ``config.epsilon_min``
        :return: (int) action index
        """
        if epsilon is None:
            epsilon = self.config.epsilon_min
        if random.random() > epsilon or not self.is_training:
            state = torch.tensor(state, dtype=torch.float) / 255.0
            if self.config.use_cuda:
                state = state.to(self.config.device)
            # Inference only: no autograd graph is needed.
            with torch.no_grad():
                q_value = self.model(state)
            action = q_value.max(1)[1].item()
        else:
            action = random.randrange(self.config.action_dim)
        return action

    def learning(self, fr):
        """Do one gradient step on a sampled batch and maybe sync the target.

        Uses the plain DQN target when ``config.DQN`` is set, otherwise the
        Double-DQN target when ``config.Double_DQN`` is set.

        :param fr: (int) current frame index; the target network is synced
                   when ``fr % config.update_tar_interval == 0``
        :return: (float) the MSE TD loss
        :raises ValueError: if neither ``config.DQN`` nor ``config.Double_DQN`` is set
        """
        s0, s1, a, r, done = self.buffer.sample(self.config.batch_size)
        if self.config.use_cuda:
            device = self.config.device
            s0, s1 = s0.to(device), s1.to(device)
            a, r, done = a.to(device), r.to(device), done.to(device)
        # Scale pixels to [0, 1] on every device path, matching act().
        # Previously this only happened in the CUDA branch.
        s0 = s0.float() / 255.0
        s1 = s1.float() / 255.0

        # Q(s0, .) for every action, then pick the column of the taken action.
        q_s0_values = self.model(s0)
        q_s0_a = torch.gather(q_s0_values, 1, a)

        # The TD target is treated as a constant (no gradient through it).
        with torch.no_grad():
            if self.config.DQN:
                q_next = self.target_model(s1).max(1)[0].unsqueeze(1)
            elif self.config.Double_DQN:
                # Online net selects a', target net evaluates it.
                best_next = self.model(s1).max(1)[1].unsqueeze(1)
                q_next = torch.gather(self.target_model(s1), 1, best_next)
            else:
                raise ValueError(
                    "config must enable either DQN or Double_DQN")
            # Terminal transitions have no bootstrap value.
            q_next = torch.mul(q_next, (1 - done))
            y = r + self.config.gamma * q_next

        loss = torch.nn.functional.mse_loss(q_s0_a, y)
        self.model_optim.zero_grad()
        loss.backward()
        self.model_optim.step()

        if fr % self.config.update_tar_interval == 0:
            self.target_model.load_state_dict(self.model.state_dict())
        return loss.item()

    def cuda(self):
        """Move the online and target networks to ``config.device``."""
        self.model.to(self.config.device)
        self.target_model.to(self.config.device)

    def load_weights(self, model_path):
        """Load online-network weights from a bare state dict or a checkpoint dict.

        Tensors are loaded onto the CPU first. ``load_state_dict`` then copies
        them onto whatever device the model lives on, so GPU-saved files also
        load on CPU-only hosts. Only the online network is updated.
        """
        # torch.load unpickles its input: only load trusted files.
        model = torch.load(model_path, map_location="cpu")
        if 'model' in model:
            self.model.load_state_dict(model['model'])
        else:
            self.model.load_state_dict(model)

    def save_model(self, output, name=''):
        """Save the online network's state dict to ``<output>/model_<name>.pkl``."""
        torch.save(self.model.state_dict(), f"{output}/model_{name}.pkl")

    def save_config(self, output):
        """Write every config attribute as ``key = value`` lines to ``<output>/config.txt``."""
        attr_val = get_class_attr_val(self.config)
        with open(output + '/config.txt', 'w', encoding='utf-8') as f:
            f.writelines(f"{k} = {v}\n" for k, v in attr_val.items())

    def save_checkpoint(self, fr, output):
        """Save ``{'frames', 'model'}`` under ``<output>/checkpoint_model/``."""
        checkpath = output + '/checkpoint_model'
        os.makedirs(checkpath, exist_ok=True)
        torch.save({
            'frames': fr,
            'model': self.model.state_dict()
        }, '%s/checkpoint_fr_%d.tar' % (checkpath, fr))

    def load_checkpoint(self, model_path):
        """Restore both networks from a checkpoint and return its frame count."""
        # torch.load unpickles its input: only load trusted files.
        checkpoint = torch.load(model_path, map_location="cpu")
        fr = checkpoint['frames']
        self.model.load_state_dict(checkpoint['model'])
        self.target_model.load_state_dict(checkpoint['model'])
        return fr
Exemplo n.º 6
0
class PPO(BaseAlgorithm):
    """Proximal Policy Optimization with a clipped surrogate objective.

    :param env:                     environment handed to the base class and the
                                    policy; stepped in ``collect_samples``
    :param learning_rate: (float)   Adam learning rate of the policy
    :param buffer_size: (int)       rollout steps collected per iteration (each
                                    step adds ``num_envs`` timesteps)
    :param batch_size: (int)        minibatch size requested from the rollout
    :param n_epochs: (int)          passes over the rollout per ``train`` call
    :param gamma: (float)           discount factor (forwarded to the base class)
    :param gae_lam: (float)         GAE lambda given to the rollout storage
    :param clip_range: (float)      probability-ratio clip range of the surrogate loss
    :param ent_coef: (float)        entropy loss coefficient
    :param vf_coef: (float)         value loss coefficient
    :param max_grad_norm: (float)   gradient-norm clipping threshold
    """

    def __init__(self,
                 *,
                 env,
                 learning_rate=3e-4,
                 buffer_size=128,
                 batch_size=64,
                 n_epochs=10,
                 gamma=0.99,
                 gae_lam=0.95,
                 clip_range=0.1,
                 ent_coef=.01,
                 vf_coef=1.0,
                 max_grad_norm=0.5):
        # Zero-argument super() binds to this class even if the module-level
        # name ``PPO`` is rebound (another ``class PPO`` appears later in
        # this file).
        super().__init__(env, learning_rate, buffer_size, batch_size,
                         n_epochs, gamma, gae_lam, clip_range,
                         ent_coef, vf_coef, max_grad_norm)

        self.policy = Policy(env=env, device=self.device)
        self.rollout = RolloutStorage(buffer_size,
                                      self.num_envs,
                                      env.observation_space,
                                      env.action_space,
                                      gae_lam=gae_lam)
        self.optimizer = optim.Adam(self.policy.parameters(), lr=learning_rate)

        self.last_obs = self.env.reset()

    def collect_samples(self):
        """Refill the rollout buffer with ``buffer_size`` environment steps.

        After collection, returns and advantages are computed. Bootstrapping
        uses the value of the observation that follows the last stored step.

        :return: (bool) always True
        """
        assert self.last_obs is not None

        self.rollout.reset()

        for _ in range(self.buffer_size):
            with torch.no_grad():
                actions, values, log_probs = self.policy.act(self.last_obs)

            actions = actions.cpu()
            obs, rewards, dones, infos = self.env.step(actions)
            self.num_timesteps += self.num_envs
            self.update_info_buffer(infos)
            self.rollout.add(self.last_obs, actions, rewards, values, dones,
                             log_probs)
            self.last_obs = obs

        # Bootstrap from V(s_T), the observation *after* the final step.
        # The in-loop ``values`` belong to s_{T-1`}; passing those
        # mis-bootstraps the last GAE step.
        with torch.no_grad():
            _, last_values, _ = self.policy.act(self.last_obs)
        self.rollout.compute_returns_and_advantages(last_values, dones=dones)

        return True

    def train(self):
        """Run ``n_epochs`` passes of minibatch PPO updates over the rollout.

        Records the mean total/policy/value/entropy losses to ``logger`` and
        advances ``self._n_updates`` by ``n_epochs``.
        """
        total_losses, policy_losses, value_losses, entropy_losses = [], [], [], []

        for _ in range(self.n_epochs):
            for batch in self.rollout.get(self.batch_size):
                actions = batch.actions.long().flatten()
                old_log_probs = batch.old_log_probs.to(self.device)
                advantages = batch.advantages.to(self.device)
                returns = batch.returns.to(self.device)

                state_values, action_log_probs, entropy = self.policy.evaluate(
                    batch.observations, actions)
                state_values = state_values.squeeze()

                # Per-minibatch advantage normalisation.
                advantages = (advantages -
                              advantages.mean()) / (advantages.std() + 1e-8)

                ratio = torch.exp(action_log_probs - old_log_probs)

                policy_loss_1 = advantages * ratio
                policy_loss_2 = advantages * torch.clamp(
                    ratio, 1 - self.clip_range, 1 + self.clip_range)
                policy_loss = -torch.min(policy_loss_1, policy_loss_2).mean()

                value_loss = F.mse_loss(returns, state_values)

                if entropy is None:
                    # No analytic entropy: estimate H ~= -E[log pi]. The loss
                    # term is -H = +E[log pi]. Using -E[log pi] here, as
                    # before, flipped the sign and penalised exploration.
                    entropy_loss = action_log_probs.mean()
                else:
                    entropy_loss = -torch.mean(entropy)

                loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss

                self.optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.policy.parameters(),
                                               self.max_grad_norm)
                self.optimizer.step()

                total_losses.append(loss.item())
                policy_losses.append(policy_loss.item())
                value_losses.append(value_loss.item())
                entropy_losses.append(entropy_loss.item())

        logger.record("train/entropy_loss", np.mean(entropy_losses))
        logger.record("train/policy_gradient_loss", np.mean(policy_losses))
        logger.record("train/value_loss", np.mean(value_losses))
        logger.record("train/total_loss", np.mean(total_losses))

        self._n_updates += self.n_epochs

    def _log_training_stats(self, start_time):
        """Record timesteps, episode stats, fps and wall time, then dump the logger."""
        elapsed = time.time() - start_time
        logger.record("time/total timesteps", self.num_timesteps)
        if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
            logger.record(
                "rollout/ep_rew_mean",
                np.mean([ep_info["r"] for ep_info in self.ep_info_buffer]))
            logger.record(
                "rollout/ep_len_mean",
                np.mean([ep_info["l"] for ep_info in self.ep_info_buffer]))
        # fps used to be computed and then discarded. Guard against a
        # zero elapsed time on coarse clocks.
        if elapsed > 0:
            logger.record("time/fps", int(self.num_timesteps / elapsed))
        logger.record("time/total_time", elapsed)
        logger.dump(step=self.num_timesteps)

    def learn(self, total_timesteps, log_interval):
        """Alternate rollout collection and training until ``total_timesteps``.

        :param total_timesteps: (int)  environment timesteps to train for
        :param log_interval: (int)     log every this many iterations (None disables)
        :return: self
        """
        start_time = time.time()
        iteration = 0

        while self.num_timesteps < total_timesteps:
            progress = round(self.num_timesteps / total_timesteps * 100, 2)
            self.collect_samples()

            iteration += 1
            if log_interval is not None and iteration % log_interval == 0:
                logger.record("Progress", str(progress) + '%')
                self._log_training_stats(start_time)

            self.train()

        logger.record("Complete", '.')
        self._log_training_stats(start_time)

        return self
Exemplo n.º 7
0
class PPO_ICM(BaseAlgorithm):
    """
    PPO agent augmented with an Intrinsic Curiosity Module (ICM): an intrinsic
    reward from the ICM is blended into the environment reward, and the ICM
    is trained jointly with the policy.
    :param env_id: (str)                    name of environment to perform training on
    :param lr: (float)                      learning rate
    :param int_lr: (float)                  learning rate of the intrinsic curiosity module
    :param nstep: (int)                     storage rollout steps
    :param batch_size: (int)                batch size for training
    :param n_epochs: (int)                  number of training epochs
    :param gamma: (float)                   discount factor
    :param gae_lam: (float)                 lambda for generalized advantage estimation
    :param clip_range: (float)              clip range for surrogate and value losses
    :param ent_coef: (float)                entropy loss coefficient
    :param vf_coef: (float)                 value loss coefficient
    :param max_grad_norm: (float)           max grad norm for the policy parameters
    :param hidden_size: (int)               size of the hidden layers of policy
    :param int_hidden_size: (int)           hidden size passed to the intrinsic curiosity module
    :param int_rew_integration: (float):    weight w of intrinsic reward: r = (1-w)*r_ext + w*r_int
    :param beta: (float)                    ICM loss = (1-beta)*inverse_loss + beta*forward_loss
    :param policy_weight: (float)           parameter specifying weight of policy vs intrinsic loss
    """
    def __init__(self,
                 *,
                 env_id,
                 lr=3e-4,
                 int_lr=3e-4,
                 nstep=128,
                 batch_size=128,
                 n_epochs=10,
                 gamma=0.99,
                 gae_lam=0.95,
                 clip_range=0.2,
                 ent_coef=.01,
                 vf_coef=0.5,
                 max_grad_norm=0.2,
                 hidden_size=128,
                 int_hidden_size=32,
                 int_rew_integration=0.05,
                 beta=0.2,
                 policy_weight=1):
        super().__init__(env_id, lr, nstep, batch_size, n_epochs,
                         gamma, gae_lam, clip_range, ent_coef,
                         vf_coef, max_grad_norm)

        self.int_rew_integration = int_rew_integration

        self.policy = Policy(self.env, hidden_size)
        self.rollout = RolloutStorage(nstep,
                                      self.num_envs,
                                      self.env.observation_space,
                                      self.env.action_space,
                                      gae_lam=gae_lam)
        self.intrinsic_module = IntrinsicCuriosityModule(
            self.state_dim, self.action_converter, hidden_size=int_hidden_size)

        self.optimizer = optim.Adam(self.policy.parameters(), lr=lr)
        self.icm_optimizer = optim.Adam(self.intrinsic_module.parameters(),
                                        lr=int_lr)

        self.last_obs = self.env.reset()

        self.policy_weight = policy_weight
        # Previously hard-coded to 0.2, which ignored the ``beta`` argument.
        self.beta = beta

    def collect_samples(self):
        """
        Collect one full rollout, as determined by the nstep parameter, and add it to the buffer.
        Stored rewards blend extrinsic and intrinsic reward using int_rew_integration.
        """
        assert self.last_obs is not None

        self.rollout.reset()

        # For logging
        test_int_rewards = []

        for _ in range(self.nstep):

            with torch.no_grad():
                actions, values, log_probs = self.policy.act(self.last_obs)

            obs, rewards, dones, infos = self.env.step(actions.numpy())

            if any(dones):
                self.num_episodes += sum(dones)
            self.num_timesteps += self.num_envs
            self.update_info_buffer(infos)

            int_rewards = self.intrinsic_module.int_reward(
                torch.Tensor(self.last_obs), torch.Tensor(obs), actions)
            rewards = (
                1 - self.int_rew_integration
            ) * rewards + self.int_rew_integration * int_rewards.detach(
            ).numpy()

            # For logging
            test_int_rewards.append(int_rewards.mean().item())

            actions = actions.reshape(self.num_envs,
                                      self.action_converter.action_output)
            log_probs = log_probs.reshape(self.num_envs,
                                          self.action_converter.action_output)

            self.rollout.add(self.last_obs, actions, rewards, values, dones,
                             log_probs)

            self.last_obs = obs
        logger.record("rollout/mean_int_reward",
                      np.round(np.mean(np.array(test_int_rewards)), 10))

        # Bootstrap from V(s_T), the observation *after* the final step.
        # The in-loop ``values`` belong to s_{T-1}.
        with torch.no_grad():
            _, last_values, _ = self.policy.act(self.last_obs)
        self.rollout.compute_returns_and_advantages(last_values, dones=dones)

        return True

    def train(self):
        """
        Use the collected data from the buffer to train the policy network and the ICM jointly.
        """
        total_losses, policy_losses, value_losses, entropy_losses, icm_losses = [], [], [], [], []

        inv_criterion = self.action_converter.get_loss()

        for _ in range(self.n_epochs):
            for batch in self.rollout.get(self.batch_size):
                observations = batch.observations
                actions = batch.actions
                old_log_probs = batch.old_log_probs
                old_values = batch.old_values
                advantages = batch.advantages
                returns = batch.returns

                state_values, action_log_probs, entropy = self.policy.evaluate(
                    observations, actions)

                advantages = (advantages -
                              advantages.mean()) / (advantages.std() + 1e-8)
                ratio = torch.exp(action_log_probs - old_log_probs)

                # Surrogate loss
                surr_loss_1 = advantages * ratio
                surr_loss_2 = advantages * torch.clamp(
                    ratio, 1 - self.clip_range, 1 + self.clip_range)
                policy_loss = -torch.min(surr_loss_1, surr_loss_2).mean()

                # Clipped value loss: take the element-wise max of the
                # unclipped and clipped squared errors, then average. The old
                # code took the max of two already-averaged losses.
                state_values_clipped = old_values + (
                    state_values - old_values).clamp(-self.clip_range,
                                                     self.clip_range)
                value_loss = torch.max(
                    F.mse_loss(returns, state_values, reduction="none"),
                    F.mse_loss(returns, state_values_clipped, reduction="none"),
                ).mean()

                # Icm loss
                # NOTE(review): observations[:-1] -> observations[1:] are true
                # transitions only if rollout.get() yields samples in time
                # order from a single env. Shuffled or env-mixed batches would
                # pair unrelated states -- confirm RolloutStorage.get.
                actions_hat, next_features, next_features_hat = self.intrinsic_module(
                    observations[:-1], observations[1:], actions[:-1])

                forward_loss = F.mse_loss(next_features, next_features_hat)
                inverse_loss = inv_criterion(
                    actions_hat, self.action_converter.action(actions[:-1]))
                icm_loss = (
                    1 - self.beta) * inverse_loss + self.beta * forward_loss

                entropy_loss = -torch.mean(entropy)

                loss = self.policy_weight * (
                    policy_loss + self.vf_coef * value_loss +
                    self.ent_coef * entropy_loss) + icm_loss

                self.optimizer.zero_grad()
                self.icm_optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.policy.net.parameters(),
                                               self.max_grad_norm)
                self.optimizer.step()
                self.icm_optimizer.step()

                total_losses.append(loss.item())
                policy_losses.append(policy_loss.item())
                value_losses.append(value_loss.item())
                entropy_losses.append(entropy_loss.item())
                icm_losses.append(icm_loss.item())

        logger.record("train/entropy_loss", np.mean(entropy_losses))
        logger.record("train/policy_gradient_loss", np.mean(policy_losses))
        logger.record("train/value_loss", np.mean(value_losses))
        logger.record("train/total_loss", np.mean(total_losses))
        logger.record("train/icm_loss", np.mean(icm_losses))

        self._n_updates += self.n_epochs

    def _log_training_stats(self, start_time):
        """Record timesteps, episode stats, fps and wall time, then dump the logger."""
        elapsed = time.time() - start_time
        logger.record("time/total timesteps", self.num_timesteps)
        if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
            logger.record(
                "rollout/ep_rew_mean",
                np.mean([ep_info["r"] for ep_info in self.ep_info_buffer]))
            logger.record("rollout/num_episodes", self.num_episodes)
        # fps used to be computed and then discarded. Guard against a
        # zero elapsed time on coarse clocks.
        if elapsed > 0:
            logger.record("time/fps", int(self.num_timesteps / elapsed))
        logger.record("time/total_time", elapsed)
        logger.dump(step=self.num_timesteps)

    def learn(self,
              total_timesteps,
              log_interval=5,
              reward_target=None,
              log_to_file=False):
        """
        Initiate the training of the algorithm.

        :param total_timesteps: (int)   total number of timesteps the agent is to run for
        :param log_interval: (int)      how often to perform logging
        :param reward_target: (int)     reaching the reward target stops training early
        :param log_to_file: (bool)      specify whether output ought to be logged
        :return: self
        """
        logger.configure("ICM", self.env_id, log_to_file)
        start_time = time.time()
        iteration = 0

        while self.num_timesteps < total_timesteps:
            progress = round(self.num_timesteps / total_timesteps * 100, 2)
            self.collect_samples()

            iteration += 1
            if log_interval is not None and iteration % log_interval == 0:
                logger.record("Progress", str(progress) + '%')
                self._log_training_stats(start_time)

            self.train()

            # Skip the check until at least one episode has finished, to
            # avoid np.mean([]) (NaN plus a RuntimeWarning).
            if (reward_target is not None and len(self.ep_info_buffer) > 0
                    and np.mean([ep_info["r"] for ep_info in self.ep_info_buffer])
                    > reward_target):
                self._log_training_stats(start_time)
                break
        return self
Exemplo n.º 8
0
class PPO(BaseAlgorithm):
    """
    Proximal Policy Optimization (PPO) agent, with an optional SimHash switch
    (forwarded to the rollout storage) and optional self-imitation learning (SIL).

    :param env_id: (str)            name of environment to perform training on
    :param lr: (float)              learning rate
    :param nstep: (int)             storage rollout steps
    :param batch_size: (int)        batch size for training
    :param n_epochs: (int)          number of training epochs
    :param gamma: (float)           discount factor
    :param gae_lam: (float)         lambda for generalized advantage estimation
    :param clip_range: (float)      clip range for surrogate loss (also used to clip value updates)
    :param ent_coef: (float)        entropy loss coefficient
    :param vf_coef: (float)         value loss coefficient
    :param max_grad_norm: (float)   max grad norm for optimizer
    :param hidden_size: (int)       size of the hidden layers of policy
    :param sim_hash: (bool)         sim hash switch (passed to RolloutStorage)
    :param sil: (bool)              self imitation learning switch
    """
    def __init__(self,
                 *,
                 env_id,
                 lr=3e-4,
                 nstep=128,
                 batch_size=128,
                 n_epochs=10,
                 gamma=0.99,
                 gae_lam=0.95,
                 clip_range=0.2,
                 ent_coef=.01,
                 vf_coef=1,
                 max_grad_norm=0.2,
                 hidden_size=128,
                 sim_hash=False,
                 sil=False):
        super(PPO, self).__init__(env_id, lr, nstep, batch_size, n_epochs,
                                  gamma, gae_lam, clip_range, ent_coef,
                                  vf_coef, max_grad_norm)

        self.policy = Policy(self.env, hidden_size)
        self.rollout = RolloutStorage(nstep,
                                      self.num_envs,
                                      self.env.observation_space,
                                      self.env.action_space,
                                      gae_lam=gae_lam,
                                      gamma=gamma,
                                      sim_hash=sim_hash)
        self.optimizer = optim.Adam(self.policy.net.parameters(), lr=lr)

        self.last_obs = self.env.reset()
        self.sim_hash = sim_hash
        self.sil = sil
        if sil:
            # NOTE(review): 50000 is presumably the SIL replay capacity — confirm
            # against SilModule's signature.
            self.sil_module = SilModule(50000, self.policy, self.optimizer,
                                        self.num_envs, self.env)

    def collect_samples(self):
        """
        Collect one full rollout of ``nstep`` vectorized environment steps,
        store it in ``self.rollout`` and compute returns/advantages.

        :return: (bool) always True
        """
        assert self.last_obs is not None
        rollout_step = 0
        self.rollout.reset()

        while rollout_step < self.nstep:
            with torch.no_grad():
                actions, values, log_probs = self.policy.act(self.last_obs)

            actions = actions.numpy()
            obs, rewards, dones, infos = self.env.step(actions)
            if any(dones):
                self.num_episodes += sum(dones)

            self.num_timesteps += self.num_envs
            self.update_info_buffer(infos)

            actions = actions.reshape(self.num_envs,
                                      self.action_converter.action_output)
            log_probs = log_probs.reshape(self.num_envs,
                                          self.action_converter.action_output)

            if self.sil:
                self.sil_module.step(self.last_obs, actions, log_probs,
                                     rewards, dones)

            self.rollout.add(self.last_obs, actions, rewards, values, dones,
                             log_probs)
            self.last_obs = obs
            rollout_step += 1

        # Bootstrap from the value of the observation that *follows* the last
        # stored transition (self.last_obs now holds it). Reusing `values` from
        # the loop would bootstrap with the value of the last stored
        # observation itself, which is off by one step.
        with torch.no_grad():
            _, last_values, _ = self.policy.act(self.last_obs)

        self.rollout.compute_returns_and_advantages(last_values, dones=dones)

        return True

    def train(self):
        """
        Run ``n_epochs`` passes of minibatch PPO updates over the collected
        rollout, optionally follow with a SIL update, and record mean losses.
        """
        total_losses, policy_losses, value_losses, entropy_losses = [], [], [], []

        for _ in range(self.n_epochs):
            for batch in self.rollout.get(self.batch_size):
                observations = batch.observations
                actions = batch.actions
                old_log_probs = batch.old_log_probs
                old_values = batch.old_values
                advantages = batch.advantages
                returns = batch.returns

                # Get values and action probabilities using the updated policy on gathered observations
                state_values, action_log_probs, entropy = self.policy.evaluate(
                    observations, actions)

                # Normalize batch advantages
                advantages = (advantages -
                              advantages.mean()) / (advantages.std() + 1e-8)

                # Compute policy gradient ratio of current actions probs over previous
                ratio = torch.exp(action_log_probs - old_log_probs)
                # Compute surrogate loss
                surr_loss_1 = advantages * ratio
                surr_loss_2 = advantages * torch.clamp(
                    ratio, 1 - self.clip_range, 1 + self.clip_range)
                policy_loss = -torch.min(surr_loss_1, surr_loss_2).mean()

                # Clipped value loss: the max of the clipped and unclipped
                # squared errors must be taken per sample, *before* averaging.
                # Taking the max of two batch means is a different objective.
                state_values_clipped = old_values + (
                    state_values - old_values).clamp(-self.clip_range,
                                                     self.clip_range)
                value_loss_unclipped = F.mse_loss(state_values, returns,
                                                  reduction="none")
                value_loss_clipped = F.mse_loss(state_values_clipped, returns,
                                                reduction="none")
                value_loss = torch.max(value_loss_unclipped,
                                       value_loss_clipped).mean()

                # Compute entropy loss (negated so minimizing it raises entropy)
                entropy_loss = -torch.mean(entropy)

                # Total loss
                loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss

                # Perform optimization
                self.optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.policy.net.parameters(),
                                               self.max_grad_norm)
                self.optimizer.step()

                total_losses.append(loss.item())
                policy_losses.append(policy_loss.item())
                value_losses.append(value_loss.item())
                entropy_losses.append(entropy_loss.item())

        if self.sil:
            self.sil_module.train(4, 128, clip_range=0.2)

        logger.record("train/entropy_loss", np.mean(entropy_losses))
        logger.record("train/policy_gradient_loss", np.mean(policy_losses))
        logger.record("train/value_loss", np.mean(value_losses))
        logger.record("train/total_loss", np.mean(total_losses))

        self._n_updates += self.n_epochs

    def _mean_episode_reward(self):
        """Return the mean ``"r"`` over ``self.ep_info_buffer``, or None if it is empty."""
        if len(self.ep_info_buffer) == 0:
            return None
        return np.mean([ep_info["r"] for ep_info in self.ep_info_buffer])

    def _log_rollout_stats(self, start_time):
        """
        Record timestep/episode statistics and dump the logger.

        :param start_time: (float) ``time.time()`` value taken when training started
        """
        logger.record("time/total timesteps", self.num_timesteps)
        if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
            logger.record("rollout/ep_rew_mean", self._mean_episode_reward())
            logger.record("rollout/num_episodes", self.num_episodes)
        logger.record("time/total_time", (time.time() - start_time))
        logger.dump(step=self.num_timesteps)

    def learn(self,
              total_timesteps,
              log_interval,
              reward_target=None,
              log_to_file=False):
        """
        Initiate the training of the algorithm.

        :param total_timesteps: (int)   total number of timesteps the agent is to run for
        :param log_interval: (int)      how often (in iterations) to perform logging;
                                        None or 0 disables periodic logging
        :param reward_target: (int)     reaching the reward target stops training early
        :param log_to_file: (bool)      specify whether output ought to be logged
        :return: self
        """
        if self.sim_hash:
            logger.configure("PPO_SimHash", self.env_id, log_to_file)
        elif self.sil:
            logger.configure("PPO_SIL", self.env_id, log_to_file)
        else:
            logger.configure("PPO", self.env_id, log_to_file)

        start_time = time.time()
        iteration = 0

        while self.num_timesteps < total_timesteps:
            self.collect_samples()

            iteration += 1
            # `log_interval` of 0 would otherwise raise ZeroDivisionError.
            if log_interval and iteration % log_interval == 0:
                self._log_rollout_stats(start_time)

            self.train()

            if reward_target is not None:
                # Skip the check until at least one episode has finished,
                # instead of comparing np.mean([]) (nan + RuntimeWarning).
                mean_reward = self._mean_episode_reward()
                if mean_reward is not None and mean_reward > reward_target:
                    self._log_rollout_stats(start_time)
                    break

        return self