Code Example #1
    def send_shared_memory(self):
        """Used in async mode only, in optimizer process; copies parameters
        from trained model (maybe GPU) to shared model, which the sampler can
        access. Does so under write-lock, and increments send-count which sampler
        can check.

        Typically called in the XXX during YY."""
        if self.shared_model is not self.model:
            with self._rw_lock.write_lock:
                self.shared_model.load_state_dict(
                    strip_ddp_state_dict(self.model.state_dict()))
                if self.dual_model:
                    self.shared_model_int.load_state_dict(
                    strip_ddp_state_dict(self.model_int.state_dict()))
                self._send_count.value += 1
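
Every example here runs the trained model's ``state_dict()`` through ``strip_ddp_state_dict`` before loading it into an unwrapped model. For reference, a minimal sketch of such a helper (modeled on the rlpyt utility, so treat the exact details as an assumption): it removes the ``module.`` prefix that ``torch.nn.parallel.DistributedDataParallel`` prepends to every parameter key, so the cleaned dict matches a model that was never wrapped.

    def strip_ddp_state_dict(state_dict):
        # Keys from a DistributedDataParallel-wrapped model look like
        # "module.layer.weight"; unwrapped models expect "layer.weight".
        clean_state_dict = type(state_dict)()  # typically an OrderedDict
        for key, value in state_dict.items():
            clean_state_dict[key[7:] if key.startswith("module.") else key] = value
        return clean_state_dict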
Code Example #2
    def sync_shared_memory(self):
        """Copies model parameters into shared_model, e.g. to make new values
        available to sampler workers.  If running CPU-only, these will be the
        same object--no copy necessary.  If model is on GPU, copy to CPU is
        performed. (Requires ``initialize(share_memory=True)`` called
        previously.)  Not used in async mode.

        Typically called in the XXX during YY.
        """
        if self.shared_model is not self.model:  # (self.model gets trained)
            self.shared_model.load_state_dict(
                strip_ddp_state_dict(self.model.state_dict()))
            if self.dual_model:
                self.shared_model_int.load_state_dict(
                    strip_ddp_state_dict(self.model_int.state_dict()))
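
In the non-async setup, a runner or training loop would typically call this right after each optimizer update so that CPU sampler workers act with fresh parameters. A rough usage sketch (the loop and the names ``sampler``, ``algo``, and ``n_itr`` are illustrative, not a specific library API):

    # Illustrative non-async loop: the agent was set up with
    # agent.initialize(..., share_memory=True), so shared_model lives in
    # shared CPU memory that sampler worker processes can read.
    for itr in range(n_itr):
        samples = sampler.obtain_samples(itr)   # workers read agent.shared_model
        algo.optimize_agent(itr, samples)       # trains agent.model (maybe on GPU)
        agent.sync_shared_memory()              # copy new params into shared_model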
Code Example #3
 def send_shared_memory(self):
     """Used in async mode."""
     if self.shared_model is not self.model:
         with self._rw_lock.write_lock:
             self.shared_model.load_state_dict(
                 strip_ddp_state_dict(self.model.state_dict()))
             self._send_count.value += 1
Code Example #4
 def sync_shared_memory(self):
     """Call in sampler master (non-async), after initialize(share_memory=True)."""
     if self.shared_model is not self.model:  # (self.model gets trained)
         self.shared_model.load_state_dict(strip_ddp_state_dict(
             self.model.state_dict()))
Code Example #5
File: dqn_agent.py  Project: Xingyu-Lin/softagent
 def update_target(self):
     self.target_model.load_state_dict(
         strip_ddp_state_dict(self.model.state_dict()))
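
``update_target`` here is the standard DQN hard target-network refresh: the target network is overwritten with the online network's (DDP-stripped) weights. It would normally be invoked only every so many optimizer steps, for example (counter and interval names are illustrative):

    # Illustrative: refresh the target network periodically, not every step,
    # so the TD targets stay fixed between refreshes.
    if update_counter % target_update_interval == 0:
        agent.update_target()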
Code Example #6
    def compute_beta_kl(self, loss_inputs, init_rnn_state, batch_size, mb_size,
                        T):
        """Ratio of KL divergences from reward-only vs cost-only updates."""
        self.agent.beta_r_model.load_state_dict(
            strip_ddp_state_dict(self.agent.model.state_dict()))
        self.agent.beta_c_model.load_state_dict(
            strip_ddp_state_dict(self.agent.model.state_dict()))
        self.beta_r_optimizer.load_state_dict(self.optimizer.state_dict())
        self.beta_c_optimizer.load_state_dict(self.optimizer.state_dict())

        recurrent = self.agent.recurrent
        for _ in range(self.beta_kl_epochs):
            for idxs in iterate_mb_idxs(batch_size,
                                        mb_size,
                                        shuffle=batch_size > mb_size):
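                # Recover (time, batch) coordinates from a flat index i over
                # the [T, B] batch: t = i % T, b = i // T. Recurrent runs keep
                # whole time sequences and index only the batch dimension.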
                T_idxs = slice(None) if recurrent else idxs % T
                B_idxs = idxs if recurrent else idxs // T
                rnn_state = init_rnn_state[B_idxs] if recurrent else None
                self.beta_r_optimizer.zero_grad()
                self.beta_c_optimizer.zero_grad()

                beta_r_loss, beta_c_loss = self.beta_kl_losses(
                    *loss_inputs[T_idxs, B_idxs], rnn_state)

                beta_r_loss.backward()
                _ = torch.nn.utils.clip_grad_norm_(
                    self.agent.beta_r_model.parameters(), self.clip_grad_norm)
                self.beta_r_optimizer.step()

                beta_c_loss.backward()
                _ = torch.nn.utils.clip_grad_norm_(
                    self.agent.beta_c_model.parameters(), self.clip_grad_norm)
                self.beta_c_optimizer.step()

        if init_rnn_state is not None:
            # [B,N,H] --> [N,B,H] (for cudnn).
            init_rnn_state = buffer_method(init_rnn_state, "transpose", 0, 1)
            init_rnn_state = buffer_method(init_rnn_state, "contiguous")
        with torch.no_grad():
            r_dist_info, c_dist_info = self.agent.beta_dist_infos(
                *loss_inputs.agent_inputs, init_rnn_state)

        dist = self.agent.distribution
        beta_r_KL = dist.mean_kl(new_dist_info=r_dist_info,
                                 old_dist_info=loss_inputs.old_dist_info,
                                 valid=loss_inputs.valid)
        beta_c_KL = dist.mean_kl(new_dist_info=c_dist_info,
                                 old_dist_info=loss_inputs.old_dist_info,
                                 valid=loss_inputs.valid)

        if self._ddp:
            beta_KLs = torch.stack([beta_r_KL, beta_c_KL])
            beta_KLs = beta_KLs.to(self.agent.device)
            torch.distributed.all_reduce(beta_KLs)
            beta_KLs = beta_KLs.to("cpu")
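            # all_reduce sums across ranks; divide by world size to average.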
            beta_KLs /= torch.distributed.get_world_size()
            beta_r_KL, beta_c_KL = beta_KLs[0], beta_KLs[1]

        raw_beta_KL = float(beta_r_KL / max(beta_c_KL, 1e-8))

        return raw_beta_KL, float(beta_r_KL), float(beta_c_KL)
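
``iterate_mb_idxs`` above is assumed to yield flat index batches over the ``T*B`` samples, reshuffled on each pass when ``shuffle`` is set. A minimal sketch of such a generator (modeled on the rlpyt utility; treat the exact behavior as an assumption):

    import numpy as np

    def iterate_mb_idxs(data_length, minibatch_size, shuffle=False):
        # Yield minibatch-sized chunks of indexes covering [0, data_length);
        # with shuffle, draw a fresh permutation for the pass, otherwise
        # yield plain slices.
        if shuffle:
            indexes = np.random.permutation(data_length)
        for start in range(0, data_length - minibatch_size + 1, minibatch_size):
            batch = slice(start, start + minibatch_size)
            if shuffle:
                batch = indexes[batch]
            yield batch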