Example #1
    def select_actions(self, obs, raw_context):

        # Tile the obs as BCQ does; candidate_size indicates how many
        # candidate actions to generate.
        obs = from_numpy(np.tile(obs.reshape(1, -1), (self.candidate_size, 1)))
        if len(raw_context) == 0:
            # When no context has been collected yet, inferred_mdp is set to the zero vector.
            inferred_mdp = ptu.zeros((1, self.f.latent_dim))
        else:
            # Construct the context from raw context
            context = from_numpy(np.concatenate(raw_context, axis=0))[None]
            inferred_mdp = self.f(context)
        with torch.no_grad():
            inferred_mdp = inferred_mdp.repeat(self.candidate_size, 1)
            z = from_numpy(
                np.random.normal(0, 1, size=(obs.size(0),
                                             self.vae_latent_dim))).clamp(
                                                 -0.5, 0.5).to(ptu.device)
            candidate_actions = self.vae_decoder(obs, z, inferred_mdp)
            perturbed_actions = self.perturbation_generator.get_perturbed_actions(
                obs, candidate_actions, inferred_mdp)
            qv = self.Qs(obs, perturbed_actions, inferred_mdp)
            ind = qv.max(0)[1]
        return ptu.get_numpy(perturbed_actions[ind])
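
For context, a minimal sketch of how select_actions could be driven during evaluation. Here agent, env, build_transition, and max_path_length are hypothetical names, and build_transition is assumed to pack a transition into the 1 x context_dim array that the encoder self.f expects:

raw_context = []
obs = env.reset()
for _ in range(max_path_length):
    action = agent.select_actions(np.array(obs), raw_context)
    next_obs, reward, done, _ = env.step(action)
    # Accumulate the raw context used to infer the MDP embedding.
    raw_context.append(build_transition(obs, action, reward, next_obs))
    obs = next_obs
    if done:
        break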
Example #2
def _get_prod_of_gauss_mask(num_selected, desired_len):

    # Taken from
    # https://discuss.pytorch.org/t/create-a-2d-tensor-with-varying-lengths-of-one-in-each-row/25359

    # desired_len is the desired size of the second dimension of the mask

    seq_lens = ptu.from_numpy(np.array(num_selected)).unsqueeze(-1)
    max_len = torch.max(seq_lens)

    # create tensor of suitable shape and same number of dimensions
    range_tensor = torch.arange(max_len).unsqueeze(0)
    range_tensor = range_tensor.to(ptu.device)
    range_tensor = range_tensor.expand(seq_lens.size(0), range_tensor.size(1))

    # Up to this point we have only created auxiliary tensors;
    # the actual mask is created with the binary comparison below:
    mask_tensor = (range_tensor < seq_lens)

    mask_tensor = mask_tensor.type(torch.float)

    current_len = mask_tensor.shape[1]

    pad = ptu.zeros(mask_tensor.shape[0], desired_len - current_len)

    mask_tensor = torch.cat((mask_tensor, pad), dim=1)

    return mask_tensor
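
A small usage sketch of the helper above; the values are hypothetical and assume ptu is the same torch utility module used inside the function:

# Two rows with 2 and 3 valid entries, right-padded with zeros to width 5.
mask = _get_prod_of_gauss_mask(num_selected=[2, 3], desired_len=5)
# mask == tensor([[1., 1., 0., 0., 0.],
#                 [1., 1., 1., 0., 0.]])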
Example #3
    def rsample(self, return_pretanh_value=False):
        """
        Sampling in the reparameterization case.
        """
        z = (self.normal_mean + self.normal_std *
             Normal(ptu.zeros(self.normal_mean.size()),
                    ptu.ones(self.normal_std.size())).sample())
        z.requires_grad_()

        if return_pretanh_value:
            return torch.tanh(z), z
        else:
            return torch.tanh(z)
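
The reason rsample is used instead of sample is the reparameterization trick: noise is drawn from a fixed N(0, 1) and then shifted and scaled by normal_mean and normal_std, so gradients can flow back into those parameters. A standalone sketch of the idea (not part of the class above):

import torch

mean = torch.tensor([0.5], requires_grad=True)
std = torch.tensor([1.5], requires_grad=True)

# Reparameterized sample: the noise does not depend on the learned parameters,
# so the sampled value stays differentiable w.r.t. mean and std.
eps = torch.randn(1)
z = mean + std * eps
torch.tanh(z).sum().backward()  # mean.grad and std.grad are now populated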
Example #4
    def compute_kl_div(self):
        ''' compute KL( q(z|c) || r(z) ) '''
        prior = torch.distributions.Normal(ptu.zeros(self.latent_dim),
                                           ptu.ones(self.latent_dim))
        posteriors = [
            torch.distributions.Normal(mu, torch.sqrt(var)) for mu, var in zip(
                torch.unbind(self.z_means), torch.unbind(self.z_vars))
        ]
        kl_divs = [
            torch.distributions.kl.kl_divergence(post, prior)
            for post in posteriors
        ]
        kl_div_sum = torch.sum(torch.stack(kl_divs))
        return kl_div_sum
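
For reference, the closed-form KL that compute_kl_div evaluates for each diagonal Gaussian posterior against the standard normal prior is:

\mathrm{KL}\big(\mathcal{N}(\mu, \operatorname{diag}(\sigma^2)) \,\|\, \mathcal{N}(0, I)\big) = \frac{1}{2} \sum_{d=1}^{D} \left( \sigma_d^2 + \mu_d^2 - 1 - \log \sigma_d^2 \right)

where D is latent_dim and the \sigma_d^2 are the entries of z_vars.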
Example #5
    def clear_z(self, num_tasks=1):
        '''
        reset q(z|c) to the prior
        sample a new z from the prior
        '''
        # reset distribution over z to the prior
        mu = ptu.zeros(num_tasks, self.latent_dim)
        var = ptu.ones(num_tasks, self.latent_dim)

        self.z_means = mu
        self.z_vars = var
        # sample a new z from the prior
        self.sample_z()
        # reset the context collected so far
        self.context = None
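
sample_z is called above but not shown in this snippet. In PEARL-style encoders it typically draws a reparameterized sample from the current posterior; a minimal sketch under that assumption (not the original implementation):

    def sample_z(self):
        # Sketch only: reparameterized draw from N(z_means, diag(z_vars)).
        posteriors = [
            torch.distributions.Normal(m, torch.sqrt(v)) for m, v in zip(
                torch.unbind(self.z_means), torch.unbind(self.z_vars))
        ]
        self.z = torch.stack([d.rsample() for d in posteriors])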
Example #6
    def select_actions(self, obs, raw_context):

        # Infer the MDP embedding from the collected context and delegate
        # action selection to the underlying policy.
        if len(raw_context) == 0:
            # When no context has been collected yet, inferred_mdp is set to the zero vector.
            inferred_mdp = ptu.zeros(
                (1, self.policy.mlp_encoder.encoder_latent_dim))
        else:
            # Construct the context from raw context
            context = from_numpy(np.concatenate(raw_context, axis=0))[None]
            inferred_mdp = self.policy.mlp_encoder(context)

        # obs = torch.cat([obs, inferred_mdp], dim=1)
        action = self.policy.select_action(obs, get_numpy(inferred_mdp))

        return action
Example #7
    def __init__(
        self,
        policy_producer,
        qf1,
        target_qf1,
        qf2,
        target_qf2,
        lr,
        action_space=None,
        discount=0.99,
        reward_scale=1.0,
        optimizer_class=optim.Adam,
        soft_target_tau_qf=5e-3,
        soft_target_tau_policy=1e-2,
        target_update_period=1,
        use_automatic_entropy_tuning=True,
        target_entropy=None,
    ):
        super().__init__()
        """
        The class state which should not mutate
        """
        self.use_automatic_entropy_tuning = use_automatic_entropy_tuning
        if self.use_automatic_entropy_tuning:
            if target_entropy:
                self.target_entropy = target_entropy
            else:
                # heuristic value from Tuomas
                self.target_entropy = -np.prod(action_space.shape).item()

        self.soft_target_tau_qf = soft_target_tau_qf
        self.soft_target_tau_policy = soft_target_tau_policy

        self.target_update_period = target_update_period

        self.qf_criterion = nn.MSELoss()
        self.vf_criterion = nn.MSELoss()

        self.discount = discount
        self.reward_scale = reward_scale
        """
        The class mutable state
        """

        self.policy = policy_producer()
        self.target_policy = policy_producer()
        self.target_policy.load_state_dict(self.policy.state_dict())

        self.qf1 = qf1
        self.qf2 = qf2
        self.target_qf1 = target_qf1
        self.target_qf2 = target_qf2

        if self.use_automatic_entropy_tuning:
            self.log_alpha = ptu.zeros(1, requires_grad=True)
            self.alpha_optimizer = optimizer_class(
                [self.log_alpha],
                lr=3e-4,
            )

        self.policy_optimizer = optimizer_class(
            self.policy.parameters(),
            lr=lr,
        )

        self.policy_imitation_optimizer = optimizer_class(
            self.policy.parameters(),
            lr=3e-4,
        )

        self.qf1_optimizer = optimizer_class(
            self.qf1.parameters(),
            lr=lr,
        )
        self.qf2_optimizer = optimizer_class(
            self.qf2.parameters(),
            lr=lr,
        )

        print('----------------------------------')
        print('qf_optimizer learning rate: ', lr)
        print('soft_target_tau_qf: ', soft_target_tau_qf)
        print('soft_target_tau_policy: ', soft_target_tau_policy)
        print('----------------------------------')

        self.eval_statistics = OrderedDict()
        self._n_train_steps_total = 0
        self._need_to_update_eval_statistics = True
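
soft_target_tau_qf, soft_target_tau_policy, and target_update_period are the usual ingredients of Polyak-averaged target networks. A minimal sketch of how they are typically applied in a training step; the helper below is hypothetical and the actual update in this codebase may differ:

def soft_update(source, target, tau):
    # Polyak averaging: target <- (1 - tau) * target + tau * source
    for target_param, param in zip(target.parameters(), source.parameters()):
        target_param.data.copy_(
            (1.0 - tau) * target_param.data + tau * param.data)

# Inside the trainer, typically gated by target_update_period:
# if self._n_train_steps_total % self.target_update_period == 0:
#     soft_update(self.qf1, self.target_qf1, self.soft_target_tau_qf)
#     soft_update(self.qf2, self.target_qf2, self.soft_target_tau_qf)
#     soft_update(self.policy, self.target_policy, self.soft_target_tau_policy)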
Example #8
    def train(self, batch, batch_idxes):
        """
        Unpack data from the batch
        """
        obs = batch['obs']
        actions = batch['actions']
        contexts = batch['contexts']

        num_candidate_context = contexts[0].shape[0]
        meta_batch_size = batch_idxes.shape[0]
        num_posterior = meta_batch_size * num_candidate_context
        contexts = torch.cat(contexts, dim=0)

        # Get the in_mdp_batch_size
        in_mdp_batch_size = obs.shape[0] // batch_idxes.shape[0]

        # Sample z for each state
        z = self.bcq_polices[0].vae.sample_z(obs).to(ptu.device)

        target_q = []
        target_candidates = []
        target_perturbations = []

        for i, batch_idx in enumerate(batch_idxes):

            tq = self.bcq_polices[batch_idx].critic.q1(
                obs[i * in_mdp_batch_size:(i + 1) * in_mdp_batch_size],
                actions[i * in_mdp_batch_size:(i + 1) *
                        in_mdp_batch_size]).detach()
            target_q.append(tq)

            tc = self.bcq_polices[batch_idx].vae.decode(
                obs[i * in_mdp_batch_size:(i + 1) * in_mdp_batch_size],
                z[i * in_mdp_batch_size:(i + 1) * in_mdp_batch_size]).detach()
            target_candidates.append(tc)

            tp = self.bcq_polices[batch_idx].get_perturbation(
                obs[i * in_mdp_batch_size:(i + 1) * in_mdp_batch_size],
                tc).detach()
            target_perturbations.append(tp)

        target_q = torch.cat(target_q, dim=0).squeeze()
        target_candidates = torch.cat(target_candidates, dim=0)
        target_perturbations = torch.cat(target_perturbations, dim=0)

        gt.stamp('get_the_targets', unique=False)
        """
        Compute triplet loss
        """
        self.context_encoder_optimizer.zero_grad()

        # z_means, z_vars: (num_posterior, latent_dim), num_posterior = meta_batch_size * num_candidate_context
        z_means, z_vars = self.context_encoder.infer_posterior_with_mean_var(
            contexts)

        # z_means_interleave: (num_posterior * num_posterior, latent_dim) [1, 2, 3] -> [1, 1, 1, 2, 2, 2, 3, 3, 3]
        z_means_interleave = torch.repeat_interleave(z_means,
                                                     num_posterior,
                                                     dim=0)
        # z_means_repeat: (num_posterior * num_posterior, latent_dim) [1, 2, 3] -> [1, 2, 3, 2, 3, 1, 3, 1, 2].
        # This layout makes the triplet loss easy to compute.
        z_means_repeat = []
        for i in range(meta_batch_size):
            z_means_repeat.append(
                torch.cat([
                    z_means[i * num_candidate_context:],
                    z_means[:i * num_candidate_context]
                ],
                          dim=0).repeat(num_candidate_context, 1))
        z_means_repeat = torch.cat(z_means_repeat, dim=0)

        # As above
        z_vars_interleave = torch.repeat_interleave(z_vars,
                                                    num_posterior,
                                                    dim=0)
        z_vars_repeat = []
        for i in range(meta_batch_size):
            z_vars_repeat.append(
                torch.cat([
                    z_vars[i * num_candidate_context:],
                    z_vars[:i * num_candidate_context]
                ],
                          dim=0).repeat(num_candidate_context, 1))
        z_vars_repeat = torch.cat(z_vars_repeat, dim=0)

        gt.stamp('get_repeated_mean_var', unique=False)

        # log(det(Sigma2) / det(Sigma1)): (num_posterior * num_posterior, 1)
        kl_divergence = torch.log(
            torch.prod(z_vars_repeat / z_vars_interleave, dim=1))
        # -d
        kl_divergence -= z_means.shape[-1]
        # Tr(Sigma2^{-1} * Sigma1)
        kl_divergence += torch.sum(z_vars_interleave / z_vars_repeat, dim=1)
        # (m2 - m1).T Sigma2^{-1} (m2 - m1))
        kl_divergence += torch.sum(
            (z_means_repeat - z_means_interleave)**2 / z_vars_repeat, dim=1)
        # / 2
        # (num_posterior, num_posterior): each element kl_{i, j} denotes the kl divergence between the two distributions.
        # Task number for row: i // num_posterior // num_candidate_context.
        #             for col: j % num_posterior // num_candidate_context.
        # Batch number for row: i // num_posterior % num_candidate_context.
        #              for col: j % num_posterior % num_candidate_context.
        kl_divergence = kl_divergence.reshape(num_posterior, num_posterior) / 2
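        # Taken together, the four terms above form the closed-form KL between
        # diagonal Gaussians,
        #   KL(N(m1, S1) || N(m2, S2)) = 0.5 * [log(det S2 / det S1) - d
        #       + Tr(S2^{-1} S1) + (m2 - m1)^T S2^{-1} (m2 - m1)],
        # with S1 = z_vars_interleave and S2 = z_vars_repeat.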

        within_task_dist = torch.max(kl_divergence[:, :num_candidate_context],
                                     dim=1)[0]
        across_task_dist = torch.min(kl_divergence[:, num_candidate_context:],
                                     dim=1)[0]

        unscaled_triplet_loss = torch.sum(
            F.relu(within_task_dist - across_task_dist + self.triplet_margin))

        gt.stamp('get_triplet_loss', unique=False)
        """
        Infer the context variables
        """
        index = np.random.choice(
            num_candidate_context, meta_batch_size
        ) + num_candidate_context * np.arange(meta_batch_size)
        # Get the sampled mean and vars for each task.
        # mean: (meta_batch_size, latent_dim)
        # var: (meta_batch_size, latent_dim)
        mean = z_means[index]
        var = z_vars[index]

        # Get the inferred MDP
        # inferred_mdps: (meta_batch_size, latent_dim)
        inferred_mdps = self.context_encoder.sample_z_from_mean_var(mean, var)

        inferred_mdps = torch.repeat_interleave(inferred_mdps,
                                                in_mdp_batch_size,
                                                dim=0)

        gt.stamp('infer_mdps', unique=False)
        """
        Obtain the KL loss
        """
        prior_mean = ptu.zeros(mean.shape)
        prior_var = ptu.ones(var.shape)

        kl_loss = self.kl_lambda * self.context_encoder.compute_kl_div_between_posterior(
            mean, var, prior_mean, prior_var)

        gt.stamp('get_kl_loss', unique=False)

        # triplet_loss = (kl_loss / unscaled_triplet_loss).detach() * unscaled_triplet_loss
        # posterior_loss = unscaled_triplet_loss + kl_loss
        # posterior_loss.backward(retain_graph=True)

        # gt.stamp('get_posterior_gradient', unique=False)
        """
        Obtain the Q-function loss
        """
        self.Qs_optimizer.zero_grad()
        pred_q = self.Qs(obs, actions, inferred_mdps)
        pred_q = torch.squeeze(pred_q)
        qf_loss = F.mse_loss(pred_q, target_q)

        gt.stamp('get_qf_loss', unique=False)

        (qf_loss + unscaled_triplet_loss + kl_loss).backward()

        gt.stamp('get_qf_encoder_gradient', unique=False)

        self.Qs_optimizer.step()
        self.context_encoder_optimizer.step()
        """
        Obtain the candidate action and perturbation loss
        """

        self.vae_decoder_optimizer.zero_grad()
        self.perturbation_generator_optimizer.zero_grad()

        pred_candidates = self.vae_decoder(obs, z, inferred_mdps.detach())
        pred_perturbations = self.perturbation_generator(
            obs, target_candidates, inferred_mdps.detach())

        candidate_loss = F.mse_loss(pred_candidates, target_candidates)
        perturbation_loss = F.mse_loss(pred_perturbations,
                                       target_perturbations)

        gt.stamp('get_candidate_and_perturbation_loss', unique=False)

        candidate_loss.backward()
        perturbation_loss.backward()

        gt.stamp('get_candidate_and_perturbation_gradient', unique=False)

        self.vae_decoder_optimizer.step()
        self.perturbation_generator_optimizer.step()
        """
        Save some statistics for eval
        """
        if self._need_to_update_eval_statistics:
            self._need_to_update_eval_statistics = False
            """
            Eval should set this to None.
            This way, these statistics are only computed for one batch.
            """
            self.eval_statistics['qf_loss'] = np.mean(ptu.get_numpy(qf_loss))
            self.eval_statistics['unscaled_triplet_loss'] = np.mean(
                ptu.get_numpy(unscaled_triplet_loss))
            self.eval_statistics['kl_loss'] = np.mean(ptu.get_numpy(kl_loss))
            self.eval_statistics['candidate_loss'] = np.mean(
                ptu.get_numpy(candidate_loss))
            self.eval_statistics['perturbation_loss'] = np.mean(
                ptu.get_numpy(perturbation_loss))
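
The interleave/rotate pattern used above to enumerate every ordered pair of posteriors is easier to see on a toy tensor. A standalone sketch with meta_batch_size = 3 and num_candidate_context = 1:

import torch

meta_batch_size, num_candidate_context = 3, 1
num_posterior = meta_batch_size * num_candidate_context
z_means = torch.tensor([[1.], [2.], [3.]])

# Left element of each pair: [1, 1, 1, 2, 2, 2, 3, 3, 3]
left = torch.repeat_interleave(z_means, num_posterior, dim=0)

# Right element: rotate the rows per task, then tile -> [1, 2, 3, 2, 3, 1, 3, 1, 2]
right = torch.cat([
    torch.cat([z_means[i * num_candidate_context:],
               z_means[:i * num_candidate_context]], dim=0)
    .repeat(num_candidate_context, 1)
    for i in range(meta_batch_size)
], dim=0)

# After reshaping the pairwise KLs to (num_posterior, num_posterior), the first
# num_candidate_context columns compare a posterior to its own task (within-task)
# and the remaining columns compare it to the other tasks (across-task).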
Example #9
    def train(self, batch, batch_idxes):
        """
        Unpack data from the batch
        """
        obs = batch['obs']
        actions = batch['actions']
        contexts = batch['contexts']

        num_tasks = batch_idxes.shape[0]

        gt.stamp('unpack_data_from_the_batch', unique=False)

        # Get the in_mdp_batch_size
        obs_dim = obs.shape[1]
        action_dim = actions.shape[1]
        in_mdp_batch_size = obs.shape[0] // batch_idxes.shape[0]
        num_trans_context = contexts.shape[0] // batch_idxes.shape[0]
        """
        Relabel the context batches for each training task
        """
        with torch.no_grad():
            contexts_obs_actions = contexts[:, :obs_dim + action_dim]

            manual_batched_rewards = self.reward_ensemble_predictor.forward_mul_device(
                contexts_obs_actions)

            relabeled_rewards = manual_batched_rewards.reshape(
                num_tasks, self.num_network_ensemble, contexts.shape[0])

            gt.stamp('reward_ensemble_forward', unique=False)

            manual_batched_next_obs = self.transition_ensemble_predictor.forward_mul_device(
                contexts_obs_actions)

            relabeled_next_obs = manual_batched_next_obs.reshape(
                num_tasks, self.num_network_ensemble, contexts.shape[0],
                obs_dim)

            gt.stamp('transition_ensemble_forward', unique=False)

            relabeled_rewards_mean = torch.mean(relabeled_rewards,
                                                dim=1).squeeze()
            relabeled_rewards_std = torch.std(relabeled_rewards,
                                              dim=1).squeeze()

            relabeled_next_obs_mean = torch.mean(relabeled_next_obs, dim=1)

            relabeled_next_obs_std = torch.std(relabeled_next_obs, dim=1)
            relabeled_next_obs_std = torch.mean(relabeled_next_obs_std, dim=-1)

            # Replace the predicted rewards and next observations with the
            # ground-truth values for each task's own context transitions
            for i in range(num_tasks):

                relabeled_rewards_mean[i, i*num_trans_context: (i+1)*num_trans_context] \
                    = contexts[i*num_trans_context: (i+1)*num_trans_context, -1]

                relabeled_next_obs_mean[i, i*num_trans_context: (i+1)*num_trans_context, :] \
                    = contexts[i*num_trans_context: (i+1)*num_trans_context, obs_dim + action_dim : -1]

                if self.is_combine:
                    # Set the std above the threshold so that these transitions are
                    # initially filtered out when producing the mask; when is_combine
                    # is True they are added back explicitly below.
                    relabeled_rewards_std[
                        i, i * num_trans_context:(i + 1) *
                        num_trans_context] = self.reward_std_threshold + 1.0
                    relabeled_next_obs_std[
                        i, i * num_trans_context:(i + 1) *
                        num_trans_context] = self.next_obs_std_threshold + 1.0
                else:
                    relabeled_rewards_std[i, i * num_trans_context:(i + 1) *
                                          num_trans_context] = 0.0
                    relabeled_next_obs_std[i, i * num_trans_context:(i + 1) *
                                           num_trans_context] = 0.0

            mask_reward = relabeled_rewards_std < self.reward_std_threshold
            mask_reward = mask_reward.type(torch.float)

            mask_next_obs = relabeled_next_obs_std < self.next_obs_std_threshold
            mask_next_obs = mask_next_obs.type(torch.float)

            mask = mask_reward * mask_next_obs
            mask = mask.type(torch.uint8)
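            # NOTE: boolean indexing with uint8 masks is deprecated in newer
            # PyTorch releases; torch.bool is the preferred mask dtype there.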

            mask_from_the_other_tasks = mask.type(torch.uint8).clone()

            num_context_candidate_each_task = torch.sum(mask, dim=1)

            mask_list = []

            for i in range(num_tasks):

                assert mask[i].dim() == 1

                mask_nonzero = torch.nonzero(mask[i])
                mask_nonzero = mask_nonzero.flatten()

                mask_i = ptu.zeros_like(mask[i], dtype=torch.uint8)

                assert num_context_candidate_each_task[i].item(
                ) == mask_nonzero.shape[0]

                np_ind = np.random.choice(mask_nonzero.shape[0],
                                          num_trans_context,
                                          replace=False)

                ind = mask_nonzero[np_ind]

                mask_i[ind] = 1

                if self.is_combine:
                    # Combine the additional relabeled context transitions with
                    # the original context transitions with ground-truth rewards
                    mask_i[i * num_trans_context:(i + 1) *
                           num_trans_context] = 1
                    assert torch.sum(mask_i).item() == 2 * num_trans_context
                else:
                    assert torch.sum(mask_i).item() == num_trans_context

                mask_list.append(mask_i)

            mask = torch.cat(mask_list)
            mask = mask.type(torch.uint8)

            repeated_contexts = contexts.repeat(num_tasks, 1)
            context_without_next_obs_rewards = repeated_contexts[:, :obs_dim +
                                                                 action_dim]

            assert context_without_next_obs_rewards.shape[
                0] == relabeled_rewards_mean.reshape(-1, 1).shape[0]
            assert context_without_next_obs_rewards.shape[
                0] == relabeled_next_obs_mean.reshape(-1, obs_dim).shape[0]

            context_without_next_obs_rewards = context_without_next_obs_rewards[
                mask]

            context_next_obs = relabeled_next_obs_mean.reshape(-1,
                                                               obs_dim)[mask]

            context_rewards = relabeled_rewards_mean.reshape(-1, 1)[mask]

            fast_contexts = torch.cat((context_without_next_obs_rewards,
                                       context_next_obs, context_rewards),
                                      dim=1)

            fast_contexts = fast_contexts.reshape(num_tasks, -1,
                                                  contexts.shape[-1])

        gt.stamp('relabel_context_transitions', unique=False)
        """
        Obtain the targets
        """
        with torch.no_grad():
            # Sample z for each state
            z = self.bcq_polices[0].vae.sample_z(obs).to(ptu.device)

            # Each item in critic_weights is a list with one entry per device;
            # each entry critic_weights[i] is a list with one entry per layer;
            # each entry critic_weights[i][j] is a tensor of shape
            # (num_tasks // device_count, layer_input_size, layer_output_size).
            # The same layout applies to the other weights and biases.
            critic_weights, critic_biases, vae_weights, vae_biases, actor_weights, actor_biases = self.combined_bcq_policies

            # CRITIC
            obs_reshaped = obs.reshape(len(batch_idxes), in_mdp_batch_size, -1)
            acs_reshaped = actions.reshape(len(batch_idxes), in_mdp_batch_size,
                                           -1)

            obs_acs_reshaped = torch.cat((obs_reshaped, acs_reshaped), dim=-1)

            target_q = batch_bcq(obs_acs_reshaped, critic_weights,
                                 critic_biases)
            target_q = target_q.reshape(-1)

            # VAE
            z_reshaped = z.reshape(len(batch_idxes), in_mdp_batch_size, -1)
            obs_z_reshaped = torch.cat((obs_reshaped, z_reshaped), dim=-1)

            tc = batch_bcq(obs_z_reshaped, vae_weights, vae_biases)
            tc = self.bcq_polices[0].vae.max_action * torch.tanh(tc)
            target_candidates = tc.reshape(-1, tc.shape[-1])

            # PERTURBATION
            tc_reshaped = target_candidates.reshape(len(batch_idxes),
                                                    in_mdp_batch_size, -1)

            obs_tc_reshaped = torch.cat((obs_reshaped, tc_reshaped), dim=-1)

            tp = batch_bcq(obs_tc_reshaped, actor_weights, actor_biases)
            tp = self.bcq_polices[0].actor.max_action * torch.tanh(tp)
            target_perturbations = tp.reshape(-1, tp.shape[-1])

        gt.stamp('get_the_targets', unique=False)
        """
        Compute the triplet loss
        """
        # ----------------------------------Vectorized-------------------------------------------
        self.context_encoder_optimizer.zero_grad()

        anchors = []
        positives = []
        negatives = []

        num_selected_list = []

        # Pair of task (i,j)
        # where no transitions from j is selected by the ensemble of task i
        exclude_tasks = []

        exclude_task_masks = []

        for i in range(num_tasks):
            # Compute the triplet loss for task i

            for j in range(num_tasks):

                if j != i:

                    # mask_for_task_j: (num_trans_context, )
                    # mask_from_the_other_tasks: (num_tasks, num_tasks * num_trans_context)
                    mask_for_task_j = mask_from_the_other_tasks[
                        i, j * num_trans_context:(j + 1) * num_trans_context]
                    num_selected = int(torch.sum(mask_for_task_j).item())

                    if num_selected == 0:
                        exclude_tasks.append((i, j))
                        exclude_task_masks.append(0)
                    else:
                        exclude_task_masks.append(1)

                    # context_trans_all: (num_trans_context, context_dim)
                    context_trans_all = contexts[j *
                                                 num_trans_context:(j + 1) *
                                                 num_trans_context]
                    # context_trans_selected: (num_selected, context_dim)
                    context_trans_selected = context_trans_all[mask_for_task_j]

                    # relabel_reward_all: (num_trans_context, )
                    relabel_reward_all = relabeled_rewards_mean[
                        i, j * num_trans_context:(j + 1) * num_trans_context]
                    # relabel_reward_selected: (num_selected, )
                    relabel_reward_selected = relabel_reward_all[
                        mask_for_task_j]
                    # relabel_reward_selected: (num_selected, 1)
                    relabel_reward_selected = relabel_reward_selected.reshape(
                        -1, 1)

                    # relabel_next_obs_all: (num_trans_context, obs_dim)
                    relabel_next_obs_all = relabeled_next_obs_mean[
                        i, j * num_trans_context:(j + 1) * num_trans_context]
                    # relabel_next_obs_selected: (num_selected, obs_dim)
                    relabel_next_obs_selected = relabel_next_obs_all[
                        mask_for_task_j]

                    # context_trans_selected_relabel: (num_selected, context_dim)
                    context_trans_selected_relabel = torch.cat([
                        context_trans_selected[:, :obs_dim + action_dim],
                        relabel_next_obs_selected, relabel_reward_selected
                    ],
                                                               dim=1)

                    # c_{i}
                    ind = np.random.choice(num_trans_context,
                                           num_selected,
                                           replace=False)

                    # Next 2 lines used for comparing to sequential version
                    # ind = ind_list[count]
                    # count += 1

                    # context_trans_task_i: (num_trans_context, context_dim)
                    context_trans_task_i = contexts[i *
                                                    num_trans_context:(i + 1) *
                                                    num_trans_context]
                    # context_trans_task_i_sampled: (num_selected, context_dim)
                    context_trans_task_i_sampled = context_trans_task_i[ind]

                    # Pad the contexts with 0 tensor
                    num_to_pad = num_trans_context - num_selected
                    # pad_zero_tensor: (num_to_pad, context_dim)
                    pad_zero_tensor = ptu.zeros(
                        (num_to_pad, context_trans_selected.shape[1]))

                    num_selected_list.append(num_selected)

                    # Dim: (1, num_trans_context, context_dim)
                    context_trans_selected = torch.cat(
                        [context_trans_selected, pad_zero_tensor], dim=0)
                    context_trans_selected_relabel = torch.cat(
                        [context_trans_selected_relabel, pad_zero_tensor],
                        dim=0)
                    context_trans_task_i_sampled = torch.cat(
                        [context_trans_task_i_sampled, pad_zero_tensor], dim=0)

                    anchors.append(context_trans_selected_relabel[None])
                    positives.append(context_trans_task_i_sampled[None])
                    negatives.append(context_trans_selected[None])

        # Dim: (num_tasks * (num_tasks - 1), num_trans_context, context_dim)
        anchors = torch.cat(anchors, dim=0)
        positives = torch.cat(positives, dim=0)
        negatives = torch.cat(negatives, dim=0)

        # input_contexts: (3 * num_tasks * (num_tasks - 1), num_trans_context, context_dim)
        input_contexts = torch.cat([anchors, positives, negatives], dim=0)

        # num_selected_pt: (num_tasks * (num_tasks - 1), )
        num_selected_pt = torch.from_numpy(np.array(num_selected_list))

        # num_selected_repeat: (3 * num_tasks * (num_tasks - 1), )
        num_selected_repeat = num_selected_pt.repeat(3)

        # z_means_vec, z_vars_vec: (3 * num_tasks * (num_tasks - 1), latent_dim)
        z_means_vec, z_vars_vec = self.context_encoder.infer_posterior_with_mean_var(
            input_contexts, num_trans_context, num_selected_repeat)

        # z_means_vec, z_vars_vec: (3, num_tasks * (num_tasks - 1), latent_dim)
        z_means_vec = z_means_vec.reshape(3, anchors.shape[0], -1)
        z_vars_vec = z_vars_vec.reshape(3, anchors.shape[0], -1)

        # Dim: (num_tasks * (num_tasks - 1), latent_dim)
        z_means_anchors, z_vars_anchors = z_means_vec[0], z_vars_vec[0]
        z_means_positives, z_vars_positives = z_means_vec[1], z_vars_vec[1]
        z_means_negatives, z_vars_negatives = z_means_vec[2], z_vars_vec[2]

        with_task_dist = compute_kl_div_diagonal(z_means_anchors,
                                                 z_vars_anchors,
                                                 z_means_positives,
                                                 z_vars_positives)
        across_task_dist = compute_kl_div_diagonal(z_means_anchors,
                                                   z_vars_anchors,
                                                   z_means_negatives,
                                                   z_vars_negatives)

        # Zero out the triplet terms for task pairs where
        # no transitions were selected
        exclude_task_masks = ptu.from_numpy(np.array(exclude_task_masks))

        with_task_dist = with_task_dist * exclude_task_masks
        across_task_dist = across_task_dist * exclude_task_masks

        unscaled_triplet_loss_vec = F.relu(with_task_dist - across_task_dist +
                                           self.triplet_margin)
        unscaled_triplet_loss_vec = torch.mean(unscaled_triplet_loss_vec)

        # Guard against NaNs in the triplet loss
        assert not torch.isnan(unscaled_triplet_loss_vec).any()

        gt.stamp('get_triplet_loss', unique=False)

        unscaled_triplet_loss_vec.backward()

        check_grad_nan_nets(self.networks,
                            f'triplet: {unscaled_triplet_loss_vec}')

        gt.stamp('get_triplet_loss_gradient', unique=False)
        """
        Infer the context variables
        """
        # inferred_mdps = self.context_encoder(new_contexts)
        inferred_mdps = self.context_encoder(fast_contexts)
        inferred_mdps = torch.repeat_interleave(inferred_mdps,
                                                in_mdp_batch_size,
                                                dim=0)

        gt.stamp('infer_mdps', unique=False)
        """
        Obtain the KL loss
        """

        kl_div = self.context_encoder.compute_kl_div()

        kl_loss_each_task = self.kl_lambda * torch.sum(kl_div, dim=1)

        kl_loss = torch.sum(kl_loss_each_task)

        gt.stamp('get_kl_loss', unique=False)
        """
        Obtain the Q-function loss
        """
        self.Qs_optimizer.zero_grad()

        pred_q = self.Qs(obs, actions, inferred_mdps)
        pred_q = torch.squeeze(pred_q)

        qf_loss_each_task = (pred_q - target_q)**2
        qf_loss_each_task = qf_loss_each_task.reshape(num_tasks, -1)
        qf_loss_each_task = torch.mean(qf_loss_each_task, dim=1)

        qf_loss = torch.mean(qf_loss_each_task)

        gt.stamp('get_qf_loss', unique=False)

        (kl_loss + qf_loss).backward()

        check_grad_nan_nets(self.networks, 'kl q')

        gt.stamp('get_kl_qf_gradient', unique=False)

        self.Qs_optimizer.step()
        self.context_encoder_optimizer.step()

        gt.stamp('update_Qs_encoder', unique=False)
        """
        Obtain the candidate action and perturbation loss
        """

        self.vae_decoder_optimizer.zero_grad()
        self.perturbation_generator_optimizer.zero_grad()

        pred_candidates = self.vae_decoder(obs, z, inferred_mdps.detach())
        pred_perturbations = self.perturbation_generator(
            obs, target_candidates, inferred_mdps.detach())

        candidate_loss_each_task = (pred_candidates - target_candidates)**2

        # averaging over action dimension
        candidate_loss_each_task = torch.mean(candidate_loss_each_task, dim=1)
        candidate_loss_each_task = candidate_loss_each_task.reshape(
            num_tasks, in_mdp_batch_size)

        # average over action in each task
        candidate_loss_each_task = torch.mean(candidate_loss_each_task, dim=1)

        candidate_loss = torch.mean(candidate_loss_each_task)

        perturbation_loss_each_task = (pred_perturbations -
                                       target_perturbations)**2

        # average over action dimension
        perturbation_loss_each_task = torch.mean(perturbation_loss_each_task,
                                                 dim=1)
        perturbation_loss_each_task = perturbation_loss_each_task.reshape(
            num_tasks, in_mdp_batch_size)

        # average over action in each task
        perturbation_loss_each_task = torch.mean(perturbation_loss_each_task,
                                                 dim=1)

        perturbation_loss = torch.mean(perturbation_loss_each_task)

        gt.stamp('get_candidate_and_perturbation_loss', unique=False)

        (candidate_loss + perturbation_loss).backward()

        check_grad_nan_nets(self.networks, 'perb')

        gt.stamp('get_candidate_and_perturbation_gradient', unique=False)

        self.vae_decoder_optimizer.step()
        self.perturbation_generator_optimizer.step()

        for net in self.networks:

            for name, m in net.named_parameters():

                if (m != m).any():

                    print(net, name)
                    print(num_selected_list)
                    print(min(num_selected_list))

                    exit()

        gt.stamp('update_vae_perturbation', unique=False)
        """
        Save some statistics for eval
        """
        if self._need_to_update_eval_statistics:
            self._need_to_update_eval_statistics = False
            """
            Eval should set this to None.
            This way, these statistics are only computed for one batch.
            """
            self.eval_statistics['qf_loss'] = np.mean(ptu.get_numpy(qf_loss))
            self.eval_statistics['qf_loss_each_task'] = ptu.get_numpy(
                qf_loss_each_task)

            self.eval_statistics['kl_loss'] = np.mean(ptu.get_numpy(kl_loss))
            self.eval_statistics['triplet_loss'] = np.mean(
                ptu.get_numpy(unscaled_triplet_loss_vec))
            self.eval_statistics['kl_loss_each_task'] = ptu.get_numpy(
                kl_loss_each_task)

            self.eval_statistics['candidate_loss'] = np.mean(
                ptu.get_numpy(candidate_loss))
            self.eval_statistics['candidate_loss_each_task'] = ptu.get_numpy(
                candidate_loss_each_task)

            self.eval_statistics['perturbation_loss'] = np.mean(
                ptu.get_numpy(perturbation_loss))
            self.eval_statistics[
                'perturbation_loss_each_task'] = ptu.get_numpy(
                    perturbation_loss_each_task)

            self.eval_statistics[
                'num_context_candidate_each_task'] = num_context_candidate_each_task
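
check_grad_nan_nets is called after each backward pass above but is not shown in this snippet. A minimal sketch of what such a guard could look like, assuming networks is an iterable of nn.Module instances (hypothetical implementation):

import torch

def check_grad_nan_nets(networks, msg=''):
    # Sketch only: flag NaN gradients right after backward(), before the
    # optimizer step can propagate them into the parameters.
    for net in networks:
        for name, param in net.named_parameters():
            if param.grad is not None and torch.isnan(param.grad).any():
                raise RuntimeError(
                    f'NaN gradient in {net.__class__.__name__}.{name}: {msg}')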