Example #1
File: sac.py  Project: uenian33/avant_RL
    def update_parameters(self, memory, batch_size, updates):
        # Sample a batch from memory
        state_batch, action_batch, reward_batch, next_state_batch, mask_batch = memory.sample(batch_size=batch_size)

        state_batch = torch.FloatTensor(state_batch).to(self.device)
        next_state_batch = torch.FloatTensor(next_state_batch).to(self.device)
        action_batch = torch.FloatTensor(action_batch).to(self.device)
        reward_batch = torch.FloatTensor(reward_batch).to(self.device).unsqueeze(1)
        mask_batch = torch.FloatTensor(mask_batch).to(self.device).unsqueeze(1)

        with torch.no_grad():
            next_state_action, next_state_log_pi, _ = self.policy.sample(next_state_batch)
            qf1_next_target, qf2_next_target = self.critic_target(next_state_batch, next_state_action)
            min_qf_next_target = torch.min(qf1_next_target, qf2_next_target) - self.alpha * next_state_log_pi
            next_q_value = reward_batch + mask_batch * self.gamma * (min_qf_next_target)

        qf1, qf2 = self.critic(state_batch, action_batch)  # Two Q-functions to mitigate positive bias in the policy improvement step
        qf1_loss = F.mse_loss(qf1, next_q_value) # JQ = 𝔼(st,at)~D[0.5(Q1(st,at) - r(st,at) - γ(𝔼st+1~p[V(st+1)]))^2]
        qf2_loss = F.mse_loss(qf2, next_q_value) # JQ = 𝔼(st,at)~D[0.5(Q1(st,at) - r(st,at) - γ(𝔼st+1~p[V(st+1)]))^2]

        pi, log_pi, _ = self.policy.sample(state_batch)

        qf1_pi, qf2_pi = self.critic(state_batch, pi)
        min_qf_pi = torch.min(qf1_pi, qf2_pi)

        policy_loss = ((self.alpha * log_pi) - min_qf_pi).mean() # Jπ = 𝔼st∼D,εt∼N[α * logπ(f(εt;st)|st) − Q(st,f(εt;st))]

        self.critic_optim.zero_grad()
        qf1_loss.backward()
        self.critic_optim.step()

        self.critic_optim.zero_grad()
        qf2_loss.backward()
        self.critic_optim.step()
        
        self.policy_optim.zero_grad()
        policy_loss.backward()
        self.policy_optim.step()

        if self.automatic_entropy_tuning:
            alpha_loss = -(self.log_alpha * (log_pi + self.target_entropy).detach()).mean()

            self.alpha_optim.zero_grad()
            alpha_loss.backward()
            self.alpha_optim.step()

            self.alpha = self.log_alpha.exp()
            alpha_tlogs = self.alpha.clone() # For TensorboardX logs
        else:
            alpha_loss = torch.tensor(0.).to(self.device)
            alpha_tlogs = torch.tensor(self.alpha) # For TensorboardX logs


        if updates % self.target_update_interval == 0:
            soft_update(self.critic_target, self.critic, self.tau)

        return qf1_loss.item(), qf2_loss.item(), policy_loss.item(), alpha_loss.item(), alpha_tlogs.item()
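
All four snippets on this page call a soft_update helper for the Polyak target-network update. A minimal sketch of such a helper, assuming target and source are torch.nn.Module instances with matching parameter order (illustrative, not the projects' actual code):

import torch

def soft_update(target, source, tau):
    # Polyak averaging: θ_target ← τ·θ_source + (1 − τ)·θ_target
    with torch.no_grad():
        for target_param, source_param in zip(target.parameters(), source.parameters()):
            target_param.data.mul_(1.0 - tau)
            target_param.data.add_(tau * source_param.data)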
Example #2
    def _take_step(self, indices, context):

        num_tasks = len(indices)

        # data is (task, batch, feat)
        obs, actions, rewards, next_obs, terms = self.sample_data(indices)  # sample s, a, r, s', d from the replay buffer

        # run inference in networks
        policy_outputs, task_z = self.agent(obs, context)  # policy forward outputs and the task latent variable z
        new_actions, policy_mean, policy_log_std, log_pi = policy_outputs[:4]  # sampled actions, Gaussian mean/log-std, and log-probabilities
        # flattens out the task dimension:
        t, b, _ = obs.size()
        obs = obs.view(t * b, -1)
        actions = actions.view(t * b, -1)
        next_obs = next_obs.view(t * b, -1)
        rewards_flat = rewards.view(self.batch_size * num_tasks, -1)
        # scale rewards for Bellman update
        rewards_flat = rewards_flat * self.reward_scale
        terms_flat = terms.view(self.batch_size * num_tasks, -1)

        with torch.no_grad():
            q1_next_target, q2_next_target = self.critic_target(next_obs, new_actions, task_z)  # target Q-values
            min_qf_next_target = torch.min(q1_next_target, q2_next_target) - self.alpha * log_pi  # min of the two target Qs with the entropy bonus −α·logπ
            next_q_value = rewards_flat + (1. - terms_flat) * self.discount * min_qf_next_target  # q = r + (1 − d)·γ·V(s_{t+1})

        q1, q2 = self.critic(obs, actions, task_z)
        q1_loss = F.mse_loss(q1, next_q_value)  # JQ = 𝔼(st,at)~D[0.5(Q1(st,at) - r(st,at) - γ(𝔼st+1~p[V(st+1)]))^2]
        q2_loss = F.mse_loss(q2, next_q_value)

        # pi, log_pi, _ = self.agent.policy.sample(obs)  # action and its log-probability
        # print(obs.size())   # [1024, 27]
        # print(task_z.size())
        in_policy = torch.cat([obs, task_z], 1)
        pi, _, _, log_pi, _, _, _, _ = self.agent.policy(in_policy)

        q1_pi, q2_pi = self.critic(obs, pi, task_z.detach())  # Q-values of the sampled actions
        min_q_pi = torch.min(q1_pi, q2_pi)

        # KL constraint on z if probabilistic
        self.context_optimizer.zero_grad()
        if self.use_information_bottleneck:
            kl_div = self.agent.compute_kl_div()
            kl_loss = self.kl_lambda * kl_div
            kl_loss.backward(retain_graph=True)
        self.context_optimizer.step()

        policy_loss = ((self.alpha * log_pi) - min_q_pi).mean()
        self.policy_optimizer.zero_grad()
        policy_loss.backward(retain_graph=True)
        self.policy_optimizer.step()

        self.critic_optimizer.zero_grad()
        q1_loss.backward(retain_graph=True)
        self.critic_optimizer.step()

        self.critic_optimizer.zero_grad()
        q2_loss.backward(retain_graph=True)
        self.critic_optimizer.step()

        if self.automatic_entropy_tuning:
            alpha_loss = -(self.log_alpha *
                           (log_pi + self.target_entropy).detach()
                           ).mean()  # E[-αlogπ(at|st)-αH]
            self.alpha_optim.zero_grad()
            alpha_loss.backward(retain_graph=True)
            self.alpha_optim.step()
            self.alpha = self.log_alpha.exp()
        else:
            alpha_loss = 0
            self.alpha = 1

        soft_update(self.critic_target, self.critic, self.soft_target_tau)

        # save some statistics for eval
        if self.eval_statistics is None:
            # eval should set this to None.
            # this way, these statistics are only computed for one batch.
            self.eval_statistics = OrderedDict()
            if self.use_information_bottleneck:
                z_mean = np.mean(np.abs(ptu.get_numpy(self.agent.z_means[0])))
                z_sig = np.mean(ptu.get_numpy(self.agent.z_vars[0]))
                self.eval_statistics['Z mean train'] = z_mean
                self.eval_statistics['Z variance train'] = z_sig
                self.eval_statistics['KL Divergence'] = ptu.get_numpy(kl_div)
                self.eval_statistics['KL Loss'] = ptu.get_numpy(kl_loss)

            self.eval_statistics['Q1 Loss'] = np.mean(ptu.get_numpy(q1_loss))
            self.eval_statistics['Q2 Loss'] = np.mean(ptu.get_numpy(q2_loss))
            self.eval_statistics['Policy Loss'] = np.mean(
                ptu.get_numpy(policy_loss))
            if self.automatic_entropy_tuning:
                self.eval_statistics['Alpha Loss'] = np.mean(
                    ptu.get_numpy(alpha_loss))
                self.eval_statistics['Alpha'] = np.mean(
                    ptu.get_numpy(self.alpha))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q1 Predictions',
                    ptu.get_numpy(q1),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q2 Predictions',
                    ptu.get_numpy(q2),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Log Pis',
                    ptu.get_numpy(log_pi),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Policy mu',
                    ptu.get_numpy(policy_mean),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Policy log std',
                    ptu.get_numpy(policy_log_std),
                ))
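
Examples #2 and #3 assume a twin-Q critic that is additionally conditioned on the task latent z, called as self.critic(obs, actions, task_z). A minimal sketch of what such a critic could look like; the class name, layer sizes, and constructor arguments below are assumptions for illustration, not the project's actual network:

import torch
import torch.nn as nn

class TwinQCritic(nn.Module):
    """Illustrative twin-Q critic conditioned on a task latent z."""
    def __init__(self, obs_dim, action_dim, latent_dim, hidden_dim=256):
        super().__init__()
        in_dim = obs_dim + action_dim + latent_dim
        def mlp():
            return nn.Sequential(
                nn.Linear(in_dim, hidden_dim), nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
                nn.Linear(hidden_dim, 1))
        self.q1, self.q2 = mlp(), mlp()

    def forward(self, obs, action, task_z):
        # concatenate observation, action, and task latent, then evaluate both Q heads
        x = torch.cat([obs, action, task_z], dim=1)
        return self.q1(x), self.q2(x)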
Example #3
    def _take_step(self, indices, context):
        # print("alpha:",self.alpha)
        num_tasks = len(indices)

        # data is (task, batch, feat)
        obs, actions, rewards, next_obs, terms = self.sample_sac(indices)

        # run inference in networks
        policy_outputs, task_z = self.agent(obs, context)
        _, policy_mean, policy_log_std, _ = policy_outputs[:4]

        # flattens out the task dimension
        t, b, _ = obs.size()
        obs = obs.view(t * b, -1)
        actions = actions.view(t * b, -1)
        next_obs = next_obs.view(t * b, -1)
        rewards_flat = rewards.view(self.batch_size * num_tasks, -1)
        # scale rewards for Bellman update
        rewards_flat = rewards_flat * self.reward_scale
        terms_flat = terms.view(self.batch_size * num_tasks, -1)

        with torch.no_grad():
            next_state_action, next_state_log_pi, _ = self.agent.policy.sample(torch.cat([next_obs, task_z], 1))  # sample the next-state action and its log-probability from the policy
            target_qf1, target_qf2 = self.target_critic(next_obs, next_state_action, task_z.detach())  # target Q-values Q(s_{t+1}, a_{t+1})
            min_target_qf = torch.min(target_qf1, target_qf2) - self.alpha * next_state_log_pi  # min of the two target Qs with the entropy bonus −α·logπ
            next_q_value = rewards_flat + (1. - terms_flat) * self.discount * min_target_qf  # Bellman update
        # compute the Q losses from Q and the target Q
        qf1, qf2 = self.critic(obs, actions, task_z)
        qf1_loss = F.mse_loss(qf1, next_q_value)
        qf2_loss = F.mse_loss(qf2, next_q_value)
        # compute the policy loss
        pi, log_pi, _ = self.agent.policy.sample(torch.cat([obs, task_z], 1))  # a, logπ(a|s)
        qf1_pi, qf2_pi = self.critic(obs, pi, task_z)  # Q(s, a)
        min_qf_pi = torch.min(qf1_pi, qf2_pi)  # min Q(s, a)
        policy_loss = ((self.alpha * log_pi) - min_qf_pi).mean()  # E[α·logπ(at|st) − Q(st,at)]
        # policy regularization terms on the mean, log-std, and pre-tanh activations
        mean_reg_loss = self.policy_mean_reg_weight * (policy_mean ** 2).mean()
        std_reg_loss = self.policy_std_reg_weight * (policy_log_std ** 2).mean()
        pre_tanh_value = policy_outputs[-1]
        pre_activation_reg_loss = self.policy_pre_activation_weight * (
            (pre_tanh_value ** 2).sum(dim=1).mean()
        )
        policy_reg_loss = mean_reg_loss + std_reg_loss + pre_activation_reg_loss
        policy_loss = policy_loss + policy_reg_loss

        # KL constraint on z if probabilistic
        self.context_optimizer.zero_grad()
        if self.use_information_bottleneck:
            kl_div = self.agent.compute_kl_div()
            kl_loss = self.kl_lambda * kl_div
            kl_loss.backward(retain_graph=True)
        self.context_optimizer.step()

        self.policy_optimizer.zero_grad()
        policy_loss.backward(retain_graph=True)
        self.policy_optimizer.step()

        self.critic_optimizer.zero_grad()
        qf1_loss.backward(retain_graph=True)
        self.critic_optimizer.step()
        self.critic_optimizer.zero_grad()
        qf2_loss.backward(retain_graph=True)
        self.critic_optimizer.step()

        if self.automatic_entropy_tuning:
            alpha_loss = -(self.log_alpha * (log_pi + self.target_entropy).detach()).mean()  # E[-αlogπ(at|st)-αH]
            self.alpha_optimizer.zero_grad()
            alpha_loss.backward(retain_graph=True)
            self.alpha_optimizer.step()  # log_alpha has been updated
            self.alpha = self.log_alpha.exp()
        else:
            alpha_loss = 0
            self.alpha = 1

        # target update
        soft_update(self.target_critic, self.critic, self.soft_target_tau)

        # policy update
        # n.b. policy update includes dQ/da

        # save some statistics for eval
        if self.eval_statistics is None:
            # eval should set this to None.
            # this way, these statistics are only computed for one batch.
            self.eval_statistics = OrderedDict()
            if self.use_information_bottleneck:
                z_mean = np.mean(np.abs(ptu.get_numpy(self.agent.z_means[0])))
                z_sig = np.mean(ptu.get_numpy(self.agent.z_vars[0]))
                self.eval_statistics['Z mean train'] = z_mean
                self.eval_statistics['Z variance train'] = z_sig
                self.eval_statistics['KL Divergence'] = ptu.get_numpy(kl_div)
                self.eval_statistics['KL Loss'] = ptu.get_numpy(kl_loss)

            if self.automatic_entropy_tuning:
                self.eval_statistics['Alpha'] = np.mean(ptu.get_numpy(self.alpha))
                self.eval_statistics['Alpha Loss'] = np.mean(ptu.get_numpy(alpha_loss))
            self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
            self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))
            self.eval_statistics['Policy Loss'] = np.mean(ptu.get_numpy(
                policy_loss
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Q1 Predictions',
                ptu.get_numpy(qf1),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Q2 Predictions',
                ptu.get_numpy(qf2),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Log Pis',
                ptu.get_numpy(log_pi),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Policy mu',
                ptu.get_numpy(policy_mean),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Policy log std',
                ptu.get_numpy(policy_log_std),
            ))
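
The entropy-tuning branch in these snippets presupposes that target_entropy, log_alpha, and the alpha optimizer were created beforehand. A typical setup, sketched under the common heuristic of a target entropy of −|A| (the values and learning rate shown are illustrative assumptions):

import torch

action_dim = 6                                    # e.g. a 6-dimensional action space
target_entropy = -float(action_dim)               # common heuristic: −|A|
log_alpha = torch.zeros(1, requires_grad=True)    # optimize log α so that α = exp(log α) stays positive
alpha_optim = torch.optim.Adam([log_alpha], lr=3e-4)
alpha = log_alpha.exp()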
Example #4
    def update_parameters(self, memory, batch_size, updates):
        """
        Computes the losses and updates the parameters of the objective functions (Q functions, policy, and alpha).
        
        ## Input:  
        
        - **memory**: instance of class ReplayMemory
        - **batch_size** *(int)*: batch size to sample from memory
        - **updates**: number of update steps already performed

        ## Output:

        - **qf1_loss.item()**: loss of the first Q function
        - **qf2_loss.item()**: loss of the second Q function
        - **policy_loss.item()**: loss of the policy
        - **alpha_loss.item()**: loss of alpha
        - **alpha_tlogs.item()**: alpha value (for TensorboardX logs)
        """
        # Sample a batch from memory
        state_batch, action_batch, reward_batch, next_state_batch, mask_batch = memory.sample(batch_size=batch_size)

        state_batch = torch.FloatTensor(state_batch).to(DEVICE)
        next_state_batch = torch.FloatTensor(next_state_batch).to(DEVICE)
        action_batch = torch.FloatTensor(action_batch).to(DEVICE)
        reward_batch = torch.FloatTensor(reward_batch).to(DEVICE).unsqueeze(1)
        mask_batch = torch.FloatTensor(mask_batch).to(DEVICE).unsqueeze(1)

        with torch.no_grad():
            next_state_action, next_state_log_pi, _ = self.policy.sample(next_state_batch)
            qf1_next_target, qf2_next_target = self.critic_target(next_state_batch, next_state_action)
            min_qf_next_target = torch.min(qf1_next_target, qf2_next_target) - self.alpha * next_state_log_pi
            next_q_value = reward_batch + mask_batch * self.gamma * min_qf_next_target

        qf1, qf2 = self.critic(state_batch, action_batch)  # Two Q-functions to mitigate positive bias in the policy improvement step
        qf1_loss = F.mse_loss(qf1, next_q_value)  # JQ = 𝔼(st,at)~D[0.5(Q1(st,at) - r(st,at) - γ(𝔼st+1~p[V(st+1)]))^2]
        qf2_loss = F.mse_loss(qf2, next_q_value)

        pi, log_pi, _ = self.policy.sample(state_batch)

        qf1_pi, qf2_pi = self.critic(state_batch, pi)
        min_qf_pi = torch.min(qf1_pi, qf2_pi)

        policy_loss = ((self.alpha * log_pi) - min_qf_pi).mean()  # Jπ = 𝔼st∼D,εt∼N[α * logπ(f(εt;st)|st) − Q(st,f(εt;st))]

        self.critic_optim.zero_grad()
        qf1_loss.backward()
        self.critic_optim.step()

        self.critic_optim.zero_grad()
        qf2_loss.backward()
        self.critic_optim.step()

        self.policy_optim.zero_grad()
        policy_loss.backward()
        self.policy_optim.step()

        if self.automatic_temperature_tuning:
            alpha_loss = -(self.log_alpha * (log_pi + self.target_entropy).detach()).mean()

            self.alpha_optim.zero_grad()
            alpha_loss.backward()
            self.alpha_optim.step()

            self.alpha = self.log_alpha.exp()
            alpha_tlogs = self.alpha.clone()  # For TensorboardX logs
        else:
            alpha_loss = torch.tensor(0.).to(DEVICE)
            alpha_tlogs = torch.tensor(self.alpha)  # For TensorboardX logs

        if updates % self.target_update_interval == 0:
            soft_update(self.critic_target, self.critic, self.tau)

        return qf1_loss.item(), qf2_loss.item(), policy_loss.item(), alpha_loss.item(), alpha_tlogs.item()
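
The docstring above requires memory to be an instance of ReplayMemory. A minimal sketch of the interface the method relies on (push, sample, __len__), assuming transitions are stored as (state, action, reward, next_state, mask) tuples; this is illustrative, not the project's actual class:

import random
import numpy as np

class ReplayMemory:
    """Illustrative fixed-capacity replay buffer with the interface used above."""
    def __init__(self, capacity):
        self.capacity = capacity
        self.buffer = []
        self.position = 0

    def push(self, state, action, reward, next_state, mask):
        # overwrite the oldest transition once the buffer is full
        if len(self.buffer) < self.capacity:
            self.buffer.append(None)
        self.buffer[self.position] = (state, action, reward, next_state, mask)
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        state, action, reward, next_state, mask = map(np.stack, zip(*batch))
        return state, action, reward, next_state, mask

    def __len__(self):
        return len(self.buffer)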