Example #1
    def update_parameters(self, state_batch, next_state_batch, action_batch,
                          reward_batch, done_batch):

        state_batch = state_batch.to(self.device)
        next_state_batch = next_state_batch.to(self.device)
        action_batch = action_batch.to(self.device)
        reward_batch = reward_batch.to(self.device)
        done_batch = done_batch.to(self.device)

        with torch.no_grad():
            next_state_action, next_state_log_pi, _, _, _ = self.actor.noisy_action(
                next_state_batch, return_only_action=False)
            qf1_next_target, qf2_next_target, _ = self.critic_target.forward(
                next_state_batch, next_state_action)
            min_qf_next_target = torch.min(
                qf1_next_target,
                qf2_next_target) - self.alpha * next_state_log_pi
            next_q_value = reward_batch + self.gamma * (min_qf_next_target) * (
                1 - done_batch)
            self.writer.add_scalar('next_q', next_q_value.mean().item())

        qf1, qf2, _ = self.critic.forward(
            state_batch, action_batch
        )  # Two Q-functions to mitigate positive bias in the policy improvement step
        qf1_loss = F.mse_loss(
            qf1, next_q_value
        )  # JQ = 𝔼(st,at)~D[0.5(Q1(st,at) - r(st,at) - γ(𝔼st+1~p[V(st+1)]))^2]
        qf2_loss = F.mse_loss(
            qf2, next_q_value
        )  # JQ = 𝔼(st,at)~D[0.5(Q1(st,at) - r(st,at) - γ(𝔼st+1~p[V(st+1)]))^2]
        self.writer.add_scalar('q_loss',
                               (qf1_loss + qf2_loss).mean().item() / 2.0)

        pi, log_pi, _, _, _ = self.actor.noisy_action(state_batch,
                                                      return_only_action=False)
        self.writer.add_scalar('log_pi', log_pi.mean().item())

        qf1_pi, qf2_pi, _ = self.critic.forward(state_batch, pi)
        min_qf_pi = torch.min(qf1_pi, qf2_pi)
        self.writer.add_scalar('policy_q', min_qf_pi.mean().item())

        policy_loss = ((self.alpha * log_pi) - min_qf_pi).mean(
        )  # Jπ = 𝔼st∼D,εt∼N[α * logπ(f(εt;st)|st) − Q(st,f(εt;st))]
        self.writer.add_scalar('policy_loss', policy_loss.mean().item())

        self.critic_optim.zero_grad()
        qf_loss = qf1_loss + qf2_loss  # single backward pass; avoids re-using a freed graph if the Q-heads share layers
        qf_loss.backward()
        self.critic_optim.step()

        self.actor_optim.zero_grad()
        policy_loss.backward()
        self.actor_optim.step()

        self.num_updates += 1
        soft_update(self.critic_target, self.critic, self.tau)
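Every example in this collection ends with a call to soft_update(target, source, tau), but the helper itself is never shown. A minimal Polyak-averaging sketch, assuming the (target_net, source_net, tau) argument order used above:

import torch

def soft_update(target_net, source_net, tau):
    """Polyak-average in place: target <- tau * source + (1 - tau) * target."""
    with torch.no_grad():
        for target_param, source_param in zip(target_net.parameters(),
                                              source_net.parameters()):
            target_param.data.mul_(1.0 - tau)
            target_param.data.add_(tau * source_param.data)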
Example #2
    def update(self, batch, policy_update=False):
        with torch.no_grad():
            state_batch = batch['states']
            action_batch = batch['actions']
            next_state_batch = batch['next_states']
            reward_batch = batch['rewards']
            done_batch = batch['dones']

        # Target policy smoothing: add clipped Gaussian noise to the target actions
        with torch.no_grad():
            noise = np.clip(
                np.random.normal(0,
                                 self.target_noise,
                                 size=(self.batch_size, self.action_dim)),
                -self.noise_clip, self.noise_clip)
            next_action = self.actor_target(next_state_batch) + to_tensor(noise)
            next_action_clip = next_action.clamp(-self.action_limit,
                                                 self.action_limit)
            q1_next, q2_next = self.critic_target(next_state_batch,
                                                  next_action_clip)
            min_q_next = torch.min(q1_next, q2_next)
            # Compute the TD target from the clipped double-Q estimate
            q_target = reward_batch + self.gamma * (
                1 - done_batch.float()) * min_q_next
        q1_predict, q2_predict = self.critic(state_batch, action_batch)

        # critic update
        self.critic_optim.zero_grad()
        critic_loss = self.loss(q1_predict, q_target) + self.loss(
            q2_predict, q_target)
        critic_loss.backward()
        nn.utils.clip_grad_norm_(self.critic.parameters(), 10)
        self.critic_optim.step()

        if policy_update:
            # Delayed policy update
            self.actor_optim.zero_grad()
            q1, _ = self.critic(state_batch, self.actor(state_batch))
            actor_loss = -q1.mean()
            actor_loss.backward()
            nn.utils.clip_grad_norm_(self.actor.parameters(), 10)
            self.actor_optim.step()

            # actor/critic network soft update
            soft_update(self.actor_target, self.actor, self.tau)
            soft_update(self.critic_target, self.critic, self.tau)
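Example #2 relies on a to_tensor helper to turn the NumPy noise into a torch tensor before adding it to the actor output; the helper is not part of the snippet. A plausible minimal version, assuming float32 tensors and an optional device argument:

import numpy as np
import torch

def to_tensor(array, device=None):
    """Convert a NumPy array (or nested list) to a float32 torch tensor."""
    tensor = torch.as_tensor(np.asarray(array), dtype=torch.float32)
    return tensor.to(device) if device is not None else tensor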
Example #3
File: sac.py Project: marsXyr/ERL
    def update(self, batch):
        with torch.no_grad():
            state_batch = batch['states']
            action_batch = batch['actions']
            next_state_batch = batch['next_states']
            reward_batch = batch['rewards']
            done_batch = batch['dones']
        # Compute q_target
        q1, q2 = self.Q(state_batch, action_batch)
        v_next_target = self.V_target(next_state_batch)
        q_target = reward_batch + self.gamma * (1 - done_batch.float()) * v_next_target.detach()

        # Compute v_target
        sample_actions, log_prob, _, mean, log_std = self.P.sample(state_batch)
        q1_new, q2_new = self.Q(state_batch, sample_actions)
        min_q_new = torch.min(q1_new, q2_new)
        v_target = min_q_new - (self.alpha * log_prob)
        v = self.V(state_batch)

        # q network update
        self.q_optim.zero_grad()
        q_loss = self.loss(q1, q_target) + self.loss(q2, q_target)
        q_loss.backward(retain_graph=True)
        nn.utils.clip_grad_norm_(self.Q.parameters(), 10)
        self.q_optim.step()

        # v network update
        self.v_optim.zero_grad()
        v_loss = self.loss(v, v_target.detach())
        v_loss.backward(retain_graph=True)
        nn.utils.clip_grad_norm_(self.V.parameters(), 10)
        self.v_optim.step()

        """
        Reparameterization trick is used to get a low variance estimator
        """
        # policy network update
        self.p_optim.zero_grad()
        mean_loss = 0.001 * mean.pow(2).mean()
        std_loss = 0.001 * log_std.pow(2).mean()
        p_loss = (self.alpha * log_prob - min_q_new).mean() + mean_loss + std_loss
        p_loss.backward()
        nn.utils.clip_grad_norm_(self.P.parameters(), 10)
        self.p_optim.step()

        soft_update(self.V_target, self.V, self.tau)
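The comment in the policy update above refers to the reparameterization trick, but self.P.sample itself is not shown. A sketch of what a tanh-squashed Gaussian sample with reparameterization commonly looks like; the five return values mirror the unpacking in the snippet, and details such as the log-prob correction are assumptions:

import torch

def reparameterized_sample(mean, log_std, epsilon=1e-6):
    """Draw a tanh-squashed Gaussian action with rsample() so gradients flow
    through the noise (reparameterization trick)."""
    std = log_std.exp()
    normal = torch.distributions.Normal(mean, std)
    x = normal.rsample()                     # mean + std * N(0, 1), differentiable
    action = torch.tanh(x)
    # Change-of-variables correction for the tanh squashing
    log_prob = normal.log_prob(x) - torch.log(1 - action.pow(2) + epsilon)
    log_prob = log_prob.sum(dim=1, keepdim=True)
    return action, log_prob, x, mean, log_std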
Example #4
    def update_parameters(self, state_batch, next_state_batch, action_batch,
                          reward_batch, done_batch):

        state_batch = state_batch.to(self.device)
        next_state_batch = next_state_batch.to(self.device)
        action_batch = action_batch.to(self.device)
        reward_batch = reward_batch.to(self.device)
        done_batch = done_batch.to(self.device)

        action_batch = action_batch.long().unsqueeze(1)
        with torch.no_grad():
            na = self.actor.clean_action(next_state_batch,
                                         return_only_action=True)
            _, _, ns_logits = self.actor_target.noisy_action(
                next_state_batch, return_only_action=False)
            next_entropy = -(F.softmax(ns_logits, dim=1) * F.log_softmax(
                ns_logits, dim=1)).mean(1).unsqueeze(1)

            ns_logits = ns_logits.gather(1, na.unsqueeze(1))

            next_target = ns_logits + self.alpha * next_entropy
            next_q_value = reward_batch + (
                1 - done_batch) * self.gamma * next_target

        _, _, logits = self.actor.noisy_action(state_batch,
                                               return_only_action=False)
        entropy = -(F.softmax(logits, dim=1) *
                    F.log_softmax(logits, dim=1)).mean(1).unsqueeze(1)
        q_val = logits.gather(1, action_batch)

        q_loss = (next_q_value - q_val)**2
        q_loss -= self.alpha * entropy
        q_loss = q_loss.mean()

        self.actor_optim.zero_grad()
        q_loss.backward()
        self.actor_optim.step()

        self.num_updates += 1
        soft_update(self.actor_target, self.actor, self.tau)
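The discrete-action update above computes the policy entropy directly from softmax/log-softmax of the logits. For reference, a sketch of the same quantity via torch.distributions.Categorical; note that entropy is conventionally a sum over actions, whereas the snippet takes a mean, so the two differ by a constant factor equal to the number of actions:

import torch

def categorical_entropy(logits):
    """Per-state entropy of the categorical policy defined by raw logits."""
    dist = torch.distributions.Categorical(logits=logits)
    return dist.entropy().unsqueeze(1)  # shape (batch, 1)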
Example #5
File: td3.py Project: ShawK91/l2m
    def update_parameters(self,
                          state_batch,
                          next_state_batch,
                          action_batch,
                          reward_batch,
                          done_batch,
                          num_epoch=1):
        """Runs a step of Bellman upodate and policy gradient using a batch of experiences

             Parameters:
                  state_batch (tensor): Current States
                  next_state_batch (tensor): Next States
                  action_batch (tensor): Actions
                  reward_batch (tensor): Rewards
                  done_batch (tensor): Done batch
                  num_epoch (int): Number of learning iterations to run with the same data

             Returns:
                   None

         """

        if isinstance(state_batch, list):
            state_batch = torch.cat(state_batch)
            next_state_batch = torch.cat(next_state_batch)
            action_batch = torch.cat(action_batch)
            reward_batch = torch.cat(reward_batch)
            done_batch = torch.cat(done_batch)

        for _ in range(num_epoch):
            ########### CRITIC UPDATE ####################

            #Compute next q-val, next_v and target
            with torch.no_grad():
                #Policy Noise
                policy_noise = np.random.normal(
                    0, self.policy_noise,
                    (action_batch.size()[0], action_batch.size()[1]))
                policy_noise = torch.clamp(
                    torch.Tensor(policy_noise),
                    -self.policy_noise_clip,
                    self.policy_noise_clip,
                )
                if torch.cuda.is_available():
                    policy_noise = policy_noise.cuda()

                #Compute next action_batch
                next_action_batch = self.actor_target.clean_action(
                    next_state_batch) + policy_noise
                next_action_batch = torch.clamp(next_action_batch, -1, 1)

                #Compute Q-val and value of next state masking by done
                q1, q2, _ = self.critic_target.forward(next_state_batch,
                                                       next_action_batch)
                q1 = (1 - done_batch) * q1
                q2 = (1 - done_batch) * q2

                #Select which q to use as next-q (depends on algo)
                next_q = torch.min(q1, q2)

                #Compute target q and target val
                target_q = reward_batch + (self.gamma * next_q)

            self.critic_optim.zero_grad()
            current_q1, current_q2, current_val = self.critic.forward(
                state_batch, action_batch)
            self.compute_stats(current_q1, self.q)

            dt = self.loss(current_q1, target_q)

            dt = dt + self.loss(current_q2, target_q)
            self.critic_loss['mean'].append(dt.item())

            dt.backward()

            self.critic_optim.step()
            self.num_critic_updates += 1

            #Delayed Actor Update
            if self.num_critic_updates % self.policy_ups_freq == 0:

                actor_actions = self.actor.clean_action(state_batch)
                Q1, Q2, val = self.critic.forward(state_batch, actor_actions)

                # if self.args.use_advantage: policy_loss = -(Q1 - val)
                policy_loss = -Q1

                self.compute_stats(policy_loss, self.policy_loss)
                policy_loss = policy_loss.mean()

                self.actor_optim.zero_grad()

                policy_loss.backward(retain_graph=True)
                self.actor_optim.step()

            if self.num_critic_updates % self.policy_ups_freq == 0:
                utils.soft_update(self.actor_target, self.actor, self.tau)
            utils.soft_update(self.critic_target, self.critic, self.tau)
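compute_stats shows up in several of these snippets (here and in Examples #7 and #8) but is never defined. A guess at a minimal tracker that appends summary statistics of a tensor into a dict of lists, consistent with how self.critic_loss['mean'] is used above; the real helper may record more:

def compute_stats(tensor, tracker):
    """Append mean/min/max of a tensor to a tracker dict of lists (assumed layout)."""
    tracker.setdefault('mean', []).append(tensor.mean().item())
    tracker.setdefault('min', []).append(tensor.min().item())
    tracker.setdefault('max', []).append(tensor.max().item())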
Example #7
File: sac.py Project: ShawK91/l2m
    def update_parameters(self, state_batch, next_state_batch, action_batch,
                          reward_batch, done_batch):

        with torch.no_grad():
            next_state_action, next_state_log_pi, _, _, _ = self.actor.noisy_action(
                next_state_batch, return_only_action=False)
            qf1_next_target, qf2_next_target, _ = self.critic_target.forward(
                next_state_batch, next_state_action)
            min_qf_next_target = torch.min(qf1_next_target, qf2_next_target)
            if self.sac_kwargs['entropy']:
                min_qf_next_target -= self.alpha * next_state_log_pi
            next_q_value = reward_batch + (1 - done_batch) * self.gamma * (
                min_qf_next_target)
            self.compute_stats(next_state_log_pi, self.next_entropy)

        qf1, qf2, _ = self.critic.forward(
            state_batch, action_batch
        )  # Two Q-functions to mitigate positive bias in the policy improvement step
        qf1_loss = F.mse_loss(
            qf1, next_q_value
        )  # JQ = 𝔼(st,at)~D[0.5(Q1(st,at) - r(st,at) - γ(𝔼st+1~p[V(st+1)]))^2]
        qf2_loss = F.mse_loss(
            qf2, next_q_value
        )  # JQ = 𝔼(st,at)~D[0.5(Q1(st,at) - r(st,at) - γ(𝔼st+1~p[V(st+1)]))^2]
        self.compute_stats(qf1_loss, self.critic_loss)

        pi, log_pi, _, _, _ = self.actor.noisy_action(state_batch,
                                                      return_only_action=False)
        self.compute_stats(log_pi, self.entropy)

        qf1_pi, qf2_pi, _ = self.critic.forward(state_batch, pi)
        min_qf_pi = torch.min(qf1_pi, qf2_pi)
        self.compute_stats(min_qf_pi, self.policy_q)

        policy_loss = -min_qf_pi  # Jπ = 𝔼st∼D,εt∼N[α * logπ(f(εt;st)|st) − Q(st,f(εt;st))]
        if self.sac_kwargs['entropy']:
            policy_loss += self.alpha * log_pi
        policy_loss = policy_loss.mean()

        self.critic_optim.zero_grad()
        qf_loss = qf1_loss + qf2_loss  # combine both Q losses into one backward and one optimizer step
        qf_loss.backward()
        self.critic_optim.step()

        self.actor_optim.zero_grad()
        policy_loss.backward()
        self.actor_optim.step()

        if self.automatic_entropy_tuning:
            alpha_loss = -(self.log_alpha *
                           (log_pi + self.target_entropy).detach()).mean()

            self.alpha_optim.zero_grad()
            alpha_loss.backward(retain_graph=True)
            self.alpha_optim.step()

            self.alpha = self.log_alpha.exp()
            alpha_tlogs = self.alpha.clone()  # For TensorboardX logs
        else:
            alpha_loss = torch.tensor(0.)
            alpha_tlogs = torch.tensor(self.alpha)  # For TensorboardX logs

        self.num_updates += 1
        if self.num_updates % self.target_update_interval == 0:
            soft_update(self.critic_target, self.critic, self.tau)
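The automatic entropy tuning branch above assumes self.log_alpha, self.target_entropy, and self.alpha_optim were set up in the constructor, which is not included here. A typical SAC-style initialization, shown purely as an assumption:

import torch

action_dim = 4                                    # hypothetical action dimensionality
target_entropy = -float(action_dim)               # common heuristic: -|A|
log_alpha = torch.zeros(1, requires_grad=True)    # optimize alpha in log space
alpha_optim = torch.optim.Adam([log_alpha], lr=3e-4)
alpha = log_alpha.exp()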
Example #8
    def update_parameters(self, state_batch, next_state_batch, action_batch,
                          reward_batch, done_batch):

        action_batch = action_batch.long()
        with torch.no_grad():
            na = self.actor.clean_action(next_state_batch,
                                         return_only_action=True)
            _, _, ns_logits = self.actor_target.clean_action(
                next_state_batch, return_only_action=False)

            #Compute Duelling Q-Val
            #next_target = self.compute_duelling_out(ns_logits, na)
            next_target = [
                ns_logits[:, i:i + 3][na[:, i]][:, -1:] for i in range(22)
            ]
            next_target = torch.cat(next_target, dim=1)

            #Entropy
            #next_target -= self.alpha * ns_log_prob.unsqueeze(1)

            next_q_value = reward_batch + (
                1 - done_batch) * self.gamma * next_target
            #self.compute_stats(ns_log_prob, self.next_entropy)
            self.compute_stats(next_q_value, self.next_q)

        # Compute Duelling Q-Val
        _, _, logits = self.actor.clean_action(state_batch,
                                               return_only_action=False)
        q_val = [
            logits[:, i:i + 3][action_batch[:, i]][:, -1:] for i in range(22)
        ]
        q_val = torch.cat(q_val, dim=1)
        #q_val = self.compute_duelling_out(logits, action_batch)
        #self.compute_stats(log_prob, self.entropy)
        self.compute_stats(q_val, self.policy_q)

        loss_function = torch.nn.MSELoss()
        q_loss = loss_function(
            q_val, next_q_value
        )  # JQ = 𝔼(st,at)~D[0.5(Q1(st,at) - r(st,at) - γ(𝔼st+1~p[V(st+1)]))^2]
        self.compute_stats(q_loss, self.critic_loss)

        self.actor_optim.zero_grad()
        q_loss.backward()
        self.actor_optim.step()

        # if self.automatic_entropy_tuning:
        #     alpha_loss = -(self.log_alpha * (log_pi + self.target_entropy).detach()).mean()
        #
        #     self.alpha_optim.zero_grad()
        #     alpha_loss.backward(retain_graph=True)
        #     self.alpha_optim.step()
        #
        #     self.alpha = self.log_alpha.exp()
        #     alpha_tlogs = self.alpha.clone()  # For TensorboardX logs
        # else:
        #     alpha_loss = torch.tensor(0.)
        #     alpha_tlogs = torch.tensor(self.alpha)  # For TensorboardX logs

        self.num_updates += 1
        if self.num_updates % self.target_update_interval == 0:
            soft_update(self.actor_target, self.actor, self.tau)
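The per-branch value selection in this last example indexes each group of three logits with a list comprehension. One possible reading of the intent, written with gather, assuming logits has shape (batch, 66) laid out as 22 consecutive groups of 3 and actions holds an index in [0, 3) for each branch:

import torch

def select_branch_values(logits, actions, num_branches=22, branch_size=3):
    """Pick the value of the chosen sub-action for each discrete branch."""
    batch = logits.size(0)
    per_branch = logits.view(batch, num_branches, branch_size)  # (B, 22, 3)
    chosen = per_branch.gather(2, actions.long().unsqueeze(2))  # (B, 22, 1)
    return chosen.squeeze(2)                                    # (B, 22)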