def ul_optimize(self, itr, rl_samples=None):
    opt_info_ul = OptInfoUl(*([] for _ in range(len(OptInfoUl._fields))))
    n_ul_updates = self.compute_ul_update_schedule(itr)
    ul_bs = self.ul_batch_size
    n_rl_samples = (0 if rl_samples is None else
        len(rl_samples.agent_inputs.observation))
    for i in range(n_ul_updates):
        self.ul_update_counter += 1
        if self.ul_lr_scheduler is not None:
            self.ul_lr_scheduler.step(self.ul_update_counter)
        # Slice the next UL minibatch out of the RL samples while enough
        # remain; otherwise let the UL step draw its own batch.
        if n_rl_samples >= ul_bs * (i + 1):
            ul_samples = rl_samples[i * ul_bs:(i + 1) * ul_bs]
        else:
            ul_samples = None
        ul_loss, ul_accuracy, grad_norm = self.ul_optimize_one_step(ul_samples)
        opt_info_ul.ulLoss.append(ul_loss.item())
        opt_info_ul.ulAccuracy.append(ul_accuracy.item())
        opt_info_ul.ulGradNorm.append(grad_norm.item())
        if self.ul_update_counter % self.ul_target_update_interval == 0:
            update_state_dict(
                self.ul_target_encoder,
                self.ul_encoder.state_dict(),
                self.ul_target_update_tau,
            )
    opt_info_ul.ulUpdates.append(self.ul_update_counter)
    return opt_info_ul
def optimize(self, itr):
    opt_info = OptInfo(*([] for _ in range(len(OptInfo._fields))))
    samples = self.replay_buffer.sample_batch(self.batch_size)
    if self.lr_scheduler is not None:
        self.lr_scheduler.step(itr)  # Do every itr instead of every epoch.
    self.optimizer.zero_grad()
    stdim_loss, loss_vals, accuracies, conv_output = self.stdim_loss(samples)
    act_loss = self.activation_loss(conv_output)
    loss = stdim_loss + act_loss
    loss.backward()
    if self.clip_grad_norm is None:
        grad_norm = 0.0
    else:
        grad_norm = torch.nn.utils.clip_grad_norm_(
            self.parameters(), self.clip_grad_norm)
    self.optimizer.step()
    opt_info.stdimLoss.append(stdim_loss.item())
    opt_info.ggLoss.append(loss_vals[0])
    opt_info.glLoss.append(loss_vals[1])
    opt_info.llLoss.append(loss_vals[2])
    opt_info.ggAccuracy.append(accuracies[0])
    opt_info.glAccuracy.append(accuracies[1])
    opt_info.llAccuracy.append(accuracies[2])
    opt_info.activationLoss.append(act_loss.item())
    # float() handles both the 0.0 fallback and the tensor returned by
    # clip_grad_norm_ (calling .item() on the plain float would raise).
    opt_info.gradNorm.append(float(grad_norm))
    opt_info.convActivation.append(
        conv_output[0].detach().cpu().view(-1).numpy())  # Keep 1 full one.
    if itr % self.target_update_interval == 0:
        update_state_dict(self.target_encoder, self.encoder.state_dict(),
            self.target_update_tau)
    return opt_info
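The `OptInfo(*([] for _ in range(len(OptInfo._fields))))` idiom above instantiates a namedtuple with one fresh empty list per field, so each field accumulates per-iteration scalars for logging. A self-contained sketch, with the field list reconstructed from the appends in optimize() above (the real OptInfo may define additional fields):

from collections import namedtuple

OptInfo = namedtuple("OptInfo", [
    "stdimLoss", "ggLoss", "glLoss", "llLoss",
    "ggAccuracy", "glAccuracy", "llAccuracy",
    "activationLoss", "gradNorm", "convActivation",
])

# One independent empty list per field. The generator matters:
# [[]] * n would alias a single list n times.
opt_info = OptInfo(*([] for _ in range(len(OptInfo._fields))))
opt_info.stdimLoss.append(0.25)  # each update appends its scalars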
def update_target(self, tau=1):
    """
    Update the target network, i.e. copy the main network's (self.model)
    parameters into the target network (self.target_model). With tau = 1 the
    parameters are copied outright; with 0 < tau < 1 a soft (Polyak) update
    blends them in gradually.

    To keep learning stable and efficient, the target network is not updated
    in real time but only periodically; in DQN.optimize_agent(), for example,
    this function is called once per fixed interval.

    :param tau: the interpolation factor of the soft-update rule.
    """
    update_state_dict(self.target_model, self.model.state_dict(), tau)
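Every snippet in this section delegates the actual parameter copy to update_state_dict(), whose definition is not shown here. The following is a minimal sketch of what such a helper typically looks like (rlpyt's version additionally strips DistributedDataParallel key prefixes, omitted here):

def update_state_dict(model, state_dict, tau=1):
    """Blend `state_dict` into `model`'s parameters: a hard copy when
    tau == 1, otherwise the Polyak/soft update
    new = tau * source + (1 - tau) * target."""
    if tau == 1:
        model.load_state_dict(state_dict)
    elif tau > 0:
        blended = {k: tau * state_dict[k] + (1 - tau) * v
                   for k, v in model.state_dict().items()}
        model.load_state_dict(blended)

Usage, with illustrative tau values (a small tau keeps the target a slow-moving average of the online network, which stabilizes bootstrapped targets):

import torch

net = torch.nn.Linear(4, 2)
target = torch.nn.Linear(4, 2)
update_state_dict(target, net.state_dict(), tau=1)     # hard copy
update_state_dict(target, net.state_dict(), tau=0.01)  # soft/EMA update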
def ul_optimize(self, itr):
    opt_info_ul = OptInfoUl(*([] for _ in range(len(OptInfoUl._fields))))
    n_ul_updates = self.compute_ul_update_schedule(itr)
    for _ in range(n_ul_updates):
        self.ul_update_counter += 1
        if self.ul_lr_scheduler is not None:
            self.ul_lr_scheduler.step(self.ul_update_counter)
        ul_loss, ul_accuracy, grad_norm = self.ul_optimize_one_step()
        opt_info_ul.ulLoss.append(ul_loss.item())
        opt_info_ul.ulAccuracy.append(ul_accuracy.item())
        opt_info_ul.ulGradNorm.append(grad_norm.item())
        if self.ul_update_counter % self.ul_target_update_interval == 0:
            update_state_dict(self.ul_target_encoder,
                self.ul_encoder.state_dict(),
                self.ul_target_update_tau)
    opt_info_ul.ulUpdates.append(self.ul_update_counter)
    return opt_info_ul
def update_targets(self, q_tau=1, encoder_tau=1):
    """Do each parameter ONLY ONCE."""
    update_state_dict(self.target_conv, self.conv.state_dict(), encoder_tau)
    update_state_dict(self.target_q_fc1, self.q_fc1.state_dict(), encoder_tau)
    update_state_dict(self.target_q_mlps, self.q_mlps.state_dict(), q_tau)
def update_target(self, tau=1):
    update_state_dict(self.target_model, self.model.state_dict(), tau)
    update_state_dict(self.target_q_model, self.q_model.state_dict(), tau)
def update_target(self, tau=1):
    update_state_dict(self.target_v_model, self.v_model.state_dict(), tau)
def update_target(self, tau=1):
    super().update_target(tau)
    update_state_dict(self.target_q2_model, self.q2_model.state_dict(), tau)
def update_target(self, tau=1):
    """Copies the model parameters into the target model."""
    update_state_dict(self.target_model, self.model.state_dict(), tau)
def update_target(self, tau=1):
    # A plain loop, not a list comprehension: the calls are for side effects.
    for target_q, q in zip(self.target_q_models, self.q_models):
        update_state_dict(target_q, q.state_dict(), tau)
def do_spr_loss(self, pred_latents, observation):
    # pred_latents.shape = [6, 32, 64, 7, 7] or [6, 32, 600]
    # observation.shape = [16, 32, 4, 1, 84, 84]
    pred_latents = torch.stack(pred_latents, 1)
    latents = pred_latents[:observation.shape[1]].flatten(0, 1)  # batch*jumps, *
    neg_latents = pred_latents[observation.shape[1]:].flatten(0, 1)
    latents = torch.cat([latents, neg_latents], 0)
    target_images = observation[
        self.time_offset:self.jumps + self.time_offset + 1
    ].transpose(0, 1).flatten(2, 3)
    # [16, 32, 4, 1, 84, 84] --> [32, 6, 4, 84, 84]
    target_images = self.transform(target_images, True)
    if not self.momentum_encoder and not self.shared_encoder:
        target_images = target_images[..., -1:, :, :]
    with torch.no_grad() if self.momentum_encoder else dummy_context_mgr():
        target_latents = self.target_encoder(target_images.flatten(0, 1))
        if self.renormalize:
            target_latents = self.renormalize_tensor(
                target_latents, first_dim=-3, target=True)
        target_latents = self.target_encoder_proj(target_latents)
    if self.local_spr:
        local_loss = self.local_spr_loss(latents, target_latents, observation)
    else:
        local_loss = 0
    if self.global_spr:
        global_loss = self.global_spr_loss(latents, target_latents, observation)
    else:
        global_loss = 0
    spr_loss = (global_loss + local_loss) / self.num_sprs
    spr_loss = spr_loss.view(-1, observation.shape[1])  # split to batch, jumps
    if self.momentum_encoder:
        update_state_dict(self.target_encoder, self.conv.state_dict(),
                          self.momentum_tau)
        update_state_dict(
            self.target_renormalize_ln,
            self.renormalize_ln.state_dict(),
            1.0,  # we don't use momentum for ln
        )
        update_state_dict(
            self.target_encoder_proj,
            self.conv_proj.state_dict(),
            self.momentum_tau,
        )
        if self.classifier_type != "bilinear":  # q_l1 is also bilinear for local
            if self.local_spr and self.classifier_type != "q_l1":
                update_state_dict(self.local_target_classifier,
                                  self.local_classifier.state_dict(),
                                  self.momentum_tau)
            if self.global_spr:
                update_state_dict(self.global_target_classifier,
                                  self.global_classifier.state_dict(),
                                  self.momentum_tau)
    return spr_loss
def update_target(self, tau=1):
    update_state_dict(self.target_model, self.critic.state_dict(), tau)