Example #1
    def update(self, replay_buffer, logger, step):
        total_actor_loss, total_alpha_loss, total_critic_loss = [], [], []
        target_vs = []
        irm_penalties = []
        for env_id in range(self.num_envs):
            (
                obs,
                action,
                reward,
                next_obs,
                not_done,
                not_done_no_max,
            ) = replay_buffer.sample(self.batch_size, env_id)

            logger.log("train/batch_reward", reward.mean(), step)

            critic_loss, target_v = self.update_critic(
                obs, action, reward, next_obs, not_done_no_max, logger, step
            )
            total_critic_loss.append(critic_loss)
            target_vs.append(target_v)

            if step % self.actor_update_frequency == 0:
                actor_loss, alpha_loss = self.update_actor_and_alpha(obs, logger, step)
                total_actor_loss.append(actor_loss)
                total_alpha_loss.append(alpha_loss)

            irm_penalties.append(self.irm_penalty)

        # Combine the per-environment critic losses with the (annealed) IRM penalty
        train_penalty = torch.stack(irm_penalties).mean()
        penalty_weight = (
            self.penalty_weight if step >= self.penalty_anneal_iters else 1.0
        )
        logger.log("train_encoder/penalty", train_penalty, step)
        total_critic_loss = torch.stack(total_critic_loss).mean()
        total_critic_loss += penalty_weight * train_penalty
        if penalty_weight > 1.0:
            # Rescale the entire loss to keep gradients in a reasonable range
            total_critic_loss /= penalty_weight

        self.critic_optimizer.zero_grad()
        total_critic_loss.backward()
        self.critic_optimizer.step()
        self.critic.log(logger, step)

        if step % self.actor_update_frequency == 0:
            # optimize the actor
            self.actor_optimizer.zero_grad()
            torch.stack(total_actor_loss).mean().backward()
            self.actor_optimizer.step()

            self.actor.log(logger, step)

            self.log_alpha_optimizer.zero_grad()
            torch.stack(total_alpha_loss).mean().backward()
            self.log_alpha_optimizer.step()

        if step % self.critic_target_update_frequency == 0:
            utils.soft_update_params(self.critic, self.critic_target, self.critic_tau)
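Every snippet in this listing calls `utils.soft_update_params` to Polyak-average the target networks. The helper itself is not shown here; a minimal sketch of the usual implementation (an assumption about these repositories, not their verbatim code) is:

def soft_update_params(net, target_net, tau):
    # Polyak averaging: target <- tau * online + (1 - tau) * target.
    for param, target_param in zip(net.parameters(), target_net.parameters()):
        target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)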
Example #2
    def update_curl(self, obs_anchor, obs_pos, L=None, step=None, ema=False):
        assert obs_anchor.shape[-1] == 84 and obs_pos.shape[-1] == 84

        z_a = self.curl.encode(obs_anchor)
        z_pos = self.curl.encode(obs_pos, ema=True)
        
        logits = self.curl.compute_logits(z_a, z_pos)
        labels = torch.arange(logits.shape[0], device=logits.device).long()  # positives are on the diagonal
        curl_loss = F.cross_entropy(logits, labels)
        
        self.encoder_optimizer.zero_grad()
        self.curl_optimizer.zero_grad()
        curl_loss.backward()

        self.encoder_optimizer.step()
        self.curl_optimizer.step()
        if L is not None:
            L.log('train/curl_loss', curl_loss, step)

        if ema:
            utils.soft_update_params(
                self.critic.encoder, self.critic_target.encoder,
                self.encoder_tau
            )

        return curl_loss.item()
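The `compute_logits` call above (and in Example #18) is the CURL-style bilinear InfoNCE head. A minimal sketch, assuming a learned square matrix `W` of shape (z_dim, z_dim); this is an illustration, not the verbatim code of this project:

import torch

def compute_logits(W, z_a, z_pos):
    # Bilinear similarity z_a^T W z_pos for every anchor/positive pair.
    # Positives sit on the diagonal, so cross-entropy against
    # torch.arange(batch_size) labels yields the InfoNCE loss.
    Wz = torch.matmul(W, z_pos.T)                          # (z_dim, B)
    logits = torch.matmul(z_a, Wz)                         # (B, B)
    logits = logits - logits.max(dim=1, keepdim=True)[0]   # numerical stability
    return logits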
Example #3
    def update(self, replay_buffer, L, step):

        if step < 2000:
            for _ in range(2):
                obs, action, reward, next_obs, not_done = replay_buffer.sample()
                self.update_critic(obs, action, reward, next_obs, not_done, L,
                                   step)
                self.update_actor_and_alpha(obs, L, step)

            if step % self.log_interval == 0:
                L.log('train/batch_reward', reward.mean(), step)

        else:
            obs, action, reward, next_obs, not_done = replay_buffer.sample()

            if step % self.log_interval == 0:
                L.log('train/batch_reward', reward.mean(), step)

            self.MVE_prediction(replay_buffer, L, step)
            self.update_critic(obs, action, reward, next_obs, not_done, L,
                               step)
            self.update_actor_and_alpha(obs, L, step)

        if step % self.critic_target_update_freq == 0:
            utils.soft_update_params(self.critic.Q1, self.critic_target.Q1,
                                     self.critic_tau)
            utils.soft_update_params(self.critic.Q2, self.critic_target.Q2,
                                     self.critic_tau)
Example #4
    def update_with_latent(self, latent_buffer_critic, latent_buffer_actor, L,
                           step):
        obs, action, reward, next_obs, not_done, idxs, copy_nums = \
            latent_buffer_critic.sample_proprio()
        obs_a, action_a, reward_a, next_obs_a, not_done_a = \
            latent_buffer_actor.sample_proprio_with_idxs(idxs, copy_nums)

        if step % self.log_interval == 0:
            L.log('train/batch_reward', reward.mean(), step)

        # set flag to indicate detach everything before fc layer
        self.critic.encoder.detach_fc = True
        self.critic_target.encoder.detach_fc = True
        self.actor.encoder.detach_fc = True

        self.update_critic_with_latent(obs, action, reward, next_obs, not_done,
                                       obs_a, action_a, reward_a, next_obs_a,
                                       not_done_a, L, step)

        if step % self.actor_update_freq == 0:
            self.update_actor_and_alpha_with_latent(obs, obs_a, L, step)

        if step % self.critic_target_update_freq == 0:
            utils.soft_update_params(self.critic.Q1, self.critic_target.Q1,
                                     self.critic_tau)
            utils.soft_update_params(self.critic.Q2, self.critic_target.Q2,
                                     self.critic_tau)
Example #5
    def update(self, train_dataloader, val_dataloader,
               early_stopper_contrastive, early_stopper_dynamics):
        #torch.cuda.empty_cache() # Releases cache so the GPU has more memory
        if early_stopper_contrastive.early_stop or early_stopper_dynamics.early_stop:
            print('Early stopping triggered (contrastive, dynamics):',
                  early_stopper_contrastive.early_stop,
                  early_stopper_dynamics.early_stop)
            return

        for step, (obs, actions, next_obs,
                   cpc_kwargs) in enumerate(train_dataloader):
            obs, actions, next_obs = obs.to(self.device), actions.to(
                self.device), next_obs.to(self.device)

            if step % self.encoder_update_freq == 0:
                soft_update_params(self.CURL.encoder, self.CURL.encoder_target,
                                   self.encoder_tau)
            if step % self.cpc_update_freq == 0:
                obs_anchor, obs_pos = cpc_kwargs["obs_anchor"], cpc_kwargs[
                    "obs_pos"]
                obs_anchor, obs_pos = obs_anchor.to(self.device), obs_pos.to(
                    self.device)
                self.update_cpc(obs_anchor, obs_pos)  # contrastive (CPC) loss on anchor/positive pairs

            if step % self.dynamics_update_freq == 0:
                self.update_dynamics(obs, actions, next_obs)

        self.validation(val_dataloader, early_stopper_contrastive,
                        early_stopper_dynamics)
Example #6
    def update(self, replay_buffer, L, step):
        obs, action, reward, next_obs, not_done, obs2 = replay_buffer.sample()

        L.log('train/batch_reward', reward.mean(), step)

        self.update_critic(obs, action, reward, next_obs, not_done, L, step)

        if step % self.actor_update_freq == 0:
            self.update_actor_and_alpha(obs, L, step)

        if step % self.critic_target_update_freq == 0:
            utils.soft_update_params(self.critic.Q1, self.critic_target.Q1,
                                     self.critic_tau)
            utils.soft_update_params(self.critic.Q2, self.critic_target.Q2,
                                     self.critic_tau)
            # utils.soft_update_params(
            #     self.critic.encoder, self.critic_target.encoder,
            #     self.encoder_tau
            # )

        # Previously used to use all three images
        # self.update_imm(torch.cat((obs[:, 0:3, :, :],
        #                            obs[:, 3:6, :, :],
        #                            obs[:, 6:9, :, :]), dim=0),
        #                 torch.cat((obs2[:, 0:3, :, :],
        #                            obs2[:, 3:6, :, :],
        #                            obs2[:, 6:9, :, :]), dim=0),
        #                 L, step)

        # train only on one of the three images - selected randomly each step
        img_idx = np.random.randint(3)
        self.update_imm(obs[:, (3 * img_idx):(3 * img_idx + 3), :, :],
                        obs2[:, (3 * img_idx):(3 * img_idx + 3), :, :],
                        L, step)
Example #7
    def update(self, replay_buffer, logger, step):
        total_actor_loss, total_alpha_loss, total_critic_loss, obses, env_ids = (
            [],
            [],
            [],
            [],
            [],
        )
        for env_id in range(self.num_envs):
            (
                obs,
                action,
                reward,
                next_obs,
                not_done,
                not_done_no_max,
            ) = replay_buffer.sample(self.batch_size, env_id)
            obses.append(obs)
            env_ids.append(torch.ones_like(reward).long() * env_id)

            logger.log("train/batch_reward", reward.mean(), step)

            critic_loss = self.update_critic(
                obs, action, reward, next_obs, not_done_no_max, logger, step
            )
            total_critic_loss.append(critic_loss)

            if step % self.actor_update_frequency == 0:
                actor_loss, alpha_loss = self.update_actor_and_alpha(obs, logger, step)
                total_actor_loss.append(actor_loss)
                total_alpha_loss.append(alpha_loss)

            self.update_decoder(obs, action, reward, next_obs, logger, step, env_id)

        # Optimize the critic
        self.critic_optimizer.zero_grad()
        torch.stack(total_critic_loss).mean().backward()
        self.critic_optimizer.step()
        self.critic.log(logger, step)

        # Optimize classifier
        self.update_classifier(
            torch.cat(obses, dim=0), torch.cat(env_ids, dim=0).squeeze()
        )

        if step % self.actor_update_frequency == 0:
            # optimize the actor
            self.actor_optimizer.zero_grad()
            torch.stack(total_actor_loss).mean().backward()
            self.actor_optimizer.step()

            self.actor.log(logger, step)

            self.log_alpha_optimizer.zero_grad()
            torch.stack(total_alpha_loss).mean().backward()
            self.log_alpha_optimizer.step()

        if step % self.critic_target_update_frequency == 0:
            utils.soft_update_params(self.critic, self.critic_target, self.critic_tau)
Example #8
    def update_neg_rad(self, x, neg, anchor, L=None, step=None):
        neg_loss = self.compute_neg_loss(x, neg, anchor)

        self.neg_optimizer.zero_grad()
        neg_loss.backward()
        self.neg_optimizer.step()

        utils.soft_update_params(self.predictor, self.predictor_target,
                                 self.soda_tau)
        if L is not None:
            L.log('train/neg_loss', neg_loss, step)
Example #9
    def update(self, replay_buffer, step):
        observation, desired_goal, action, reward, next_observation, not_done = replay_buffer.sample(
            self.batch_size)

        self.update_critic(observation, desired_goal, action, reward,
                           next_observation, not_done, step)

        if step % self.actor_update_frequency == 0:
            self.update_actor_and_alpha(observation, desired_goal, step)

        if step % self.critic_target_update_frequency == 0:
            utils.soft_update_params(self.critic, self.critic_target,
                                     self.critic_tau)
Example #10
    def update(self, replay_buffer, logger, step):
        obs, action, reward, next_obs, not_done, not_done_no_max = replay_buffer.sample(
            self.batch_size
        )

        logger.log("train/batch_reward", reward.mean(), step)

        self.update_critic(obs, action, reward, next_obs, not_done_no_max, logger, step)

        if step % self.actor_update_frequency == 0:
            self.update_actor_and_alpha(obs, logger, step)

        if step % self.critic_target_update_frequency == 0:
            utils.soft_update_params(self.critic, self.critic_target, self.critic_tau)
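For context, `update` methods like the one above are typically driven by an outer training loop. A hedged sketch follows; the names `agent`, `env`, `replay_buffer`, `logger`, `num_seed_steps`, and `num_train_steps` are illustrative assumptions, not part of the snippets in this listing:

obs = env.reset()
for step in range(num_train_steps):
    # act, step the environment, and store the transition
    action = agent.act(obs, sample=True)
    next_obs, reward, done, _ = env.step(action)
    replay_buffer.add(obs, action, reward, next_obs, done)
    obs = env.reset() if done else next_obs

    # start agent updates only once the buffer holds some seed transitions
    if step >= num_seed_steps:
        agent.update(replay_buffer, logger, step)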
Example #11
    def update(self, replay_buffer, logger, step):
        obs, action, reward, next_obs, not_done, obs_aug, next_obs_aug = replay_buffer.sample(
            self.batch_size)

        logger.log('train/batch_reward', reward.mean(), step)

        self.update_critic(obs, obs_aug, action, reward, next_obs,
                           next_obs_aug, not_done, logger, step)
        #
        if step % self.osl_update_frequency == 0:
            # for _ in range(2):
            # self.update_osl(obs, action, next_obs)
            # for _ in range(3):
            self.update_osl_traj(replay_buffer)

        if step % self.actor_update_frequency == 0:
            self.update_actor_and_alpha(obs, logger, step)

        if step % self.critic_target_update_frequency == 0:
            utils.soft_update_params(self.critic.Q1, self.critic_target.Q1,
                                     0.01)
            utils.soft_update_params(self.critic.Q2, self.critic_target.Q2,
                                     0.01)
            utils.soft_update_params(self.osl.proj_online,
                                     self.osl.proj_momentum, 0.05)
            utils.soft_update_params(self.osl.encoder_online,
                                     self.osl.encoder_momentum, 0.05)
Example #12
    def pretrain(self, replay_buffer, step):
        # obs, action, reward, next_obs, not_done, obs_copy, next_obs_copy = replay_buffer.sample(self.batch_size)

        # self.update_osl(obs, action, next_obs)
        self.update_osl_traj(replay_buffer)

        # z = torch.FloatTensor(self.batch_size, self.critic.encoder.feature_dim).uniform_(0.8, 1.2).to(self.device)
        # z_two = torch.FloatTensor(self.batch_size, self.critic.encoder.feature_dim).uniform_(0.8, 1.2).to(self.device)
        #
        # self.update_osl(obs, action, next_obs, obs_copy, reward, z)
        if step % self.critic_target_update_frequency == 0:
            utils.soft_update_params(self.osl.proj_online,
                                     self.osl.proj_momentum, 0.05)
            utils.soft_update_params(self.osl.encoder_online,
                                     self.osl.encoder_momentum, 0.05)
Example #13
    def update(self, replay_buffer, logger, step):
        obs, action, reward, next_obs, discount = replay_buffer.sample(
            self.batch_size, self.discount)

        logger.log('train/batch_reward', reward.mean(), step)

        self.update_critic(obs, action, reward, next_obs, discount, logger,
                           step)

        if step % self.actor_update_frequency == 0:
            self.update_actor(obs, logger, step)

        if step % self.critic_target_update_frequency == 0:
            utils.soft_update_params(self.critic, self.critic_target,
                                     self.critic_tau)
Example #14
File: drq.py  Project: HosseinSheikhi/drq
    def update(self, replay_buffer, logger, step):
        obs, action, reward, next_obs, not_done, obs_aug, next_obs_aug, idxs = replay_buffer.sample(
            self.batch_size, logger, step)

        logger.log('train/batch_reward', reward.mean(), step)

        priorities = self.update_critic(obs, obs_aug, action, reward, next_obs,
                                        next_obs_aug, not_done, logger, step)
        replay_buffer.update_priorities(idxs, priorities)

        if step % self.actor_update_frequency == 0:
            self.update_actor_and_alpha(obs, logger, step)

        if step % self.critic_target_update_frequency == 0:
            utils.soft_update_params(self.critic, self.critic_target,
                                     self.critic_tau)
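Example #14 is a prioritized-replay variant: `update_critic` returns per-sample priorities that are written back into the buffer. A common way to derive them is from absolute TD errors; the sketch below assumes that convention and is not the verbatim code of HosseinSheikhi/drq:

def td_error_priorities(current_q, target_q, eps=1e-6):
    # priority = |TD error| + eps, detached so writing priorities back
    # into the buffer does not interact with the critic's gradients
    td_error = target_q - current_q
    return (td_error.abs() + eps).detach().cpu().numpy()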
Example #15
    def update(self, replay_buffer, logger, step):
        obs, action_vec, action, reward, next_obs, not_done, not_done_no_max = replay_buffer.sample(
            self.batch_size)
        # print(type(obs), type(next_obs), obs.shape, next_obs.shape)
        logger.log('train/batch_reward', reward.mean(), step)

        self.fusion_optimizer.zero_grad()
        self.update_critic(obs, action_vec, reward, next_obs, not_done_no_max,
                           logger, step)

        if step % self.actor_update_frequency == 0:
            self.update_actor_and_alpha(obs, logger, step)

        if step % self.critic_target_update_frequency == 0:
            utils.soft_update_params(self.critic, self.critic_target,
                                     self.critic_tau)
        self.fusion_optimizer.step()
Example #16
    def update(self, replay_buffer, L, step):
        obs, action, reward, next_obs, not_done, obs2 = replay_buffer.sample()
        # replace the observations with all-zero tensors (same shape and device)
        obs = torch.zeros_like(obs)
        next_obs = torch.zeros_like(next_obs)
        L.log('train/batch_reward', reward.mean(), step)

        self.update_critic(obs, action, reward, next_obs, not_done, L, step)

        if step % self.actor_update_freq == 0:
            self.update_actor_and_alpha(obs, L, step)

        if step % self.critic_target_update_freq == 0:
            utils.soft_update_params(
                self.critic.Q1, self.critic_target.Q1, self.critic_tau
            )
            utils.soft_update_params(
                self.critic.Q2, self.critic_target.Q2, self.critic_tau
            )
Example #17
    def update_soda_same(self, x, L=None, step=None):
        assert x.size(-1) == 84

        aug_x = x.clone()

        # x = augmentations.random_crop(x)
        # aug_x = augmentations.random_crop(aug_x)
        aug_x = augmentations.random_overlay(aug_x, self.args)

        soda_loss = self.compute_soda_loss(aug_x, x)

        self.soda_optimizer.zero_grad()
        soda_loss.backward()
        self.soda_optimizer.step()
        if L is not None:
            L.log('train/aux_loss', soda_loss, step)

        utils.soft_update_params(self.predictor, self.predictor_target,
                                 self.soda_tau)
Example #18
    def update_cpc(self, obs_anchor, obs_pos, cpc_kwargs, L, step, ema=False):

        z_a = self.CURL.encode(obs_anchor)
        z_pos = self.CURL.encode(obs_pos, ema=True)

        logits = self.CURL.compute_logits(z_a, z_pos)
        labels = torch.arange(logits.shape[0]).long().to(self.device)
        loss = self.cross_entropy_loss(logits, labels)

        self.encoder_optimizer.zero_grad()
        self.cpc_optimizer.zero_grad()
        loss.backward()

        self.encoder_optimizer.step()
        self.cpc_optimizer.step()
        if step % self.log_interval == 0:
            L.log('train/curl_loss', loss, step)
        if ema:
            utils.soft_update_params(self.critic.encoder,
                                     self.critic_target.encoder,
                                     self.encoder_tau)
Example #19
File: sac_ae.py  Project: s206283/gcrl
    def update(self, replay_buffer, L, step, enc_train=True):
        obs, action, reward, next_obs, goal_obs, not_done = replay_buffer.sample_proprio()

        if enc_train:
            if self.decoder is not None and step % self.decoder_update_freq == 0:
                self.update_decoder(obs, obs, L, step)

        else:
            L.log('train/batch_reward', reward.mean(), step)

            self.update_critic(obs, action, reward, next_obs, goal_obs,
                               not_done, L, step)

            if step % self.actor_update_freq == 0:
                self.update_actor_and_alpha(obs, goal_obs, L, step)

            if step % self.critic_target_update_freq == 0:
                utils.soft_update_params(self.critic.Q1, self.critic_target.Q1,
                                         self.critic_tau)
                utils.soft_update_params(self.critic.Q2, self.critic_target.Q2,
                                         self.critic_tau)
                utils.soft_update_params(self.critic.encoder,
                                         self.critic_target.encoder,
                                         self.encoder_tau)
Example #20
    def update(self, replay_buffer, L, step):
        if self.encoder_type == 'pixel':
            obs, clean_obs, action, reward, next_obs, clean_next_obs, not_done = replay_buffer.sample_rad(
                self.augs_funcs)
        else:
            obs, action, reward, next_obs, not_done = replay_buffer.sample_proprio()

        if step % self.log_interval == 0:
            L.log('train/batch_reward', reward.mean(), step)

        self.update_critic(obs, clean_obs, action, reward, next_obs,
                           clean_next_obs, not_done, L, step)

        if step % self.actor_update_freq == 0:
            self.update_actor_and_alpha(obs, clean_obs, L, step)

        if step % self.critic_target_update_freq == 0:
            utils.soft_update_params(self.critic.Q1, self.critic_target.Q1,
                                     self.critic_tau)
            utils.soft_update_params(self.critic.Q2, self.critic_target.Q2,
                                     self.critic_tau)
            utils.soft_update_params(self.critic.encoder,
                                     self.critic_target.encoder,
                                     self.encoder_tau)
Example #21
    def update(self, replay_buffer, L, step):
        if self.encoder_type == 'pixel':
            obs, action, reward, next_obs, not_done, cpc_kwargs = replay_buffer.sample_cpc()
        else:
            obs, action, reward, next_obs, not_done = replay_buffer.sample_proprio()

        if step % self.log_interval == 0:
            L.log('train/batch_reward', reward.mean(), step)

        self.update_critic(obs, action, reward, next_obs, not_done, L, step)

        if step % self.actor_update_freq == 0:
            self.update_actor_and_alpha(obs, L, step)

        if step % self.critic_target_update_freq == 0:
            utils.soft_update_params(self.critic.Q1, self.critic_target.Q1,
                                     self.critic_tau)
            utils.soft_update_params(self.critic.Q2, self.critic_target.Q2,
                                     self.critic_tau)
            utils.soft_update_params(self.critic.encoder,
                                     self.critic_target.encoder,
                                     self.encoder_tau)

        if step % self.cpc_update_freq == 0 and self.encoder_type == 'pixel':
            obs_anchor, obs_pos = cpc_kwargs["obs_anchor"], cpc_kwargs[
                "obs_pos"]
            self.update_cpc(obs_anchor, obs_pos, cpc_kwargs, L, step)
Example #22
    def soft_update_critic_target(self):
        utils.soft_update_params(self.critic.Q1, self.critic_target.Q1,
                                 self.critic_tau)
        utils.soft_update_params(self.critic.Q2, self.critic_target.Q2,
                                 self.critic_tau)
        utils.soft_update_params(self.critic.encoder,
                                 self.critic_target.encoder, self.encoder_tau)
Example #23
    def update(self, replay_buffer, L, step):
        if self.decoder_type == 'inverse':
            obs, action, reward, next_obs, not_done, k_obs = replay_buffer.sample(
                k=True)
        else:
            obs, action, _, reward, next_obs, not_done = replay_buffer.sample()

        L.log('train/batch_reward', reward.mean(), step)

        self.update_critic(obs, action, reward, next_obs, not_done, L, step)

        if step % self.actor_update_freq == 0:
            self.update_actor_and_alpha(obs, L, step)

        if step % self.critic_target_update_freq == 0:
            utils.soft_update_params(self.critic.Q1, self.critic_target.Q1,
                                     self.critic_tau)
            utils.soft_update_params(self.critic.Q2, self.critic_target.Q2,
                                     self.critic_tau)
            utils.soft_update_params(self.critic.encoder,
                                     self.critic_target.encoder,
                                     self.encoder_tau)

        if self.decoder is not None and step % self.decoder_update_freq == 0:  # decoder_type is pixel
            self.update_decoder(obs, action, next_obs, L, step)

        if self.decoder_type == 'contrastive':
            self.update_contrastive(obs, action, next_obs, L, step)
        elif self.decoder_type == 'inverse':
            self.update_inverse(obs, action, k_obs, L, step)
Example #24
    def update(self, replay_buffer, L, step):
        obs, action, reward, next_obs, not_done, obs2 = replay_buffer.sample()

        L.log('train/batch_reward', reward.mean(), step)

        self.update_critic(obs, action, reward, next_obs, not_done, L, step)

        if step % self.actor_update_freq == 0:
            self.update_actor_and_alpha(obs, L, step)

        if step % self.critic_target_update_freq == 0:
            utils.soft_update_params(
                self.critic.Q1, self.critic_target.Q1, self.critic_tau
            )
            utils.soft_update_params(
                self.critic.Q2, self.critic_target.Q2, self.critic_tau
            )
            utils.soft_update_params(
                self.critic.encoder, self.critic_target.encoder,
                self.encoder_tau
            )

        if self.decoder is not None and step % self.decoder_update_freq == 0:
            self.update_decoder(obs, obs, L, step)
        
        self.update_imm(obs, obs2, L, step)
Example #25
    def update(self, replay_buffer, L, step):
        obs, action, _, reward, next_obs, not_done = replay_buffer.sample()

        L.log('train/batch_reward', reward.mean(), step)

        self.update_critic(obs, action, reward, next_obs, not_done, L, step)
        transition_reward_loss = self.update_transition_reward_model(obs, action, next_obs, reward, L, step)
        encoder_loss = self.update_encoder(obs, action, reward, L, step)
        total_loss = self.bisim_coef * encoder_loss + transition_reward_loss
        self.encoder_optimizer.zero_grad()
        self.decoder_optimizer.zero_grad()
        total_loss.backward()
        self.encoder_optimizer.step()
        self.decoder_optimizer.step()

        if step % self.actor_update_freq == 0:
            self.update_actor_and_alpha(obs, L, step)

        if step % self.critic_target_update_freq == 0:
            utils.soft_update_params(
                self.critic.Q1, self.critic_target.Q1, self.critic_tau
            )
            utils.soft_update_params(
                self.critic.Q2, self.critic_target.Q2, self.critic_tau
            )
            utils.soft_update_params(
                self.critic.encoder, self.critic_target.encoder,
                self.encoder_tau
            )
Example #26
    def update(self, replay_buffer, L, step):
        if self.encoder_type == 'pixel':
            t0 = time.time()
            if self.augmix:
                obs, clean_obs, action, reward, next_obs, clean_next_obses, not_done = replay_buffer.sample_augmix()
                # clean obs will be used later when implementing jsd loss
            else:
                obs, action, reward, next_obs, not_done = replay_buffer.sample_rad(self.augs_funcs)

            t1 = time.time()
            # print(f"sampling done in {t1-t0:.3f}sec")

        else:
            obs, action, reward, next_obs, not_done = replay_buffer.sample_proprio()

        if step % self.log_interval == 0:
            L.log('train/batch_reward', reward.mean(), step)

        self.update_critic(obs, action, reward, next_obs, not_done, L, step)

        if step % self.actor_update_freq == 0:
            self.update_actor_and_alpha(obs, L, step)

        if step % self.critic_target_update_freq == 0:
            utils.soft_update_params(
                self.critic.Q1, self.critic_target.Q1, self.critic_tau
            )
            utils.soft_update_params(
                self.critic.Q2, self.critic_target.Q2, self.critic_tau
            )
            utils.soft_update_params(
                self.critic.encoder, self.critic_target.encoder,
                self.encoder_tau
            )
Example #27
    def update(self, replay_buffer, L, step):
        if self.use_curl:
            obs, action, reward, next_obs, not_done, curl_kwargs = replay_buffer.sample_curl()
        else:
            obs, action, reward, next_obs, not_done = replay_buffer.sample()
        
        L.log('train/batch_reward', reward.mean(), step)

        self.update_critic(obs, action, reward, next_obs, not_done, L, step)

        if step % self.actor_update_freq == 0:
            self.update_actor_and_alpha(obs, L, step)

        if step % self.critic_target_update_freq == 0:
            utils.soft_update_params(
                self.critic.Q1, self.critic_target.Q1, self.critic_tau
            )
            utils.soft_update_params(
                self.critic.Q2, self.critic_target.Q2, self.critic_tau
            )
            utils.soft_update_params(
                self.critic.encoder, self.critic_target.encoder,
                self.encoder_tau
            )
        
        if self.rot is not None and step % self.ss_update_freq == 0:
            self.update_rot(obs, L, step)

        if self.inv is not None and step % self.ss_update_freq == 0:
            self.update_inv(obs, next_obs, action, L, step)

        if self.curl is not None and step % self.ss_update_freq == 0:
            obs_anchor, obs_pos = curl_kwargs["obs_anchor"], curl_kwargs["obs_pos"]
            self.update_curl(obs_anchor, obs_pos, L, step)
Example #28
    def update_soda(self, replay_buffer, L=None, step=None):
        x = replay_buffer.sample_soda(self.soda_batch_size)
        assert x.size(-1) == 100

        aug_x = x.clone()

        x = augmentations.random_crop(x)
        aug_x = augmentations.random_crop(aug_x)
        # print(x.shape, aug_x.shape)
        aug_x = augmentations.random_overlay(aug_x, self.args)
        # print(x.shape, aug_x.shape)

        soda_loss = self.compute_soda_loss(aug_x, x)

        self.soda_optimizer.zero_grad()
        soda_loss.backward()
        self.soda_optimizer.step()
        if L is not None:
            L.log('train/aux_loss', soda_loss, step)

        utils.soft_update_params(self.predictor, self.predictor_target,
                                 self.soda_tau)
Example #29
    def update(self, replay_buffer, step):
        if len(replay_buffer) < self.num_seed_steps:
            return

        obs, action, extr_reward, next_obs, discount = replay_buffer.sample(
            self.batch_size, self.discount)

        obs = self.aug(obs)
        next_obs = self.aug(next_obs)

        # train representation only during the task-agnostic phase
        if self.task_agnostic:
            if step % self.encoder_update_frequency == 0:
                self.update_repr(obs, next_obs, step)

                utils.soft_update_params(self.encoder, self.encoder_target,
                                         self.encoder_target_tau)

        with torch.no_grad():
            intr_reward = self.compute_reward(next_obs, step)

        if self.task_agnostic:
            reward = intr_reward
        else:
            reward = extr_reward + self.intr_coef * intr_reward

        # decouple representation
        with torch.no_grad():
            obs = self.encoder.encode(obs)
            next_obs = self.encoder.encode(next_obs)

        self.update_critic(obs, action, reward, next_obs, discount, step)

        if step % self.actor_update_frequency == 0:
            self.update_actor_and_alpha(obs, step)

        if step % self.critic_target_update_frequency == 0:
            utils.soft_update_params(self.critic, self.critic_target,
                                     self.critic_target_tau)
Example #30
    def update_sac(self, L, step, obs, action, reward, next_obs, not_done,
                   log_networks):
        if step % self.log_interval == 0:
            L.log('train/batch_reward', reward.mean(), step)

        self.update_critic(obs, action, reward, next_obs, not_done, L, step,
                           log_networks)

        if step % self.actor_update_freq == 0:
            self.update_actor_and_alpha(obs, L, step, log_networks)

        if step % self.critic_target_update_freq == 0:
            utils.soft_update_params(self.critic.Q1, self.critic_target.Q1,
                                     self.critic_tau)
            utils.soft_update_params(self.critic.Q2, self.critic_target.Q2,
                                     self.critic_tau)
            utils.soft_update_params(self.critic.encoder,
                                     self.critic_target.encoder,
                                     self.encoder_tau)