def update_parameters(self, memory, batch_size, updates):
    # Sample a batch from memory
    state_batch, action_batch, reward_batch, next_state_batch, mask_batch = memory.sample(batch_size=batch_size)

    state_batch = torch.FloatTensor(state_batch).to(self.device)
    next_state_batch = torch.FloatTensor(next_state_batch).to(self.device)
    action_batch = torch.FloatTensor(action_batch).to(self.device)
    reward_batch = torch.FloatTensor(reward_batch).to(self.device).unsqueeze(1)
    mask_batch = torch.FloatTensor(mask_batch).to(self.device).unsqueeze(1)

    with torch.no_grad():
        next_state_action, next_state_log_pi, _ = self.policy.sample(next_state_batch)
        qf1_next_target, qf2_next_target = self.critic_target(next_state_batch, next_state_action)
        min_qf_next_target = torch.min(qf1_next_target, qf2_next_target) - self.alpha * next_state_log_pi
        next_q_value = reward_batch + mask_batch * self.gamma * min_qf_next_target

    qf1, qf2 = self.critic(state_batch, action_batch)  # Two Q-functions to mitigate positive bias in the policy improvement step
    qf1_loss = F.mse_loss(qf1, next_q_value)  # JQ = 𝔼(st,at)~D[0.5(Q1(st,at) - r(st,at) - γ(𝔼st+1~p[V(st+1)]))^2]
    qf2_loss = F.mse_loss(qf2, next_q_value)  # JQ = 𝔼(st,at)~D[0.5(Q2(st,at) - r(st,at) - γ(𝔼st+1~p[V(st+1)]))^2]

    pi, log_pi, _ = self.policy.sample(state_batch)
    qf1_pi, qf2_pi = self.critic(state_batch, pi)
    min_qf_pi = torch.min(qf1_pi, qf2_pi)
    policy_loss = ((self.alpha * log_pi) - min_qf_pi).mean()  # Jπ = 𝔼st∼D,εt∼N[α * logπ(f(εt;st)|st) − Q(st,f(εt;st))]

    self.critic_optim.zero_grad()
    qf1_loss.backward()
    self.critic_optim.step()

    self.critic_optim.zero_grad()
    qf2_loss.backward()
    self.critic_optim.step()

    self.policy_optim.zero_grad()
    policy_loss.backward()
    self.policy_optim.step()

    if self.automatic_entropy_tuning:
        alpha_loss = -(self.log_alpha * (log_pi + self.target_entropy).detach()).mean()

        self.alpha_optim.zero_grad()
        alpha_loss.backward()
        self.alpha_optim.step()

        self.alpha = self.log_alpha.exp()
        alpha_tlogs = self.alpha.clone()  # For TensorboardX logs
    else:
        alpha_loss = torch.tensor(0.).to(self.device)
        alpha_tlogs = torch.tensor(self.alpha)  # For TensorboardX logs

    if updates % self.target_update_interval == 0:
        soft_update(self.critic_target, self.critic, self.tau)

    return qf1_loss.item(), qf2_loss.item(), policy_loss.item(), alpha_loss.item(), alpha_tlogs.item()
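Both SAC versions in this section refresh the target critic through a `soft_update(target, source, tau)` helper that is not shown here. Below is a minimal sketch of what such a Polyak-averaging helper typically looks like; the signature is taken from the calls above, but the body is an assumption rather than the repo's own code.

import torch

def soft_update(target, source, tau):
    # Polyak averaging: θ_target ← τ·θ_source + (1 − τ)·θ_target.
    # Sketch only: signature matches the calls above, body is a standard implementation.
    with torch.no_grad():
        for target_param, param in zip(target.parameters(), source.parameters()):
            target_param.data.mul_(1.0 - tau)
            target_param.data.add_(tau * param.data)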
def _take_step(self, indices, context):
    num_tasks = len(indices)

    # data is (task, batch, feat)
    obs, actions, rewards, next_obs, terms = self.sample_data(indices)  # sample (s, a, r, s', done) from the replay buffer

    # run inference in networks
    policy_outputs, task_z = self.agent(obs, context)  # policy forward outputs and the task latent variable z
    new_actions, policy_mean, policy_log_std, log_pi = policy_outputs[:4]  # actions sampled by the policy and their log-probabilities (reused as next-state actions in the target below)  # line 63

    # flattens out the task dimension:
    t, b, _ = obs.size()
    obs = obs.view(t * b, -1)
    actions = actions.view(t * b, -1)
    next_obs = next_obs.view(t * b, -1)
    rewards_flat = rewards.view(self.batch_size * num_tasks, -1)
    # scale rewards for Bellman update
    rewards_flat = rewards_flat * self.reward_scale
    terms_flat = terms.view(self.batch_size * num_tasks, -1)

    with torch.no_grad():
        q1_next_target, q2_next_target = self.critic_target(next_obs, new_actions, task_z)  # target Q  # line 64
        min_qf_next_target = torch.min(q1_next_target, q2_next_target) - self.alpha * log_pi  # min of the two target Qs minus the entropy term  # line 65
        next_q_value = rewards_flat + (1. - terms_flat) * self.discount * min_qf_next_target  # q = r + (1 - d) * γ * V(s_{t+1})  # line 66

    q1, q2 = self.critic(obs, actions, task_z)  # critic forward  # line 68
    q1_loss = F.mse_loss(q1, next_q_value)  # JQ = 𝔼(st,at)~D[0.5(Q1(st,at) - r(st,at) - γ(𝔼st+1~p[V(st+1)]))^2]  # line 69
    q2_loss = F.mse_loss(q2, next_q_value)  # line 70

    # pi, log_pi, _ = self.agent.policy.sample(obs)  # action and its log-probability
    in_policy = torch.cat([obs, task_z], 1)
    pi, _, _, log_pi, _, _, _, _ = self.agent.policy(in_policy)  # line 72
    q1_pi, q2_pi = self.critic(obs, pi, task_z.detach())  # Q-values of the sampled actions  # line 74
    min_q_pi = torch.min(q1_pi, q2_pi)  # line 75

    # KL constraint on z if probabilistic
    self.context_optimizer.zero_grad()
    if self.use_information_bottleneck:
        kl_div = self.agent.compute_kl_div()
        kl_loss = self.kl_lambda * kl_div
        kl_loss.backward(retain_graph=True)
    self.context_optimizer.step()

    policy_loss = ((self.alpha * log_pi) - min_q_pi).mean()  # line 77
    self.policy_optimizer.zero_grad()
    policy_loss.backward(retain_graph=True)
    self.policy_optimizer.step()

    self.critic_optimizer.zero_grad()
    q1_loss.backward(retain_graph=True)
    self.critic_optimizer.step()

    self.critic_optimizer.zero_grad()
    q2_loss.backward(retain_graph=True)
    self.critic_optimizer.step()

    if self.automatic_entropy_tuning:
        alpha_loss = -(self.log_alpha * (log_pi + self.target_entropy).detach()).mean()  # E[-α·logπ(at|st) - α·H]
        self.alpha_optim.zero_grad()
        alpha_loss.backward(retain_graph=True)
        self.alpha_optim.step()
        self.alpha = self.log_alpha.exp()
    else:
        alpha_loss = 0
        self.alpha = 1

    soft_update(self.critic_target, self.critic, self.soft_target_tau)

    # save some statistics for eval
    if self.eval_statistics is None:
        # eval should set this to None.
        # this way, these statistics are only computed for one batch.
        self.eval_statistics = OrderedDict()
        if self.use_information_bottleneck:
            z_mean = np.mean(np.abs(ptu.get_numpy(self.agent.z_means[0])))
            z_sig = np.mean(ptu.get_numpy(self.agent.z_vars[0]))
            self.eval_statistics['Z mean train'] = z_mean
            self.eval_statistics['Z variance train'] = z_sig
            self.eval_statistics['KL Divergence'] = ptu.get_numpy(kl_div)
            self.eval_statistics['KL Loss'] = ptu.get_numpy(kl_loss)

        self.eval_statistics['Q1 Loss'] = np.mean(ptu.get_numpy(q1_loss))
        self.eval_statistics['Q2 Loss'] = np.mean(ptu.get_numpy(q2_loss))
        self.eval_statistics['Policy Loss'] = np.mean(ptu.get_numpy(policy_loss))
        if self.automatic_entropy_tuning:
            self.eval_statistics['Alpha Loss'] = np.mean(ptu.get_numpy(alpha_loss))
            self.eval_statistics['Alpha'] = np.mean(ptu.get_numpy(self.alpha))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Q1 Predictions',
            ptu.get_numpy(q1),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Q2 Predictions',
            ptu.get_numpy(q2),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Log Pis',
            ptu.get_numpy(log_pi),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Policy mu',
            ptu.get_numpy(policy_mean),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Policy log std',
            ptu.get_numpy(policy_log_std),
        ))
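When the information bottleneck is enabled, the PEARL-style version above regularizes the task latent z through `self.agent.compute_kl_div()`. As a rough sketch of what that term computes, assuming a diagonal-Gaussian posterior q(z|c) with means `z_means` and variances `z_vars` and a unit-Gaussian prior (the actual PEARL implementation may differ in details):

import torch

def compute_kl_div(z_means, z_vars):
    # Sketch of the information-bottleneck term: KL(q(z|c) || N(0, I)),
    # summed over tasks and latent dimensions. `z_means` and `z_vars` are
    # assumed to have shape (num_tasks, latent_dim).
    prior = torch.distributions.Normal(torch.zeros_like(z_means), torch.ones_like(z_vars))
    posterior = torch.distributions.Normal(z_means, torch.sqrt(z_vars))
    kl = torch.distributions.kl.kl_divergence(posterior, prior)  # element-wise KL
    return kl.sum()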
def _take_step(self, indices, context):
    num_tasks = len(indices)

    # data is (task, batch, feat)
    obs, actions, rewards, next_obs, terms = self.sample_sac(indices)

    # run inference in networks
    policy_outputs, task_z = self.agent(obs, context)
    _, policy_mean, policy_log_std, _ = policy_outputs[:4]

    # flattens out the task dimension
    t, b, _ = obs.size()
    obs = obs.view(t * b, -1)
    actions = actions.view(t * b, -1)
    next_obs = next_obs.view(t * b, -1)
    rewards_flat = rewards.view(self.batch_size * num_tasks, -1)
    # scale rewards for Bellman update
    rewards_flat = rewards_flat * self.reward_scale
    terms_flat = terms.view(self.batch_size * num_tasks, -1)

    with torch.no_grad():
        # policy forward pass: next-state action and its log-probability
        next_state_action, next_state_log_pi, _ = self.agent.policy.sample(torch.cat([next_obs, task_z], 1))
        # target Q-values Q(s_{t+1}, a_{t+1})
        target_qf1, target_qf2 = self.target_critic(next_obs, next_state_action, task_z.detach())
        # min of the two target Qs minus the entropy term
        min_target_qf = torch.min(target_qf1, target_qf2) - self.alpha * next_state_log_pi
        # Bellman update
        next_q_value = rewards_flat + (1. - terms_flat) * self.discount * min_target_qf

    # compute the Q losses from Q and the Bellman target
    qf1, qf2 = self.critic(obs, actions, task_z)
    qf1_loss = F.mse_loss(qf1, next_q_value)
    qf2_loss = F.mse_loss(qf2, next_q_value)

    # compute the policy loss
    pi, log_pi, _ = self.agent.policy.sample(torch.cat([obs, task_z], 1))  # a, logπ(a|s)
    qf1_pi, qf2_pi = self.critic(obs, pi, task_z)  # Q(s, a)
    min_qf_pi = torch.min(qf1_pi, qf2_pi)  # min Q(s, a)
    policy_loss = ((self.alpha * log_pi) - min_qf_pi).mean()  # E[α·logπ(at|st) - Q(st,at)]

    # policy regularization terms carried over from PEARL
    mean_reg_loss = self.policy_mean_reg_weight * (policy_mean ** 2).mean()
    std_reg_loss = self.policy_std_reg_weight * (policy_log_std ** 2).mean()
    pre_tanh_value = policy_outputs[-1]
    pre_activation_reg_loss = self.policy_pre_activation_weight * (
        (pre_tanh_value ** 2).sum(dim=1).mean()
    )
    policy_reg_loss = mean_reg_loss + std_reg_loss + pre_activation_reg_loss
    policy_loss = policy_loss + policy_reg_loss

    # KL constraint on z if probabilistic
    self.context_optimizer.zero_grad()
    if self.use_information_bottleneck:
        kl_div = self.agent.compute_kl_div()
        kl_loss = self.kl_lambda * kl_div
        kl_loss.backward(retain_graph=True)
    self.context_optimizer.step()

    # policy update
    # n.b. policy update includes dQ/da
    self.policy_optimizer.zero_grad()
    policy_loss.backward(retain_graph=True)
    self.policy_optimizer.step()

    self.critic_optimizer.zero_grad()
    qf1_loss.backward(retain_graph=True)
    self.critic_optimizer.step()

    self.critic_optimizer.zero_grad()
    qf2_loss.backward(retain_graph=True)
    self.critic_optimizer.step()

    if self.automatic_entropy_tuning:
        alpha_loss = -(self.log_alpha * (log_pi + self.target_entropy).detach()).mean()  # E[-α·logπ(at|st) - α·H]
        self.alpha_optimizer.zero_grad()
        alpha_loss.backward(retain_graph=True)
        self.alpha_optimizer.step()  # log_alpha has been updated
        self.alpha = self.log_alpha.exp()
    else:
        alpha_loss = 0
        self.alpha = 1

    # target update
    soft_update(self.target_critic, self.critic, self.soft_target_tau)

    # save some statistics for eval
    if self.eval_statistics is None:
        # eval should set this to None.
        # this way, these statistics are only computed for one batch.
        self.eval_statistics = OrderedDict()
        if self.use_information_bottleneck:
            z_mean = np.mean(np.abs(ptu.get_numpy(self.agent.z_means[0])))
            z_sig = np.mean(ptu.get_numpy(self.agent.z_vars[0]))
            self.eval_statistics['Z mean train'] = z_mean
            self.eval_statistics['Z variance train'] = z_sig
            self.eval_statistics['KL Divergence'] = ptu.get_numpy(kl_div)
            self.eval_statistics['KL Loss'] = ptu.get_numpy(kl_loss)
        if self.automatic_entropy_tuning:
            self.eval_statistics['Alpha'] = np.mean(ptu.get_numpy(self.alpha))
            self.eval_statistics['Alpha Loss'] = np.mean(ptu.get_numpy(alpha_loss))

        self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
        self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))
        self.eval_statistics['Policy Loss'] = np.mean(ptu.get_numpy(policy_loss))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Q1 Predictions',
            ptu.get_numpy(qf1),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Q2 Predictions',
            ptu.get_numpy(qf2),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Log Pis',
            ptu.get_numpy(log_pi),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Policy mu',
            ptu.get_numpy(policy_mean),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Policy log std',
            ptu.get_numpy(policy_log_std),
        ))
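All of the variants above obtain actions and log-probabilities from the policy's forward or `sample` call (for PEARL, the observation is first concatenated with the task latent z). Below is a minimal sketch of a tanh-squashed Gaussian policy that matches the `(action, log_prob, mean)` return convention used here; the layer sizes and clamping constants are assumptions rather than values taken from either codebase.

import torch
import torch.nn as nn

LOG_SIG_MIN, LOG_SIG_MAX = -20, 2
EPSILON = 1e-6

class TanhGaussianPolicy(nn.Module):
    # Sketch of a squashed-Gaussian policy with a .sample() method as used above.
    def __init__(self, obs_dim, action_dim, hidden_dim=256):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(obs_dim, hidden_dim), nn.ReLU(),
                                 nn.Linear(hidden_dim, hidden_dim), nn.ReLU())
        self.mean_head = nn.Linear(hidden_dim, action_dim)
        self.log_std_head = nn.Linear(hidden_dim, action_dim)

    def sample(self, obs):
        h = self.net(obs)
        mean = self.mean_head(h)
        log_std = torch.clamp(self.log_std_head(h), LOG_SIG_MIN, LOG_SIG_MAX)
        normal = torch.distributions.Normal(mean, log_std.exp())
        x = normal.rsample()                       # reparameterized sample f(ε; s)
        action = torch.tanh(x)                     # squash to (-1, 1)
        # change-of-variables correction for the tanh squashing
        log_prob = normal.log_prob(x) - torch.log(1 - action.pow(2) + EPSILON)
        log_prob = log_prob.sum(dim=1, keepdim=True)
        return action, log_prob, torch.tanh(mean)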
def update_parameters(self, memory, batch_size, updates):
    """
    Computes loss and updates parameters of objective functions
    (Q functions, policy and alpha).

    ## Input:

    - **memory**: instance of class ReplayMemory
    - **batch_size** *(int)*: batch size that shall be sampled from memory
    - **updates**: indicates the number of the update steps already done

    ## Output:

    - **qf1_loss.item()**: loss of first q function
    - **qf2_loss.item()**: loss of second q function
    - **policy_loss.item()**: loss of policy
    - **alpha_loss.item()**: loss of alpha
    - **alpha_tlogs.item()**: alpha tlogs (for TensorboardX logs)
    """
    # Sample a batch from memory
    state_batch, action_batch, reward_batch, next_state_batch, mask_batch = memory.sample(
        batch_size=batch_size)

    state_batch = torch.FloatTensor(state_batch).to(DEVICE)
    next_state_batch = torch.FloatTensor(next_state_batch).to(DEVICE)
    action_batch = torch.FloatTensor(action_batch).to(DEVICE)
    reward_batch = torch.FloatTensor(reward_batch).to(DEVICE).unsqueeze(1)
    mask_batch = torch.FloatTensor(mask_batch).to(DEVICE).unsqueeze(1)

    with torch.no_grad():
        next_state_action, next_state_log_pi, _ = self.policy.sample(next_state_batch)
        qf1_next_target, qf2_next_target = self.critic_target(next_state_batch, next_state_action)
        min_qf_next_target = torch.min(qf1_next_target, qf2_next_target) - self.alpha * next_state_log_pi
        next_q_value = reward_batch + mask_batch * self.gamma * min_qf_next_target

    qf1, qf2 = self.critic(state_batch, action_batch)  # Two Q-functions to mitigate positive bias in the policy improvement step
    qf1_loss = F.mse_loss(qf1, next_q_value)  # JQ = 𝔼(st,at)~D[0.5(Q1(st,at) - r(st,at) - γ(𝔼st+1~p[V(st+1)]))^2]
    qf2_loss = F.mse_loss(qf2, next_q_value)  # JQ = 𝔼(st,at)~D[0.5(Q2(st,at) - r(st,at) - γ(𝔼st+1~p[V(st+1)]))^2]

    pi, log_pi, _ = self.policy.sample(state_batch)
    qf1_pi, qf2_pi = self.critic(state_batch, pi)
    min_qf_pi = torch.min(qf1_pi, qf2_pi)
    policy_loss = ((self.alpha * log_pi) - min_qf_pi).mean()  # Jπ = 𝔼st∼D,εt∼N[α * logπ(f(εt;st)|st) − Q(st,f(εt;st))]

    self.critic_optim.zero_grad()
    qf1_loss.backward()
    self.critic_optim.step()

    self.critic_optim.zero_grad()
    qf2_loss.backward()
    self.critic_optim.step()

    self.policy_optim.zero_grad()
    policy_loss.backward()
    self.policy_optim.step()

    if self.automatic_temperature_tuning:
        alpha_loss = -(self.log_alpha * (log_pi + self.target_entropy).detach()).mean()

        self.alpha_optim.zero_grad()
        alpha_loss.backward()
        self.alpha_optim.step()

        self.alpha = self.log_alpha.exp()
        alpha_tlogs = self.alpha.clone()  # For TensorboardX logs
    else:
        alpha_loss = torch.tensor(0.).to(DEVICE)
        alpha_tlogs = torch.tensor(self.alpha)  # For TensorboardX logs

    if updates % self.target_update_interval == 0:
        soft_update(self.critic_target, self.critic, self.tau)

    return qf1_loss.item(), qf2_loss.item(), policy_loss.item(), alpha_loss.item(), alpha_tlogs.item()
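This last version samples transitions from a `ReplayMemory` instance, as noted in its docstring. Below is a minimal sketch of the buffer interface it assumes, where `sample(batch_size=...)` returns stacked arrays of (state, action, reward, next_state, mask) with `mask = 1 - done`; the actual class may differ.

import random
import numpy as np

class ReplayMemory:
    # Sketch of the replay buffer interface assumed by update_parameters().
    def __init__(self, capacity):
        self.capacity = capacity
        self.buffer = []
        self.position = 0

    def push(self, state, action, reward, next_state, mask):
        # mask = 1 - done, so terminal transitions cut off bootstrapping in the Bellman target
        if len(self.buffer) < self.capacity:
            self.buffer.append(None)
        self.buffer[self.position] = (state, action, reward, next_state, mask)
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        state, action, reward, next_state, mask = map(np.stack, zip(*batch))
        return state, action, reward, next_state, mask

    def __len__(self):
        return len(self.buffer)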