Code Example #1
class DDPGAgent():
    def __init__(self, state_dim, action_dim, random_seed):
        self.state_dim = state_dim
        self.action_dim = action_dim
        random.seed(random_seed)
        self.seed = random_seed

        # Actor network with its target network
        self.actor_local = Actor(state_dim, action_dim, random_seed).to(device)
        self.actor_target = Actor(state_dim, action_dim,
                                  random_seed).to(device)
        self.actor_optimizer = optim.Adam(self.actor_local.parameters(),
                                          lr=LR_ACTOR)

        # Critic network with its target network
        self.critic_local = Critic(state_dim, action_dim,
                                   random_seed).to(device)
        self.critic_target = Critic(state_dim, action_dim,
                                    random_seed).to(device)
        self.critic_optimizer = optim.Adam(self.critic_local.parameters(),
                                           lr=LR_CRITIC,
                                           weight_decay=WEIGHT_DECAY)

        # Noise
        self.noise = OUNoise(action_dim, random_seed)
        self.epsilon = EPSILON

        # Replay memory
        self.memory = ReplayBuffer(action_dim, BUFFER_SIZE, BATCH_SIZE,
                                   random_seed)

    def step(self, state, action, reward, next_state, done, timestamp):
        """Save experience in replay memory, and use random sample from memory to learn."""
        # Save experience
        self.memory.add(state, action, reward, next_state, done)
        # Learn (if there are enough samples in memory)
        if len(self.memory) > BATCH_SIZE and timestamp % LEARN_EVERY == 0:
            for _ in range(LEARN_NUMBER):
                experiences = self.memory.sample()
                self.learn(experiences, GAMMA)

    def act(self, state, add_noise=True):
        """Return actions for given state from current policy."""
        state = torch.from_numpy(state).float().to(device)
        self.actor_local.eval()
        with torch.no_grad():
            action = self.actor_local(state).cpu().data.numpy()
        self.actor_local.train()
        if add_noise:
            action += self.noise.sample() * self.epsilon
        return np.clip(action, -1, 1)

    def reset(self):
        self.noise.reset()

    def learn(self, experiences, gamma):
        """Update policy and value parameters using given batch of experience tuples."""
        states, actions, rewards, next_states, dones = experiences

        #   UPDATE CRITIC   #
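        # TD target: y = r + gamma * Q_target(s', mu_target(s')) * (1 - done)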
        actions_next = self.actor_target(next_states.to(device))
        Q_targets_next = self.critic_target(next_states.to(device),
                                            actions_next.to(device))
        Q_targets = rewards + (gamma * Q_targets_next * (1 - dones))
        Q_expected = self.critic_local(states, actions)
        critic_loss = F.mse_loss(Q_expected, Q_targets)

        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        clip_grad_norm_(self.critic_local.parameters(),
                        1)  # Clip the gradient when update critic network
        self.critic_optimizer.step()

        #   UPDATE ACTOR   #
        actions_pred = self.actor_local(states)
        actor_loss = -self.critic_local(states, actions_pred).mean()

        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        #   UPDATE TARGET NETWORKS   #
        self.soft_update(self.critic_local, self.critic_target, RHO)
        self.soft_update(self.actor_local, self.actor_target, RHO)

        #   UPDATE EPSILON AND NOISE   #
        self.epsilon *= EPSILON_DECAY
        self.noise.reset()

    def soft_update(self, local_model, target_model, rho):
        """Soft update model parameters."""
        for target_param, local_param in zip(target_model.parameters(),
                                             local_model.parameters()):
            target_param.data.copy_(rho * target_param.data +
                                    (1.0 - rho) * local_param.data)
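
The class above reads module-level hyperparameters (BUFFER_SIZE, BATCH_SIZE, GAMMA, RHO, LR_ACTOR, LR_CRITIC, WEIGHT_DECAY, EPSILON, EPSILON_DECAY, LEARN_EVERY, LEARN_NUMBER) and helper classes (Actor, Critic, OUNoise, ReplayBuffer) defined elsewhere in its repository. The following is only a minimal sketch of how such an agent might be driven; the constant values, the environment name and the loop lengths are assumptions, and the snippet presumes it lives in the same module as the class.

# Sketch only: every constant below is an assumed value, not taken from the original repository.
import gym
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

BUFFER_SIZE = int(1e6)   # replay buffer capacity (assumed)
BATCH_SIZE = 128         # minibatch size (assumed)
GAMMA = 0.99             # discount factor (assumed)
RHO = 0.995              # fraction of the target weights kept in soft_update (assumed)
LR_ACTOR = 1e-4          # actor learning rate (assumed)
LR_CRITIC = 1e-3         # critic learning rate (assumed)
WEIGHT_DECAY = 0.0       # critic L2 penalty (assumed)
EPSILON = 1.0            # initial exploration-noise scale (assumed)
EPSILON_DECAY = 0.999    # per-update noise decay (assumed)
LEARN_EVERY = 20         # learn every N environment steps (assumed)
LEARN_NUMBER = 10        # gradient updates per learning step (assumed)

env = gym.make("Pendulum-v1")  # hypothetical continuous-control task (old Gym step/reset API)
agent = DDPGAgent(env.observation_space.shape[0],
                  env.action_space.shape[0], random_seed=0)

for episode in range(200):
    state = env.reset()
    agent.reset()
    for t in range(1000):
        action = agent.act(state)
        next_state, reward, done, _ = env.step(action)
        agent.step(state, action, reward, next_state, done, t)
        state = next_state
        if done:
            break
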
Code Example #2
class DDPGAgent:
    def __init__(self,
                 output_dim,
                 input_dim,
                 name,
                 hidden=256,
                 lr_actor=1.0e-3,
                 lr_critic=1.0e-3,
                 tau=1.0e-2,
                 seed=10):
        super(DDPGAgent, self).__init__()

        self.seed = seed
        self.actor = Actor(input_dim, hidden, output_dim, seed).to(device)
        self.critic = Critic(input_dim=input_dim,
                             action_dim=output_dim,
                             hidden=hidden,
                             seed=seed,
                             output_dim=1).to(device)
        self.target_actor = Actor(input_dim, hidden, output_dim,
                                  seed).to(device)
        self.target_critic = Critic(input_dim=input_dim,
                                    action_dim=output_dim,
                                    hidden=hidden,
                                    seed=seed,
                                    output_dim=1).to(device)
        self.name = name
        self.noise = OUNoise(output_dim, seed)
        self.tau = tau
        self.epsilon = EPSILON
        self.gamma = GAMMA
        self.clipgrad = CLIPGRAD
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=lr_actor)
        self.critic_optimizer = optim.Adam(self.critic.parameters(),
                                           lr=lr_critic,
                                           weight_decay=0)

    def act(self, state, add_noise=True):
        """Return actions for given state from current policy."""
        state = torch.from_numpy(state).float().unsqueeze(0).to(device)
        self.actor.eval()
        with torch.no_grad():
            action = self.actor(state).cpu().squeeze(0).data.numpy()
        self.actor.train()
        if add_noise:
            action += self.noise.sample() * self.epsilon
        return np.clip(action, -1, 1)

    def learn(self, experiences):
        """Update policy and value parameters using given batch of experience tuples."""
        states, actions, rewards, next_states, dones = experiences

        #   UPDATE CRITIC   #
        actions_next = self.target_actor(next_states.to(device))
        Q_targets_next = self.target_critic(next_states.to(device),
                                            actions_next.to(device))
        Q_targets = rewards + (self.gamma * Q_targets_next * (1 - dones))
        Q_expected = self.critic(states, actions)
        critic_loss = F.mse_loss(Q_expected, Q_targets)

        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        clip_grad_norm_(self.critic.parameters(), self.clipgrad)
        self.critic_optimizer.step()

        #   UPDATE ACTOR   #
        actions_pred = self.actor(states)
        actor_loss = -self.critic(states, actions_pred).mean()

        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        #clip_grad_norm_(self.actor.parameters(), self.clipgrad)
        self.actor_optimizer.step()

        #   UPDATE TARGET NETWORKS   #
        self.soft_update(self.critic, self.target_critic)
        self.soft_update(self.actor, self.target_actor)

        #   UPDATE EPSILON AND NOISE   #
        self.epsilon *= EPSILON_DECAY
        self.noise.reset()

    def reset(self):
        self.noise.reset()

    def soft_update(self, local_model, target_model):
        """Soft update model parameters."""
        for target_param, local_param in zip(target_model.parameters(),
                                             local_model.parameters()):
            target_param.data.copy_(self.tau * local_param.data +
                                    (1.0 - self.tau) * target_param.data)
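
Both DDPG examples draw exploration noise from an OUNoise class that is not included in the snippets. A minimal Ornstein-Uhlenbeck process along conventional lines is sketched below; the mu, theta and sigma defaults are assumptions.

import copy
import random

import numpy as np


class OUNoise:
    """Ornstein-Uhlenbeck process (sketch; the mu, theta and sigma defaults are assumed)."""

    def __init__(self, size, seed, mu=0.0, theta=0.15, sigma=0.2):
        self.mu = mu * np.ones(size)
        self.theta = theta
        self.sigma = sigma
        random.seed(seed)
        np.random.seed(seed)
        self.reset()

    def reset(self):
        """Reset the internal state to the long-run mean."""
        self.state = copy.copy(self.mu)

    def sample(self):
        """Drift the state toward the mean, add Gaussian noise, and return the new state."""
        dx = self.theta * (self.mu - self.state) \
            + self.sigma * np.random.standard_normal(self.mu.shape)
        self.state = self.state + dx
        return self.state
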
Code Example #3
class PPOAgent:
    def __init__(self, env, gamma, _avlambda, beta, delta, init_lr,
                 init_clip_range, n_updates, n_epochs, n_steps, n_mini_batch):
        self.env = env
        self.gamma = gamma
        self._avlambda = _avlambda
        self.beta = beta
        self.delta = delta
        self.init_lr = init_lr
        self.init_clip_range = init_clip_range
        self.lr = self.init_lr
        self.clip_range = self.init_clip_range
        self.n_updates = n_updates
        self.epochs = n_epochs
        self.n_steps = n_steps
        self.n_mini_batch = n_mini_batch

        self.n_episodes = 0
        self.episodes_rewards = []

        self.mini_batch_size = self.n_steps // self.n_mini_batch
        assert (self.n_steps % self.n_mini_batch == 0)

        self.current_state = None

        self.input_dim = self.env.observation_space.shape[0]
        self.actions_dim = self.env.action_space.n
        self.actor = Actor(self.input_dim, self.actions_dim).to(device)
        self.critic = Critic(self.input_dim).to(device)

        self.optim_critic = optim.Adam(self.critic.parameters(),
                                       lr=self.init_lr)
        self.optim_actor = optim.Adam(self.actor.parameters(), lr=self.init_lr)

    def act(self, state):
        action = self.actor(state).sample()
        return action

    def act_greedy(self, state):
        action = self.actor(state).probs.argmax(dim=-1)
        return action

    @staticmethod
    def chunked_mask(mask):
        end_pts = torch.where(mask)[0]
        start = 0
        for end in end_pts:
            yield list(range(start, end + 1))
            start = end + 1

        if end_pts.nelement() == 0:
            yield list(range(len(mask)))

        elif end_pts[-1] != len(mask) - 1:
            yield list(range(end_pts[-1] + 1, len(mask)))

    def sample_trajectories(self):
        """ Sample trajectories with current policy"""

        rewards = np.zeros((self.n_steps, ), dtype=np.float32)
        actions = np.zeros((self.n_steps, ), dtype=np.int32)
        done = np.zeros((self.n_steps, ), dtype=bool)
        states = np.zeros((self.n_steps, self.input_dim), dtype=np.float32)
        log_pis = np.zeros((self.n_steps, ), dtype=np.float32)
        values = np.zeros((self.n_steps, ), dtype=np.float32)

        tmp_rewards = []
        self.current_state = self.env.reset()

        writer = SummaryWriter('runs/lunar/adaptative_kl/updates100')

        for t in range(self.n_steps):
            with torch.no_grad():
                states[t] = self.current_state
                state = torch.FloatTensor(self.current_state).to(device)
                pi = self.actor(state)
                values[t] = self.critic(state).cpu().numpy()
                a = self.act(state)
                actions[t] = a.cpu().numpy()
                log_pis[t] = pi.log_prob(a).cpu().numpy()

            new_state, r, end, info = self.env.step(actions[t])
            rewards[t] = r / 100.
            done[t] = False if info.get("TimeLimit.truncated") else end
            tmp_rewards.append(r)
            if end:
                if info.get("TimeLimit.truncated"):
                    print(f"Episode {self.n_episodes} Reward: ",
                          np.sum(tmp_rewards), " TimeLimit.truncated !")
                else:
                    print(f"Episode {self.n_episodes} Reward: ",
                          np.sum(tmp_rewards))

                writer.add_scalar("Reward", np.sum(tmp_rewards),
                                  self.n_episodes)

                self.n_episodes += 1
                self.episodes_rewards.append(np.sum(tmp_rewards))
                if (self.n_episodes + 1) % 100 == 0:
                    print(
                        f"Total Episodes: {self.n_episodes + 1}, "
                        f"Mean Rewards of last 100 episodes: {np.mean(self.episodes_rewards[-100:]):.2f}"
                    )
                tmp_rewards = []
                new_state = self.env.reset()

            self.current_state = new_state

        states = torch.tensor(states, dtype=torch.float32, device=device)
        actions = torch.tensor(actions, device=device)
        values = torch.tensor(values, device=device)
        rewards = torch.tensor(rewards, device=device)
        done = torch.tensor(done, device=device)
        log_pis = torch.tensor(log_pis, device=device)

        advantages = self.compute_gae(rewards, values, done)
        return states, actions, values, log_pis, advantages

    def compute_gae(self, rewards, values, dones):
        """Compute Generalized Advantage Estimates"""
        gaes = []
        last_value = self.critic(
            torch.FloatTensor(self.current_state).to(device))
        next_values = torch.cat([values[1:], last_value.detach()])

        for chunk in self.chunked_mask(dones):
            T = rewards[chunk].size(0)
            td_target = rewards[chunk] + self.gamma * (
                1 - dones[chunk].int()) * next_values[chunk]
            td_error = td_target - values[chunk]
            discount = ((self.gamma *
                         self._avlambda)**torch.arange(T)).to(device)
            gae = torch.tensor([(discount[:T - t] * td_error[t:]).sum()
                                for t in range(T)],
                               dtype=torch.float32).to(device)
            gaes.append(gae)
        return torch.cat(gaes)

    def optimize_adaptative_kl(self, states, actions, values, log_pis,
                               advantages, learning_rate, clip_range):
        old_pis = self.actor(states)
        old_probs = old_pis.probs.detach()

        for _ in tqdm(range(self.epochs)):
            # shuffle for each epoch
            indexes = torch.randperm(self.n_steps)
            # for each mini batch
            for start in range(0, self.n_steps, self.mini_batch_size):

                # get mini batch
                idx = indexes[start:start + self.mini_batch_size]

                # compute loss
                lambda_returns = values[idx] + advantages[idx]
                A_old = (advantages[idx] - advantages[idx].mean()) / (
                    advantages[idx].std() + 1e-8)
                new_pis = self.actor(states[idx])
                new_values = self.critic(states[idx])

                # policy loss (with entropy)
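                # Adaptive-KL surrogate: maximize E[ratio * A] - beta * KL(pi_old || pi_new) + 0.01 * H(pi_new)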
                new_log_pis = new_pis.log_prob(actions[idx])
                L = ((new_log_pis - log_pis[idx]).exp() * A_old).mean()
                kl = kl_divergence(Categorical(old_probs[idx]), new_pis).mean()
                entropy = new_pis.entropy().mean()
                policy_loss = -(L - self.beta * kl + 0.01 * entropy)

                # value loss
                clipped_value = values[idx] + (new_values - values[idx]).clamp(
                    min=-clip_range, max=clip_range)
                vf_loss = torch.max((new_values - lambda_returns)**2,
                                    (clipped_value - lambda_returns)**2)
                vf_loss = vf_loss.mean()

                self.optim_actor.zero_grad()
                self.optim_critic.zero_grad()

                for pg in self.optim_actor.param_groups:
                    pg['lr'] = learning_rate

                for pg in self.optim_critic.param_groups:
                    pg['lr'] = learning_rate

                policy_loss.backward()
                vf_loss.backward()

                clip_grad_norm_(self.actor.parameters(), max_norm=0.5)
                clip_grad_norm_(self.critic.parameters(), max_norm=0.5)
                self.optim_actor.step()
                self.optim_critic.step()

        new_pis = self.actor(states)
        kl = kl_divergence(Categorical(old_probs), new_pis).mean()
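        # Adapt the KL penalty: raise beta when the realized KL overshoots 1.5 * delta,
        # lower it when the KL falls below delta / 1.5.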
        if kl >= 1.5 * self.delta:
            self.beta *= 2
        elif kl <= self.delta / 1.5:
            self.beta /= 2

    def optimize_clipped(self, states, actions, values, log_pis, advantages,
                         learning_rate, clip_range):

        for _ in tqdm(range(self.epochs)):
            # shuffle for each epoch
            indexes = torch.randperm(self.n_steps)

            # for each mini batch
            for start in range(0, self.n_steps, self.mini_batch_size):

                # get mini batch
                idx = indexes[start:start + self.mini_batch_size]

                # compute loss
                lambda_returns = values[idx] + advantages[idx]
                A_old = (advantages[idx] - advantages[idx].mean()) / (
                    advantages[idx].std() + 1e-8)
                new_pis = self.actor(states[idx])
                new_values = self.critic(states[idx])

                # policy loss (with entropy)
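                # PPO clipped surrogate:
                #   ratio_t = exp(log pi_new(a_t|s_t) - log pi_old(a_t|s_t))
                #   L_clip  = E[ min(ratio_t * A_t, clip(ratio_t, 1 - eps, 1 + eps) * A_t) ]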
                new_log_pis = new_pis.log_prob(actions[idx])
                ratio = (new_log_pis - log_pis[idx]).exp()
                clipped_ratio = ratio.clamp(min=1.0 - clip_range,
                                            max=1.0 + clip_range)
                Lclip = torch.min(ratio * A_old, clipped_ratio * A_old)
                Lclip = Lclip.mean()
                entropy = new_pis.entropy().mean()
                policy_loss = -(Lclip + 0.01 * entropy)

                # value loss
                clipped_value = values[idx] + (new_values - values[idx]).clamp(
                    min=-clip_range, max=clip_range)
                vf_loss = torch.max((new_values - lambda_returns)**2,
                                    (clipped_value - lambda_returns)**2)
                vf_loss = vf_loss.mean()

                self.optim_actor.zero_grad()
                self.optim_critic.zero_grad()

                for pg in self.optim_actor.param_groups:
                    pg['lr'] = learning_rate

                for pg in self.optim_critic.param_groups:
                    pg['lr'] = learning_rate

                policy_loss.backward()
                vf_loss.backward()

                clip_grad_norm_(self.actor.parameters(), max_norm=0.5)
                clip_grad_norm_(self.critic.parameters(), max_norm=0.5)
                self.optim_actor.step()
                self.optim_critic.step()

    def optimize_without(self, states, actions, values, log_pis, advantages,
                         learning_rate, clip_range):
        old_pis = self.actor(states)
        old_probs = old_pis.probs.detach()

        for _ in tqdm(range(self.epochs)):
            # shuffle for each epoch
            indexes = torch.randperm(self.n_steps)
            # for each mini batch
            for start in range(0, self.n_steps, self.mini_batch_size):

                # get mini batch
                idx = indexes[start:start + self.mini_batch_size]

                # compute loss
                lambda_returns = values[idx] + advantages[idx]
                A_old = (advantages[idx] - advantages[idx].mean()) / (
                    advantages[idx].std() + 1e-8)
                new_pis = self.actor(states[idx])
                new_values = self.critic(states[idx])

                # policy loss (with entropy)
                new_log_pis = new_pis.log_prob(actions[idx])
                L = ((new_log_pis - log_pis[idx]).exp() * A_old).mean()
                entropy = new_pis.entropy().mean()
                policy_loss = -(L + 0.01 * entropy)

                # value loss
                clipped_value = values[idx] + (new_values - values[idx]).clamp(
                    min=-clip_range, max=clip_range)
                vf_loss = torch.max((new_values - lambda_returns)**2,
                                    (clipped_value - lambda_returns)**2)
                vf_loss = vf_loss.mean()

                self.optim_actor.zero_grad()
                self.optim_critic.zero_grad()

                for pg in self.optim_actor.param_groups:
                    pg['lr'] = learning_rate

                for pg in self.optim_critic.param_groups:
                    pg['lr'] = learning_rate

                policy_loss.backward()
                vf_loss.backward()

                clip_grad_norm_(self.actor.parameters(), max_norm=0.5)
                clip_grad_norm_(self.critic.parameters(), max_norm=0.5)
                self.optim_actor.step()
                self.optim_critic.step()

        new_pis = self.actor(states)
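
The 'runs/lunar/...' log directory suggests a LunarLander-style task, but the environment name, the hyperparameter values and the linear annealing schedule in the sketch below are all assumptions; it only illustrates how the outer loop could call sample_trajectories and one of the optimize_* variants.

# Sketch only: environment, hyperparameters and annealing schedule are assumed.
import gym
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

env = gym.make("LunarLander-v2")   # hypothetical discrete-action task (old Gym API)
agent = PPOAgent(env,
                 gamma=0.99,
                 _avlambda=0.95,
                 beta=1.0,
                 delta=0.01,
                 init_lr=2.5e-4,
                 init_clip_range=0.2,
                 n_updates=100,
                 n_epochs=4,
                 n_steps=2048,
                 n_mini_batch=8)

for update in range(agent.n_updates):
    frac = 1.0 - update / agent.n_updates          # linear decay (assumed schedule)
    lr = agent.init_lr * frac
    clip_range = agent.init_clip_range * frac
    states, actions, values, log_pis, advantages = agent.sample_trajectories()
    agent.optimize_clipped(states, actions, values, log_pis, advantages,
                           lr, clip_range)
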
Code Example #4
class ActorCriticAgent(object):
    """Actor-Critic agent"""

    def __init__(self, input_dim, n_actions, gamma, lr, beta, _lambda, _avlambda, target_freq):
        self.gamma = gamma
        self.lr = lr
        self.beta = beta
        self._lambda = _lambda
        self._avlambda = _avlambda
        self.input_dim = input_dim
        self.n_actions = n_actions

        self.actor = Actor(self.input_dim, self.n_actions).to(device)
        self.critic = Critic(self.input_dim).to(device)
        self.critic_target = deepcopy(self.critic)

        self.actor_optim = torch.optim.Adam(self.actor.parameters(), lr=lr)
        self.critic_optim = torch.optim.Adam(self.critic.parameters(), lr=lr)
        self.transitions = []

        self.target_freq = target_freq

        self.learn_step = 0
        self.event_count = 0

    def act(self, state):
        action = self.actor(state).sample()
        return action.item()

    def act_greedy(self, state):
        action = self.actor(state).probs.argmax(dim=-1)
        return action.item()

    def store_transition(self, transition):
        if len(self.transitions) == 0:
            self.transitions.append([])
        self.transitions[-1].append(transition)

    @staticmethod
    def _get_transitions(transitions):
        states, actions, rewards, next_states, dones = [], [], [], [], []
        for transition in transitions:
            states.append(transition[0])
            actions.append(transition[1])
            rewards.append(transition[2] / 100.0)
            next_states.append(transition[3])
            dones.append(int(transition[4]))

        return (
            torch.stack(states),
            torch.tensor(actions, dtype=torch.int64).to(device),
            torch.FloatTensor(rewards).to(device),
            torch.stack(next_states),
            torch.IntTensor(dones).to(device),
        )

    def get_transitions(self, shuffle=True):
        g_states, g_actions, g_rewards, g_next_states, g_dones, g_gaes = [], [], [], [], [], []

        for transitions in self.transitions:
            # get transitions for current trajectory
            states, actions, rewards, next_states, dones = self._get_transitions(transitions)
            # compute GAE for each trajectory
            gaes = self.compute_gae(rewards, states, next_states, 1-dones)
            g_states.append(states)
            g_actions.append(actions)
            g_rewards.append(rewards)
            g_next_states.append(next_states)
            g_dones.append(dones)
            g_gaes.append(gaes)

        self.transitions = []
        if shuffle:
            idx = torch.randperm(torch.cat(g_actions).size(0))
        else:
            idx = torch.arange(torch.cat(g_actions).size(0))
        return (
            torch.cat(g_states)[idx],
            torch.cat(g_actions)[idx],
            torch.cat(g_rewards)[idx],
            torch.cat(g_next_states)[idx],
            torch.cat(g_dones)[idx],
            torch.cat(g_gaes)[idx],
        )

    def compute_gae(self, rewards, states, next_states, masks):
        T = rewards.size(0)
        values = self.critic(states).view(-1)
        next_values = self.critic(next_states).view(-1)
        td_target = rewards + self.gamma * masks * next_values
        td_error = td_target - values

        discount = ((self.gamma * self._avlambda) ** torch.arange(T)).to(device)
        gae = torch.Tensor([(discount[:T - t] * td_error[t:]).sum() for t in range(T)]).to(device)
        return gae.detach()

    def optimize(self):
        self.learn_step += 1
        states, actions, rewards, next_states, dones, gaes = self.get_transitions()

        lambda_returns = self.critic(states).view(-1) + gaes

        # value loss
        value_loss = F.mse_loss(self.critic(states).view(-1), lambda_returns.detach())

        # policy loss
        dist = self.actor(states)
        log_probs = dist.log_prob(actions)
        entropy = dist.entropy().mean()
        policy_loss = - (log_probs * gaes.detach()).mean() - self.beta * entropy

        # backpropagation
        self.critic_optim.zero_grad()
        value_loss.backward()
        clip_grad_norm_(self.critic.parameters(), 0.1)
        self.critic_optim.step()

        self.actor_optim.zero_grad()
        policy_loss.backward()
        clip_grad_norm_(self.actor.parameters(), 0.1)
        self.actor_optim.step()
        return policy_loss, value_loss, gaes, entropy
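
A rough sketch of a training loop for this agent follows; the environment, the hyperparameter values and the one-update-per-episode cadence are assumptions. Transitions are stored as tensors already on device because _get_transitions stacks them directly.

# Sketch only: environment and hyperparameters are assumed.
import gym
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

env = gym.make("CartPole-v1")      # hypothetical discrete-action task (old Gym API)
agent = ActorCriticAgent(input_dim=env.observation_space.shape[0],
                         n_actions=env.action_space.n,
                         gamma=0.99, lr=3e-4, beta=0.01,
                         _lambda=0.95, _avlambda=0.95, target_freq=100)

for episode in range(500):
    state = torch.FloatTensor(env.reset()).to(device)
    done = False
    while not done:
        action = agent.act(state)
        next_obs, reward, done, _ = env.step(action)
        next_state = torch.FloatTensor(next_obs).to(device)
        agent.store_transition((state, action, reward, next_state, done))
        state = next_state
    agent.optimize()               # one gradient update per collected episode (assumed cadence)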