def send_shared_memory(self):
    """Used in async mode only, in optimizer process; copies parameters
    from trained model (maybe GPU) to shared model, which the sampler can
    access. Does so under write-lock, and increments send-count which
    sampler can check. Typically called in the XXX during YY."""
    if self.shared_model is not self.model:
        with self._rw_lock.write_lock:
            self.shared_model.load_state_dict(
                strip_ddp_state_dict(self.model.state_dict()))
            if self.dual_model:
                # Also publish the trained int model's weights.
                self.shared_model_int.load_state_dict(
                    strip_ddp_state_dict(self.model_int.state_dict()))
            self._send_count.value += 1
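# The sampler-side counterpart of the write-lock/send-count protocol is not
# shown here; a minimal sketch of what it could look like, assuming the
# default context of self._rw_lock takes the read-lock and assuming a plain
# int attribute self._recv_count (that name is illustrative, mirroring
# _send_count above):
def recv_shared_memory(self):
    """Sketch: async mode, in sampler process; pulls new parameters from
    the shared model under read-lock, only if any have been sent."""
    if self.shared_model is not self.model:
        with self._rw_lock:  # Read-lock.
            if self._recv_count < self._send_count.value:
                self.model.load_state_dict(self.shared_model.state_dict())
                self._recv_count = self._send_count.value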
def sync_shared_memory(self):
    """Copies model parameters into shared_model, e.g. to make new values
    available to sampler workers. If running CPU-only, these will be the
    same object--no copy necessary. If model is on GPU, copy to CPU is
    performed. (Requires ``initialize(share_memory=True)`` called
    previously.) Not used in async mode. Typically called in the XXX
    during YY.
    """
    if self.shared_model is not self.model:  # (self.model gets trained.)
        self.shared_model.load_state_dict(
            strip_ddp_state_dict(self.model.state_dict()))
        if self.dual_model:
            self.shared_model_int.load_state_dict(
                strip_ddp_state_dict(self.model_int.state_dict()))
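# Hedged usage sketch (the runner/sampler objects and call names are
# illustrative, not defined in this file): in non-async mode the sampler
# master publishes weights once per iteration before collecting samples,
# e.g.
#
#   agent.initialize(env_spaces, share_memory=True)
#   for itr in range(n_itr):
#       agent.sync_shared_memory()  # No copy if CPU-only (same object).
#       samples, traj_infos = sampler.obtain_samples(itr)
#       opt_info = algo.optimize_agent(itr, samples)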
def send_shared_memory(self):
    """Used in async mode."""
    if self.shared_model is not self.model:
        with self._rw_lock.write_lock:
            self.shared_model.load_state_dict(
                strip_ddp_state_dict(self.model.state_dict()))
            self._send_count.value += 1
def sync_shared_memory(self):
    """Call in sampler master (non-async), after initialize(share_memory=True)."""
    if self.shared_model is not self.model:  # (self.model gets trained)
        self.shared_model.load_state_dict(strip_ddp_state_dict(
            self.model.state_dict()))
def update_target(self):
    """Copies the current model parameters into the target model."""
    self.target_model.load_state_dict(
        strip_ddp_state_dict(self.model.state_dict()))
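# A common variant of the hard copy above is Polyak averaging; a minimal
# sketch, assuming float parameters (the name ``soft_update_target`` and
# the default tau are illustrative, not taken from this file):
def soft_update_target(self, tau=0.01):
    """Sketch: target <- tau * model + (1 - tau) * target, in-place.
    state_dict() returns references to the live tensors, so the in-place
    ops update the target model directly."""
    model_sd = strip_ddp_state_dict(self.model.state_dict())
    target_sd = self.target_model.state_dict()
    for k, v in model_sd.items():
        t = target_sd[k]
        if t.dtype.is_floating_point:
            t.mul_(1 - tau).add_(v, alpha=tau)
        else:
            t.copy_(v)  # Integer buffers (e.g. batchnorm counts): hard copy.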
def compute_beta_kl(self, loss_inputs, init_rnn_state, batch_size, mb_size, T):
    """Ratio of KL divergences from reward-only vs cost-only updates:
    copies the current model and optimizer state into the beta_r/beta_c
    models and optimizers, trains each for ``beta_kl_epochs`` on its
    single-objective loss, then measures each one's mean KL from the old
    (pre-update) policy and returns the ratio KL_r / KL_c plus both KLs.
    """
    # Start both side models and optimizers from the current state.
    self.agent.beta_r_model.load_state_dict(
        strip_ddp_state_dict(self.agent.model.state_dict()))
    self.agent.beta_c_model.load_state_dict(
        strip_ddp_state_dict(self.agent.model.state_dict()))
    self.beta_r_optimizer.load_state_dict(self.optimizer.state_dict())
    self.beta_c_optimizer.load_state_dict(self.optimizer.state_dict())
    recurrent = self.agent.recurrent
    for _ in range(self.beta_kl_epochs):
        for idxs in iterate_mb_idxs(batch_size, mb_size,
                shuffle=batch_size > mb_size):
            T_idxs = slice(None) if recurrent else idxs % T
            B_idxs = idxs if recurrent else idxs // T
            rnn_state = init_rnn_state[B_idxs] if recurrent else None
            self.beta_r_optimizer.zero_grad()
            self.beta_c_optimizer.zero_grad()
            beta_r_loss, beta_c_loss = self.beta_kl_losses(
                *loss_inputs[T_idxs, B_idxs], rnn_state)
            beta_r_loss.backward()
            torch.nn.utils.clip_grad_norm_(
                self.agent.beta_r_model.parameters(), self.clip_grad_norm)
            self.beta_r_optimizer.step()
            beta_c_loss.backward()
            torch.nn.utils.clip_grad_norm_(
                self.agent.beta_c_model.parameters(), self.clip_grad_norm)
            self.beta_c_optimizer.step()
    if init_rnn_state is not None:
        # [B,N,H] --> [N,B,H] (for cudnn).
        init_rnn_state = buffer_method(init_rnn_state, "transpose", 0, 1)
        init_rnn_state = buffer_method(init_rnn_state, "contiguous")
    with torch.no_grad():
        r_dist_info, c_dist_info = self.agent.beta_dist_infos(
            *loss_inputs.agent_inputs, init_rnn_state)
    dist = self.agent.distribution
    beta_r_KL = dist.mean_kl(new_dist_info=r_dist_info,
        old_dist_info=loss_inputs.old_dist_info, valid=loss_inputs.valid)
    beta_c_KL = dist.mean_kl(new_dist_info=c_dist_info,
        old_dist_info=loss_inputs.old_dist_info, valid=loss_inputs.valid)
    if self._ddp:
        # Average both KLs across data-parallel workers.
        beta_KLs = torch.stack([beta_r_KL, beta_c_KL])
        beta_KLs = beta_KLs.to(self.agent.device)
        torch.distributed.all_reduce(beta_KLs)
        beta_KLs = beta_KLs.to("cpu")
        beta_KLs /= torch.distributed.get_world_size()
        beta_r_KL, beta_c_KL = beta_KLs[0], beta_KLs[1]
    raw_beta_KL = float(beta_r_KL / max(beta_c_KL, 1e-8))
    return raw_beta_KL, float(beta_r_KL), float(beta_c_KL)
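# Hedged usage sketch (the attribute names ``beta_kl`` and
# ``beta_kl_smooth`` are illustrative assumptions, not taken from this
# file): the returned ratio can set the relative weight on the cost
# objective, optionally smoothed across iterations to damp minibatch
# noise, e.g.
#
#   raw_beta_KL, beta_r_KL, beta_c_KL = self.compute_beta_kl(
#       loss_inputs, init_rnn_state, batch_size, mb_size, T)
#   self.beta_kl = ((1 - self.beta_kl_smooth) * self.beta_kl
#       + self.beta_kl_smooth * raw_beta_KL)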