Example 1
 def ul_optimize(self, itr, rl_samples=None):
     opt_info_ul = OptInfoUl(*([] for _ in range(len(OptInfoUl._fields))))
     n_ul_updates = self.compute_ul_update_schedule(itr)
     ul_bs = self.ul_batch_size
     n_rl_samples = (0 if rl_samples is None else len(
         rl_samples.agent_inputs.observation))
     for i in range(n_ul_updates):
         self.ul_update_counter += 1
         if self.ul_lr_scheduler is not None:
             self.ul_lr_scheduler.step(self.ul_update_counter)
         if n_rl_samples >= self.ul_batch_size * (i + 1):
             ul_samples = rl_samples[i * ul_bs:(i + 1) * ul_bs]
         else:
             ul_samples = None
         ul_loss, ul_accuracy, grad_norm = self.ul_optimize_one_step(
             ul_samples)
         opt_info_ul.ulLoss.append(ul_loss.item())
         opt_info_ul.ulAccuracy.append(ul_accuracy.item())
         opt_info_ul.ulGradNorm.append(grad_norm.item())
         if self.ul_update_counter % self.ul_target_update_interval == 0:
             update_state_dict(
                 self.ul_target_encoder,
                 self.ul_encoder.state_dict(),
                 self.ul_target_update_tau,
             )
     opt_info_ul.ulUpdates.append(self.ul_update_counter)
     return opt_info_ul
Example 2
 def optimize(self, itr):
     opt_info = OptInfo(*([] for _ in range(len(OptInfo._fields))))
     samples = self.replay_buffer.sample_batch(self.batch_size)
     if self.lr_scheduler is not None:
         self.lr_scheduler.step(itr)  # Do every itr instead of every epoch
     self.optimizer.zero_grad()
     stdim_loss, loss_vals, accuracies, conv_output = self.stdim_loss(
         samples)
     act_loss = self.activation_loss(conv_output)
     loss = stdim_loss + act_loss
     loss.backward()
     if self.clip_grad_norm is None:
         grad_norm = torch.tensor(0.0)  # keep a tensor so .item() works below
     else:
         grad_norm = torch.nn.utils.clip_grad_norm_(self.parameters(),
                                                    self.clip_grad_norm)
     self.optimizer.step()
     opt_info.stdimLoss.append(stdim_loss.item())
     opt_info.ggLoss.append(loss_vals[0])
     opt_info.glLoss.append(loss_vals[1])
     opt_info.llLoss.append(loss_vals[2])
     opt_info.ggAccuracy.append(accuracies[0])
     opt_info.glAccuracy.append(accuracies[1])
     opt_info.llAccuracy.append(accuracies[2])
     opt_info.activationLoss.append(act_loss.item())
     opt_info.gradNorm.append(grad_norm.item())
     opt_info.convActivation.append(
         conv_output[0].detach().cpu().view(-1).numpy())  # Keep 1 full one.
     if itr % self.target_update_interval == 0:
         update_state_dict(self.target_encoder, self.encoder.state_dict(),
                           self.target_update_tau)
     return opt_info
Example 3
    def update_target(self, tau=1):
        """
        更新target network,即把main network(self.model)的参数拷贝到target network(self.target_model)上。
        当τ>0的时候会使用soft update算法来更新参数。
        为了保持learning过程的稳定性以及高效性,target network不是实时更新,而是周期性地更新一次。例如,在DQN.optimize_agent()函数中,
        会看到每隔一定的周期才会调用一次本函数的代码逻辑。

        :param tau: soft update算法里的τ参数。
        """
        update_state_dict(self.target_model, self.model.state_dict(), tau)
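
The docstring above describes the soft-update (Polyak averaging) scheme that all of these examples rely on. The helper update_state_dict itself does not appear in this listing; the following is only a minimal sketch of what such a helper typically does, assuming an rlpyt-style signature update_state_dict(model, state_dict, tau) in which tau=1 means a full hard copy:

 # Sketch only; the actual helper used by these examples may differ.
 def update_state_dict(model, state_dict, tau=1):
     """Blend the source state_dict into the model's parameters.
     tau == 1 is a hard copy; 0 < tau < 1 is a soft (Polyak) update:
     new = tau * source + (1 - tau) * current.
     """
     if tau == 1:
         model.load_state_dict(state_dict)
     elif tau > 0:
         blended = {k: tau * state_dict[k] + (1 - tau) * v
                    for k, v in model.state_dict().items()}
         model.load_state_dict(blended)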
Example 4
 def ul_optimize(self, itr):
     opt_info_ul = OptInfoUl(*([] for _ in range(len(OptInfoUl._fields))))
     n_ul_updates = self.compute_ul_update_schedule(itr)
     for _ in range(n_ul_updates):
         self.ul_update_counter += 1
         if self.ul_lr_scheduler is not None:
             self.ul_lr_scheduler.step(self.ul_update_counter)
         ul_loss, ul_accuracy, grad_norm = self.ul_optimize_one_step()
         opt_info_ul.ulLoss.append(ul_loss.item())
         opt_info_ul.ulAccuracy.append(ul_accuracy.item())
         opt_info_ul.ulGradNorm.append(grad_norm.item())
         if self.ul_update_counter % self.ul_target_update_interval == 0:
             update_state_dict(self.ul_target_encoder, self.ul_encoder.state_dict(),
                               self.ul_target_update_tau)
     opt_info_ul.ulUpdates.append(self.ul_update_counter)
     return opt_info_ul
Example 5
 def update_targets(self, q_tau=1, encoder_tau=1):
     """Do each parameter ONLY ONCE."""
     update_state_dict(self.target_conv, self.conv.state_dict(),
                       encoder_tau)
     update_state_dict(self.target_q_fc1, self.q_fc1.state_dict(),
                       encoder_tau)
     update_state_dict(self.target_q_mlps, self.q_mlps.state_dict(), q_tau)
Example 6
 def update_target(self, tau=1):
     update_state_dict(self.target_model, self.model.state_dict(), tau)
     update_state_dict(self.target_q_model, self.q_model.state_dict(), tau)
Example 7
 def update_target(self, tau=1):
     update_state_dict(self.target_v_model, self.v_model.state_dict(), tau)
Example 8
 def update_target(self, tau=1):
     super().update_target(tau)
     update_state_dict(self.target_q2_model, self.q2_model.state_dict(),
                       tau)
Example 9
 def update_target(self, tau=1):
     """Copies the model parameters into the target model."""
     update_state_dict(self.target_model, self.model.state_dict(), tau)
Example 10
 def update_target(self, tau=1):
     for target_q, q in zip(self.target_q_models, self.q_models):
         update_state_dict(target_q, q.state_dict(), tau)
Example 11
    def do_spr_loss(self, pred_latents, observation):
        # pred_latents.shape = [6,  32, 64, 7, 7] or [6, 32, 600]
        # observation.shape =  [16, 32, 4, 1, 84, 84]

        pred_latents = torch.stack(pred_latents, 1)
        latents = pred_latents[:observation.shape[1]].flatten(0, 1)  # batch*jumps, *
        neg_latents = pred_latents[observation.shape[1]:].flatten(0, 1)
        latents = torch.cat([latents, neg_latents], 0)
        target_images = observation[self.time_offset:self.jumps + self.time_offset+1].transpose(0, 1).flatten(2, 3)
        # [16, 32, 4, 1, 84, 84] --> [32, 6, 4, 84, 84]
        target_images = self.transform(target_images, True)

        if not self.momentum_encoder and not self.shared_encoder:
            target_images = target_images[..., -1:, :, :]
        with torch.no_grad() if self.momentum_encoder else dummy_context_mgr():
            target_latents = self.target_encoder(target_images.flatten(0, 1))
            if self.renormalize:
                target_latents = self.renormalize_tensor(target_latents, first_dim=-3, target=True)

        target_latents = self.target_encoder_proj(target_latents)

        if self.local_spr:
            local_loss = self.local_spr_loss(latents, target_latents, observation)
        else:
            local_loss = 0
        if self.global_spr:
            global_loss = self.global_spr_loss(latents, target_latents, observation)
        else:
            global_loss = 0

        spr_loss = (global_loss + local_loss)/self.num_sprs
        spr_loss = spr_loss.view(-1, observation.shape[1]) # split to batch, jumps

        if self.momentum_encoder:
            update_state_dict(self.target_encoder,
                              self.conv.state_dict(),
                              self.momentum_tau)

            update_state_dict(
                self.target_renormalize_ln,
                self.renormalize_ln.state_dict(),
                1.0 # we don't use momentum for ln
            )

            update_state_dict(
                self.target_encoder_proj,
                self.conv_proj.state_dict(),
                self.momentum_tau
            )

            if self.classifier_type != "bilinear":
                # q_l1 is also bilinear for local
                if self.local_spr and self.classifier_type != "q_l1":
                    update_state_dict(self.local_target_classifier,
                                      self.local_classifier.state_dict(),
                                      self.momentum_tau)
                if self.global_spr:
                    update_state_dict(self.global_target_classifier,
                                      self.global_classifier.state_dict(),
                                      self.momentum_tau)
        return spr_loss
Example 12
 def update_target(self, tau=1):
     update_state_dict(self.target_model, self.critic.state_dict(), tau)