def _update_network(self):
    # sample the episodes
    transitions = self.buffer.sample(self.args.batch_size)
    # pre-process the observation and goal
    o, o_next, g = transitions['obs'], transitions['obs_next'], transitions['g']
    transitions['obs'], transitions['g'] = self._preproc_og(o, g)
    transitions['obs_next'], transitions['g_next'] = self._preproc_og(o_next, g)
    # start to do the update
    obs_norm = self.o_norm.normalize(transitions['obs'])
    g_norm = self.g_norm.normalize(transitions['g'])
    inputs_norm = np.concatenate([obs_norm, g_norm], axis=1)
    obs_next_norm = self.o_norm.normalize(transitions['obs_next'])
    g_next_norm = self.g_norm.normalize(transitions['g_next'])
    inputs_next_norm = np.concatenate([obs_next_norm, g_next_norm], axis=1)
    # transfer them into tensors
    inputs_norm_tensor = torch.tensor(inputs_norm, dtype=torch.float32)
    inputs_next_norm_tensor = torch.tensor(inputs_next_norm, dtype=torch.float32)
    actions_tensor = torch.tensor(transitions['actions'], dtype=torch.float32)
    r_tensor = torch.tensor(transitions['r'], dtype=torch.float32)
    if self.args.cuda:
        inputs_norm_tensor = inputs_norm_tensor.cuda()
        inputs_next_norm_tensor = inputs_next_norm_tensor.cuda()
        actions_tensor = actions_tensor.cuda()
        r_tensor = r_tensor.cuda()
    # calculate the target Q value
    with torch.no_grad():
        actions_next = self.actor_target_network(inputs_next_norm_tensor)
        q_next_value = self.critic_target_network(inputs_next_norm_tensor, actions_next)
        q_next_value = q_next_value.detach()
        target_q_value = r_tensor + self.args.gamma * q_next_value
        target_q_value = target_q_value.detach()
        # clip the target Q value to [-1 / (1 - gamma), 0]
        clip_return = 1 / (1 - self.args.gamma)
        target_q_value = torch.clamp(target_q_value, -clip_return, 0)
    # the q loss
    real_q_value = self.critic_network(inputs_norm_tensor, actions_tensor)
    critic_loss = (target_q_value - real_q_value).pow(2).mean()
    # the actor loss, with an L2 penalty on the (scaled) actions
    actions_real = self.actor_network(inputs_norm_tensor)
    actor_loss = -self.critic_network(inputs_norm_tensor, actions_real).mean()
    actor_loss += self.args.action_l2 * (actions_real / self.env_params['action_max']).pow(2).mean()
    # update the actor network
    self.actor_optim.zero_grad()
    actor_loss.backward()
    sync_grads(self.actor_network)
    self.actor_optim.step()
    # update the critic network
    self.critic_optim.zero_grad()
    critic_loss.backward()
    sync_grads(self.critic_network)
    self.critic_optim.step()
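# ---------------------------------------------------------------------------
# Every update in this file calls `sync_grads(...)` after `backward()` to
# average gradients across MPI workers before stepping the optimizer. The real
# helper lives elsewhere in the codebase; the following is only a minimal
# sketch of the assumed behaviour (requires mpi4py, and assumes every
# parameter of the network already holds a gradient).
# ---------------------------------------------------------------------------
import numpy as np
import torch
from mpi4py import MPI


def sync_grads(network):
    """Sketch: average the gradients of `network` across all MPI workers."""
    flat_grads = np.concatenate(
        [p.grad.data.cpu().numpy().flatten() for p in network.parameters()])
    global_grads = np.zeros_like(flat_grads)
    MPI.COMM_WORLD.Allreduce(flat_grads, global_grads, op=MPI.SUM)
    global_grads /= MPI.COMM_WORLD.Get_size()
    # write the averaged gradients back into the parameter gradients
    idx = 0
    for p in network.parameters():
        n = p.grad.data.numel()
        p.grad.data.copy_(torch.tensor(
            global_grads[idx:idx + n], dtype=p.grad.dtype).view_as(p.grad.data))
        idx += n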
def _update_network(self):
    # sample the episodes
    transitions = self.buffer.sample(self.args.batch_size)
    # pre-process the observation and goal
    o, o_next, g = transitions['obs'], transitions['obs_next'], transitions['g']
    transitions['obs'], transitions['g'] = self._preproc_og(o, g)
    transitions['obs_next'], transitions['g_next'] = self._preproc_og(o_next, g)
    # start to do the update
    obs_norm = self.o_norm.normalize(transitions['obs'])
    g_norm = self.g_norm.normalize(transitions['g'])
    inputs_norm = np.concatenate([obs_norm, g_norm], axis=1)
    obs_next_norm = self.o_norm.normalize(transitions['obs_next'])
    g_next_norm = self.g_norm.normalize(transitions['g_next'])
    inputs_next_norm = np.concatenate([obs_next_norm, g_next_norm], axis=1)
    # transfer them into tensors
    inputs_norm_tensor = torch.tensor(inputs_norm, dtype=torch.float32)
    inputs_next_norm_tensor = torch.tensor(inputs_next_norm, dtype=torch.float32)
    actions_tensor = torch.tensor(transitions['actions'], dtype=torch.float32)
    # rewards are vector-valued here: one column per reward/task
    r_tensor = torch.tensor(transitions['r'], dtype=torch.float32).reshape(transitions['r'].shape[0], -1)
    # if self.args.scale_rewards:
    #     r_tensor = r_tensor / self.reward_scales
    if self.args.cuda:
        inputs_norm_tensor = inputs_norm_tensor.cuda()
        inputs_next_norm_tensor = inputs_next_norm_tensor.cuda()
        actions_tensor = actions_tensor.cuda()
        r_tensor = r_tensor.cuda()
    # calculate the target Q value
    with torch.no_grad():
        rep = self.actor_target_network['rep'](inputs_next_norm_tensor)
        actions_next = self.actor_target_network[0](rep)
        q_next_value = self.critic_target_network(inputs_next_norm_tensor, actions_next)
        q_next_value = q_next_value.detach()
        target_q_value = r_tensor + self.args.gamma * q_next_value
        target_q_value = target_q_value.detach()
        # clip the q value
        clip_return = 1 / (1 - self.args.gamma)
        target_q_value = torch.clamp(target_q_value, -clip_return, 0)
    # the q loss
    real_q_value = self.critic_network(inputs_norm_tensor, actions_tensor)
    if self.args.critic_loss_type == 'MSE':
        critic_loss = (target_q_value - real_q_value).pow(2).mean()
    elif self.args.critic_loss_type == 'MAE':
        critic_loss = (target_q_value - real_q_value).abs().mean()

    # the actor loss: one task per reward component, combined with MGDA
    tasks = list(range(self.env.num_reward))
    loss_data = {}
    grads = {}
    scale = {}
    # compute the gradient of each task's loss w.r.t. the shared representation
    for t_num, t in enumerate(tasks):
        self.actor_optim.zero_grad()
        rep = self.actor_network['rep'](inputs_norm_tensor)
        actions_real = self.actor_network[t_num](rep)
        loss = -(self.critic_network(inputs_norm_tensor, actions_real))[:, t_num].mean()
        loss += self.args.action_l2 * (actions_real / self.env_params['action_max']).pow(2).mean()
        loss_data[t] = loss.data.item()
        loss.backward()
        grads[t] = []
        for param in self.actor_network['rep'].parameters():
            if param.grad is not None:
                grads[t].append(Variable(param.grad.data.clone(), requires_grad=False))
        rep.detach_()
        actions_real.detach_()
    # normalize all gradients; this is optional and not included in the paper
    gn = gradient_normalizers(grads, loss_data, self.args.normalization_type)
    for t in tasks:
        for gr_i in range(len(grads[t])):
            grads[t][gr_i] = grads[t][gr_i] / gn[t]
    # Frank-Wolfe iteration to compute the task scales
    sol, min_norm = MinNormSolver.find_min_norm_element([grads[t] for t in tasks])
    for i, t in enumerate(tasks):
        scale[t] = float(sol[i])
    # scaled back-propagation through the shared representation
    self.actor_optim.zero_grad()
    each_loss = []
    rep = self.actor_network['rep'](inputs_norm_tensor)
    for i, t in enumerate(tasks):
        actions_real = self.actor_network[t](rep)  # per-task policy head
        loss_t = -(self.critic_network(inputs_norm_tensor, actions_real))[:, i].mean()
        loss_t += self.args.action_l2 * (actions_real / self.env_params['action_max']).pow(2).mean()
        each_loss.append(loss_t)
        loss_data[t] = loss_t.data.item()
        if i > 0:
            loss = loss + scale[t] * loss_t
        else:
            loss = scale[t] * loss_t
    loss.backward()
    self.actor_optim.step()
    # update the critic network
    self.critic_optim.zero_grad()
    critic_loss.backward()
    sync_grads(self.critic_network)
    self.critic_optim.step()
    return loss, each_loss, critic_loss, scale
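# ---------------------------------------------------------------------------
# The MGDA update above depends on `gradient_normalizers` and
# `MinNormSolver.find_min_norm_element` from the multi-task learning
# literature (Sener & Koltun, "Multi-Task Learning as Multi-Objective
# Optimization"). The codebase imports its own versions; the sketch below only
# illustrates a typical `gradient_normalizers`, and the supported
# `normalization_type` values ('l2', 'loss', 'loss+', anything else -> 1.0)
# may differ from the actual implementation.
# ---------------------------------------------------------------------------
def gradient_normalizers(grads, losses, normalization_type):
    """Sketch: per-task gradient normalization constants for MGDA."""
    gn = {}
    for t in grads:
        l2 = sum(g.pow(2).sum().item() for g in grads[t]) ** 0.5
        if normalization_type == 'l2':
            gn[t] = l2
        elif normalization_type == 'loss':
            gn[t] = losses[t]
        elif normalization_type == 'loss+':
            gn[t] = losses[t] * l2
        else:
            gn[t] = 1.0
    return gn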
def _update_network(self):
    # sample the episodes
    transitions = self.buffer.sample(self.args.batch_size)
    # pre-process the observation and goal
    o, o_next, g = transitions['obs'], transitions['obs_next'], transitions['g']
    transitions['obs'], transitions['g'] = self._preproc_og(o, g)
    transitions['obs_next'], transitions['g_next'] = self._preproc_og(o_next, g)
    # start to do the update
    obs_norm = self.o_norm.normalize(transitions['obs'])
    g_norm = self.g_norm.normalize(transitions['g'])
    inputs_norm = np.concatenate([obs_norm, g_norm], axis=1)
    obs_next_norm = self.o_norm.normalize(transitions['obs_next'])
    g_next_norm = self.g_norm.normalize(transitions['g_next'])
    inputs_next_norm = np.concatenate([obs_next_norm, g_next_norm], axis=1)
    # transfer them into tensors
    inputs_norm_tensor = torch.tensor(inputs_norm, dtype=torch.float32)
    inputs_next_norm_tensor = torch.tensor(inputs_next_norm, dtype=torch.float32)
    actions_tensor = torch.tensor(transitions['actions'], dtype=torch.float32)
    # update one critic per reward component
    for i in range(self.env.num_reward):
        r_tensor = torch.tensor(transitions['r'], dtype=torch.float32).reshape(
            transitions['r'].shape[0], -1)[:, i]
        if self.args.cuda:
            inputs_norm_tensor = inputs_norm_tensor.cuda()
            inputs_next_norm_tensor = inputs_next_norm_tensor.cuda()
            actions_tensor = actions_tensor.cuda()
            r_tensor = r_tensor.cuda()
        # calculate the target Q value for this reward
        with torch.no_grad():
            actions_next = self.actor_target_network(inputs_next_norm_tensor)
            q_next_value = self.critic_target_network[i](inputs_next_norm_tensor, actions_next)
            q_next_value = q_next_value.detach()
            target_q_value = r_tensor + self.args.gamma * q_next_value
            target_q_value = target_q_value.detach()
            # clip the q value
            clip_return = 1 / (1 - self.args.gamma)
            target_q_value = torch.clamp(target_q_value, -clip_return, 0)
        # the q loss
        real_q_value = self.critic_network[i](inputs_norm_tensor, actions_tensor)
        if self.args.critic_loss_type == 'MSE':
            critic_loss = (target_q_value - real_q_value).pow(2).mean()
        # elif self.args.critic_loss_type == 'max':
        #     critic_loss, _ = torch.max((target_q_value - real_q_value).pow(2), dim=1)
        #     critic_loss = torch.mean(critic_loss)
        elif self.args.critic_loss_type == 'MAE':
            critic_loss = (target_q_value - real_q_value).abs().mean()
        # update the i-th critic network
        self.critic_optim[i].zero_grad()
        critic_loss.backward()
        sync_grads(self.critic_network[i])
        self.critic_optim[i].step()
    # the actor loss
    actions_real = self.actor_network(inputs_norm_tensor)
    # sampling probabilities over critics: softmin of their mean Q values
    update_index_sampling_prob = []
    for i in range(self.env.num_reward):
        update_index_sampling_prob.append(self.critic_network[i](
            inputs_norm_tensor, actions_real).mean().data.cpu().numpy().item())
    update_index_sampling_prob = torch.Tensor(np.array(update_index_sampling_prob))
    update_index_sampling_prob = torch.nn.Softmin(dim=0)(
        update_index_sampling_prob / self.args.softmax_temperature).numpy()
    if self.args.actor_loss_type == 'default':
        actor_loss = -(self.critic_network(inputs_norm_tensor, actions_real)).mean()
        actor_loss += self.args.action_l2 * (actions_real / self.env_params['action_max']).pow(2).mean()
    elif self.args.actor_loss_type == 'min':
        actor_loss = -(self.critic_network(inputs_norm_tensor, actions_real)).min(axis=1)[0].mean()
        actor_loss += self.args.action_l2 * (actions_real / self.env_params['action_max']).pow(2).mean()
    elif self.args.actor_loss_type == 'softmin':
        actor_loss = -(self.critic_network.softmin_forward(inputs_norm_tensor, actions_real)).mean()
        actor_loss += self.args.action_l2 * (actions_real / self.env_params['action_max']).pow(2).mean()
    elif self.args.actor_loss_type == 'random':
        update_index = np.random.choice(self.env.num_reward, p=update_index_sampling_prob)
        actor_loss = -(self.critic_network[update_index](inputs_norm_tensor, actions_real)).mean()
        actor_loss += self.args.action_l2 * (actions_real / self.env_params['action_max']).pow(2).mean()
    # for i in range(self.env.num_reward):
    #     self.writer.add_scalar('update_probability/number_{}'.format(i), update_index_sampling_prob[i])
    # start to update the actor network
    self.actor_optim.zero_grad()
    actor_loss.backward()
    sync_grads(self.actor_network)
    self.actor_optim.step()
def _update_network(self):
    # sample the episodes
    transitions = self.buffer.sample(self.args.batch_size)
    # pre-process the observation and goal
    o, o_next, g = transitions['obs'], transitions['obs_next'], transitions['g']
    transitions['obs'], transitions['g'] = self._preproc_og(o, g)
    transitions['obs_next'], transitions['g_next'] = self._preproc_og(o_next, g)
    # start to do the update
    obs_norm = self.o_norm.normalize(transitions['obs'])
    g_norm = self.g_norm.normalize(transitions['g'])
    inputs_norm = np.concatenate([obs_norm, g_norm], axis=1)
    obs_next_norm = self.o_norm.normalize(transitions['obs_next'])
    g_next_norm = self.g_norm.normalize(transitions['g_next'])
    inputs_next_norm = np.concatenate([obs_next_norm, g_next_norm], axis=1)
    # transfer them into tensors
    inputs_norm_tensor = torch.tensor(inputs_norm, dtype=torch.float32)
    inputs_next_norm_tensor = torch.tensor(inputs_next_norm, dtype=torch.float32)
    actions_tensor = torch.tensor(transitions['actions'], dtype=torch.float32)
    r_tensor = torch.tensor(transitions['r'], dtype=torch.float32)
    if self.args.cuda:
        inputs_norm_tensor = inputs_norm_tensor.cuda(self.device)
        inputs_next_norm_tensor = inputs_next_norm_tensor.cuda(self.device)
        actions_tensor = actions_tensor.cuda(self.device)
        r_tensor = r_tensor.cuda(self.device)
    # calculate the target Q value, with smoothing noise added to the next action
    with torch.no_grad():
        actions_next = self.actor_target_network(inputs_next_norm_tensor)
        actions_next += self.args.noise_eps * self.env_params['action_max'] * \
            torch.randn(actions_next.shape).cuda(self.device)
        actions_next = torch.clamp(actions_next, -self.env_params['action_max'],
                                   self.env_params['action_max'])
        q_next_value1 = self.critic_target_network1(inputs_next_norm_tensor, actions_next)
        q_next_value2 = self.critic_target_network2(inputs_next_norm_tensor, actions_next)
        target_q_value = r_tensor + self.args.gamma * torch.min(q_next_value1, q_next_value2)
        # clip the q value
        clip_return = 1 / (1 - self.args.gamma)
        target_q_value = torch.clamp(target_q_value, -clip_return, 0)
        target_q_value = target_q_value.detach()
    # the q loss for both critics
    real_q_value1 = self.critic_network1(inputs_norm_tensor, actions_tensor)
    critic_loss1 = (target_q_value - real_q_value1).pow(2).mean()
    real_q_value2 = self.critic_network2(inputs_norm_tensor, actions_tensor)
    critic_loss2 = (target_q_value - real_q_value2).pow(2).mean()
    # the actor loss
    actions_real = self.actor_network(inputs_norm_tensor)
    actor_loss = -torch.min(self.critic_network1(inputs_norm_tensor, actions_real),
                            self.critic_network2(inputs_norm_tensor, actions_real)).mean()
    actor_loss += self.args.action_l2 * (actions_real / self.env_params['action_max']).pow(2).mean()
    # start to update the actor network
    self.actor_optim.zero_grad()
    actor_loss.backward()
    sync_grads(self.actor_network)
    self.actor_optim.step()
    # update both critic networks
    self.critic_optim1.zero_grad()
    critic_loss1.backward()
    sync_grads(self.critic_network1)
    self.critic_optim1.step()
    self.critic_optim2.zero_grad()
    critic_loss2.backward()
    sync_grads(self.critic_network2)
    self.critic_optim2.step()
    self.logger.store(LossPi=actor_loss.detach().cpu().numpy())
    self.logger.store(LossQ=(critic_loss1 + critic_loss2).detach().cpu().numpy())
def _update_network(self):
    # sample the episodes
    transitions = self.buffer.sample(self.args.batch_size)
    # pre-process the observation and goal
    o, o_next, g = transitions['obs'], transitions['obs_next'], transitions['g']
    transitions['obs'], transitions['g'] = self._preproc_og(o, g)
    transitions['obs_next'], transitions['g_next'] = self._preproc_og(o_next, g)
    # start to do the update
    obs_norm = self.o_norm.normalize(transitions['obs'])
    g_norm = self.g_norm.normalize(transitions['g'])
    inputs_norm = np.concatenate([obs_norm, g_norm], axis=1)
    obs_next_norm = self.o_norm.normalize(transitions['obs_next'])
    g_next_norm = self.g_norm.normalize(transitions['g_next'])
    inputs_next_norm = np.concatenate([obs_next_norm, g_next_norm], axis=1)
    # transfer them into tensors
    inputs_norm_tensor = torch.tensor(inputs_norm, dtype=torch.float32)
    inputs_next_norm_tensor = torch.tensor(inputs_next_norm, dtype=torch.float32)
    actions_tensor = torch.tensor(transitions['actions'], dtype=torch.float32)
    # rewards are vector-valued here: one column per reward/task
    r_tensor = torch.tensor(transitions['r'], dtype=torch.float32).reshape(transitions['r'].shape[0], -1)
    # if self.args.scale_rewards:
    #     r_tensor = r_tensor / self.reward_scales
    if self.args.cuda:
        inputs_norm_tensor = inputs_norm_tensor.cuda()
        inputs_next_norm_tensor = inputs_next_norm_tensor.cuda()
        actions_tensor = actions_tensor.cuda()
        r_tensor = r_tensor.cuda()
    # calculate the target Q value
    with torch.no_grad():
        actions_next = self.actor_target_network(inputs_next_norm_tensor)
        q_next_value = self.critic_target_network(inputs_next_norm_tensor, actions_next)
        q_next_value = q_next_value.detach()
        target_q_value = r_tensor + self.args.gamma * q_next_value
        target_q_value = target_q_value.detach()
        # clip the q value
        clip_return = 1 / (1 - self.args.gamma)
        target_q_value = torch.clamp(target_q_value, -clip_return, 0)
    # the q loss
    real_q_value = self.critic_network(inputs_norm_tensor, actions_tensor)
    if self.args.critic_loss_type == 'MSE':
        critic_loss = (target_q_value - real_q_value).pow(2).mean()
    elif self.args.critic_loss_type == 'MAE':
        critic_loss = (target_q_value - real_q_value).abs().mean()
    # the actor loss
    actions_real = self.actor_network(inputs_norm_tensor)
    with torch.no_grad():
        each_reward = (self.critic_network(inputs_norm_tensor, actions_real)).mean(axis=0)
    if self.args.actor_loss_type == 'sum':
        actor_loss = -(self.critic_network(inputs_norm_tensor, actions_real)).mean()
        actor_loss += self.args.action_l2 * (actions_real / self.env_params['action_max']).pow(2).mean()
        update_index = 0
    elif self.args.actor_loss_type == 'smoothed_minmax':
        actor_loss = torch.exp(self.args.softmax_temperature *
                               (-(self.critic_network(inputs_norm_tensor, actions_real)))).mean()
        actor_loss += self.args.action_l2 * (actions_real / self.env_params['action_max']).pow(2).mean()
        update_index = 0
    elif self.args.actor_loss_type == 'minmax':
        # only update the objective whose mean Q value is currently the lowest
        update_index = np.argmin((self.critic_network(
            inputs_norm_tensor, actions_real)).detach().cpu().numpy().mean(axis=0))
        onehot = torch.zeros(self.env.num_reward)
        onehot[update_index] = 1.
        if self.args.cuda:
            onehot = onehot.cuda()
        actor_loss = (-(self.critic_network(inputs_norm_tensor, actions_real)) * onehot).mean()
        actor_loss += self.args.action_l2 * (actions_real / self.env_params['action_max']).pow(2).mean()
    elif self.args.actor_loss_type == 'prod':
        actor_loss = torch.prod(-self.critic_network(inputs_norm_tensor, actions_real), 1).mean()
        actor_loss += self.args.action_l2 * (actions_real / self.env_params['action_max']).pow(2).mean()
        update_index = None
    # start to update the actor network
    self.actor_optim.zero_grad()
    actor_loss.backward()
    sync_grads(self.actor_network)
    self.actor_optim.step()
    # update the critic network
    self.critic_optim.zero_grad()
    critic_loss.backward()
    sync_grads(self.critic_network)
    self.critic_optim.step()
    return update_index, actor_loss, each_reward, critic_loss
def _update_network(self):
    # sample the episodes
    transitions = self.buffer.sample(self.args.batch_size)
    # pre-process the observation and goal
    o, o_next, g = transitions['obs'], transitions['obs_next'], transitions['g']
    transitions['obs'], transitions['g'] = self._preproc_og(o, g)
    transitions['obs_next'], transitions['g_next'] = self._preproc_og(o_next, g)
    # start to do the update
    obs_norm = self.o_norm.normalize(transitions['obs'])
    g_norm = self.g_norm.normalize(transitions['g'])
    inputs_norm = np.concatenate([obs_norm, g_norm], axis=1)
    obs_next_norm = self.o_norm.normalize(transitions['obs_next'])
    g_next_norm = self.g_norm.normalize(transitions['g_next'])
    inputs_next_norm = np.concatenate([obs_next_norm, g_next_norm], axis=1)
    # transfer them into tensors
    inputs_norm_tensor = torch.tensor(inputs_norm, dtype=torch.float32)
    inputs_next_norm_tensor = torch.tensor(inputs_next_norm, dtype=torch.float32)
    actions_tensor = torch.tensor(transitions['actions'], dtype=torch.float32)
    r_tensor = torch.tensor(transitions['r'], dtype=torch.float32) * self.args.reward_scale
    if self.args.cuda:
        inputs_norm_tensor = inputs_norm_tensor.cuda(self.device)
        inputs_next_norm_tensor = inputs_next_norm_tensor.cuda(self.device)
        actions_tensor = actions_tensor.cuda(self.device)
        r_tensor = r_tensor.cuda(self.device)
    # calculate the target Q value (soft Bellman backup)
    with torch.no_grad():
        actions_next, log_prob_actions_next = self.actor_network(inputs_next_norm_tensor)
        q_next_value1 = self.critic_target_network1(inputs_next_norm_tensor, actions_next).detach()
        q_next_value2 = self.critic_target_network2(inputs_next_norm_tensor, actions_next).detach()
        target_q_value = r_tensor + self.args.gamma * (
            torch.min(q_next_value1, q_next_value2) - self.alpha * log_prob_actions_next)
        target_q_value = target_q_value.detach()
        # clip the q value
        clip_return = 1 / (1 - self.args.gamma)
        target_q_value = torch.clamp(target_q_value, -clip_return, 0)
    # the q loss for both critics
    real_q_value1 = self.critic_network1(inputs_norm_tensor, actions_tensor)
    real_q_value2 = self.critic_network2(inputs_norm_tensor, actions_tensor)
    critic_loss1 = (target_q_value - real_q_value1).pow(2).mean()
    critic_loss2 = (target_q_value - real_q_value2).pow(2).mean()
    # the actor loss
    actions, log_prob_actions = self.actor_network(inputs_norm_tensor)
    log_prob_actions = log_prob_actions.mean()
    actor_loss = self.alpha * log_prob_actions - torch.min(
        self.critic_network1(inputs_norm_tensor, actions),
        self.critic_network2(inputs_norm_tensor, actions)).mean()
    # actor_loss += self.args.action_l2 * (actions_real / self.env_params['action_max']).pow(2).mean()
    # start to update the actor network
    self.actor_optim.zero_grad()
    actor_loss.backward()
    sync_grads(self.actor_network)
    self.actor_optim.step()
    # update both critic networks
    self.critic_optim1.zero_grad()
    critic_loss1.backward()
    sync_grads(self.critic_network1)
    self.critic_optim1.step()
    self.critic_optim2.zero_grad()
    critic_loss2.backward()
    sync_grads(self.critic_network2)
    self.critic_optim2.step()
    self.logger.store(LossPi=actor_loss.detach().cpu().numpy())
    self.logger.store(LossQ=(critic_loss1 + critic_loss2).detach().cpu().numpy())
    self.logger.store(Entropy=-log_prob_actions.detach().cpu().numpy())
    # automatic temperature tuning (enabled when args.alpha < 0)
    if self.args.alpha < 0:
        comm = MPI.COMM_WORLD
        log_prob_actions = log_prob_actions.detach().cpu().numpy()
        global_log_prob_actions = np.zeros_like(log_prob_actions)
        comm.Allreduce(log_prob_actions, global_log_prob_actions, op=MPI.SUM)
        global_log_prob_actions /= MPI.COMM_WORLD.Get_size()
        # use the entropy estimate averaged across all workers
        logalpha_loss = -self.log_alpha * (global_log_prob_actions + self.target_entropy)
        self.alpha_optim.zero_grad()
        logalpha_loss.backward()
        self.alpha_optim.step()
        with torch.no_grad():
            self.alpha = self.log_alpha.exp().detach()
        self.logger.store(alpha=self.alpha.detach().cpu().numpy())
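# ---------------------------------------------------------------------------
# None of the `_update_network` variants in this file moves the target
# networks; a separate soft (Polyak) update is assumed to run elsewhere in the
# training loop. A minimal sketch is shown below; the coefficient name
# `self.args.polyak` is an assumption, not taken from the original code.
# ---------------------------------------------------------------------------
def _soft_update_target_network(self, target, source):
    """Sketch: Polyak-average the online network into the target network."""
    for target_param, param in zip(target.parameters(), source.parameters()):
        target_param.data.copy_(
            (1 - self.args.polyak) * param.data + self.args.polyak * target_param.data)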
def _update_network(self):
    # sample a batch of transitions
    batches = self.buffer.sample(self.args.batch_size)
    o = torch.FloatTensor(batches['obs']).to(self.device)
    o2 = torch.FloatTensor(batches['obs2']).to(self.device)
    a = torch.FloatTensor(batches['act']).to(self.device)
    r = torch.FloatTensor(batches['rew']).to(self.device)
    c = torch.FloatTensor(batches['cost']).to(self.device)
    d = torch.FloatTensor(batches['done']).to(self.device)
    # calculate the target Q value and the target advice (cost) value
    with torch.no_grad():
        a2 = self.actor_network(o2)
        q_next_value1 = self.critic_target_network1(o2, a2).detach()
        q_next_value2 = self.critic_target_network2(o2, a2).detach()
        target_q_value = r + self.args.gamma * (1 - d) * torch.min(q_next_value1, q_next_value2)
        target_q_value = target_q_value.detach()
        p_next_value1 = self.advice_target_network1(o2, a2).detach()
        p_next_value2 = self.advice_target_network2(o2, a2).detach()
        target_p_value = -c + self.args.gamma * (1 - d) * torch.min(p_next_value1, p_next_value2)
        target_p_value = target_p_value.detach()
    # the q loss
    real_q_value1 = self.critic_network1(o, a)
    real_q_value2 = self.critic_network2(o, a)
    critic_loss1 = (target_q_value - real_q_value1).pow(2).mean()
    critic_loss2 = (target_q_value - real_q_value2).pow(2).mean()
    # the p (advice) loss
    real_p_value1 = self.advice_network1(o, a)
    real_p_value2 = self.advice_network2(o, a)
    advice_loss1 = (target_p_value - real_p_value1).pow(2).mean()
    advice_loss2 = (target_p_value - real_p_value2).pow(2).mean()
    # the actor loss, computed on an expanded batch
    o_exp = o.repeat(self.args.expand_batch, 1)
    a_exp = self.actor_network(o_exp)
    actor_loss = -torch.min(self.critic_network1(o_exp, a_exp),
                            self.critic_network2(o_exp, a_exp)).mean()
    actor_loss -= self.args.advice * torch.min(self.advice_network1(o_exp, a_exp),
                                               self.advice_network2(o_exp, a_exp)).mean()
    mmd_entropy = torch.tensor(0.0)
    if self.args.mmd:
        # mmd is computationally expensive
        a_exp_reshape = a_exp.view(self.args.expand_batch, -1, a_exp.shape[-1]).transpose(0, 1)
        with torch.no_grad():
            uniform_actions = (2 * torch.rand_like(a_exp_reshape) - 1)
        mmd_entropy = mmd(a_exp_reshape, uniform_actions)
        if self.args.beta_mmd <= 0.0:
            mmd_entropy.detach_()
        else:
            actor_loss += self.args.beta_mmd * mmd_entropy
    # start to update the actor network
    self.actor_optim.zero_grad()
    actor_loss.backward()
    sync_grads(self.actor_network)
    self.actor_optim.step()
    # update both critic networks
    self.critic_optim1.zero_grad()
    critic_loss1.backward()
    sync_grads(self.critic_network1)
    self.critic_optim1.step()
    self.critic_optim2.zero_grad()
    critic_loss2.backward()
    sync_grads(self.critic_network2)
    self.critic_optim2.step()
    self.logger.store(LossPi=actor_loss.detach().cpu().numpy())
    self.logger.store(LossQ=(critic_loss1 + critic_loss2).detach().cpu().numpy())
    self.logger.store(MMDEntropy=mmd_entropy.detach().cpu().numpy())
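# ---------------------------------------------------------------------------
# `mmd(...)` above acts as an entropy-like regularizer: it measures the
# maximum mean discrepancy between the expanded action samples and uniform
# samples from the action box. The kernel and bandwidth used in the codebase
# are not shown here; the sketch below is a biased RBF-kernel MMD^2 estimator
# with the same calling convention (inputs of shape
# (batch, expand_batch, action_dim)), with an assumed bandwidth `sigma`.
# ---------------------------------------------------------------------------
import torch


def mmd(x, y, sigma=1.0):
    """Sketch: biased MMD^2 estimate between two batched sample sets."""
    def rbf(a, b):
        # pairwise squared distances between samples within each batch element
        d = (a.unsqueeze(2) - b.unsqueeze(1)).pow(2).sum(-1)
        return torch.exp(-d / (2 * sigma ** 2))
    return rbf(x, x).mean() + rbf(y, y).mean() - 2 * rbf(x, y).mean()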
def _update_network(self, step, if_write=False):
    # --- agent 1, the protagonist agent ---
    # sample the episodes from the replay buffer
    transitions = self.buffer_1.sample(self.args.batch_size)
    # pre-process the observation and goal
    o, o_next, g = transitions['obs'], transitions['obs_next'], transitions['g']
    transitions['obs'], transitions['g'] = self._preproc_og(o, g)
    transitions['obs_next'], transitions['g_next'] = self._preproc_og(o_next, g)
    # start to do the update
    obs_norm = self.o_norm_1.normalize(transitions['obs'])
    g_norm = self.g_norm.normalize(transitions['g'])
    inputs_norm = np.concatenate([obs_norm, g_norm], axis=1)
    obs_next_norm = self.o_norm_1.normalize(transitions['obs_next'])
    g_next_norm = self.g_norm.normalize(transitions['g_next'])
    inputs_next_norm = np.concatenate([obs_next_norm, g_next_norm], axis=1)
    # transfer them into tensors
    inputs_norm_tensor = torch.tensor(inputs_norm, dtype=torch.float32)
    inputs_next_norm_tensor = torch.tensor(inputs_next_norm, dtype=torch.float32)
    actions_tensor = torch.tensor(transitions['actions'], dtype=torch.float32)
    r_tensor = torch.tensor(transitions['r'], dtype=torch.float32)
    if self.args.cuda:
        inputs_norm_tensor = inputs_norm_tensor.to(self.args.device)
        inputs_next_norm_tensor = inputs_next_norm_tensor.to(self.args.device)
        actions_tensor = actions_tensor.to(self.args.device)
        r_tensor = r_tensor.to(self.args.device)
    # calculate the target Q value; no gradients are tracked in this block
    with torch.no_grad():
        actions_next, next_state_logp, _ = self.actor_target_network_1.sample(inputs_next_norm_tensor)
        q_next_value = self.critic_target_network_1(inputs_next_norm_tensor, actions_next)
        q_next_value = q_next_value.detach()
        target_q_value = r_tensor + self.args.gamma * q_next_value
        target_q_value = target_q_value.detach()
        q_next_value_1_2 = self.critic_target_network_1_2(inputs_next_norm_tensor, actions_next)
        q_next_value_1_2 = q_next_value_1_2.detach()
        target_q_value_1_2 = r_tensor + self.args.gamma * q_next_value_1_2
        target_q_value_1_2 = target_q_value_1_2.detach()
        # clip the q value: keep target_q_value between -clip_return and 0
        clip_return = 1 / (1 - self.args.gamma)
        target_q_value = torch.min(target_q_value, target_q_value_1_2) - self.alpha_1 * next_state_logp
        target_q_value = torch.clamp(target_q_value, -clip_return, 0)
    # the q loss for the first critic
    real_q_value = self.critic_network_1(inputs_norm_tensor, actions_tensor)
    critic_loss = (target_q_value - real_q_value).pow(2).mean()
    # the q loss for the second critic
    real_q_value_1_2 = self.critic_network_1_2(inputs_norm_tensor, actions_tensor)
    critic_loss_1_2 = (target_q_value_1_2 - real_q_value_1_2).pow(2).mean()
    # the actor loss
    actions_real, logp, _ = self.actor_network_1.sample(inputs_norm_tensor)
    actor_loss = (self.alpha_1 * logp - torch.min(
        self.critic_network_1(inputs_norm_tensor, actions_real),
        self.critic_network_1_2(inputs_norm_tensor, actions_real))).mean()
    actor_loss += self.args.action_l2 * (actions_real / self.env_params['action_max']).pow(2).mean()
    # temperature (alpha) update for agent 1
    alpha_loss_1 = -(self.log_alpha_1 * (logp + self.target_entropy_1).detach()).mean()
    self.alpha_optim_1.zero_grad()
    alpha_loss_1.backward()
    sync_grads_for_tensor(self.log_alpha_1)
    self.alpha_optim_1.step()
    self.alpha_1 = self.log_alpha_1.exp()
    # start to update the actor network
    self.actor_optim_1.zero_grad()
    actor_loss.backward()
    # clip_grad_norm_(self.actor_network_1.parameters(), max_norm=20, norm_type=2)
    sync_grads(self.actor_network_1)
    self.actor_optim_1.step()
    # update the critic networks
    self.critic_optim_1.zero_grad()
    critic_loss.backward()
    # clip_grad_norm_(self.critic_network_1.parameters(), max_norm=20, norm_type=2)
    sync_grads(self.critic_network_1)
    self.critic_optim_1.step()
    self.critic_optim_1_2.zero_grad()
    critic_loss_1_2.backward()
    # clip_grad_norm_(self.critic_network_1_2.parameters(), max_norm=20, norm_type=2)
    sync_grads(self.critic_network_1_2)
    self.critic_optim_1_2.step()
    if self.writer is not None and if_write:
        self.writer.add_scalar('actor_loss_1', actor_loss, step)
        self.writer.add_scalar('critic_loss_1', critic_loss, step)

    # --- agent 2, the adversary agent ---
    # sample the episodes from the replay buffer
    transitions = self.buffer_2.sample(self.args.batch_size)
    # pre-process the observation and goal
    o, o_next, g = transitions['obs'], transitions['obs_next'], transitions['g']
    transitions['obs'], transitions['g'] = self._preproc_og(o, g)
    transitions['obs_next'], transitions['g_next'] = self._preproc_og(o_next, g)
    # start to do the update
    obs_norm = self.o_norm_2.normalize(transitions['obs'])
    g_norm = self.g_norm.normalize(transitions['g'])
    inputs_norm = np.concatenate([obs_norm, g_norm], axis=1)
    obs_next_norm = self.o_norm_2.normalize(transitions['obs_next'])
    g_next_norm = self.g_norm.normalize(transitions['g_next'])
    inputs_next_norm = np.concatenate([obs_next_norm, g_next_norm], axis=1)
    # transfer them into tensors
    inputs_norm_tensor = torch.tensor(inputs_norm, dtype=torch.float32)
    inputs_next_norm_tensor = torch.tensor(inputs_next_norm, dtype=torch.float32)
    actions_tensor = torch.tensor(transitions['actions'], dtype=torch.float32)
    r_tensor = torch.tensor(transitions['r'], dtype=torch.float32)
    if self.args.cuda:
        inputs_norm_tensor = inputs_norm_tensor.to(self.args.device)
        inputs_next_norm_tensor = inputs_next_norm_tensor.to(self.args.device)
        actions_tensor = actions_tensor.to(self.args.device)
        r_tensor = r_tensor.to(self.args.device)
    # calculate the target Q value; no gradients are tracked in this block
    with torch.no_grad():
        actions_next, next_state_logp, _ = self.actor_target_network_2.sample(inputs_next_norm_tensor)
        q_next_value = self.critic_target_network_2(inputs_next_norm_tensor, actions_next)
        q_next_value = q_next_value.detach()
        target_q_value = r_tensor + self.args.gamma * q_next_value
        target_q_value = target_q_value.detach()
        q_next_value_2_2 = self.critic_target_network_2_2(inputs_next_norm_tensor, actions_next)
        q_next_value_2_2 = q_next_value_2_2.detach()
        target_q_value_2_2 = r_tensor + self.args.gamma * q_next_value_2_2
        target_q_value_2_2 = target_q_value_2_2.detach()
        # clip the q value: keep target_q_value between -clip_return and 0
        clip_return = 1 / (1 - self.args.gamma)
        target_q_value = torch.min(target_q_value, target_q_value_2_2) - self.alpha_2 * next_state_logp
        target_q_value = torch.clamp(target_q_value, -clip_return, 0)
    # the q loss for the first critic
    real_q_value = self.critic_network_2(inputs_norm_tensor, actions_tensor)
    critic_loss = (target_q_value - real_q_value).pow(2).mean()
    # the q loss for the second critic
    real_q_value_2_2 = self.critic_network_2_2(inputs_norm_tensor, actions_tensor)
    critic_loss_2_2 = (target_q_value_2_2 - real_q_value_2_2).pow(2).mean()
    # the actor loss (sign flipped for the adversary agent)
    actions_real, logp, _ = self.actor_network_2.sample(inputs_norm_tensor)
    actor_loss = (self.alpha_2 * logp + torch.min(
        self.critic_network_2(inputs_norm_tensor, actions_real),
        self.critic_network_2_2(inputs_norm_tensor, actions_real))).mean()
    actor_loss += self.args.action_l2 * (actions_real / self.env_params['action_max']).pow(2).mean()
    # temperature (alpha) update for agent 2
    alpha_loss_2 = -(self.log_alpha_2 * (logp + self.target_entropy_2).detach()).mean()
    self.alpha_optim_2.zero_grad()
    alpha_loss_2.backward()
    sync_grads_for_tensor(self.log_alpha_2)
    self.alpha_optim_2.step()
    self.alpha_2 = self.log_alpha_2.exp()
    # start to update the actor network
    self.actor_optim_2.zero_grad()
    actor_loss.backward()
    # clip_grad_norm_(self.actor_network_2.parameters(), max_norm=20, norm_type=2)
    sync_grads(self.actor_network_2)
    self.actor_optim_2.step()
    # update the critic networks
    self.critic_optim_2.zero_grad()
    critic_loss.backward()
    # clip_grad_norm_(self.critic_network_2.parameters(), max_norm=20, norm_type=2)
    sync_grads(self.critic_network_2)
    self.critic_optim_2.step()
    self.critic_optim_2_2.zero_grad()
    critic_loss_2_2.backward()
    # clip_grad_norm_(self.critic_network_2_2.parameters(), max_norm=20, norm_type=2)
    sync_grads(self.critic_network_2_2)
    self.critic_optim_2_2.step()
    if self.writer is not None and if_write:
        self.writer.add_scalar('actor_loss_2', actor_loss, step)
        self.writer.add_scalar('critic_loss_2', critic_loss, step)
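# ---------------------------------------------------------------------------
# The temperature updates above call `sync_grads_for_tensor(self.log_alpha_*)`.
# It is assumed to mirror `sync_grads` for a single tensor (here the scalar
# log-temperature); the sketch below shows that assumed behaviour.
# ---------------------------------------------------------------------------
import numpy as np
import torch
from mpi4py import MPI


def sync_grads_for_tensor(tensor):
    """Sketch: average the gradient of a single tensor across MPI workers."""
    grad = tensor.grad.data.cpu().numpy()
    global_grad = np.zeros_like(grad)
    MPI.COMM_WORLD.Allreduce(grad, global_grad, op=MPI.SUM)
    global_grad /= MPI.COMM_WORLD.Get_size()
    tensor.grad.data.copy_(torch.tensor(global_grad, dtype=tensor.grad.dtype))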
def _update_network(self):
    # sample the episodes
    transitions = self.buffer.sample(self.args.batch_size)
    # pre-process the observation and goal
    o, o_next, g = transitions['obs'], transitions['obs_next'], transitions['g']
    transitions['obs'], transitions['g'] = self._preproc_og(o, g)
    transitions['obs_next'], transitions['g_next'] = self._preproc_og(o_next, g)
    # start to do the update
    obs_norm = self.o_norm.normalize(transitions['obs'])
    g_norm = self.g_norm.normalize(transitions['g'])
    inputs_norm = np.concatenate([obs_norm, g_norm], axis=1)
    obs_next_norm = self.o_norm.normalize(transitions['obs_next'])
    g_next_norm = self.g_norm.normalize(transitions['g_next'])
    inputs_next_norm = np.concatenate([obs_next_norm, g_next_norm], axis=1)
    # transfer them into tensors
    inputs_norm_tensor = torch.tensor(inputs_norm, dtype=torch.float32)
    inputs_next_norm_tensor = torch.tensor(inputs_next_norm, dtype=torch.float32)
    actions_tensor = torch.tensor(transitions['actions'], dtype=torch.float32)
    r_tensor = torch.tensor(transitions['r'], dtype=torch.float32) * self.args.reward_scale
    if self.args.cuda:
        inputs_norm_tensor = inputs_norm_tensor.cuda(self.device)
        inputs_next_norm_tensor = inputs_next_norm_tensor.cuda(self.device)
        actions_tensor = actions_tensor.cuda(self.device)
        r_tensor = r_tensor.cuda(self.device)
    # calculate the target Q value
    with torch.no_grad():
        actions_next = self.actor_network(inputs_next_norm_tensor)
        q_next_value1 = self.critic_target_network1(inputs_next_norm_tensor, actions_next).detach()
        q_next_value2 = self.critic_target_network2(inputs_next_norm_tensor, actions_next).detach()
        target_q_value = r_tensor + self.args.gamma * torch.min(q_next_value1, q_next_value2)
        target_q_value = target_q_value.detach()
        # clip the q value
        clip_return = 1 / (1 - self.args.gamma)
        target_q_value = torch.clamp(target_q_value, -clip_return, 0)
    # the q loss for both critics
    real_q_value1 = self.critic_network1(inputs_norm_tensor, actions_tensor)
    real_q_value2 = self.critic_network2(inputs_norm_tensor, actions_tensor)
    critic_loss1 = (target_q_value - real_q_value1).pow(2).mean()
    critic_loss2 = (target_q_value - real_q_value2).pow(2).mean()
    # the actor loss, computed on an expanded batch
    exp_inputs_norm_tensor = inputs_norm_tensor.repeat(self.args.expand_batch, 1)
    exp_actions_real = self.actor_network(exp_inputs_norm_tensor)
    actor_loss = -torch.min(
        self.critic_network1(exp_inputs_norm_tensor, exp_actions_real),
        self.critic_network2(exp_inputs_norm_tensor, exp_actions_real)).mean()
    mmd_entropy = torch.tensor(0.0)
    if self.args.mmd:
        # mmd is computationally expensive
        exp_actions_real2 = exp_actions_real.view(
            self.args.expand_batch, -1, exp_actions_real.shape[-1]).transpose(0, 1)
        with torch.no_grad():
            uniform_actions = (2 * torch.rand_like(exp_actions_real2) - 1)
        mmd_entropy = mmd(exp_actions_real2, uniform_actions)
        if self.args.beta_mmd <= 0.0:
            mmd_entropy.detach_()
        else:
            actor_loss += self.args.beta_mmd * mmd_entropy
    # actor_loss += self.args.action_l2 * (actions_real / self.env_params['action_max']).pow(2).mean()
    # start to update the actor network
    self.actor_optim.zero_grad()
    actor_loss.backward()
    sync_grads(self.actor_network)
    self.actor_optim.step()
    # update both critic networks
    self.critic_optim1.zero_grad()
    critic_loss1.backward()
    sync_grads(self.critic_network1)
    self.critic_optim1.step()
    self.critic_optim2.zero_grad()
    critic_loss2.backward()
    sync_grads(self.critic_network2)
    self.critic_optim2.step()
    self.logger.store(LossPi=actor_loss.detach().cpu().numpy())
    self.logger.store(LossQ=(critic_loss1 + critic_loss2).detach().cpu().numpy())
    self.logger.store(MMDEntropy=mmd_entropy.detach().cpu().numpy())
def update_disentangled(actor_network, critic_network, critic_target_network,
                        configuration_network, policy_optim, critic_optim, alpha,
                        log_alpha, target_entropy, alpha_optim, obs_norm, ag_norm,
                        g_norm, obs_next_norm, ag_next_norm, g_next_norm, actions,
                        rewards, args):
    # Tensorize
    obs_norm_tensor = torch.tensor(obs_norm, dtype=torch.float32)
    obs_next_norm_tensor = torch.tensor(obs_next_norm, dtype=torch.float32)
    ag_norm_tensor = torch.tensor(ag_norm, dtype=torch.float32)
    ag_next_norm_tensor = torch.tensor(ag_next_norm, dtype=torch.float32)
    g_norm_tensor = torch.tensor(g_norm, dtype=torch.float32)
    g_next_norm_tensor = torch.tensor(g_next_norm, dtype=torch.float32)
    actions_tensor = torch.tensor(actions, dtype=torch.float32)
    r_tensor = torch.tensor(rewards, dtype=torch.float32).reshape(rewards.shape[0], 1)
    if args.cuda:
        # inputs_norm_tensor = inputs_norm_tensor.cuda()
        # inputs_next_norm_tensor = inputs_next_norm_tensor.cuda()
        actions_tensor = actions_tensor.cuda()
        r_tensor = r_tensor.cuda()
    # encode achieved goals and goals with the configuration network
    config_ag_z = configuration_network(ag_norm_tensor)
    config_ag_z_next = configuration_network(ag_next_norm_tensor)
    config_g_z = configuration_network(g_norm_tensor)
    config_g_z_next = configuration_network(g_next_norm_tensor)
    with torch.no_grad():
        # concatenate observations with the (detached) goal embeddings
        inputs_norm_tensor = torch.tensor(
            np.concatenate([obs_norm, config_ag_z.detach(), config_g_z.detach()], axis=1),
            dtype=torch.float32)
        inputs_next_norm_tensor = torch.tensor(
            np.concatenate([obs_next_norm, config_ag_z_next.detach(), config_g_z_next.detach()], axis=1),
            dtype=torch.float32)
        actions_next, log_pi_next, _ = actor_network.sample(inputs_next_norm_tensor)
        qf1_next_target, qf2_next_target = critic_target_network(
            obs_next_norm_tensor, actions_next, config_ag_z_next.detach(), config_g_z_next.detach())
        min_qf_next_target = torch.min(qf1_next_target, qf2_next_target) - alpha * log_pi_next
        next_q_value = r_tensor + args.gamma * min_qf_next_target
        # clip the q value
        """clip_return = 1 / (1 - args.gamma)
        next_q_value = torch.clamp(next_q_value, 0, clip_return)"""
    # the q loss
    qf1, qf2 = critic_network(obs_norm_tensor, actions_tensor, config_ag_z, config_g_z)
    qf1_loss = F.mse_loss(qf1, next_q_value)
    qf2_loss = F.mse_loss(qf2, next_q_value)
    # the actor loss
    pi, log_pi, _ = actor_network.sample(inputs_norm_tensor)
    qf1_pi, qf2_pi = critic_network(obs_norm_tensor, pi, config_ag_z, config_g_z)
    min_qf_pi = torch.min(qf1_pi, qf2_pi)
    policy_loss = ((alpha * log_pi) - min_qf_pi).mean()
    # start to update the actor network
    policy_optim.zero_grad()
    policy_loss.backward(retain_graph=True)
    sync_grads(actor_network)
    policy_optim.step()
    # update the critic network (gradients also flow into the configuration network)
    configuration_network.zero_grad()
    critic_optim.zero_grad()
    qf1_loss.backward(retain_graph=True)
    sync_grads(critic_network)
    critic_optim.step()
    critic_optim.zero_grad()
    qf2_loss.backward()
    sync_grads(critic_network)
    critic_optim.step()
    # configuration_optim.step()
    sync_grads(configuration_network)
    # configuration_optim.step()
    alpha_loss, alpha_tlogs = update_entropy(alpha, log_alpha, target_entropy, log_pi,
                                             alpha_optim, args)
    return qf1_loss.item(), qf2_loss.item(), policy_loss.item(), alpha_loss.item(), alpha_tlogs.item()
def update_flat(actor_network, critic_network, critic_target_network, policy_optim,
                critic_optim, alpha, log_alpha, target_entropy, alpha_optim, obs_norm,
                ag_norm, g_norm, obs_next_norm, actions, rewards, args):
    inputs_norm = np.concatenate([obs_norm, ag_norm, g_norm], axis=1)
    inputs_next_norm = np.concatenate([obs_next_norm, ag_norm, g_norm], axis=1)
    # Tensorize
    inputs_norm_tensor = torch.tensor(inputs_norm, dtype=torch.float32)
    inputs_next_norm_tensor = torch.tensor(inputs_next_norm, dtype=torch.float32)
    actions_tensor = torch.tensor(actions, dtype=torch.float32)
    r_tensor = torch.tensor(rewards, dtype=torch.float32).reshape(rewards.shape[0], 1)
    if args.cuda:
        inputs_norm_tensor = inputs_norm_tensor.cuda()
        inputs_next_norm_tensor = inputs_next_norm_tensor.cuda()
        actions_tensor = actions_tensor.cuda()
        r_tensor = r_tensor.cuda()
    # compute the target Q value
    with torch.no_grad():
        actions_next, log_pi_next, _ = actor_network.sample(inputs_next_norm_tensor)
        qf1_next_target, qf2_next_target = critic_target_network(inputs_next_norm_tensor, actions_next)
        min_qf_next_target = torch.min(qf1_next_target, qf2_next_target) - alpha * log_pi_next
        next_q_value = r_tensor + args.gamma * min_qf_next_target
    # the q loss
    qf1, qf2 = critic_network(inputs_norm_tensor, actions_tensor)
    qf1_loss = F.mse_loss(qf1, next_q_value)
    qf2_loss = F.mse_loss(qf2, next_q_value)
    # the actor loss
    pi, log_pi, _ = actor_network.sample(inputs_norm_tensor)
    qf1_pi, qf2_pi = critic_network(inputs_norm_tensor, pi)
    min_qf_pi = torch.min(qf1_pi, qf2_pi)
    policy_loss = ((alpha * log_pi) - min_qf_pi).mean()
    # start to update the actor network
    policy_optim.zero_grad()
    policy_loss.backward()
    sync_grads(actor_network)
    policy_optim.step()
    # update the critic network; retain the graph shared by the two Q heads so the
    # second backward pass below does not fail
    critic_optim.zero_grad()
    qf1_loss.backward(retain_graph=True)
    sync_grads(critic_network)
    critic_optim.step()
    critic_optim.zero_grad()
    qf2_loss.backward()
    sync_grads(critic_network)
    critic_optim.step()
    alpha_loss, alpha_tlogs = update_entropy(alpha, log_alpha, target_entropy, log_pi,
                                             alpha_optim, args)
    return qf1_loss.item(), qf2_loss.item(), policy_loss.item(), alpha_loss.item(), alpha_tlogs.item()
def update_deepsets(model, language, policy_optim, critic_optim, alpha, log_alpha,
                    target_entropy, alpha_optim, obs_norm, ag_norm, g_norm,
                    obs_next_norm, ag_next_norm, anchor_g, actions, rewards,
                    language_goals, args):
    # Tensorize
    obs_norm_tensor = torch.tensor(obs_norm, dtype=torch.float32)
    obs_next_norm_tensor = torch.tensor(obs_next_norm, dtype=torch.float32)
    if language:
        g_norm_tensor = g_norm
    else:
        g_norm_tensor = torch.tensor(g_norm, dtype=torch.float32)
    ag_norm_tensor = torch.tensor(ag_norm, dtype=torch.float32)
    ag_next_norm_tensor = torch.tensor(ag_next_norm, dtype=torch.float32)
    actions_tensor = torch.tensor(actions, dtype=torch.float32)
    r_tensor = torch.tensor(rewards, dtype=torch.float32).reshape(rewards.shape[0], 1)
    anchor_g_tensor = torch.tensor(anchor_g, dtype=torch.float32)
    if args.cuda:
        obs_norm_tensor = obs_norm_tensor.cuda()
        obs_next_norm_tensor = obs_next_norm_tensor.cuda()
        g_norm_tensor = g_norm_tensor.cuda()
        ag_norm_tensor = ag_norm_tensor.cuda()
        ag_next_norm_tensor = ag_next_norm_tensor.cuda()
        actions_tensor = actions_tensor.cuda()
        r_tensor = r_tensor.cuda()
    # compute the target Q value
    with torch.no_grad():
        if args.algo == 'language':
            model.forward_pass(obs_next_norm_tensor, language_goals=language_goals)
        elif args.algo == 'continuous':
            model.forward_pass(obs_next_norm_tensor, ag_next_norm_tensor, g_norm_tensor)
        else:
            model.forward_pass(obs_next_norm_tensor, ag_next_norm_tensor, g_norm_tensor, anchor_g_tensor)
        actions_next, log_pi_next = model.pi_tensor, model.log_prob
        qf1_next_target, qf2_next_target = model.target_q1_pi_tensor, model.target_q2_pi_tensor
        min_qf_next_target = torch.min(qf1_next_target, qf2_next_target) - alpha * log_pi_next
        next_q_value = r_tensor + args.gamma * min_qf_next_target
    # the q loss
    if args.algo == 'language':
        qf1, qf2 = model.forward_pass(obs_norm_tensor, actions=actions_tensor,
                                      language_goals=language_goals)
    elif args.algo == 'continuous':
        qf1, qf2 = model.forward_pass(obs_norm_tensor, ag_norm_tensor, g_norm_tensor,
                                      actions=actions_tensor)
    else:
        qf1, qf2 = model.forward_pass(obs_norm_tensor, ag_norm_tensor, g_norm_tensor,
                                      anchor_g_tensor, actions=actions_tensor)
    qf1_loss = F.mse_loss(qf1, next_q_value)
    qf2_loss = F.mse_loss(qf2, next_q_value)
    qf_loss = qf1_loss + qf2_loss
    # the actor loss
    pi, log_pi = model.pi_tensor, model.log_prob
    qf1_pi, qf2_pi = model.q1_pi_tensor, model.q2_pi_tensor
    min_qf_pi = torch.min(qf1_pi, qf2_pi)
    policy_loss = ((alpha * log_pi) - min_qf_pi).mean()
    # start to update the actor network
    policy_optim.zero_grad()
    policy_loss.backward(retain_graph=True)
    sync_grads(model.actor)
    policy_optim.step()
    # update the critic network
    critic_optim.zero_grad()
    qf_loss.backward()
    sync_grads(model.critic)
    critic_optim.step()
    alpha_loss, alpha_tlogs = update_entropy(alpha, log_alpha, target_entropy, log_pi,
                                             alpha_optim, args)
    return qf1_loss.item(), qf2_loss.item(), policy_loss.item(), alpha_loss.item(), alpha_tlogs.item()
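# ---------------------------------------------------------------------------
# `update_flat`, `update_disentangled` and `update_deepsets` all delegate the
# temperature update to `update_entropy(...)`. The sketch below assumes the
# standard SAC automatic temperature adjustment; the flag name
# `args.automatic_entropy_tuning` and the fixed-alpha fallback are assumptions.
# ---------------------------------------------------------------------------
import torch


def update_entropy(alpha, log_alpha, target_entropy, log_pi, alpha_optim, args):
    """Sketch: one gradient step on the SAC temperature parameter."""
    if getattr(args, 'automatic_entropy_tuning', False):
        alpha_loss = -(log_alpha * (log_pi + target_entropy).detach()).mean()
        alpha_optim.zero_grad()
        alpha_loss.backward()
        alpha_optim.step()
        alpha_tlogs = log_alpha.exp()
    else:
        # keep the temperature fixed
        alpha_loss = torch.tensor(0.)
        alpha_tlogs = torch.tensor(float(alpha))
    return alpha_loss, alpha_tlogs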