Example #1: Soft Actor-Critic (SAC) actor and entropy-temperature update.
    def update_actor_temp(self, states, actions, rewards, next_states, dones):

        # freeze the critic and target networks so that only the actor's
        # parameters receive gradients from the policy loss
        for p in self.sac_net.target.parameters():
            p.requires_grad = False
        for p in self.sac_net.critic.parameters():
            p.requires_grad = False

        # update actor: resample actions from the current policy (the replay
        # actions passed in are not used for the policy gradient)
        actions, log_probs, aux_losses = self.sac_net.sample(states,
                                                             training=True)
        q1, q2 = self.sac_net.critic(states, actions)
        q_old = torch.min(q1, q2)
        actor_loss = (self.sac_net.alpha.detach() * log_probs - q_old).mean()
        aux_losses = compute_sum_aux_losses(aux_losses)
        overall_loss = actor_loss + aux_losses
        self.actor_optimizer.zero_grad()
        overall_loss.backward()
        self.actor_optimizer.step()

        # update temp: adjust alpha so the policy entropy tracks the target
        # entropy of -action_size
        temp_loss = (self.sac_net.log_alpha.exp() *
                     (-log_probs.detach().mean() + self.action_size).detach())
        self.log_alpha_optimizer.zero_grad()
        temp_loss.backward()
        self.log_alpha_optimizer.step()
        self.sac_net.alpha.data = self.sac_net.log_alpha.exp().detach()

        for p in self.sac_net.target.parameters():
            p.requires_grad = True
        for p in self.sac_net.critic.parameters():
            p.requires_grad = True

        return actor_loss, temp_loss
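All four examples call a compute_sum_aux_losses helper that is not part of the snippets. A minimal sketch, assuming aux_losses is a (possibly empty) dict mapping names to scalar loss tensors, could look like this:

import torch

def compute_sum_aux_losses(aux_losses):
    # Hypothetical helper: collapse a dict of named auxiliary loss tensors
    # into a single scalar that can be added to the main loss.
    if not aux_losses:
        return torch.tensor(0.0)
    return torch.stack(list(aux_losses.values())).sum()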
Example #2: Soft Actor-Critic (SAC) critic update against a soft Bellman target.
    def update_critic(self, states, actions, rewards, next_states, dones):

        q1_current, q2_current, aux_losses = self.sac_net.critic(states,
                                                                 actions,
                                                                 training=True)
        # soft Bellman target: bootstrap from the target critic's min Q minus
        # the entropy term, without tracking gradients
        with torch.no_grad():
            next_actions, log_probs, _ = self.sac_net.sample(next_states)
            q1_next, q2_next = self.sac_net.target(next_states, next_actions)
            v_next = (torch.min(q1_next, q2_next) -
                      self.sac_net.alpha.detach() * log_probs)
            q_target = (rewards + ((1 - dones) * self.gamma * v_next)).detach()

        critic_loss = F.mse_loss(q1_current, q_target) + F.mse_loss(
            q2_current, q_target)

        aux_losses = compute_sum_aux_losses(aux_losses)
        overall_loss = critic_loss + aux_losses
        self.critic_optimizer.zero_grad()
        overall_loss.backward()
        self.critic_optimizer.step()

        return critic_loss
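Examples #1 and #2 only compute the losses; they do not show how the two updates are sequenced or how the target critic is refreshed. A minimal driver sketch, assuming the same batch layout, a hypothetical Polyak coefficient self.tau, and a soft_update(target, source, tau) helper like the one Example #3 calls (sketched after that example):

    def train_step(self, states, actions, rewards, next_states, dones):
        # Hypothetical ordering: critic first, then actor and temperature,
        # then Polyak-average the online critic into the target network.
        critic_loss = self.update_critic(states, actions, rewards,
                                         next_states, dones)
        actor_loss, temp_loss = self.update_actor_temp(states, actions,
                                                       rewards, next_states,
                                                       dones)
        # self.tau is assumed here; it is not defined in the snippets above.
        self.soft_update(self.sac_net.target, self.sac_net.critic, self.tau)
        return critic_loss, actor_loss, temp_loss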
Example #3: TD3 learning step with clipped double-Q, target policy smoothing, and delayed actor updates.
    def learn(self):
        output = {}
        states, actions, rewards, next_states, dones, others = self.memory.sample(
            device=self.device)
        actions = actions.squeeze(dim=1)
        # target policy smoothing: perturb the target policy's action with
        # clipped noise, then clamp to the valid action range
        next_actions = self.actor_target(next_states)
        noise = torch.randn_like(next_actions).mul(self.policy_noise)
        noise = noise.clamp(-self.noise_clip, self.noise_clip)
        next_actions += noise
        next_actions = torch.max(
            torch.min(next_actions, self.action_high.to(self.device)),
            self.action_low.to(self.device),
        )

        # clipped double-Q: bootstrap from the smaller of the two target critics
        target_Q1 = self.critic_1_target(next_states, next_actions)
        target_Q2 = self.critic_2_target(next_states, next_actions)
        target_Q = torch.min(target_Q1, target_Q2)
        target_Q = (rewards + ((1 - dones) * self.gamma * target_Q)).detach()

        # Optimize Critic 1:
        current_Q1, aux_losses_Q1 = self.critic_1(states,
                                                  actions,
                                                  training=True)
        loss_Q1 = F.mse_loss(current_Q1,
                             target_Q) + compute_sum_aux_losses(aux_losses_Q1)
        self.critic_1_optimizer.zero_grad()
        loss_Q1.backward()
        self.critic_1_optimizer.step()

        # Optimize Critic 2:
        current_Q2, aux_losses_Q2 = self.critic_2(states,
                                                  actions,
                                                  training=True)
        loss_Q2 = F.mse_loss(current_Q2,
                             target_Q) + compute_sum_aux_losses(aux_losses_Q2)
        self.critic_2_optimizer.zero_grad()
        loss_Q2.backward()
        self.critic_2_optimizer.step()

        # delayed actor updates
        if (self.step_count + 1) % self.policy_delay == 0:
            critic_out = self.critic_1(states,
                                       self.actor(states),
                                       training=True)
            actor_loss, actor_aux_losses = -critic_out[0], critic_out[1]
            actor_loss = actor_loss.mean() + compute_sum_aux_losses(
                actor_aux_losses)
            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            self.actor_optimizer.step()
            self.soft_update(self.actor_target, self.actor, self.actor_tau)

            self.num_actor_updates += 1
            output = {
                "loss/critic_1": {
                    "type": "scalar",
                    "data": loss_Q1.data.cpu().numpy(),
                    "freq": 10,
                },
                "loss/actor": {
                    "type": "scalar",
                    "data": actor_loss.data.cpu().numpy(),
                    "freq": 10,
                },
            }
        self.soft_update(self.critic_1_target, self.critic_1, self.critic_tau)
        self.soft_update(self.critic_2_target, self.critic_2, self.critic_tau)
        self.current_iteration += 1
        return output
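Example #3 calls a soft_update helper that is not shown. Given the call sites soft_update(target, source, tau), a plausible sketch is standard Polyak averaging:

    def soft_update(self, target, source, tau):
        # Blend the online (source) parameters into the target network with
        # interpolation weight tau; no gradients are needed for this copy.
        with torch.no_grad():
            for t_param, s_param in zip(target.parameters(),
                                        source.parameters()):
                t_param.mul_(1.0 - tau).add_(tau * s_param)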
Example #4: PPO update with the clipped surrogate objective over minibatched epochs.
    def update(self, n_epochs, mini_batch_size, states, actions, log_probs,
               returns, advantages):
        total_actor_loss = 0
        total_critic_loss = 0
        total_entropy_loss = 0
        # multiple epochs
        for _ in range(n_epochs):
            # minibatch updates
            for (
                    state,
                    action,
                    old_pi_log_probs,
                    return_batch,
                    advantage,
            ) in self.get_minibatch(mini_batch_size, states, actions,
                                    log_probs, returns, advantages):
                (dist, value), aux_losses = self.ppo_net(state, training=True)
                entropy = dist.entropy().mean()  # L_S
                new_pi_log_probs = dist.log_prob(action)

                ratio = self.get_ratio(new_pi_log_probs, old_pi_log_probs)
                L_CPI = ratio * advantage
                clipped_version = (
                    torch.clamp(ratio, 1.0 - self.eps, 1.0 + self.eps) *
                    advantage)

                # loss and clipping
                actor_loss = -torch.min(L_CPI,
                                        clipped_version).mean()  # L_CLIP
                critic_loss = ((return_batch - value).pow(2).mean()
                               )  # L_VF (squared error loss)

                aux_losses = compute_sum_aux_losses(aux_losses)

                # overall loss
                loss = (self.critic_tau * critic_loss +
                        self.actor_tau * actor_loss -
                        self.entropy_tau * entropy + aux_losses)

                # calculate gradients and update the weights
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                total_actor_loss += actor_loss.item()
                total_critic_loss += critic_loss.item()
                total_entropy_loss += entropy.item()

        # average over all gradient steps: n_epochs * (batch_size / mini_batch_size)
        average_actor_loss = total_actor_loss / (
            n_epochs * (self.batch_size / self.mini_batch_size))
        average_critic_loss = total_critic_loss / (
            n_epochs * (self.batch_size / self.mini_batch_size))
        average_entropy_loss = total_entropy_loss / (
            n_epochs * (self.batch_size / self.mini_batch_size))

        output = {
            "loss/critic": {
                "type": "scalar",
                "data": average_critic_loss,
                "freq": self.logging_freq,
            },
            "loss/actor": {
                "type": "scalar",
                "data": average_actor_loss,
                "freq": self.logging_freq,
            },
            "loss/entropy": {
                "type": "scalar",
                "data": average_entropy_loss,
                "freq": self.logging_freq,
            },
        }

        return output
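Example #4 relies on two helpers that are not shown, get_ratio and get_minibatch. Hedged sketches, assuming per-sample log-probabilities stored from the rollout and a batch that is shuffled into mini_batch_size chunks each epoch:

    def get_ratio(self, new_pi_log_probs, old_pi_log_probs):
        # Importance ratio pi_new(a|s) / pi_old(a|s), computed in log space
        # for numerical stability.
        return torch.exp(new_pi_log_probs - old_pi_log_probs)

    def get_minibatch(self, mini_batch_size, states, actions, log_probs,
                      returns, advantages):
        # Yield shuffled minibatches that cover the whole rollout batch once.
        batch_size = states.size(0)
        permutation = torch.randperm(batch_size)
        for start in range(0, batch_size, mini_batch_size):
            idx = permutation[start:start + mini_batch_size]
            yield (states[idx], actions[idx], log_probs[idx], returns[idx],
                   advantages[idx])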