def optim_initialize(self, rank=0):
    self.rank = rank
    model = self.agent.model
    # Modules whose parameters are updated by the world-model loss.
    self.model_modules = [
        model.observation_encoder,
        model.observation_decoder,
        model.reward_model,
        model.representation,
        model.transition,
    ]
    if self.use_pcont:
        self.model_modules += [model.pcont]
    self.actor_modules = [model.action_decoder]
    self.value_modules = [model.value_model]
    # Separate Adam optimizers for the world model, the actor, and the value function.
    self.model_optimizer = torch.optim.Adam(
        get_parameters(self.model_modules), lr=self.model_lr, **self.optim_kwargs)
    self.actor_optimizer = torch.optim.Adam(
        get_parameters(self.actor_modules), lr=self.actor_lr, **self.optim_kwargs)
    self.value_optimizer = torch.optim.Adam(
        get_parameters(self.value_modules), lr=self.value_lr, **self.optim_kwargs)
    if self.initial_optim_state_dict is not None:
        self.load_optim_state_dict(self.initial_optim_state_dict)
    # These fields must be defined for logging purposes; they are used by the runner.
    self.opt_info_fields = OptInfo._fields
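# `OptInfo` and `loss_info_fields` are assumed to be defined at module level along
# these lines: a namedtuple collecting the per-step diagnostics that optimize_agent
# logs. The extra names in `loss_info_fields` below are hypothetical placeholders;
# the real ones come from whatever `self.loss()` reports.
from collections import namedtuple

loss_info_fields = ["model_loss", "actor_loss", "value_loss"]  # hypothetical diagnostics
OptInfo = namedtuple(
    "OptInfo",
    ["loss", "grad_norm_model", "grad_norm_actor", "grad_norm_value"] + loss_info_fields,
)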
def test_get_parameters():
    linear_module_1 = nn.Linear(4, 3)
    linear_module_2 = nn.Linear(3, 2)

    # Each nn.Linear contributes two parameters: a weight and a bias.
    params = get_parameters([linear_module_1])
    assert len(params) == 2

    params = get_parameters([linear_module_1, linear_module_2])
    assert len(params) == 4
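# The test above assumes a `get_parameters` helper that concatenates
# `module.parameters()` over every module in a list. A minimal sketch under that
# assumption (the actual helper lives elsewhere in the repository):
def get_parameters(modules):
    """Return a flat list of all parameters contained in the given modules."""
    model_parameters = []
    for module in modules:
        model_parameters += list(module.parameters())
    return model_parameters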
def optimize_agent(self, itr, samples=None, sampler_itr=None):
    itr = itr if sampler_itr is None else sampler_itr
    if samples is not None:
        # Note: discount not saved here
        self.replay_buffer.append_samples(samples_to_buffer(samples))
    opt_info = OptInfo(*([] for _ in range(len(OptInfo._fields))))
    # Skip optimization until the replay buffer is prefilled, and only train
    # every `train_every` iterations.
    if itr < self.prefill:
        return opt_info
    if itr % self.train_every != 0:
        return opt_info
    for i in tqdm(range(self.train_steps), desc='Imagination'):
        samples_from_replay = self.replay_buffer.sample_batch(
            self._batch_size, self.batch_length)
        buffed_samples = buffer_to(samples_from_replay, self.agent.device)
        model_loss, actor_loss, value_loss, loss_info = self.loss(
            buffed_samples, itr, i)

        self.model_optimizer.zero_grad()
        self.actor_optimizer.zero_grad()
        self.value_optimizer.zero_grad()

        model_loss.backward()
        actor_loss.backward()
        value_loss.backward()

        grad_norm_model = torch.nn.utils.clip_grad_norm_(
            get_parameters(self.model_modules), self.grad_clip)
        grad_norm_actor = torch.nn.utils.clip_grad_norm_(
            get_parameters(self.actor_modules), self.grad_clip)
        grad_norm_value = torch.nn.utils.clip_grad_norm_(
            get_parameters(self.value_modules), self.grad_clip)

        self.model_optimizer.step()
        self.actor_optimizer.step()
        self.value_optimizer.step()

        with torch.no_grad():
            loss = model_loss + actor_loss + value_loss
        opt_info.loss.append(loss.item())
        # clip_grad_norm_ may return a tensor or a float depending on the torch version.
        if isinstance(grad_norm_model, torch.Tensor):
            opt_info.grad_norm_model.append(grad_norm_model.item())
            opt_info.grad_norm_actor.append(grad_norm_actor.item())
            opt_info.grad_norm_value.append(grad_norm_value.item())
        else:
            opt_info.grad_norm_model.append(grad_norm_model)
            opt_info.grad_norm_actor.append(grad_norm_actor)
            opt_info.grad_norm_value.append(grad_norm_value)
        # Forward any additional diagnostics reported by the loss into opt_info.
        for field in loss_info_fields:
            if hasattr(opt_info, field) \
                    and getattr(loss_info, field) is not None:
                getattr(opt_info, field).append(
                    getattr(loss_info, field).item())

    return opt_info