def update(self, replay_buffer, logger, step):
    total_actor_loss, total_alpha_loss, total_critic_loss = [], [], []
    target_vs = []
    irm_penalties = []
    for env_id in range(self.num_envs):
        (
            obs,
            action,
            reward,
            next_obs,
            not_done,
            not_done_no_max,
        ) = replay_buffer.sample(self.batch_size, env_id)
        logger.log("train/batch_reward", reward.mean(), step)

        critic_loss, target_v = self.update_critic(
            obs, action, reward, next_obs, not_done_no_max, logger, step
        )
        total_critic_loss.append(critic_loss)
        target_vs.append(target_v)

        if step % self.actor_update_frequency == 0:
            actor_loss, alpha_loss = self.update_actor_and_alpha(obs, logger, step)
            total_actor_loss.append(actor_loss)
            total_alpha_loss.append(alpha_loss)

        irm_penalties.append(self.irm_penalty)

    # Optimize the critic
    train_penalty = torch.stack(irm_penalties).mean()
    penalty_weight = (
        self.penalty_weight if step >= self.penalty_anneal_iters else 1.0
    )
    logger.log("train_encoder/penalty", train_penalty, step)

    total_critic_loss = torch.stack(total_critic_loss).mean()
    total_critic_loss += penalty_weight * train_penalty
    if penalty_weight > 1.0:
        # Rescale the entire loss to keep gradients in a reasonable range
        total_critic_loss /= penalty_weight

    self.critic_optimizer.zero_grad()
    total_critic_loss.backward()
    self.critic_optimizer.step()
    self.critic.log(logger, step)

    if step % self.actor_update_frequency == 0:
        # Optimize the actor
        self.actor_optimizer.zero_grad()
        torch.stack(total_actor_loss).mean().backward()
        self.actor_optimizer.step()
        self.actor.log(logger, step)

        self.log_alpha_optimizer.zero_grad()
        torch.stack(total_alpha_loss).mean().backward()
        self.log_alpha_optimizer.step()

    if step % self.critic_target_update_frequency == 0:
        utils.soft_update_params(self.critic, self.critic_target, self.critic_tau)
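# Note: `self.irm_penalty` in the update above is populated elsewhere. A common
# choice is the IRMv1 penalty (Arjovsky et al., 2019): the squared gradient of
# the per-environment loss with respect to a fixed dummy scale. A minimal
# sketch under that assumption -- the helper name and signature are ours, not
# this repo's:
import torch


def irmv1_penalty(env_loss_fn, outputs, targets):
    """Squared gradient of the loss w.r.t. a multiplicative dummy scale of 1."""
    scale = torch.tensor(1.0, device=outputs.device, requires_grad=True)
    loss = env_loss_fn(outputs * scale, targets)
    (grad,) = torch.autograd.grad(loss, [scale], create_graph=True)
    return grad.pow(2).sum()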
def update_curl(self, obs_anchor, obs_pos, L=None, step=None, ema=False):
    assert obs_anchor.shape[-1] == 84 and obs_pos.shape[-1] == 84
    z_a = self.curl.encode(obs_anchor)
    z_pos = self.curl.encode(obs_pos, ema=True)

    logits = self.curl.compute_logits(z_a, z_pos)
    labels = torch.arange(logits.shape[0]).long().cuda()
    curl_loss = F.cross_entropy(logits, labels)

    self.encoder_optimizer.zero_grad()
    self.curl_optimizer.zero_grad()
    curl_loss.backward()
    self.encoder_optimizer.step()
    self.curl_optimizer.step()

    if L is not None:
        L.log('train/curl_loss', curl_loss, step)

    if ema:
        utils.soft_update_params(
            self.critic.encoder, self.critic_target.encoder, self.encoder_tau
        )

    return curl_loss.item()
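# For reference, `compute_logits` in the CURL objective above is typically the
# bilinear similarity z_a^T W z_pos with the row max subtracted for numerical
# stability, so the diagonal holds the positive pairs that the cross-entropy
# targets. A sketch mirroring the public CURL implementation (W is a learned
# (z_dim, z_dim) parameter; reproduced from memory, so treat as approximate):
def compute_logits(self, z_a, z_pos):
    Wz = torch.matmul(self.W, z_pos.T)                   # (z_dim, B)
    logits = torch.matmul(z_a, Wz)                       # (B, B); entry (i, j) = z_a[i]^T W z_pos[j]
    logits = logits - torch.max(logits, 1)[0][:, None]   # subtract row max for stability
    return logits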
def update(self, replay_buffer, L, step):
    if step < 2000:
        for _ in range(2):
            obs, action, reward, next_obs, not_done = replay_buffer.sample()
            self.update_critic(obs, action, reward, next_obs, not_done, L, step)
            self.update_actor_and_alpha(obs, L, step)
            if step % self.log_interval == 0:
                L.log('train/batch_reward', reward.mean(), step)
    else:
        obs, action, reward, next_obs, not_done = replay_buffer.sample()
        if step % self.log_interval == 0:
            L.log('train/batch_reward', reward.mean(), step)
        self.MVE_prediction(replay_buffer, L, step)
        self.update_critic(obs, action, reward, next_obs, not_done, L, step)
        self.update_actor_and_alpha(obs, L, step)

    if step % self.critic_target_update_freq == 0:
        utils.soft_update_params(self.critic.Q1, self.critic_target.Q1,
                                 self.critic_tau)
        utils.soft_update_params(self.critic.Q2, self.critic_target.Q2,
                                 self.critic_tau)
def update_with_latent(self, latent_buffer_critic, latent_buffer_actor, L, step):
    obs, action, reward, next_obs, not_done, idxs, copy_nums = \
        latent_buffer_critic.sample_proprio()
    obs_a, action_a, reward_a, next_obs_a, not_done_a = \
        latent_buffer_actor.sample_proprio_with_idxs(idxs, copy_nums)

    if step % self.log_interval == 0:
        L.log('train/batch_reward', reward.mean(), step)

    # Set the flag so that everything before the fc layer is detached
    self.critic.encoder.detach_fc = True
    self.critic_target.encoder.detach_fc = True
    self.actor.encoder.detach_fc = True

    self.update_critic_with_latent(obs, action, reward, next_obs, not_done,
                                   obs_a, action_a, reward_a, next_obs_a,
                                   not_done_a, L, step)

    if step % self.actor_update_freq == 0:
        self.update_actor_and_alpha_with_latent(obs, obs_a, L, step)

    if step % self.critic_target_update_freq == 0:
        utils.soft_update_params(self.critic.Q1, self.critic_target.Q1,
                                 self.critic_tau)
        utils.soft_update_params(self.critic.Q2, self.critic_target.Q2,
                                 self.critic_tau)
def update(self, train_dataloader, val_dataloader, early_stopper_contrastive,
           early_stopper_dynamics):
    # torch.cuda.empty_cache()  # release cached memory so the GPU has more headroom
    if early_stopper_contrastive.early_stop or early_stopper_dynamics.early_stop:
        print('Early stopping (contrastive, dynamics):',
              early_stopper_contrastive.early_stop,
              early_stopper_dynamics.early_stop)
        return

    for step, (obs, actions, next_obs, cpc_kwargs) in enumerate(train_dataloader):
        obs = obs.to(self.device)
        actions = actions.to(self.device)
        next_obs = next_obs.to(self.device)

        if step % self.encoder_update_freq == 0:
            soft_update_params(self.CURL.encoder, self.CURL.encoder_target,
                               self.encoder_tau)

        if step % self.cpc_update_freq == 0:
            obs_anchor, obs_pos = cpc_kwargs["obs_anchor"], cpc_kwargs["obs_pos"]
            obs_anchor = obs_anchor.to(self.device)
            obs_pos = obs_pos.to(self.device)
            self.update_cpc(obs_anchor, obs_pos)  # contrastive (CPC) loss update

        if step % self.dynamics_update_freq == 0:
            self.update_dynamics(obs, actions, next_obs)

    self.validation(val_dataloader, early_stopper_contrastive,
                    early_stopper_dynamics)
def update(self, replay_buffer, L, step):
    obs, action, reward, next_obs, not_done, obs2 = replay_buffer.sample()
    L.log('train/batch_reward', reward.mean(), step)

    self.update_critic(obs, action, reward, next_obs, not_done, L, step)

    if step % self.actor_update_freq == 0:
        self.update_actor_and_alpha(obs, L, step)

    if step % self.critic_target_update_freq == 0:
        utils.soft_update_params(self.critic.Q1, self.critic_target.Q1,
                                 self.critic_tau)
        utils.soft_update_params(self.critic.Q2, self.critic_target.Q2,
                                 self.critic_tau)
        # utils.soft_update_params(
        #     self.critic.encoder, self.critic_target.encoder, self.encoder_tau
        # )

    # Previously the IMM update ran on all three frames of the stack:
    # self.update_imm(torch.cat((obs[:, 0:3, :, :],
    #                            obs[:, 3:6, :, :],
    #                            obs[:, 6:9, :, :]), dim=0),
    #                 torch.cat((obs2[:, 0:3, :, :],
    #                            obs2[:, 3:6, :, :],
    #                            obs2[:, 6:9, :, :]), dim=0),
    #                 L, step)

    # Train on only one of the three frames, selected randomly each step
    img_idx = np.random.randint(3)
    self.update_imm(obs[:, (3 * img_idx):(3 * img_idx + 3), :, :],
                    obs2[:, (3 * img_idx):(3 * img_idx + 3), :, :],
                    L, step)
def update(self, replay_buffer, logger, step):
    total_actor_loss, total_alpha_loss, total_critic_loss, obses, env_ids = (
        [], [], [], [], [],
    )
    for env_id in range(self.num_envs):
        (
            obs,
            action,
            reward,
            next_obs,
            not_done,
            not_done_no_max,
        ) = replay_buffer.sample(self.batch_size, env_id)
        obses.append(obs)
        env_ids.append(torch.ones_like(reward).long() * env_id)
        logger.log("train/batch_reward", reward.mean(), step)

        critic_loss = self.update_critic(
            obs, action, reward, next_obs, not_done_no_max, logger, step
        )
        total_critic_loss.append(critic_loss)

        if step % self.actor_update_frequency == 0:
            actor_loss, alpha_loss = self.update_actor_and_alpha(obs, logger, step)
            total_actor_loss.append(actor_loss)
            total_alpha_loss.append(alpha_loss)

        self.update_decoder(obs, action, reward, next_obs, logger, step, env_id)

    # Optimize the critic
    self.critic_optimizer.zero_grad()
    torch.stack(total_critic_loss).mean().backward()
    self.critic_optimizer.step()
    self.critic.log(logger, step)

    # Optimize the classifier
    self.update_classifier(
        torch.cat(obses, dim=0), torch.cat(env_ids, dim=0).squeeze()
    )

    if step % self.actor_update_frequency == 0:
        # Optimize the actor
        self.actor_optimizer.zero_grad()
        torch.stack(total_actor_loss).mean().backward()
        self.actor_optimizer.step()
        self.actor.log(logger, step)

        self.log_alpha_optimizer.zero_grad()
        torch.stack(total_alpha_loss).mean().backward()
        self.log_alpha_optimizer.step()

    if step % self.critic_target_update_frequency == 0:
        utils.soft_update_params(self.critic, self.critic_target, self.critic_tau)
def update_neg_rad(self, x, neg, anchor, L=None, step=None):
    neg_loss = self.compute_neg_loss(x, neg, anchor)

    self.neg_optimizer.zero_grad()
    neg_loss.backward()
    self.neg_optimizer.step()

    utils.soft_update_params(self.predictor, self.predictor_target, self.soda_tau)

    if L is not None:
        L.log('train/neg_loss', neg_loss, step)
def update(self, replay_buffer, step):
    observation, desired_goal, action, reward, next_observation, not_done = \
        replay_buffer.sample(self.batch_size)

    self.update_critic(observation, desired_goal, action, reward,
                       next_observation, not_done, step)

    if step % self.actor_update_frequency == 0:
        self.update_actor_and_alpha(observation, desired_goal, step)

    if step % self.critic_target_update_frequency == 0:
        utils.soft_update_params(self.critic, self.critic_target, self.critic_tau)
def update(self, replay_buffer, logger, step):
    obs, action, reward, next_obs, not_done, not_done_no_max = replay_buffer.sample(
        self.batch_size
    )
    logger.log("train/batch_reward", reward.mean(), step)

    self.update_critic(obs, action, reward, next_obs, not_done_no_max, logger, step)

    if step % self.actor_update_frequency == 0:
        self.update_actor_and_alpha(obs, logger, step)

    if step % self.critic_target_update_frequency == 0:
        utils.soft_update_params(self.critic, self.critic_target, self.critic_tau)
def update(self, replay_buffer, logger, step):
    obs, action, reward, next_obs, not_done, obs_aug, next_obs_aug = \
        replay_buffer.sample(self.batch_size)
    logger.log('train/batch_reward', reward.mean(), step)

    self.update_critic(obs, obs_aug, action, reward, next_obs, next_obs_aug,
                       not_done, logger, step)

    # if step % self.osl_update_frequency == 0:
    #     for _ in range(2):
    #         self.update_osl(obs, action, next_obs)
    # for _ in range(3):
    self.update_osl_traj(replay_buffer)

    if step % self.actor_update_frequency == 0:
        self.update_actor_and_alpha(obs, logger, step)

    if step % self.critic_target_update_frequency == 0:
        utils.soft_update_params(self.critic.Q1, self.critic_target.Q1, 0.01)
        utils.soft_update_params(self.critic.Q2, self.critic_target.Q2, 0.01)
        utils.soft_update_params(self.osl.proj_online, self.osl.proj_momentum, 0.05)
        utils.soft_update_params(self.osl.encoder_online, self.osl.encoder_momentum, 0.05)
def pretrain(self, replay_buffer, step):
    # obs, action, reward, next_obs, not_done, obs_copy, next_obs_copy = \
    #     replay_buffer.sample(self.batch_size)
    # self.update_osl(obs, action, next_obs)
    self.update_osl_traj(replay_buffer)

    # z = torch.FloatTensor(self.batch_size, self.critic.encoder.feature_dim).uniform_(0.8, 1.2).to(self.device)
    # z_two = torch.FloatTensor(self.batch_size, self.critic.encoder.feature_dim).uniform_(0.8, 1.2).to(self.device)
    # self.update_osl(obs, action, next_obs, obs_copy, reward, z)

    if step % self.critic_target_update_frequency == 0:
        utils.soft_update_params(self.osl.proj_online, self.osl.proj_momentum, 0.05)
        utils.soft_update_params(self.osl.encoder_online, self.osl.encoder_momentum, 0.05)
def update(self, replay_buffer, logger, step):
    obs, action, reward, next_obs, discount = replay_buffer.sample(
        self.batch_size, self.discount)
    logger.log('train/batch_reward', reward.mean(), step)

    self.update_critic(obs, action, reward, next_obs, discount, logger, step)

    if step % self.actor_update_frequency == 0:
        self.update_actor(obs, logger, step)

    if step % self.critic_target_update_frequency == 0:
        utils.soft_update_params(self.critic, self.critic_target, self.critic_tau)
def update(self, replay_buffer, logger, step):
    obs, action, reward, next_obs, not_done, obs_aug, next_obs_aug, idxs = \
        replay_buffer.sample(self.batch_size, logger, step)
    logger.log('train/batch_reward', reward.mean(), step)

    priorities = self.update_critic(obs, obs_aug, action, reward, next_obs,
                                    next_obs_aug, not_done, logger, step)
    replay_buffer.update_priorities(idxs, priorities)

    if step % self.actor_update_frequency == 0:
        self.update_actor_and_alpha(obs, logger, step)

    if step % self.critic_target_update_frequency == 0:
        utils.soft_update_params(self.critic, self.critic_target, self.critic_tau)
def update(self, replay_buffer, logger, step):
    obs, action_vec, action, reward, next_obs, not_done, not_done_no_max = \
        replay_buffer.sample(self.batch_size)
    # print(type(obs), type(next_obs), obs.shape, next_obs.shape)
    logger.log('train/batch_reward', reward.mean(), step)

    self.fusion_optimizer.zero_grad()
    self.update_critic(obs, action_vec, reward, next_obs, not_done_no_max,
                       logger, step)

    if step % self.actor_update_frequency == 0:
        self.update_actor_and_alpha(obs, logger, step)

    if step % self.critic_target_update_frequency == 0:
        utils.soft_update_params(self.critic, self.critic_target, self.critic_tau)

    self.fusion_optimizer.step()
def update(self, replay_buffer, L, step):
    obs, action, reward, next_obs, not_done, obs2 = replay_buffer.sample()
    # Replace observations with zeros (torch.zeros_like already matches
    # device and dtype, so no .to() is needed)
    obs = torch.zeros_like(obs)
    next_obs = torch.zeros_like(next_obs)
    L.log('train/batch_reward', reward.mean(), step)

    self.update_critic(obs, action, reward, next_obs, not_done, L, step)

    if step % self.actor_update_freq == 0:
        self.update_actor_and_alpha(obs, L, step)

    if step % self.critic_target_update_freq == 0:
        utils.soft_update_params(
            self.critic.Q1, self.critic_target.Q1, self.critic_tau
        )
        utils.soft_update_params(
            self.critic.Q2, self.critic_target.Q2, self.critic_tau
        )
def update_soda_same(self, x, L=None, step=None):
    assert x.size(-1) == 84
    aug_x = x.clone()
    # x = augmentations.random_crop(x)
    # aug_x = augmentations.random_crop(aug_x)
    aug_x = augmentations.random_overlay(aug_x, self.args)

    soda_loss = self.compute_soda_loss(aug_x, x)

    self.soda_optimizer.zero_grad()
    soda_loss.backward()
    self.soda_optimizer.step()

    if L is not None:
        L.log('train/aux_loss', soda_loss, step)

    utils.soft_update_params(self.predictor, self.predictor_target, self.soda_tau)
def update_cpc(self, obs_anchor, obs_pos, cpc_kwargs, L, step, ema=False):
    z_a = self.CURL.encode(obs_anchor)
    z_pos = self.CURL.encode(obs_pos, ema=True)

    logits = self.CURL.compute_logits(z_a, z_pos)
    labels = torch.arange(logits.shape[0]).long().to(self.device)
    loss = self.cross_entropy_loss(logits, labels)

    self.encoder_optimizer.zero_grad()
    self.cpc_optimizer.zero_grad()
    loss.backward()
    self.encoder_optimizer.step()
    self.cpc_optimizer.step()

    if step % self.log_interval == 0:
        L.log('train/curl_loss', loss, step)

    if ema:
        utils.soft_update_params(self.critic.encoder, self.critic_target.encoder,
                                 self.encoder_tau)
def update(self, replay_buffer, L, step, enc_train=True):
    obs, action, reward, next_obs, goal_obs, not_done = \
        replay_buffer.sample_proprio()

    if enc_train:
        if self.decoder is not None and step % self.decoder_update_freq == 0:
            self.update_decoder(obs, obs, L, step)
    else:
        L.log('train/batch_reward', reward.mean(), step)

        self.update_critic(obs, action, reward, next_obs, goal_obs, not_done,
                           L, step)

        if step % self.actor_update_freq == 0:
            self.update_actor_and_alpha(obs, goal_obs, L, step)

        if step % self.critic_target_update_freq == 0:
            utils.soft_update_params(self.critic.Q1, self.critic_target.Q1,
                                     self.critic_tau)
            utils.soft_update_params(self.critic.Q2, self.critic_target.Q2,
                                     self.critic_tau)
            utils.soft_update_params(self.critic.encoder,
                                     self.critic_target.encoder,
                                     self.encoder_tau)
def update(self, replay_buffer, L, step):
    if self.encoder_type == 'pixel':
        obs, clean_obs, action, reward, next_obs, clean_next_obs, not_done = \
            replay_buffer.sample_rad(self.augs_funcs)
    else:
        obs, action, reward, next_obs, not_done = replay_buffer.sample_proprio()
        # Proprio samples are not augmented, so the clean copies are the same
        clean_obs, clean_next_obs = obs, next_obs

    if step % self.log_interval == 0:
        L.log('train/batch_reward', reward.mean(), step)

    self.update_critic(obs, clean_obs, action, reward, next_obs, clean_next_obs,
                       not_done, L, step)

    if step % self.actor_update_freq == 0:
        self.update_actor_and_alpha(obs, clean_obs, L, step)

    if step % self.critic_target_update_freq == 0:
        utils.soft_update_params(self.critic.Q1, self.critic_target.Q1,
                                 self.critic_tau)
        utils.soft_update_params(self.critic.Q2, self.critic_target.Q2,
                                 self.critic_tau)
        utils.soft_update_params(self.critic.encoder, self.critic_target.encoder,
                                 self.encoder_tau)
def update(self, replay_buffer, L, step):
    if self.encoder_type == 'pixel':
        obs, action, reward, next_obs, not_done, cpc_kwargs = \
            replay_buffer.sample_cpc()
    else:
        obs, action, reward, next_obs, not_done = replay_buffer.sample_proprio()

    if step % self.log_interval == 0:
        L.log('train/batch_reward', reward.mean(), step)

    self.update_critic(obs, action, reward, next_obs, not_done, L, step)

    if step % self.actor_update_freq == 0:
        self.update_actor_and_alpha(obs, L, step)

    if step % self.critic_target_update_freq == 0:
        utils.soft_update_params(self.critic.Q1, self.critic_target.Q1,
                                 self.critic_tau)
        utils.soft_update_params(self.critic.Q2, self.critic_target.Q2,
                                 self.critic_tau)
        utils.soft_update_params(self.critic.encoder, self.critic_target.encoder,
                                 self.encoder_tau)

    if step % self.cpc_update_freq == 0 and self.encoder_type == 'pixel':
        obs_anchor, obs_pos = cpc_kwargs["obs_anchor"], cpc_kwargs["obs_pos"]
        self.update_cpc(obs_anchor, obs_pos, cpc_kwargs, L, step)
def soft_update_critic_target(self):
    utils.soft_update_params(self.critic.Q1, self.critic_target.Q1,
                             self.critic_tau)
    utils.soft_update_params(self.critic.Q2, self.critic_target.Q2,
                             self.critic_tau)
    utils.soft_update_params(self.critic.encoder, self.critic_target.encoder,
                             self.encoder_tau)
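# Nearly every method here calls `utils.soft_update_params`. It is the standard
# Polyak / EMA update of target-network parameters; a minimal sketch of the
# usual implementation in these codebases:
def soft_update_params(net, target_net, tau):
    """target <- tau * online + (1 - tau) * target, parameter-wise."""
    for param, target_param in zip(net.parameters(), target_net.parameters()):
        target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)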
def update(self, replay_buffer, L, step):
    if self.decoder_type == 'inverse':
        obs, action, reward, next_obs, not_done, k_obs = replay_buffer.sample(k=True)
    else:
        obs, action, _, reward, next_obs, not_done = replay_buffer.sample()

    L.log('train/batch_reward', reward.mean(), step)

    self.update_critic(obs, action, reward, next_obs, not_done, L, step)

    if step % self.actor_update_freq == 0:
        self.update_actor_and_alpha(obs, L, step)

    if step % self.critic_target_update_freq == 0:
        utils.soft_update_params(self.critic.Q1, self.critic_target.Q1,
                                 self.critic_tau)
        utils.soft_update_params(self.critic.Q2, self.critic_target.Q2,
                                 self.critic_tau)
        utils.soft_update_params(self.critic.encoder, self.critic_target.encoder,
                                 self.encoder_tau)

    if self.decoder is not None and step % self.decoder_update_freq == 0:
        # decoder_type is 'pixel'
        self.update_decoder(obs, action, next_obs, L, step)

    if self.decoder_type == 'contrastive':
        self.update_contrastive(obs, action, next_obs, L, step)
    elif self.decoder_type == 'inverse':
        self.update_inverse(obs, action, k_obs, L, step)
def update(self, replay_buffer, L, step):
    obs, action, reward, next_obs, not_done, obs2 = replay_buffer.sample()
    L.log('train/batch_reward', reward.mean(), step)

    self.update_critic(obs, action, reward, next_obs, not_done, L, step)

    if step % self.actor_update_freq == 0:
        self.update_actor_and_alpha(obs, L, step)

    if step % self.critic_target_update_freq == 0:
        utils.soft_update_params(
            self.critic.Q1, self.critic_target.Q1, self.critic_tau
        )
        utils.soft_update_params(
            self.critic.Q2, self.critic_target.Q2, self.critic_tau
        )
        utils.soft_update_params(
            self.critic.encoder, self.critic_target.encoder, self.encoder_tau
        )

    if self.decoder is not None and step % self.decoder_update_freq == 0:
        self.update_decoder(obs, obs, L, step)

    self.update_imm(obs, obs2, L, step)
def update(self, replay_buffer, L, step):
    obs, action, _, reward, next_obs, not_done = replay_buffer.sample()
    L.log('train/batch_reward', reward.mean(), step)

    self.update_critic(obs, action, reward, next_obs, not_done, L, step)

    transition_reward_loss = self.update_transition_reward_model(
        obs, action, next_obs, reward, L, step)
    encoder_loss = self.update_encoder(obs, action, reward, L, step)
    total_loss = self.bisim_coef * encoder_loss + transition_reward_loss

    self.encoder_optimizer.zero_grad()
    self.decoder_optimizer.zero_grad()
    total_loss.backward()
    self.encoder_optimizer.step()
    self.decoder_optimizer.step()

    if step % self.actor_update_freq == 0:
        self.update_actor_and_alpha(obs, L, step)

    if step % self.critic_target_update_freq == 0:
        utils.soft_update_params(
            self.critic.Q1, self.critic_target.Q1, self.critic_tau
        )
        utils.soft_update_params(
            self.critic.Q2, self.critic_target.Q2, self.critic_tau
        )
        utils.soft_update_params(
            self.critic.encoder, self.critic_target.encoder, self.encoder_tau
        )
def update(self, replay_buffer, L, step):
    if self.encoder_type == 'pixel':
        t0 = time.time()
        if self.augmix:
            # clean obs will be used later when implementing the JSD loss
            obs, clean_obs, action, reward, next_obs, clean_next_obses, not_done = \
                replay_buffer.sample_augmix()
        else:
            obs, action, reward, next_obs, not_done = \
                replay_buffer.sample_rad(self.augs_funcs)
        t1 = time.time()
        # print(f"sampling done in {t1 - t0:.3f}sec")
    else:
        obs, action, reward, next_obs, not_done = replay_buffer.sample_proprio()

    if step % self.log_interval == 0:
        L.log('train/batch_reward', reward.mean(), step)

    self.update_critic(obs, action, reward, next_obs, not_done, L, step)

    if step % self.actor_update_freq == 0:
        self.update_actor_and_alpha(obs, L, step)

    if step % self.critic_target_update_freq == 0:
        utils.soft_update_params(
            self.critic.Q1, self.critic_target.Q1, self.critic_tau
        )
        utils.soft_update_params(
            self.critic.Q2, self.critic_target.Q2, self.critic_tau
        )
        utils.soft_update_params(
            self.critic.encoder, self.critic_target.encoder, self.encoder_tau
        )
def update(self, replay_buffer, L, step):
    if self.use_curl:
        obs, action, reward, next_obs, not_done, curl_kwargs = \
            replay_buffer.sample_curl()
    else:
        obs, action, reward, next_obs, not_done = replay_buffer.sample()

    L.log('train/batch_reward', reward.mean(), step)

    self.update_critic(obs, action, reward, next_obs, not_done, L, step)

    if step % self.actor_update_freq == 0:
        self.update_actor_and_alpha(obs, L, step)

    if step % self.critic_target_update_freq == 0:
        utils.soft_update_params(
            self.critic.Q1, self.critic_target.Q1, self.critic_tau
        )
        utils.soft_update_params(
            self.critic.Q2, self.critic_target.Q2, self.critic_tau
        )
        utils.soft_update_params(
            self.critic.encoder, self.critic_target.encoder, self.encoder_tau
        )

    if self.rot is not None and step % self.ss_update_freq == 0:
        self.update_rot(obs, L, step)

    if self.inv is not None and step % self.ss_update_freq == 0:
        self.update_inv(obs, next_obs, action, L, step)

    if self.curl is not None and step % self.ss_update_freq == 0:
        obs_anchor, obs_pos = curl_kwargs["obs_anchor"], curl_kwargs["obs_pos"]
        self.update_curl(obs_anchor, obs_pos, L, step)
def update_soda(self, replay_buffer, L=None, step=None):
    x = replay_buffer.sample_soda(self.soda_batch_size)
    assert x.size(-1) == 100

    aug_x = x.clone()
    x = augmentations.random_crop(x)
    aug_x = augmentations.random_crop(aug_x)
    aug_x = augmentations.random_overlay(aug_x, self.args)

    soda_loss = self.compute_soda_loss(aug_x, x)

    self.soda_optimizer.zero_grad()
    soda_loss.backward()
    self.soda_optimizer.step()

    if L is not None:
        L.log('train/aux_loss', soda_loss, step)

    utils.soft_update_params(self.predictor, self.predictor_target, self.soda_tau)
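# For context: `compute_soda_loss` above is SODA's BYOL-style objective --
# predict the target projection of the clean view from the augmented view.
# A hedged sketch (the names `predictor`/`predictor_target` follow the calls
# above; the exact projection heads in the published code may differ):
import torch
import torch.nn.functional as F


def compute_soda_loss(self, aug_x, x):
    h0 = self.predictor(aug_x)         # online branch on the augmented view
    with torch.no_grad():
        h1 = self.predictor_target(x)  # EMA target branch on the clean view
    h0 = F.normalize(h0, p=2, dim=1)
    h1 = F.normalize(h1, p=2, dim=1)
    return F.mse_loss(h0, h1)          # MSE between unit-norm projections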
def update(self, replay_buffer, step):
    if len(replay_buffer) < self.num_seed_steps:
        return

    obs, action, extr_reward, next_obs, discount = replay_buffer.sample(
        self.batch_size, self.discount)
    obs = self.aug(obs)
    next_obs = self.aug(next_obs)

    # Train representation only during the task-agnostic phase
    if self.task_agnostic:
        if step % self.encoder_update_frequency == 0:
            self.update_repr(obs, next_obs, step)
            utils.soft_update_params(self.encoder, self.encoder_target,
                                     self.encoder_target_tau)

    with torch.no_grad():
        intr_reward = self.compute_reward(next_obs, step)

    if self.task_agnostic:
        reward = intr_reward
    else:
        reward = extr_reward + self.intr_coef * intr_reward

    # Decouple representation
    with torch.no_grad():
        obs = self.encoder.encode(obs)
        next_obs = self.encoder.encode(next_obs)

    self.update_critic(obs, action, reward, next_obs, discount, step)

    if step % self.actor_update_frequency == 0:
        self.update_actor_and_alpha(obs, step)

    if step % self.critic_target_update_frequency == 0:
        utils.soft_update_params(self.critic, self.critic_target,
                                 self.critic_target_tau)
def update_sac(self, L, step, obs, action, reward, next_obs, not_done,
               log_networks):
    if step % self.log_interval == 0:
        L.log('train/batch_reward', reward.mean(), step)

    self.update_critic(obs, action, reward, next_obs, not_done, L, step,
                       log_networks)

    if step % self.actor_update_freq == 0:
        self.update_actor_and_alpha(obs, L, step, log_networks)

    if step % self.critic_target_update_freq == 0:
        utils.soft_update_params(self.critic.Q1, self.critic_target.Q1,
                                 self.critic_tau)
        utils.soft_update_params(self.critic.Q2, self.critic_target.Q2,
                                 self.critic_tau)
        utils.soft_update_params(self.critic.encoder, self.critic_target.encoder,
                                 self.encoder_tau)