Example #1
    def backward(self, sample_):
        self.replay_buffer.push(sample_)
        if self.step > self.learning_starts and self.learning:
            sample = self.replay_buffer.sample(self.batch_size)
            if self.gpu:
                for key in sample.keys():
                    sample[key] = sample[key].cuda()
            assert len(sample["s"]) == self.batch_size
            "update the critic "
            if self.step % self.critic_training_freq == 0:
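                # twin target critics (TD3-style): taking the element-wise
                # minimum of the two Q estimates reduces overestimation bias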
                target_a = self.target_actor(sample["s_"])
                target_input = torch.cat((sample["s_"], target_a), -1)
                Q1, Q2 = self.target_critic(target_input)
                target_Q = torch.min(Q1, Q2)
                # targets are constants for the critic update, so detach them
                expected_q_values = (sample["r"] + self.gamma *
                                     target_Q.detach() * (1.0 - sample["tr"]))

                input = torch.cat((sample["s"], sample["a"]), -1)
                Q1, Q2 = self.critic(input)
                loss = (torch.mean(huber_loss(expected_q_values - Q1))
                        + torch.mean(huber_loss(expected_q_values - Q2)))
                self.critic.zero_grad()
                loss.backward()
                self.critic_optim.step()
            "training the actor"
            if self.step % self.actor_training_freq == 0:
                Q = self.actor_critic(sample["s"])
                Q = -torch.mean(Q)
                self.actor.zero_grad()
                Q.backward()
                self.actor_optim.step()
            self.target_net_update()
            # move the loss back to the CPU before converting it to numpy
            loss = loss.detach().cpu().numpy()
            return loss, {}
        return 0, {}
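
One detail worth noting: all three examples rely on a huber_loss helper that is not shown on this page. The sketch below is one plausible element-wise implementation; the signature and the delta=1.0 default are assumptions, not taken from the original repository.

import torch

def huber_loss(td_error, delta=1.0):
    # Element-wise Huber loss: quadratic for |error| <= delta, linear beyond.
    # NOTE: the signature and the delta default are assumptions; the original
    # helper is defined elsewhere in the repository.
    abs_error = torch.abs(td_error)
    quadratic = torch.clamp(abs_error, max=delta)
    linear = abs_error - quadratic
    return 0.5 * quadratic ** 2 + delta * linear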
Example #2
 def backward(self, sample_):
     self.replay_buffer.push(sample_)
     if self.step > self.learning_starts and self.learning:
         sample = self.replay_buffer.sample(self.batch_size)
         if self.gpu:
             for key in sample.keys():
                 sample[key] = sample[key].cuda()
         assert len(sample["s"]) == self.batch_size
         a = sample["a"].long().unsqueeze(1)
         Q = self.Q_net(sample["s"]).gather(1, a)
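         # double DQN: pick the greedy next action with the online network,
         # then evaluate that action with the target network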
         if self.double_dqn:
             _, next_actions = self.Q_net(sample["s_"]).max(1, keepdim=True)
             targetQ = self.target_Q_net(sample["s_"]).gather(
                 1, next_actions)
         else:
             # plain DQN target: take the max over the target network's output
             # directly, avoiding a redundant second forward pass
             targetQ, _ = self.target_Q_net(sample["s_"]).max(1, keepdim=True)
         targetQ = targetQ.squeeze(1)
         Q = Q.squeeze(1)
         # targets are constants for the Q-network update, so detach them
         expected_q_values = sample["r"] + self.gamma * targetQ.detach() * (
             1.0 - sample["tr"])
         loss = torch.mean(huber_loss(expected_q_values - Q))
         self.optim.zero_grad()
         loss.backward()
         torch.nn.utils.clip_grad_norm_(self.Q_net.parameters(),
                                        1,
                                        norm_type=2)
         self.optim.step()
         if self.step % self.target_network_update_freq == 0:
             self.target_net_update()
         # move the loss back to the CPU before converting it to numpy
         loss = loss.detach().cpu().numpy()
         return loss, {}
     return 0, {}
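
Example #2 calls target_net_update() periodically, but its body is not part of the snippet. For DQN-style agents this is commonly a hard copy of the online weights into the target network; the method below is a sketch under that assumption, reusing the attribute names from the example.

 def target_net_update(self):
     # Hard update: copy the online Q-network's weights into the target network.
     # NOTE: this is an assumed implementation; a soft (Polyak) update with a
     # small interpolation factor is an equally common choice.
     self.target_Q_net.load_state_dict(self.Q_net.state_dict())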
Example #3
 def backward(self, sample_):
     self.replay_buffer.push(sample_)
     if self.step > self.learning_starts and self.learning:
         sample = self.replay_buffer.sample(self.batch_size)
         if self.gpu:
             for key in sample.keys():
                 sample[key] = sample[key].cuda()
         assert len(sample["s"]) == self.batch_size
         "update the critic "
         if self.step % self.critic_training_freq == 0:
             if self.sperate_critic:
                 Q = self.critic(sample["s"], sample["a"])
             else:
                 input = torch.cat((sample["s"], sample["a"]), -1)
                 Q = self.critic(input)
             target_a = self.target_actor(sample["s_"])
             if self.sperate_critic:
                 targetQ = self.target_critic(sample["s_"], target_a)
             else:
                 target_input = torch.cat((sample["s_"], target_a), -1)
                 targetQ = self.target_critic(target_input)
             targetQ = targetQ.squeeze(1)
             Q = Q.squeeze(1)
             # targets are constants for the critic update, so detach them
             expected_q_values = sample["r"] + self.gamma * targetQ.detach() * (
                 1.0 - sample["tr"])
             loss = torch.mean(huber_loss(expected_q_values - Q))
             self.critic_optim.zero_grad()
             loss.backward()
             torch.nn.utils.clip_grad_norm_(self.critic.parameters(),
                                            1,
                                            norm_type=2)
             self.critic_optim.step()
         "training the actor"
         if self.step % self.actor_training_freq == 0:
             Q = self.actor_critic(sample["s"])
             Q = -torch.mean(Q)
             self.actor_optim.zero_grad()
             Q.backward()
             torch.nn.utils.clip_grad_norm_(self.actor_critic.parameters(),
                                            1,
                                            norm_type=2)
             self.actor_optim.step()
         if self.step % self.actor_target_network_update_freq == 0:
             self.target_actor_net_update()
         if self.step % self.critic_target_network_update_freq == 0:
             self.target_critic_net_update()
         # move the loss back to the CPU before converting it to numpy
         loss = loss.detach().cpu().numpy()
         return loss, {}
     return 0, {}
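
Example #3 synchronises its targets through target_actor_net_update() and target_critic_net_update(), which are not shown. A common choice for actor-critic agents is a soft (Polyak) update; the sketch below assumes the attribute names from the example plus a hypothetical smoothing factor self.tau.

 def target_actor_net_update(self):
     # Soft (Polyak) update of the target actor; self.tau is a hypothetical
     # smoothing factor (e.g. 0.005), not taken from the original code.
     for target_p, p in zip(self.target_actor.parameters(),
                            self.actor.parameters()):
         target_p.data.copy_(self.tau * p.data + (1.0 - self.tau) * target_p.data)

 def target_critic_net_update(self):
     # The same soft update applied to the critic / target-critic pair.
     for target_p, p in zip(self.target_critic.parameters(),
                            self.critic.parameters()):
         target_p.data.copy_(self.tau * p.data + (1.0 - self.tau) * target_p.data)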