    def log_diagnostics(self, eval_statistics):
        """
        Adds logging data about encodings to eval_statistics.
        """
        z_mean = np.mean(np.abs(ptu.get_numpy(self.z_means[0])))
        z_sig = np.mean(ptu.get_numpy(self.z_vars[0]))
        eval_statistics['Z mean eval'] = z_mean
        eval_statistics['Z variance eval'] = z_sig
Example #2
    def train_to_imitate(self, np_batch):
        batch = np_to_pytorch_batch(np_batch)

        obs = batch['observations']
        actions = batch['actions']
        """
        Policy Loss
        """
        _, policy_mean, _, _, policy_std, *_ = self.policy(
            obs,
            reparameterize=True,
            return_log_prob=True,
        )

        policy_var = policy_std**2
        dist = (actions - policy_mean)**2 / policy_var
        dist = torch.sum(dist, dim=1)

        log_policy_var = torch.log(policy_var)
        det = torch.sum(log_policy_var, dim=1)

        policy_loss = torch.mean(dist + det)
        """
        Update networks
        """
        self.policy_imitation_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_imitation_optimizer.step()

        self.eval_statistics['Policy Loss'] = np.mean(
            ptu.get_numpy(policy_loss))
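
The imitation loss above is, up to an additive constant, twice the negative log-likelihood of the demonstrated actions under the diagonal-Gaussian policy. A minimal, self-contained check of that equivalence in plain PyTorch (toy shapes and values; no rlkit/ptu helpers):

import math
import torch

# Toy batch: 4 actions with 2 action dimensions (illustrative values only).
actions = torch.randn(4, 2)
policy_mean = torch.randn(4, 2)
policy_std = torch.rand(4, 2) + 0.1

policy_var = policy_std ** 2
dist = torch.sum((actions - policy_mean) ** 2 / policy_var, dim=1)
det = torch.sum(torch.log(policy_var), dim=1)
loss_as_above = torch.mean(dist + det)

# Same quantity from the Gaussian log-likelihood, dropping the d * log(2*pi) constant.
normal = torch.distributions.Normal(policy_mean, policy_std)
nll = -normal.log_prob(actions).sum(dim=1)   # per-sample negative log-likelihood
loss_from_nll = torch.mean(2 * nll) - actions.shape[1] * math.log(2 * math.pi)

assert torch.allclose(loss_as_above, loss_from_nll, atol=1e-5)
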
    def select_actions(self, obs, raw_context):

        # Tile the observation as BCQ does; candidate_size is the
        # number of candidate actions to sample.
        obs = from_numpy(np.tile(obs.reshape(1, -1), (self.candidate_size, 1)))
        if len(raw_context) == 0:
            # In the beginning, the inferred_mdp is set to zero vector.
            inferred_mdp = ptu.zeros((1, self.f.latent_dim))
        else:
            # Construct the context from raw context
            context = from_numpy(np.concatenate(raw_context, axis=0))[None]
            inferred_mdp = self.f(context)
        with torch.no_grad():
            inferred_mdp = inferred_mdp.repeat(self.candidate_size, 1)
            z = from_numpy(
                np.random.normal(0, 1, size=(obs.size(0),
                                             self.vae_latent_dim))).clamp(
                                                 -0.5, 0.5).to(ptu.device)
            candidate_actions = self.vae_decoder(obs, z, inferred_mdp)
            perturbed_actions = self.perturbation_generator.get_perturbed_actions(
                obs, candidate_actions, inferred_mdp)
            qv = self.Qs(obs, perturbed_actions, inferred_mdp)
            ind = qv.max(0)[1]
        return ptu.get_numpy(perturbed_actions[ind])
    def select_actions(self, obs, inferred_mdp):

        # Tile the observation as BCQ does; candidate_size is the
        # number of candidate actions to sample.
        obs = from_numpy(np.tile(obs.reshape(1, -1), (self.candidate_size, 1)))
        with torch.no_grad():
            inferred_mdp = inferred_mdp.repeat(self.candidate_size, 1)
            z = from_numpy(
                np.random.normal(0, 1, size=(obs.size(0),
                                             self.vae_latent_dim))).clamp(
                                                 -0.5, 0.5).to(ptu.device)
            candidate_actions = self.vae_decoder(obs, z, inferred_mdp)
            perturbed_actions = self.perturbation_generator.get_perturbed_actions(
                obs, candidate_actions, inferred_mdp)
            qv = self.Qs(obs, perturbed_actions, inferred_mdp)
            ind = qv.max(0)[1]
        return ptu.get_numpy(perturbed_actions[ind])
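
Both select_actions variants above follow the BCQ-style selection loop: tile the observation, decode a batch of candidate actions, perturb them, and keep the candidate with the highest Q-value. A stripped-down, runnable sketch of that loop; ToyDecoder and ToyQ below are hypothetical stand-ins (not the classes used above), and the perturbation model and inferred MDP are omitted for brevity:

import torch
import torch.nn as nn


class ToyDecoder(nn.Module):
    """Stand-in for the VAE decoder: maps (obs, z) to a candidate action."""
    def __init__(self, obs_dim, act_dim, latent_dim):
        super().__init__()
        self.latent_dim = latent_dim
        self.net = nn.Linear(obs_dim + latent_dim, act_dim)

    def forward(self, obs, z):
        return torch.tanh(self.net(torch.cat([obs, z], dim=1)))


class ToyQ(nn.Module):
    """Stand-in Q-network: maps (obs, action) to a scalar value."""
    def __init__(self, obs_dim, act_dim):
        super().__init__()
        self.net = nn.Linear(obs_dim + act_dim, 1)

    def forward(self, obs, act):
        return self.net(torch.cat([obs, act], dim=1))


def bcq_style_select(obs, decoder, q_net, candidate_size=10):
    """Tile the observation, sample candidates, return the highest-Q one."""
    obs = obs.reshape(1, -1).repeat(candidate_size, 1)
    z = torch.randn(candidate_size, decoder.latent_dim).clamp(-0.5, 0.5)
    with torch.no_grad():
        candidates = decoder(obs, z)
        q_values = q_net(obs, candidates)   # (candidate_size, 1)
        best = q_values.max(0)[1]           # index of the best candidate
    return candidates[best]


action = bcq_style_select(torch.randn(4), ToyDecoder(4, 2, 3), ToyQ(4, 2))
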
Example #5
    def select_actions(self, obs, raw_context):

        # This variant passes the raw observation directly to the policy,
        # so no tiling of the observation is needed here.
        if len(raw_context) == 0:
            # In the beginning, the inferred_mdp is set to zero vector.
            inferred_mdp = ptu.zeros(
                (1, self.policy.mlp_encoder.encoder_latent_dim))
        else:
            # Construct the context from raw context
            context = from_numpy(np.concatenate(raw_context, axis=0))[None]
            inferred_mdp = self.policy.mlp_encoder(context)

        # obs = torch.cat([obs, inferred_mdp], dim=1)
        action = self.policy.select_action(obs, get_numpy(inferred_mdp))

        return action
    def check_q_funct_estimate(self, paths):

        s0 = np.stack([path["observations"][0] for path in paths])
        s0 = ptu.from_numpy(s0)

        a0 = np.stack([path["actions"][0] for path in paths])
        a0 = ptu.from_numpy(a0)

        inferred_mdps = torch.repeat_interleave(self.trainer._inferred_mdp,
                                                s0.shape[0],
                                                dim=0)

        q_values = torch.min(
            self.trainer.qf1(s0, a0, inferred_mdps),
            self.trainer.qf2(s0, a0, inferred_mdps),
        )

        q_values = ptu.get_numpy(q_values)

        discounted_returns = []
        for path in paths:

            discount_coef = self.trainer.discount**np.arange(
                len(path["rewards"]))

            discounted_return = np.sum(path["rewards"].flatten() * discount_coef)
            discounted_returns.append(discounted_return)

        q_values_mean = np.mean(q_values)
        q_values_std = np.std(q_values)

        discounted_returns_mean = np.mean(discounted_returns)
        discounted_returns_std = np.std(discounted_returns)

        return dict(
            q_values_mean=q_values_mean,
            q_values_std=q_values_std,
            discounted_returns_mean=discounted_returns_mean,
            discounted_returns_std=discounted_returns_std,
        )
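
For reference, the discounted return that these Q-estimates are compared against is just the reward sequence weighted by successive powers of the discount factor. A small self-contained check of that step (toy values):

import numpy as np

rewards = np.array([1.0, 1.0, 1.0, 1.0])   # toy reward sequence
discount = 0.9

discount_coef = discount ** np.arange(len(rewards))
discounted_return = np.sum(rewards * discount_coef)

# Same value computed recursively: G_t = r_t + discount * G_{t+1}
expected = 0.0
for r in reversed(rewards):
    expected = r + discount * expected
assert np.isclose(discounted_return, expected)
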
    def train(self, batch, batch_idxes, epoch):
        """
        Unpack data from the batch
        """
        rewards = batch['rewards']
        obs = batch['obs']
        actions = batch['actions']

        # Get the in_mdp_batch_size
        in_mdp_batch_size = obs.shape[0] // batch_idxes.shape[0]
        """
        Obtain the model prediction loss
        """
        # Note that here, we do not calculate the obs_loss.

        pred_rewards = [net(obs, actions) for net in self.network_ensemble]

        # To train the reward estimator without the ensemble (to reproduce
        # Fig. 7 in our paper), comment out the assignment above and
        # uncomment the line below so that only one network predicts the
        # rewards.
        # pred_rewards = [self.network_ensemble[0](obs, actions) for net in self.network_ensemble]

        reward_loss_task_0 = [
            F.mse_loss(pred_r[:in_mdp_batch_size], rewards[:in_mdp_batch_size])
            for pred_r in pred_rewards
        ]
        gt.stamp('get_reward_loss', unique=False)

        self.network_ensemble_optimizer.zero_grad()

        for loss in reward_loss_task_0:
            loss.backward()

        # To train without the ensemble, comment out the loop above and
        # uncomment the line below.
        # reward_loss_task_0[0].backward()

        self.network_ensemble_optimizer.step()

        gt.stamp('update', unique=False)
        """
        Save some statistics for eval
        """
        if self._need_to_update_eval_statistics:

            if epoch > -1:
                obs_other_tasks = [
                    obs[in_mdp_batch_size * i:in_mdp_batch_size * (i + 1)]
                    for i in range(0, batch_idxes.shape[0])
                ]
                actions_other_tasks = [
                    actions[in_mdp_batch_size * i:in_mdp_batch_size * (i + 1)]
                    for i in range(0, batch_idxes.shape[0])
                ]
                pred_rewards_other_tasks = [
                    torch.cat([
                        pred_r[in_mdp_batch_size * i:in_mdp_batch_size *
                               (i + 1)] for pred_r in pred_rewards
                    ],
                              dim=1) for i in range(0, batch_idxes.shape[0])
                ]

                reward_loss_other_tasks = []
                reward_loss_other_tasks_std = []
                reward_loss_prop_other_tasks = []
                num_selected_trans_other_tasks = []
                for i, item in enumerate(
                        zip(pred_rewards_other_tasks, obs_other_tasks,
                            actions_other_tasks)):
                    pred_r_other_task, o_other_task, a_other_task = item
                    pred_std = torch.std(pred_r_other_task, dim=1)
                    # print(pred_std)
                    mask = ptu.get_numpy(pred_std < self.std_threshold)
                    num_selected_trans_other_tasks.append(np.sum(mask))

                    mask = mask.astype(bool)
                    pred_r_other_task = ptu.get_numpy(pred_r_other_task)
                    pred_r_record = pred_r_other_task[mask]
                    o_other_task = ptu.get_numpy(o_other_task)
                    o_other_task = o_other_task[mask]
                    a_other_task = ptu.get_numpy(a_other_task)
                    a_other_task = a_other_task[mask]

                    mse_loss = []
                    mse_loss_prop = []

                    for pred_r, o, a in zip(pred_r_record, o_other_task,
                                            a_other_task):
                        if self.domain == 'ant-dir':
                            qpos = np.concatenate([np.zeros(2), o[:13]])
                            qvel = o[13:27]
                        elif self.domain == 'ant-goal':
                            qpos = o[:15]
                            qvel = o[15:29]
                        elif self.domain == 'humanoid-ndone-goal':
                            qpos = o[:24]
                            qvel = o[24:47]
                        elif self.domain == 'humanoid-openai-dir':
                            qpos = np.concatenate([np.zeros(2), o[:22]])
                            qvel = o[22:45]
                        elif self.domain == 'halfcheetah-vel':
                            qpos = np.concatenate([np.zeros(1), o[:8]])
                            qvel = o[8:17]
                        elif 'maze' in self.domain:
                            qpos = o[:2]
                            qvel = o[2:4]

                        self.env.set_state(qpos, qvel)
                        _, r, _, _ = self.env.step(a)
                        mse_loss.append((pred_r - r)**2)
                        mse_loss_prop.append(np.sqrt((pred_r - r)**2 / r**2))

                    if len(mse_loss) > 0:
                        reward_loss_other_tasks.append(
                            np.mean(np.stack(mse_loss), axis=0).tolist())
                        reward_loss_other_tasks_std.append(
                            np.std(np.stack(mse_loss), axis=0).tolist())
                        reward_loss_prop_other_tasks.append(
                            np.mean(np.stack(mse_loss_prop), axis=0).tolist())

                self.eval_statistics[
                    'average_task_reward_loss_other_tasks_mean'] = np.mean(
                        reward_loss_other_tasks, axis=1)
                self.eval_statistics[
                    'average_task_reward_loss_other_tasks_std'] = np.std(
                        reward_loss_other_tasks, axis=1)
                self.eval_statistics[
                    'average_task_reward_loss_prop_other_task'] = np.mean(
                        reward_loss_prop_other_tasks, axis=1)

                self.eval_statistics[
                    'num_selected_trans_other_tasks'] = num_selected_trans_other_tasks

            self._need_to_update_eval_statistics = False
            """
            Eval should set this to None.
            This way, these statistics are only computed for one batch.
            """
            self.eval_statistics['reward_loss_task_0'] = np.mean(
                ptu.get_numpy(torch.mean(torch.stack(reward_loss_task_0))))
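
The eval block above keeps only the transitions on which the reward ensemble agrees, using the standard deviation across ensemble predictions as a proxy for epistemic uncertainty. A toy sketch of that masking step (shapes and the threshold are made up for illustration):

import torch

num_networks, batch_size = 5, 8
std_threshold = 0.1

# Toy ensemble predictions: one row per transition, one column per network.
pred_rewards = torch.randn(batch_size, num_networks) * 0.05
pred_rewards[0] += torch.randn(num_networks)   # make one transition disagree strongly

pred_std = torch.std(pred_rewards, dim=1)      # ensemble disagreement per transition
mask = pred_std < std_threshold                # keep only confident predictions
selected = pred_rewards[mask]
print(f"kept {selected.shape[0]} of {batch_size} transitions")
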
    def train(self, batch, batch_idxes):
        """
        Unpack data from the batch
        """
        obs = batch['obs']
        actions = batch['actions']
        contexts = batch['contexts']

        num_tasks = batch_idxes.shape[0]

        gt.stamp('unpack_data_from_the_batch', unique=False)

        # Get the in_mdp_batch_size
        obs_dim = obs.shape[1]
        action_dim = actions.shape[1]
        in_mdp_batch_size = obs.shape[0] // batch_idxes.shape[0]
        num_trans_context = contexts.shape[0] // batch_idxes.shape[0]
        """
        Relabel the context batches for each training task
        """
        with torch.no_grad():
            contexts_obs_actions = contexts[:, :obs_dim + action_dim]

            manual_batched_rewards = self.ensemble_predictor.forward_mul_device(
                contexts_obs_actions)

            relabeled_rewards = manual_batched_rewards.reshape(
                num_tasks, self.num_network_ensemble, contexts.shape[0])

            gt.stamp('relabel_ensemble', unique=False)

            relabeled_rewards_mean = torch.mean(relabeled_rewards,
                                                dim=1).squeeze()
            relabeled_rewards_std = torch.std(relabeled_rewards,
                                              dim=1).squeeze()

            # Replace the predicted reward with ground truth reward for transitions
            # with ground truth reward inside the batch
            for i in range(num_tasks):
                relabeled_rewards_mean[i, i*num_trans_context: (i+1)*num_trans_context] \
                    = contexts[i*num_trans_context: (i+1)*num_trans_context, -1]

                if self.is_combine:
                    # Set the std above self.std_threshold so that the task's own
                    # context transitions are filtered out of the mask; they are
                    # added back explicitly below, after sampling the relabeled ones.
                    relabeled_rewards_std[
                        i, i * num_trans_context:(i + 1) *
                        num_trans_context] = self.std_threshold + 1.0
                else:
                    relabeled_rewards_std[i, i * num_trans_context:(i + 1) *
                                          num_trans_context] = 0.0

            mask = relabeled_rewards_std < self.std_threshold

            num_context_candidate_each_task = torch.sum(mask, dim=1)

            mask_list = []

            for i in range(num_tasks):

                assert mask[i].dim() == 1

                mask_nonzero = torch.nonzero(mask[i])
                mask_nonzero = mask_nonzero.flatten()

                mask_i = ptu.zeros_like(mask[i], dtype=torch.uint8)

                assert num_context_candidate_each_task[i].item(
                ) == mask_nonzero.shape[0]

                np_ind = np.random.choice(mask_nonzero.shape[0],
                                          num_trans_context,
                                          replace=False)

                ind = mask_nonzero[np_ind]

                mask_i[ind] = 1

                if self.is_combine:
                    # Combine the additional relabeled context transitions with
                    # the original context transitions with ground-truth rewards
                    mask_i[i * num_trans_context:(i + 1) *
                           num_trans_context] = 1

                    assert torch.sum(mask_i).item() == 2 * num_trans_context
                else:
                    assert torch.sum(mask_i).item() == num_trans_context

                mask_list.append(mask_i)

            mask = torch.cat(mask_list)
            mask = mask.type(torch.uint8)

            repeated_contexts = contexts.repeat(num_tasks, 1)
            context_without_rewards = repeated_contexts[:, :-1]

            assert context_without_rewards.shape[
                0] == relabeled_rewards_mean.reshape(-1, 1).shape[0]

            context_without_rewards = context_without_rewards[mask]

            context_rewards = relabeled_rewards_mean.reshape(-1, 1)[mask]

            fast_contexts = torch.cat(
                (context_without_rewards, context_rewards), dim=1)

            fast_contexts = fast_contexts.reshape(num_tasks, -1,
                                                  contexts.shape[-1])

        gt.stamp('relabel_context_transitions', unique=False)
        """
        Infer the context variables
        """
        # inferred_mdps = self.context_encoder(new_contexts)
        inferred_mdps = self.context_encoder(fast_contexts)
        inferred_mdps = torch.repeat_interleave(inferred_mdps,
                                                in_mdp_batch_size,
                                                dim=0)

        gt.stamp('infer_mdps', unique=False)

        with torch.no_grad():
            # Sample z for each state
            z = self.bcq_polices[0].vae.sample_z(obs).to(ptu.device)

            # critic_weights is a list with one entry per device;
            # critic_weights[i] is a list with one entry per layer;
            # critic_weights[i][j] is a tensor of shape
            # (num_tasks // device_count, layer_input_size, layer_output_size).
            # The same layout applies to the other weights and biases.
            critic_weights, critic_biases, vae_weights, vae_biases, actor_weights, actor_biases = self.combined_bcq_policies

            # CRITIC
            obs_reshaped = obs.reshape(len(batch_idxes), in_mdp_batch_size, -1)
            acs_reshaped = actions.reshape(len(batch_idxes), in_mdp_batch_size,
                                           -1)

            obs_acs_reshaped = torch.cat((obs_reshaped, acs_reshaped), dim=-1)

            target_q = batch_bcq(obs_acs_reshaped, critic_weights,
                                 critic_biases)
            target_q = target_q.reshape(-1)

            # VAE
            z_reshaped = z.reshape(len(batch_idxes), in_mdp_batch_size, -1)
            obs_z_reshaped = torch.cat((obs_reshaped, z_reshaped), dim=-1)

            tc = batch_bcq(obs_z_reshaped, vae_weights, vae_biases)
            tc = self.bcq_polices[0].vae.max_action * torch.tanh(tc)
            target_candidates = tc.reshape(-1, tc.shape[-1])

            # PERTURBATION
            tc_reshaped = target_candidates.reshape(len(batch_idxes),
                                                    in_mdp_batch_size, -1)

            obs_tc_reshaped = torch.cat((obs_reshaped, tc_reshaped), dim=-1)

            tp = batch_bcq(obs_tc_reshaped, actor_weights, actor_biases)
            tp = self.bcq_polices[0].actor.max_action * torch.tanh(tp)
            target_perturbations = tp.reshape(-1, tp.shape[-1])

        gt.stamp('get_the_targets', unique=False)
        """
        Obtain the KL loss
        """

        # KL constraint on z if probabilistic
        self.context_encoder_optimizer.zero_grad()

        kl_div = self.context_encoder.compute_kl_div()

        kl_loss_each_task = self.kl_lambda * torch.sum(kl_div, dim=1)

        kl_loss = torch.sum(kl_loss_each_task)

        gt.stamp('get_kl_loss', unique=False)
        """
        Obtain the Q-function loss
        """
        self.Qs_optimizer.zero_grad()

        pred_q = self.Qs(obs, actions, inferred_mdps)
        pred_q = torch.squeeze(pred_q)

        qf_loss_each_task = (pred_q - target_q)**2
        qf_loss_each_task = qf_loss_each_task.reshape(num_tasks, -1)
        qf_loss_each_task = torch.mean(qf_loss_each_task, dim=1)

        qf_loss = torch.mean(qf_loss_each_task)

        gt.stamp('get_qf_loss', unique=False)

        (kl_loss + qf_loss).backward()

        gt.stamp('get_kl_qf_gradient', unique=False)

        self.Qs_optimizer.step()
        self.context_encoder_optimizer.step()

        gt.stamp('update_Qs_encoder', unique=False)
        """
        Obtain the candidate action and perturbation loss
        """

        self.vae_decoder_optimizer.zero_grad()
        self.perturbation_generator_optimizer.zero_grad()

        pred_candidates = self.vae_decoder(obs, z, inferred_mdps.detach())
        pred_perturbations = self.perturbation_generator(
            obs, target_candidates, inferred_mdps.detach())

        candidate_loss_each_task = (pred_candidates - target_candidates)**2

        # averaging over action dimension
        candidate_loss_each_task = torch.mean(candidate_loss_each_task, dim=1)
        candidate_loss_each_task = candidate_loss_each_task.reshape(
            num_tasks, in_mdp_batch_size)

        # average over action in each task
        candidate_loss_each_task = torch.mean(candidate_loss_each_task, dim=1)

        candidate_loss = torch.mean(candidate_loss_each_task)

        perturbation_loss_each_task = (pred_perturbations -
                                       target_perturbations)**2

        # average over action dimension
        perturbation_loss_each_task = torch.mean(perturbation_loss_each_task,
                                                 dim=1)
        perturbation_loss_each_task = perturbation_loss_each_task.reshape(
            num_tasks, in_mdp_batch_size)

        # average over action in each task
        perturbation_loss_each_task = torch.mean(perturbation_loss_each_task,
                                                 dim=1)

        perturbation_loss = torch.mean(perturbation_loss_each_task)

        gt.stamp('get_candidate_and_perturbation_loss', unique=False)

        (candidate_loss + perturbation_loss).backward()

        gt.stamp('get_candidate_and_perturbation_gradient', unique=False)

        self.vae_decoder_optimizer.step()
        self.perturbation_generator_optimizer.step()

        gt.stamp('update_vae_perturbation', unique=False)
        """
        Save some statistics for eval
        """
        if self._need_to_update_eval_statistics:
            self._need_to_update_eval_statistics = False
            """
            Eval should set this to None.
            This way, these statistics are only computed for one batch.
            """
            self.eval_statistics['qf_loss'] = np.mean(ptu.get_numpy(qf_loss))
            self.eval_statistics['qf_loss_each_task'] = ptu.get_numpy(
                qf_loss_each_task)

            self.eval_statistics['kl_loss'] = np.mean(ptu.get_numpy(kl_loss))
            self.eval_statistics['kl_loss_each_task'] = ptu.get_numpy(
                kl_loss_each_task)

            self.eval_statistics['candidate_loss'] = np.mean(
                ptu.get_numpy(candidate_loss))
            self.eval_statistics['candidate_loss_each_task'] = ptu.get_numpy(
                candidate_loss_each_task)

            self.eval_statistics['perturbation_loss'] = np.mean(
                ptu.get_numpy(perturbation_loss))
            self.eval_statistics[
                'perturbation_loss_each_task'] = ptu.get_numpy(
                    perturbation_loss_each_task)

            self.eval_statistics[
                'num_context_candidate_each_task'] = num_context_candidate_each_task
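
The relabeling above relies on two different tiling primitives: torch.repeat_interleave broadcasts one latent per task across that task's in-MDP batch, while Tensor.repeat tiles the whole block once per task. A minimal sketch of the difference:

import torch

z = torch.tensor([[1.0], [2.0]])   # one latent per task (2 tasks)
in_mdp_batch_size = 3

# repeat_interleave: one copy per transition of the same task.
per_transition = torch.repeat_interleave(z, in_mdp_batch_size, dim=0)
print(per_transition.flatten().tolist())   # [1.0, 1.0, 1.0, 2.0, 2.0, 2.0]

# repeat: the whole block tiled.
tiled = z.repeat(3, 1)
print(tiled.flatten().tolist())            # [1.0, 2.0, 1.0, 2.0, 1.0, 2.0]
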
    def train_from_torch_qf2_policy(self, batch):
        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']
        """
        QF Loss
        """
        if self.use_automatic_entropy_tuning:
            alpha = self.log_alpha.exp()
        else:
            alpha = 1

        q2_pred = self.qf2(obs, actions)
        # Make sure policy accounts for squashing
        # functions like tanh correctly!
        new_next_actions, _, _, new_log_pi, *_ = self.policy(
            next_obs,
            reparameterize=True,
            return_log_prob=True,
        )
        target_q2_values = torch.min(
            self.target_qf1(next_obs, new_next_actions),
            self.target_qf2(next_obs, new_next_actions),
        ) - alpha * new_log_pi

        q2_target = self.reward_scale * rewards + \
            (1. - terminals) * self.discount * target_q2_values
        qf2_loss = self.qf_criterion(q2_pred, q2_target.detach())
        """
        Policy and Alpha Loss
        """
        new_obs_actions, policy_mean, policy_log_std, log_pi, *_ = self.policy(
            obs,
            reparameterize=True,
            return_log_prob=True,
        )
        if self.use_automatic_entropy_tuning:
            alpha_loss = -(self.log_alpha *
                           (log_pi + self.target_entropy).detach()).mean()
            self.alpha_optimizer.zero_grad()
            alpha_loss.backward()
            self.alpha_optimizer.step()
            alpha = self.log_alpha.exp()
        else:
            alpha_loss = 0
            alpha = 1

        q_new_actions = torch.min(
            self.qf1(obs, new_obs_actions),
            self.qf2(obs, new_obs_actions),
        )

        policy_loss = (alpha * log_pi - q_new_actions).mean()
        """
        Update networks
        """
        self.qf2_optimizer.zero_grad()
        qf2_loss.backward()
        self.qf2_optimizer.step()

        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()
        """
        Soft Updates
        """
        if self._n_train_steps_total % self.target_update_period == 0:
            ptu.soft_update_from_to(self.qf2, self.target_qf2,
                                    self.soft_target_tau_qf)

            ptu.soft_update_from_to(self.policy, self.target_policy,
                                    self.soft_target_tau_policy)
            self.policy.load_state_dict(self.target_policy.state_dict())
        """
        Save some statistics for eval
        """
        if self._need_to_update_eval_statistics:
            self._need_to_update_eval_statistics = False
            """
            Eval should set this to None.
            This way, these statistics are only computed for one batch.
            """
            policy_loss = (log_pi - q_new_actions).mean()

            self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))

            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q2 Predictions',
                    ptu.get_numpy(q2_pred),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q2 Targets',
                    ptu.get_numpy(q2_target),
                ))
            self.eval_statistics['Policy Loss'] = np.mean(
                ptu.get_numpy(policy_loss))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Log Pis',
                    ptu.get_numpy(log_pi),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Policy mu',
                    ptu.get_numpy(policy_mean),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Policy log std',
                    ptu.get_numpy(policy_log_std),
                ))
            if self.use_automatic_entropy_tuning:
                self.eval_statistics['Alpha'] = alpha.item()
                self.eval_statistics['Alpha Loss'] = alpha_loss.item()
        self._n_train_steps_total += 1
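
ptu.soft_update_from_to above performs a Polyak (exponential-moving-average) update of the target network. A minimal sketch of such an update, assuming the usual rlkit semantics of target <- tau * source + (1 - tau) * target:

import torch
import torch.nn as nn


def soft_update(source, target, tau):
    """Move each target parameter a small step toward the source parameter."""
    for p, tp in zip(source.parameters(), target.parameters()):
        tp.data.copy_(tau * p.data + (1.0 - tau) * tp.data)


qf = nn.Linear(4, 1)
target_qf = nn.Linear(4, 1)
target_qf.load_state_dict(qf.state_dict())   # start from identical weights
soft_update(qf, target_qf, tau=0.005)
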
    def train_from_torch_qf1(self, batch):

        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']

        if self.use_automatic_entropy_tuning:
            alpha = self.log_alpha.exp()
        else:
            alpha = 1
        """
        QF Loss
        """
        q1_pred = self.qf1(obs, actions)

        # Make sure policy accounts for squashing
        # functions like tanh correctly!
        new_next_actions, _, _, new_log_pi, *_ = self.policy(
            next_obs,
            reparameterize=True,
            return_log_prob=True,
        )
        target_q1_values = torch.min(
            self.target_qf1(next_obs, new_next_actions),
            self.target_qf2(next_obs, new_next_actions),
        ) - alpha * new_log_pi

        q1_target = self.reward_scale * rewards + \
            (1. - terminals) * self.discount * target_q1_values

        qf1_loss = self.qf_criterion(q1_pred, q1_target.detach())
        """
        Update networks
        """
        self.qf1_optimizer.zero_grad()
        qf1_loss.backward()
        self.qf1_optimizer.step()
        """
        Soft Updates
        """
        if self._n_train_steps_total % self.target_update_period == 0:
            ptu.soft_update_from_to(self.qf1, self.target_qf1,
                                    self.soft_target_tau_qf)
        """
        Save some statistics for eval
        """
        if self._need_to_update_eval_statistics:
            """
            Eval should set this to None.
            This way, these statistics are only computed for one batch.
            """

            self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))

            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q1 Predictions',
                    ptu.get_numpy(q1_pred),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q1 Targets',
                    ptu.get_numpy(q1_target),
                ))

        self._n_train_steps_total += 1
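
Both QF trainers above regress onto the same soft Bellman target: the clipped double-Q value of the next action sampled from the current policy, minus the entropy term, discounted and masked by the terminal flag. A toy sketch of that target computation (values are illustrative only):

import torch

rewards = torch.tensor([[1.0], [0.5], [2.0]])
terminals = torch.tensor([[0.0], [0.0], [1.0]])
target_q1 = torch.tensor([[10.0], [8.0], [6.0]])
target_q2 = torch.tensor([[9.0], [8.5], [7.0]])
new_log_pi = torch.tensor([[-1.0], [-0.5], [-2.0]])

alpha, reward_scale, discount = 0.2, 1.0, 0.99

# Soft Bellman target: r + (1 - done) * gamma * (min(Q1', Q2') - alpha * log pi)
target_v = torch.min(target_q1, target_q2) - alpha * new_log_pi
q_target = reward_scale * rewards + (1.0 - terminals) * discount * target_v
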
Example #11
def np_ify(tensor_or_other):
    if isinstance(tensor_or_other, torch.autograd.Variable):
        return ptu.get_numpy(tensor_or_other)
    else:
        return tensor_or_other
Example #12
    def train(self, train_data, discount=0.99, tau=0.005):

        # Sample replay buffer / batch
        state_np, next_state_np, action, reward, done, context = train_data
        state = torch.FloatTensor(state_np).to(device)
        action = torch.FloatTensor(action).to(device)
        next_state = torch.FloatTensor(next_state_np).to(device)
        reward = torch.FloatTensor(reward).to(device)
        done = torch.FloatTensor(1 - done).to(device)
        context = torch.FloatTensor(context).to(device)

        gt.stamp('unpack_data', unique=False)

        # Infer MDP identity using context
        # inferred_mdp = self.mlp_encoder(context)
        # in_mdp_batch_size = state.shape[0] // context.shape[0]
        # inferred_mdp = torch.repeat_interleave(inferred_mdp, in_mdp_batch_size, dim=0)

        # gt.stamp('infer_mdp_identity', unique=False)

        # Train the mlp encoder to predict the rewards.
        # self.mlp_encoder.zero_grad()
        # pred_next_obs = self.E(state, action)
        # pred_rewards = self.P(pred_next_obs, inferred_mdp)
        # reward_loss = F.mse_loss(pred_rewards, reward)

        # gt.stamp('get_reward_loss', unique=False)

        # reward_loss.backward(retain_graph=True)

        # gt.stamp('get_reward_gradient', unique=False)

        # Extend the state space using the inferred_mdp
        # state = torch.cat([state, inferred_mdp], dim=1)
        # next_state = torch.cat([next_state, inferred_mdp], dim=1)

        # gt.stamp('extend_original_state', unique=False)

        # Critic Training
        self.critic_optimizer.zero_grad()
        with torch.no_grad():

            # Duplicate each next state 10 times
            state_rep = next_state.repeat_interleave(10, dim=0)
            gt.stamp('check0', unique=False)

            # candidate_action = self.vae.decode(state_rep)
            # torch.cuda.synchronize()
            # gt.stamp('check1', unique=False)

            # perturbated_action = self.actor_target(state_rep, candidate_action)
            # torch.cuda.synchronize()
            # gt.stamp('check2', unique=False)

            # target_Q1, target_Q2 = self.critic_target(state_rep, perturbated_action)
            # torch.cuda.synchronize()
            # gt.stamp('check3', unique=False)

            target_Q1, target_Q2 = self.critic_target(
                state_rep,
                self.actor_target(state_rep, self.vae.decode(state_rep)))

            # Soft Clipped Double Q-learning
            target_Q = self.target_q_coef * torch.min(target_Q1, target_Q2) + (
                1 - self.target_q_coef) * torch.max(target_Q1, target_Q2)

            target_Q = target_Q.view(state.shape[0], -1).max(1)[0].view(-1, 1)

            target_Q = reward + done * discount * target_Q

        current_Q1, current_Q2 = self.critic(state, action)

        critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(
            current_Q2, target_Q)

        gt.stamp('get_critic_loss', unique=False)

        critic_loss.backward()  # retain_graph=True
        gt.stamp('get_critic_gradient', unique=False)

        # self.mlp_encoder_optimizer.step()
        # gt.stamp('update_mlp_encoder', unique=False)

        # Variational Auto-Encoder Training
        recon, mean, std = self.vae(state, action)
        recon_loss = F.mse_loss(recon, action)
        KL_loss = -0.5 * (1 + torch.log(std.pow(2)) - mean.pow(2) -
                          std.pow(2)).mean()
        vae_loss = recon_loss + 0.5 * KL_loss

        gt.stamp('get_vae_loss', unique=False)

        self.vae_optimizer.zero_grad()
        vae_loss.backward()
        self.vae_optimizer.step()

        gt.stamp('update_vae', unique=False)

        self.critic_optimizer.step()
        gt.stamp('update_critic', unique=False)

        # Perturbation Model / Action Training
        sampled_actions = self.vae.decode(state)
        perturbed_actions = self.actor(state, sampled_actions)

        # Update through DPG
        self.actor_optimizer.zero_grad()
        actor_loss = -self.critic.q1(state, perturbed_actions).mean()

        gt.stamp('get_actor_loss', unique=False)

        # Backpropagate the actor loss before stepping the optimizer
        # (without this, the actor parameters never receive gradients).
        actor_loss.backward()

        self.actor_optimizer.step()
        gt.stamp('update_actor', unique=False)

        # Update Target Networks
        for param, target_param in zip(self.critic.parameters(),
                                       self.critic_target.parameters()):
            target_param.data.copy_(tau * param.data +
                                    (1 - tau) * target_param.data)

        for param, target_param in zip(self.actor.parameters(),
                                       self.actor_target.parameters()):
            target_param.data.copy_(tau * param.data +
                                    (1 - tau) * target_param.data)

        gt.stamp('update_target_network', unique=False)
        """
		Save some statistics for eval
		"""
        if self._need_to_update_eval_statistics:
            self._need_to_update_eval_statistics = False
            """
			Eval should set this to None.
			This way, these statistics are only computed for one batch.
			"""
            self.eval_statistics['actor_loss'] = np.mean(get_numpy(actor_loss))
            self.eval_statistics['critic_loss'] = np.mean(
                get_numpy(critic_loss))
            self.eval_statistics['vae_loss'] = np.mean(get_numpy(vae_loss))
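
The critic target above is BCQ's soft-clipped double-Q estimate, followed by a max over the 10 repeated candidate actions for each next state. A toy sketch of that reduction (shapes and the coefficient are illustrative):

import torch

batch_size, num_candidates = 4, 10
target_q_coef = 0.75

# Toy Q-values for each (next_state, candidate_action) pair.
target_Q1 = torch.randn(batch_size * num_candidates, 1)
target_Q2 = torch.randn(batch_size * num_candidates, 1)

# Soft-clipped double Q: a convex combination of the min and the max.
target_Q = target_q_coef * torch.min(target_Q1, target_Q2) \
    + (1 - target_q_coef) * torch.max(target_Q1, target_Q2)

# Best candidate per next state, reshaped back to (batch_size, 1).
target_Q = target_Q.view(batch_size, -1).max(1)[0].view(-1, 1)
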
    def train(self, batch, batch_idxes):
        """
        Unpack data from the batch
        """
        obs = batch['obs']
        actions = batch['actions']
        contexts = batch['contexts']

        num_candidate_context = contexts[0].shape[0]
        meta_batch_size = batch_idxes.shape[0]
        num_posterior = meta_batch_size * num_candidate_context
        contexts = torch.cat(contexts, dim=0)

        # Get the in_mdp_batch_size
        in_mdp_batch_size = obs.shape[0] // batch_idxes.shape[0]

        # Sample z for each state
        z = self.bcq_polices[0].vae.sample_z(obs).to(ptu.device)

        target_q = []
        target_candidates = []
        target_perturbations = []

        for i, batch_idx in enumerate(batch_idxes):

            tq = self.bcq_polices[batch_idx].critic.q1(
                obs[i * in_mdp_batch_size:(i + 1) * in_mdp_batch_size],
                actions[i * in_mdp_batch_size:(i + 1) *
                        in_mdp_batch_size]).detach()
            target_q.append(tq)

            tc = self.bcq_polices[batch_idx].vae.decode(
                obs[i * in_mdp_batch_size:(i + 1) * in_mdp_batch_size],
                z[i * in_mdp_batch_size:(i + 1) * in_mdp_batch_size]).detach()
            target_candidates.append(tc)

            tp = self.bcq_polices[batch_idx].get_perturbation(
                obs[i * in_mdp_batch_size:(i + 1) * in_mdp_batch_size],
                tc).detach()
            target_perturbations.append(tp)

        target_q = torch.cat(target_q, dim=0).squeeze()
        target_candidates = torch.cat(target_candidates, dim=0)
        target_perturbations = torch.cat(target_perturbations, dim=0)

        gt.stamp('get_the_targets', unique=False)
        """
        Compute triplet loss
        """
        self.context_encoder_optimizer.zero_grad()

        # z_means, z_var: (num_posterior, latent_dim), num_posterior = meta_batch_size * num_candidate_context
        z_means, z_vars = self.context_encoder.infer_posterior_with_mean_var(
            contexts)

        # z_means_interleave: (num_posterior * num_posterior, latent_dim) [1, 2, 3] -> [1, 1, 1, 2, 2, 2, 3, 3, 3]
        z_means_interleave = torch.repeat_interleave(z_means,
                                                     num_posterior,
                                                     dim=0)
        # z_means_repeat: (num_posterior * num_posterior, latent_dim) [1, 2, 3] -> [1, 2, 3, 2, 3, 1, 3, 1, 2].
        # By doing so, it is easy to get the triplet loss
        z_means_repeat = []
        for i in range(meta_batch_size):
            z_means_repeat.append(
                torch.cat([
                    z_means[i * num_candidate_context:],
                    z_means[:i * num_candidate_context]
                ],
                          dim=0).repeat(num_candidate_context, 1))
        z_means_repeat = torch.cat(z_means_repeat, dim=0)

        # As above
        z_vars_interleave = torch.repeat_interleave(z_vars,
                                                    num_posterior,
                                                    dim=0)
        z_vars_repeat = []
        for i in range(meta_batch_size):
            z_vars_repeat.append(
                torch.cat([
                    z_vars[i * num_candidate_context:],
                    z_vars[:i * num_candidate_context]
                ],
                          dim=0).repeat(num_candidate_context, 1))
        z_vars_repeat = torch.cat(z_vars_repeat, dim=0)

        gt.stamp('get_repeated_mean_var', unique=False)

        # log(det(Sigma2) / det(Sigma1)): (num_posterior * num_posterior, 1)
        kl_divergence = torch.log(
            torch.prod(z_vars_repeat / z_vars_interleave, dim=1))
        # -d
        kl_divergence -= z_means.shape[-1]
        # Tr(Sigma2^{-1} * Sigma1)
        kl_divergence += torch.sum(z_vars_interleave / z_vars_repeat, dim=1)
        # (m2 - m1).T Sigma2^{-1} (m2 - m1))
        kl_divergence += torch.sum(
            (z_means_repeat - z_means_interleave)**2 / z_vars_repeat, dim=1)
        # / 2
        # (num_posterior, num_posterior): each element kl_{i, j} denotes the kl divergence between the two distributions.
        # Task number for row: i // num_posterior // num_candidate_context.
        #             for col: j % num_posterior // num_candidate_context.
        # Batch number for row: i // num_posterior % num_candidate_context.
        #              for col: j % num_posterior % num_candidate_context.
        kl_divergence = kl_divergence.reshape(num_posterior, num_posterior) / 2

        within_task_dist = torch.max(kl_divergence[:, :num_candidate_context],
                                     dim=1)[0]
        across_task_dist = torch.min(kl_divergence[:, num_candidate_context:],
                                     dim=1)[0]

        unscaled_triplet_loss = torch.sum(
            F.relu(within_task_dist - across_task_dist + self.triplet_margin))

        gt.stamp('get_triplet_loss', unique=False)
        """
        Infer the context variables
        """
        index = np.random.choice(
            num_candidate_context, meta_batch_size
        ) + num_candidate_context * np.arange(meta_batch_size)
        # Get the sampled mean and vars for each task.
        # mean: (meta_batch_size, latent_dim)
        # var: (meta_batch_size, latent_dim)
        mean = z_means[index]
        var = z_vars[index]

        # Get the inferred MDP
        # inferred_mdps: (meta_batch_size, latent_dim)
        inferred_mdps = self.context_encoder.sample_z_from_mean_var(mean, var)

        inferred_mdps = torch.repeat_interleave(inferred_mdps,
                                                in_mdp_batch_size,
                                                dim=0)

        gt.stamp('infer_mdps', unique=False)
        """
        Obtain the KL loss
        """
        prior_mean = ptu.zeros(mean.shape)
        prior_var = ptu.ones(var.shape)

        kl_loss = self.kl_lambda * self.context_encoder.compute_kl_div_between_posterior(
            mean, var, prior_mean, prior_var)

        gt.stamp('get_kl_loss', unique=False)

        # triplet_loss = (kl_loss / unscaled_triplet_loss).detach() * unscaled_triplet_loss
        # posterior_loss = unscaled_triplet_loss + kl_loss
        # posterior_loss.backward(retain_graph=True)

        # gt.stamp('get_posterior_gradient', unique=False)
        """
        Obtain the Q-function loss
        """
        self.Qs_optimizer.zero_grad()
        pred_q = self.Qs(obs, actions, inferred_mdps)
        pred_q = torch.squeeze(pred_q)
        qf_loss = F.mse_loss(pred_q, target_q)

        gt.stamp('get_qf_loss', unique=False)

        (qf_loss + unscaled_triplet_loss + kl_loss).backward()

        gt.stamp('get_qf_encoder_gradient', unique=False)

        self.Qs_optimizer.step()
        self.context_encoder_optimizer.step()
        """
        Obtain the candidate action and perturbation loss
        """

        self.vae_decoder_optimizer.zero_grad()
        self.perturbation_generator_optimizer.zero_grad()

        pred_candidates = self.vae_decoder(obs, z, inferred_mdps.detach())
        pred_perturbations = self.perturbation_generator(
            obs, target_candidates, inferred_mdps.detach())

        candidate_loss = F.mse_loss(pred_candidates, target_candidates)
        perturbation_loss = F.mse_loss(pred_perturbations,
                                       target_perturbations)

        gt.stamp('get_candidate_and_perturbation_loss', unique=False)

        candidate_loss.backward()
        perturbation_loss.backward()

        gt.stamp('get_candidate_and_perturbation_gradient', unique=False)

        self.vae_decoder_optimizer.step()
        self.perturbation_generator_optimizer.step()
        """
        Save some statistics for eval
        """
        if self._need_to_update_eval_statistics:
            self._need_to_update_eval_statistics = False
            """
            Eval should set this to None.
            This way, these statistics are only computed for one batch.
            """
            self.eval_statistics['qf_loss'] = np.mean(ptu.get_numpy(qf_loss))
            self.eval_statistics['unscaled_triplet_loss'] = np.mean(
                ptu.get_numpy(unscaled_triplet_loss))
            self.eval_statistics['kl_loss'] = np.mean(ptu.get_numpy(kl_loss))
            self.eval_statistics['candidate_loss'] = np.mean(
                ptu.get_numpy(candidate_loss))
            self.eval_statistics['perturbation_loss'] = np.mean(
                ptu.get_numpy(perturbation_loss))
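
The per-term comments in the KL computation above follow the closed form for two diagonal Gaussians, KL(N(m1, S1) || N(m2, S2)) = 0.5 * [log(det S2 / det S1) - d + tr(S2^-1 S1) + (m2 - m1)^T S2^-1 (m2 - m1)]. A small self-contained numerical check against torch.distributions; kl_diag_gaussians below is a stand-in name (not a function from the code above), and the inputs are variances, as in the code:

import torch


def kl_diag_gaussians(m1, v1, m2, v2):
    """KL( N(m1, diag(v1)) || N(m2, diag(v2)) ); v1 and v2 are variances."""
    kl = torch.log(torch.prod(v2 / v1, dim=-1))         # log(det S2 / det S1)
    kl = kl - m1.shape[-1]                              # -d
    kl = kl + torch.sum(v1 / v2, dim=-1)                # tr(S2^-1 S1)
    kl = kl + torch.sum((m2 - m1) ** 2 / v2, dim=-1)    # (m2 - m1)^T S2^-1 (m2 - m1)
    return kl / 2


m1, m2 = torch.randn(5, 3), torch.randn(5, 3)
v1, v2 = torch.rand(5, 3) + 0.5, torch.rand(5, 3) + 0.5

p = torch.distributions.Normal(m1, v1.sqrt())
q = torch.distributions.Normal(m2, v2.sqrt())
reference = torch.distributions.kl_divergence(p, q).sum(dim=-1)

assert torch.allclose(kl_diag_gaussians(m1, v1, m2, v2), reference, atol=1e-4)
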
Example #14
def np_ify(tensor_or_other):
    if isinstance(tensor_or_other, Variable):
        return ptu.get_numpy(tensor_or_other)
    else:
        return tensor_or_other
Example #15
    def select_actions(self, obs, inferred_mdp):
        action = self.policy.select_action(obs, get_numpy(inferred_mdp))

        return action
Example #16
    def train(self, batch, batch_idxes):
        """
        Unpack data from the batch
        """
        obs = batch['obs']
        actions = batch['actions']
        contexts = batch['contexts']

        num_tasks = batch_idxes.shape[0]

        gt.stamp('unpack_data_from_the_batch', unique=False)

        # Get the in_mdp_batch_size
        obs_dim = obs.shape[1]
        action_dim = actions.shape[1]
        in_mdp_batch_size = obs.shape[0] // batch_idxes.shape[0]
        num_trans_context = contexts.shape[0] // batch_idxes.shape[0]
        """
        Relabel the context batches for each training task
        """
        with torch.no_grad():
            contexts_obs_actions = contexts[:, :obs_dim + action_dim]

            manual_batched_rewards = self.reward_ensemble_predictor.forward_mul_device(
                contexts_obs_actions)

            relabeled_rewards = manual_batched_rewards.reshape(
                num_tasks, self.num_network_ensemble, contexts.shape[0])

            gt.stamp('reward_ensemble_forward', unique=False)

            manual_batched_next_obs = self.transition_ensemble_predictor.forward_mul_device(
                contexts_obs_actions)

            relabeled_next_obs = manual_batched_next_obs.reshape(
                num_tasks, self.num_network_ensemble, contexts.shape[0],
                obs_dim)

            gt.stamp('transition_ensemble_forward', unique=False)

            relabeled_rewards_mean = torch.mean(relabeled_rewards,
                                                dim=1).squeeze()
            relabeled_rewards_std = torch.std(relabeled_rewards,
                                              dim=1).squeeze()

            relabeled_next_obs_mean = torch.mean(relabeled_next_obs, dim=1)

            relabeled_next_obs_std = torch.std(relabeled_next_obs, dim=1)
            relabeled_next_obs_std = torch.mean(relabeled_next_obs_std, dim=-1)

            # Replace the predicted reward with ground truth reward for transitions
            # with ground truth reward inside the batch
            for i in range(num_tasks):

                relabeled_rewards_mean[i, i*num_trans_context: (i+1)*num_trans_context] \
                    = contexts[i*num_trans_context: (i+1)*num_trans_context, -1]

                relabeled_next_obs_mean[i, i*num_trans_context: (i+1)*num_trans_context, :] \
                    = contexts[i*num_trans_context: (i+1)*num_trans_context, obs_dim + action_dim : -1]

                if self.is_combine:
                    # Set the stds above the thresholds so that the task's own
                    # context transitions are filtered out of the mask; they are
                    # added back explicitly below, after sampling the relabeled ones.
                    relabeled_rewards_std[
                        i, i * num_trans_context:(i + 1) *
                        num_trans_context] = self.reward_std_threshold + 1.0
                    relabeled_next_obs_std[
                        i, i * num_trans_context:(i + 1) *
                        num_trans_context] = self.next_obs_std_threshold + 1.0
                else:
                    relabeled_rewards_std[i, i * num_trans_context:(i + 1) *
                                          num_trans_context] = 0.0
                    relabeled_next_obs_std[i, i * num_trans_context:(i + 1) *
                                           num_trans_context] = 0.0

            mask_reward = relabeled_rewards_std < self.reward_std_threshold
            mask_reward = mask_reward.type(torch.float)

            mask_next_obs = relabeled_next_obs_std < self.next_obs_std_threshold
            mask_next_obs = mask_next_obs.type(torch.float)

            mask = mask_reward * mask_next_obs
            mask = mask.type(torch.uint8)

            mask_from_the_other_tasks = mask.type(torch.uint8).clone()

            num_context_candidate_each_task = torch.sum(mask, dim=1)

            mask_list = []

            for i in range(num_tasks):

                assert mask[i].dim() == 1

                mask_nonzero = torch.nonzero(mask[i])
                mask_nonzero = mask_nonzero.flatten()

                mask_i = ptu.zeros_like(mask[i], dtype=torch.uint8)

                assert num_context_candidate_each_task[i].item(
                ) == mask_nonzero.shape[0]

                np_ind = np.random.choice(mask_nonzero.shape[0],
                                          num_trans_context,
                                          replace=False)

                ind = mask_nonzero[np_ind]

                mask_i[ind] = 1

                if self.is_combine:
                    # Combine the additional relabeled context transitions with
                    # the original context transitions with ground-truth rewards
                    mask_i[i * num_trans_context:(i + 1) *
                           num_trans_context] = 1
                    assert torch.sum(mask_i).item() == 2 * num_trans_context
                else:
                    assert torch.sum(mask_i).item() == num_trans_context

                mask_list.append(mask_i)

            mask = torch.cat(mask_list)
            mask = mask.type(torch.uint8)

            repeated_contexts = contexts.repeat(num_tasks, 1)
            context_without_next_obs_rewards = repeated_contexts[:, :obs_dim +
                                                                 action_dim]

            assert context_without_next_obs_rewards.shape[
                0] == relabeled_rewards_mean.reshape(-1, 1).shape[0]
            assert context_without_next_obs_rewards.shape[
                0] == relabeled_next_obs_mean.reshape(-1, obs_dim).shape[0]

            context_without_next_obs_rewards = context_without_next_obs_rewards[
                mask]

            context_next_obs = relabeled_next_obs_mean.reshape(-1,
                                                               obs_dim)[mask]

            context_rewards = relabeled_rewards_mean.reshape(-1, 1)[mask]

            fast_contexts = torch.cat((context_without_next_obs_rewards,
                                       context_next_obs, context_rewards),
                                      dim=1)

            fast_contexts = fast_contexts.reshape(num_tasks, -1,
                                                  contexts.shape[-1])

        gt.stamp('relabel_context_transitions', unique=False)
        """
        Obtain the targets
        """
        with torch.no_grad():
            # Sample z for each state
            z = self.bcq_polices[0].vae.sample_z(obs).to(ptu.device)

            # critic_weights is a list with one entry per device;
            # critic_weights[i] is a list with one entry per layer;
            # critic_weights[i][j] is a tensor of shape
            # (num_tasks // device_count, layer_input_size, layer_output_size).
            # The same layout applies to the other weights and biases.
            critic_weights, critic_biases, vae_weights, vae_biases, actor_weights, actor_biases = self.combined_bcq_policies

            # CRITIC
            obs_reshaped = obs.reshape(len(batch_idxes), in_mdp_batch_size, -1)
            acs_reshaped = actions.reshape(len(batch_idxes), in_mdp_batch_size,
                                           -1)

            obs_acs_reshaped = torch.cat((obs_reshaped, acs_reshaped), dim=-1)

            target_q = batch_bcq(obs_acs_reshaped, critic_weights,
                                 critic_biases)
            target_q = target_q.reshape(-1)

            # VAE
            z_reshaped = z.reshape(len(batch_idxes), in_mdp_batch_size, -1)
            obs_z_reshaped = torch.cat((obs_reshaped, z_reshaped), dim=-1)

            tc = batch_bcq(obs_z_reshaped, vae_weights, vae_biases)
            tc = self.bcq_polices[0].vae.max_action * torch.tanh(tc)
            target_candidates = tc.reshape(-1, tc.shape[-1])

            # PERTURBATION
            tc_reshaped = target_candidates.reshape(len(batch_idxes),
                                                    in_mdp_batch_size, -1)

            obs_tc_reshaped = torch.cat((obs_reshaped, tc_reshaped), dim=-1)

            tp = batch_bcq(obs_tc_reshaped, actor_weights, actor_biases)
            tp = self.bcq_polices[0].actor.max_action * torch.tanh(tp)
            target_perturbations = tp.reshape(-1, tp.shape[-1])

        gt.stamp('get_the_targets', unique=False)
        """
        Compute the triplet loss
        """
        # ----------------------------------Vectorized-------------------------------------------
        self.context_encoder_optimizer.zero_grad()

        anchors = []
        positives = []
        negatives = []

        num_selected_list = []

        # Pairs of tasks (i, j) for which no transition from task j
        # was selected by task i's ensemble
        exclude_tasks = []

        exclude_task_masks = []

        for i in range(num_tasks):
            # Compute the triplet loss for task i

            for j in range(num_tasks):

                if j != i:

                    # mask_for_task_j: (num_trans_context, )
                    # mask_from_the_other_tasks: (num_tasks, num_tasks * num_trans_context)
                    mask_for_task_j = mask_from_the_other_tasks[
                        i, j * num_trans_context:(j + 1) * num_trans_context]
                    num_selected = int(torch.sum(mask_for_task_j).item())

                    if num_selected == 0:
                        exclude_tasks.append((i, j))
                        exclude_task_masks.append(0)
                    else:
                        exclude_task_masks.append(1)

                    # context_trans_all: (num_trans_context, context_dim)
                    context_trans_all = contexts[j *
                                                 num_trans_context:(j + 1) *
                                                 num_trans_context]
                    # context_trans_selected: (num_selected, context_dim)
                    context_trans_selected = context_trans_all[mask_for_task_j]

                    # relabel_reward_all: (num_trans_context, )
                    relabel_reward_all = relabeled_rewards_mean[
                        i, j * num_trans_context:(j + 1) * num_trans_context]
                    # relabel_reward_selected: (num_selected, )
                    relabel_reward_selected = relabel_reward_all[
                        mask_for_task_j]
                    # relabel_reward_selected: (num_selected, 1)
                    relabel_reward_selected = relabel_reward_selected.reshape(
                        -1, 1)

                    # relabel_next_obs_all: (num_trans_context, obs_dim)
                    relabel_next_obs_all = relabeled_next_obs_mean[
                        i, j * num_trans_context:(j + 1) * num_trans_context]
                    # relabel_next_obs_selected: (num_selected, obs_dim)
                    relabel_next_obs_selected = relabel_next_obs_all[
                        mask_for_task_j]

                    # context_trans_selected_relabel: (num_selected, context_dim)
                    context_trans_selected_relabel = torch.cat([
                        context_trans_selected[:, :obs_dim + action_dim],
                        relabel_next_obs_selected, relabel_reward_selected
                    ],
                                                               dim=1)

                    # c_{i}
                    ind = np.random.choice(num_trans_context,
                                           num_selected,
                                           replace=False)

                    # Next 2 lines used for comparing to sequential version
                    # ind = ind_list[count]
                    # count += 1

                    # context_trans_task_i: (num_trans_context, context_dim)
                    context_trans_task_i = contexts[i *
                                                    num_trans_context:(i + 1) *
                                                    num_trans_context]
                    # context_trans_task_i_sampled: (num_selected, context_dim)
                    context_trans_task_i_sampled = context_trans_task_i[ind]

                    # Pad the contexts with 0 tensor
                    num_to_pad = num_trans_context - num_selected
                    # pad_zero_tensor: (num_to_pad, context_dim)
                    pad_zero_tensor = ptu.zeros(
                        (num_to_pad, context_trans_selected.shape[1]))

                    num_selected_list.append(num_selected)

                    # Dim: (1, num_trans_context, context_dim)
                    context_trans_selected = torch.cat(
                        [context_trans_selected, pad_zero_tensor], dim=0)
                    context_trans_selected_relabel = torch.cat(
                        [context_trans_selected_relabel, pad_zero_tensor],
                        dim=0)
                    context_trans_task_i_sampled = torch.cat(
                        [context_trans_task_i_sampled, pad_zero_tensor], dim=0)

                    anchors.append(context_trans_selected_relabel[None])
                    positives.append(context_trans_task_i_sampled[None])
                    negatives.append(context_trans_selected[None])

        # Dim: (num_tasks * (num_tasks - 1), num_trans_context, context_dim)
        anchors = torch.cat(anchors, dim=0)
        positives = torch.cat(positives, dim=0)
        negatives = torch.cat(negatives, dim=0)

        # input_contexts: (3 * num_tasks * (num_tasks - 1), num_trans_context, context_dim)
        input_contexts = torch.cat([anchors, positives, negatives], dim=0)

        # num_selected_pt: (num_tasks * (num_tasks - 1), )
        num_selected_pt = torch.from_numpy(np.array(num_selected_list))

        # num_selected_repeat: (3 * num_tasks * (num_tasks - 1), )
        num_selected_repeat = num_selected_pt.repeat(3)

        # z_means_vec, z_vars_vec: (3 * num_tasks * (num_tasks - 1), latent_dim)
        z_means_vec, z_vars_vec = self.context_encoder.infer_posterior_with_mean_var(
            input_contexts, num_trans_context, num_selected_repeat)

        # z_means_vec, z_vars_vec: (3, num_tasks * (num_tasks - 1), latent_dim)
        z_means_vec = z_means_vec.reshape(3, anchors.shape[0], -1)
        z_vars_vec = z_vars_vec.reshape(3, anchors.shape[0], -1)

        # Dim: (num_tasks * (num_tasks - 1), latent_dim)
        z_means_anchors, z_vars_anchors = z_means_vec[0], z_vars_vec[0]
        z_means_positives, z_vars_positives = z_means_vec[1], z_vars_vec[1]
        z_means_negatives, z_vars_negatives = z_means_vec[2], z_vars_vec[2]

        with_task_dist = compute_kl_div_diagonal(z_means_anchors,
                                                 z_vars_anchors,
                                                 z_means_positives,
                                                 z_vars_positives)
        across_task_dist = compute_kl_div_diagonal(z_means_anchors,
                                                   z_vars_anchors,
                                                   z_means_negatives,
                                                   z_vars_negatives)

        # Zero out triplets whose task pair had no selected transitions
        # (num_selected == 0) so they do not contribute to the loss.
        exclude_task_masks = ptu.from_numpy(np.array(exclude_task_masks))

        with_task_dist = with_task_dist * exclude_task_masks
        across_task_dist = across_task_dist * exclude_task_masks

        unscaled_triplet_loss_vec = F.relu(with_task_dist - across_task_dist +
                                           self.triplet_margin)
        unscaled_triplet_loss_vec = torch.mean(unscaled_triplet_loss_vec)

        # Guard against NaNs in the triplet loss
        assert not torch.isnan(unscaled_triplet_loss_vec).any()

        gt.stamp('get_triplet_loss', unique=False)

        unscaled_triplet_loss_vec.backward()

        check_grad_nan_nets(self.networks,
                            f'triplet: {unscaled_triplet_loss_vec}')

        gt.stamp('get_triplet_loss_gradient', unique=False)
        """
        Infer the context variables
        """
        # inferred_mdps = self.context_encoder(new_contexts)
        inferred_mdps = self.context_encoder(fast_contexts)
        inferred_mdps = torch.repeat_interleave(inferred_mdps,
                                                in_mdp_batch_size,
                                                dim=0)

        gt.stamp('infer_mdps', unique=False)
        """
        Obtain the KL loss
        """

        kl_div = self.context_encoder.compute_kl_div()

        kl_loss_each_task = self.kl_lambda * torch.sum(kl_div, dim=1)

        kl_loss = torch.sum(kl_loss_each_task)

        gt.stamp('get_kl_loss', unique=False)
        """
        Obtain the Q-function loss
        """
        self.Qs_optimizer.zero_grad()

        pred_q = self.Qs(obs, actions, inferred_mdps)
        pred_q = torch.squeeze(pred_q)

        qf_loss_each_task = (pred_q - target_q)**2
        qf_loss_each_task = qf_loss_each_task.reshape(num_tasks, -1)
        qf_loss_each_task = torch.mean(qf_loss_each_task, dim=1)

        qf_loss = torch.mean(qf_loss_each_task)

        gt.stamp('get_qf_loss', unique=False)

        (kl_loss + qf_loss).backward()

        check_grad_nan_nets(self.networks, 'kl q')

        gt.stamp('get_kl_qf_gradient', unique=False)

        self.Qs_optimizer.step()
        self.context_encoder_optimizer.step()

        gt.stamp('update_Qs_encoder', unique=False)
        """
        Obtain the candidate action and perturbation loss
        """

        self.vae_decoder_optimizer.zero_grad()
        self.perturbation_generator_optimizer.zero_grad()

        pred_candidates = self.vae_decoder(obs, z, inferred_mdps.detach())
        pred_perturbations = self.perturbation_generator(
            obs, target_candidates, inferred_mdps.detach())

        candidate_loss_each_task = (pred_candidates - target_candidates)**2

        # averaging over action dimension
        candidate_loss_each_task = torch.mean(candidate_loss_each_task, dim=1)
        candidate_loss_each_task = candidate_loss_each_task.reshape(
            num_tasks, in_mdp_batch_size)

        # average over action in each task
        candidate_loss_each_task = torch.mean(candidate_loss_each_task, dim=1)

        candidate_loss = torch.mean(candidate_loss_each_task)

        perturbation_loss_each_task = (pred_perturbations -
                                       target_perturbations)**2

        # average over action dimension
        perturbation_loss_each_task = torch.mean(perturbation_loss_each_task,
                                                 dim=1)
        perturbation_loss_each_task = perturbation_loss_each_task.reshape(
            num_tasks, in_mdp_batch_size)

        # average over action in each task
        perturbation_loss_each_task = torch.mean(perturbation_loss_each_task,
                                                 dim=1)

        perturbation_loss = torch.mean(perturbation_loss_each_task)

        gt.stamp('get_candidate_and_perturbation_loss', unique=False)

        (candidate_loss + perturbation_loss).backward()

        check_grad_nan_nets(self.networks, 'perb')

        gt.stamp('get_candidate_and_perturbation_gradient', unique=False)

        self.vae_decoder_optimizer.step()
        self.perturbation_generator_optimizer.step()

        # Abort if any network parameter has become NaN; print which parameter
        # and how many transitions were selected per triplet to aid debugging.
        for net in self.networks:
            for name, m in net.named_parameters():
                if torch.isnan(m).any():
                    print(net, name)
                    print(num_selected_list)
                    print(min(num_selected_list))
                    exit()

        gt.stamp('update_vae_perturbation', unique=False)
        """
        Save some statistics for eval
        """
        if self._need_to_update_eval_statistics:
            self._need_to_update_eval_statistics = False
            """
            Eval should set this to None.
            This way, these statistics are only computed for one batch.
            """
            self.eval_statistics['qf_loss'] = np.mean(ptu.get_numpy(qf_loss))
            self.eval_statistics['qf_loss_each_task'] = ptu.get_numpy(
                qf_loss_each_task)

            self.eval_statistics['kl_loss'] = np.mean(ptu.get_numpy(kl_loss))
            self.eval_statistics['triplet_loss'] = np.mean(
                ptu.get_numpy(unscaled_triplet_loss_vec))
            self.eval_statistics['kl_loss_each_task'] = ptu.get_numpy(
                kl_loss_each_task)

            self.eval_statistics['candidate_loss'] = np.mean(
                ptu.get_numpy(candidate_loss))
            self.eval_statistics['candidate_loss_each_task'] = ptu.get_numpy(
                candidate_loss_each_task)

            self.eval_statistics['perturbation_loss'] = np.mean(
                ptu.get_numpy(perturbation_loss))
            self.eval_statistics[
                'perturbation_loss_each_task'] = ptu.get_numpy(
                    perturbation_loss_each_task)

            self.eval_statistics[
                'num_context_candidate_each_task'] = num_context_candidate_each_task
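
# A minimal, self-contained sketch of the triplet objective used above, under
# the assumption that task posteriors are diagonal Gaussians: the within-task
# KL (anchor vs. positive) is pushed below the across-task KL (anchor vs.
# negative) by a margin. kl_diag and the toy statistics are illustrative
# names, not the original compute_kl_div_diagonal implementation.
import torch
import torch.nn.functional as F


def kl_diag(mu_1, var_1, mu_2, var_2):
    # KL(N(mu_1, var_1) || N(mu_2, var_2)) for diagonal Gaussians, summed per row.
    return 0.5 * torch.sum(
        torch.log(var_2 / var_1) + (var_1 + (mu_1 - mu_2) ** 2) / var_2 - 1.0,
        dim=-1)


num_pairs, latent_dim, margin = 6, 5, 2.0
mu_anchor, var_anchor = torch.zeros(num_pairs, latent_dim), torch.ones(num_pairs, latent_dim)
mu_pos, var_pos = 0.1 * torch.ones(num_pairs, latent_dim), torch.ones(num_pairs, latent_dim)
mu_neg, var_neg = 2.0 * torch.ones(num_pairs, latent_dim), torch.ones(num_pairs, latent_dim)

within_task = kl_diag(mu_anchor, var_anchor, mu_pos, var_pos)
across_task = kl_diag(mu_anchor, var_anchor, mu_neg, var_neg)
triplet_loss = F.relu(within_task - across_task + margin).mean()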
Exemple #17
0
    def get_param_values_np(self):
        state_dict = self.state_dict()
        np_dict = OrderedDict()
        for key, tensor in state_dict.items():
            np_dict[key] = ptu.get_numpy(tensor)
        return np_dict
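
    # A possible counterpart for restoring the numpy snapshot produced by
    # get_param_values_np. set_param_values_np is a hypothetical helper, not
    # part of the original class, and assumes ptu.from_numpy mirrors
    # ptu.get_numpy.
    def set_param_values_np(self, np_dict):
        state_dict = OrderedDict(
            (key, ptu.from_numpy(value)) for key, value in np_dict.items())
        self.load_state_dict(state_dict)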
    def train(self, train_data, discount=0.99, tau=0.005):
        state_np, next_state_np, action, reward, done, context = train_data
        state = torch.FloatTensor(state_np).to(device)
        action = torch.FloatTensor(action).to(device)
        next_state = torch.FloatTensor(next_state_np).to(device)
        reward = torch.FloatTensor(reward).to(device)
        done = torch.FloatTensor(1 - done).to(device)
        context = torch.FloatTensor(context).to(device)

        gt.stamp('unpack_data', unique=False)

        # Infer the MDP identity from the context

        self.context_encoder_optimizer.zero_grad()

        inferred_mdp = self.context_encoder(context)
        in_mdp_batch_size = state.shape[0] // context.shape[0]
        inferred_mdp = torch.repeat_interleave(inferred_mdp,
                                               in_mdp_batch_size,
                                               dim=0)

        gt.stamp('infer_mdp_identity', unique=False)

        # Variational Auto-Encoder Training
        recon, mean, std = self.vae(state, action, inferred_mdp)
        recon_loss = F.mse_loss(recon, action)
        KL_loss = -0.5 * (1 + torch.log(std.pow(2)) - mean.pow(2) -
                          std.pow(2)).mean()
        vae_loss = recon_loss + 0.5 * KL_loss

        gt.stamp('get_vae_loss', unique=False)

        self.vae_optimizer.zero_grad()
        vae_loss.backward(retain_graph=True)
        self.vae_optimizer.step()

        gt.stamp('update_vae', unique=False)

        # Critic Training
        self.critic_optimizer.zero_grad()

        with torch.no_grad():

            # Duplicate state 10 times
            state_rep = next_state.repeat_interleave(10, dim=0)
            inferred_mdp_rep = inferred_mdp.repeat_interleave(10, dim=0)

            target_Q1, target_Q2 = self.critic_target(
                state_rep,
                self.actor_target(
                    state_rep,
                    self.vae.decode(state_rep, inferred_mdp=inferred_mdp_rep),
                    inferred_mdp_rep), inferred_mdp_rep)

            # Soft Clipped Double Q-learning
            target_Q = self.target_q_coef * torch.min(target_Q1, target_Q2) + (
                1 - self.target_q_coef) * torch.max(target_Q1, target_Q2)
            target_Q = target_Q.view(state.shape[0], -1).max(1)[0].view(-1, 1)

            target_Q = reward + done * discount * target_Q

        current_Q1, current_Q2 = self.critic(state, action, inferred_mdp)
        critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(
            current_Q2, target_Q)

        gt.stamp('get_critic_loss', unique=False)

        self.critic_optimizer.zero_grad()
        critic_loss.backward(retain_graph=True)
        self.critic_optimizer.step()

        gt.stamp('update_critic', unique=False)

        self.context_encoder_optimizer.step()

        # Perturbation Model / Action Training
        sampled_actions = self.vae.decode(state,
                                          inferred_mdp=inferred_mdp.detach())
        perturbed_actions = self.actor(state, sampled_actions,
                                       inferred_mdp.detach())

        # Update through DPG
        actor_loss = -self.critic.q1(state, perturbed_actions,
                                     inferred_mdp.detach()).mean()

        gt.stamp('get_actor_loss', unique=False)

        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        gt.stamp('update_actor', unique=False)

        # Update Target Networks
        for param, target_param in zip(self.critic.parameters(),
                                       self.critic_target.parameters()):
            target_param.data.copy_(tau * param.data +
                                    (1 - tau) * target_param.data)

        for param, target_param in zip(self.actor.parameters(),
                                       self.actor_target.parameters()):
            target_param.data.copy_(tau * param.data +
                                    (1 - tau) * target_param.data)
        """
        Save some statistics for eval   
        """
        if self._need_to_update_eval_statistics:
            self._need_to_update_eval_statistics = False
            """
            Eval should set this to None.
            This way, these statistics are only computed for one batch.
            """
            self.eval_statistics['actor_loss'] = np.mean(get_numpy(actor_loss))
            self.eval_statistics['critic_loss'] = np.mean(
                get_numpy(critic_loss))
            self.eval_statistics['vae_loss'] = np.mean(get_numpy(vae_loss))
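
# A minimal sketch of the soft-clipped double-Q target computed above: the two
# critic estimates are mixed with a min-weighted coefficient, the best of the
# repeated candidate actions is taken per state, and the result is discounted
# and masked by the (already flipped) done flag. The shapes and the 0.75
# coefficient below are illustrative assumptions.
import torch

batch_size, num_candidates, target_q_coef, discount = 4, 10, 0.75, 0.99
q1 = torch.randn(batch_size * num_candidates, 1)
q2 = torch.randn(batch_size * num_candidates, 1)
reward = torch.randn(batch_size, 1)
not_done = torch.ones(batch_size, 1)

# Soft clip: lean toward the smaller estimate to curb overestimation.
target_q = target_q_coef * torch.min(q1, q2) + (1 - target_q_coef) * torch.max(q1, q2)
# Keep only the best perturbed candidate action for each state.
target_q = target_q.view(batch_size, num_candidates).max(1)[0].view(-1, 1)
# Standard Bellman backup.
target_q = reward + not_done * discount * target_q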
    def train(self, batch, batch_idxes):
        """
        Unpack data from the batch
        """
        obs = batch['obs']
        actions = batch['actions']
        contexts = batch['contexts']

        num_tasks = batch_idxes.shape[0]

        gt.stamp('unpack_data_from_the_batch', unique=False)

        # Get the in_mdp_batch_size

        in_mdp_batch_size = obs.shape[0] // batch_idxes.shape[0]
        num_trans_context = contexts.shape[0] // batch_idxes.shape[0]

        contexts = contexts.reshape(num_tasks, num_trans_context, -1)
        """
        Infer the context variables
        """
        inferred_mdps = self.context_encoder(contexts)
        inferred_mdps = torch.repeat_interleave(inferred_mdps,
                                                in_mdp_batch_size,
                                                dim=0)

        gt.stamp('infer_mdps', unique=False)
        """
        Obtain the targets
        """
        with torch.no_grad():
            # Sample z for each state
            z = self.bcq_polices[0].vae.sample_z(obs).to(ptu.device)

            # critic_weights is a list with one entry per device;
            # critic_weights[i] is a list with one entry per layer;
            # critic_weights[i][j] is a tensor of shape
            # (num_tasks // device_count, layer_input_size, layer_output_size).
            # The same structure applies to the other weights and biases.
            critic_weights, critic_biases, vae_weights, vae_biases, actor_weights, actor_biases = self.combined_bcq_policies

            # CRITIC
            obs_reshaped = obs.reshape(len(batch_idxes), in_mdp_batch_size, -1)
            acs_reshaped = actions.reshape(len(batch_idxes), in_mdp_batch_size,
                                           -1)

            obs_acs_reshaped = torch.cat((obs_reshaped, acs_reshaped), dim=-1)

            target_q = batch_bcq(obs_acs_reshaped, critic_weights,
                                 critic_biases)
            target_q = target_q.reshape(-1)

            # VAE
            z_reshaped = z.reshape(len(batch_idxes), in_mdp_batch_size, -1)
            obs_z_reshaped = torch.cat((obs_reshaped, z_reshaped), dim=-1)

            tc = batch_bcq(obs_z_reshaped, vae_weights, vae_biases)
            tc = self.bcq_polices[0].vae.max_action * torch.tanh(tc)
            target_candidates = tc.reshape(-1, tc.shape[-1])

            # PERTURBATION
            tc_reshaped = target_candidates.reshape(len(batch_idxes),
                                                    in_mdp_batch_size, -1)

            obs_tc_reshaped = torch.cat((obs_reshaped, tc_reshaped), dim=-1)

            tp = batch_bcq(obs_tc_reshaped, actor_weights, actor_biases)
            tp = self.bcq_polices[0].actor.max_action * torch.tanh(tp)
            target_perturbations = tp.reshape(-1, tp.shape[-1])

        gt.stamp('get_the_targets', unique=False)
        """
        Obtain the KL loss
        """

        # KL constraint on z if probabilistic
        self.context_encoder_optimizer.zero_grad()

        kl_div = self.context_encoder.compute_kl_div()

        kl_loss_each_task = self.kl_lambda * torch.sum(kl_div, dim=1)

        kl_loss = torch.sum(kl_loss_each_task)

        gt.stamp('get_kl_loss', unique=False)
        """
        Obtain the Q-function loss
        """
        self.Qs_optimizer.zero_grad()

        pred_q = self.Qs(obs, actions, inferred_mdps)
        pred_q = torch.squeeze(pred_q)

        qf_loss_each_task = (pred_q - target_q)**2
        qf_loss_each_task = qf_loss_each_task.reshape(num_tasks, -1)
        qf_loss_each_task = torch.mean(qf_loss_each_task, dim=1)

        qf_loss = torch.mean(qf_loss_each_task)

        gt.stamp('get_qf_loss', unique=False)

        (kl_loss + qf_loss).backward()

        gt.stamp('get_kl_qf_gradient', unique=False)

        self.Qs_optimizer.step()
        self.context_encoder_optimizer.step()

        gt.stamp('update_Qs_encoder', unique=False)
        """
        Obtain the candidate action and perturbation loss
        """

        self.vae_decoder_optimizer.zero_grad()
        self.perturbation_generator_optimizer.zero_grad()

        pred_candidates = self.vae_decoder(obs, z, inferred_mdps.detach())
        pred_perturbations = self.perturbation_generator(
            obs, target_candidates, inferred_mdps.detach())

        candidate_loss_each_task = (pred_candidates - target_candidates)**2

        # averaging over action dimension
        candidate_loss_each_task = torch.mean(candidate_loss_each_task, dim=1)
        candidate_loss_each_task = candidate_loss_each_task.reshape(
            num_tasks, in_mdp_batch_size)

        # average over action in each task
        candidate_loss_each_task = torch.mean(candidate_loss_each_task, dim=1)

        candidate_loss = torch.mean(candidate_loss_each_task)

        perturbation_loss_each_task = (pred_perturbations -
                                       target_perturbations)**2

        # average over action dimension
        perturbation_loss_each_task = torch.mean(perturbation_loss_each_task,
                                                 dim=1)
        perturbation_loss_each_task = perturbation_loss_each_task.reshape(
            num_tasks, in_mdp_batch_size)

        # average over action in each task
        perturbation_loss_each_task = torch.mean(perturbation_loss_each_task,
                                                 dim=1)

        perturbation_loss = torch.mean(perturbation_loss_each_task)

        gt.stamp('get_candidate_and_perturbation_loss', unique=False)

        (candidate_loss + perturbation_loss).backward()

        gt.stamp('get_candidate_and_perturbation_gradient', unique=False)

        self.vae_decoder_optimizer.step()
        self.perturbation_generator_optimizer.step()

        gt.stamp('update_vae_perturbation', unique=False)
        """
        Save some statistics for eval
        """
        if self._need_to_update_eval_statistics:
            self._need_to_update_eval_statistics = False
            """
            Eval should set this to None.
            This way, these statistics are only computed for one batch.
            """
            self.eval_statistics['qf_loss'] = np.mean(ptu.get_numpy(qf_loss))
            self.eval_statistics['qf_loss_each_task'] = ptu.get_numpy(
                qf_loss_each_task)

            self.eval_statistics['kl_loss'] = np.mean(ptu.get_numpy(kl_loss))
            self.eval_statistics['kl_loss_each_task'] = ptu.get_numpy(
                kl_loss_each_task)

            self.eval_statistics['candidate_loss'] = np.mean(
                ptu.get_numpy(candidate_loss))
            self.eval_statistics['candidate_loss_each_task'] = ptu.get_numpy(
                candidate_loss_each_task)

            self.eval_statistics['perturbation_loss'] = np.mean(
                ptu.get_numpy(perturbation_loss))
            self.eval_statistics[
                'perturbation_loss_each_task'] = ptu.get_numpy(
                    perturbation_loss_each_task)
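
# The targets above come from batch_bcq, which is defined elsewhere; the
# comments suggest it evaluates one frozen per-task BCQ network per task in a
# single batched pass. A minimal sketch of that idea, assuming per-layer
# weights stacked as (num_tasks, in_dim, out_dim), biases as
# (num_tasks, 1, out_dim), and ReLU hidden activations:
import torch
import torch.nn.functional as F


def batched_mlp(inputs, weights, biases):
    # inputs: (num_tasks, batch_per_task, in_dim)
    h = inputs
    for layer, (w, b) in enumerate(zip(weights, biases)):
        h = torch.bmm(h, w) + b
        if layer < len(weights) - 1:
            h = F.relu(h)
    return h


num_tasks, batch_per_task, in_dim, hidden = 3, 8, 6, 16
weights = [torch.randn(num_tasks, in_dim, hidden), torch.randn(num_tasks, hidden, 1)]
biases = [torch.zeros(num_tasks, 1, hidden), torch.zeros(num_tasks, 1, 1)]
q_values = batched_mlp(torch.randn(num_tasks, batch_per_task, in_dim), weights, biases)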
def get_optimistic_exploration_action(ob_np,
                                      policy=None,
                                      qfs=None,
                                      hyper_params=None):

    #assert ob_np.ndim == 1

    beta_UB = hyper_params['beta_UB']
    delta = hyper_params['delta']

    #ob = ptu.from_numpy(ob_np)
    ob = {k: ptu.from_numpy(v[None]) for k, v in ob_np.items()}

    # Ensure that ob is not batched
    # assert len(list(ob.shape)) == 1

    _, pre_tanh_mu_T, _, _, std, _ = policy(ob)
    pre_tanh_mu_T = pre_tanh_mu_T[0]
    std = std[0]

    # Ensure that pretanh_mu_T is not batched
    assert len(list(pre_tanh_mu_T.shape)) == 1, pre_tanh_mu_T
    assert len(list(std.shape)) == 1

    pre_tanh_mu_T.requires_grad_()
    tanh_mu_T = torch.tanh(pre_tanh_mu_T)

    # Get the upper bound of the Q estimate
    args = [ob, torch.unsqueeze(tanh_mu_T, dim=0)]
    Q1 = qfs[0](*args)
    Q2 = qfs[1](*args)

    mu_Q = (Q1 + Q2) / 2.0

    sigma_Q = torch.abs(Q1 - Q2) / 2.0

    Q_UB = mu_Q + beta_UB * sigma_Q

    # Obtain the gradient of Q_UB wrt to a
    # with a evaluated at mu_t
    grad = torch.autograd.grad(Q_UB, pre_tanh_mu_T)
    grad = grad[0]

    assert grad is not None
    assert pre_tanh_mu_T.shape == grad.shape

    # Obtain Sigma_T (the covariance of the normal distribution)
    Sigma_T = torch.pow(std, 2)

    # The divisor is (g^T Sigma g) ** 0.5
    # Sigma is diagonal, so this works out to be
    # ( sum_{i=1}^k (g^(i))^2 (sigma^(i))^2 ) ** 0.5
    denom = torch.sqrt(torch.sum(torch.mul(torch.pow(grad, 2),
                                           Sigma_T))) + 10e-6

    # Obtain the change in mu
    mu_C = math.sqrt(2.0 * delta) * torch.mul(Sigma_T, grad) / denom

    assert mu_C.shape == pre_tanh_mu_T.shape

    mu_E = pre_tanh_mu_T + mu_C

    # Construct the tanh normal distribution and sample the exploratory action from it
    assert mu_E.shape == std.shape

    dist = TanhNormal(mu_E, std)

    ac = dist.sample()

    ac_np = ptu.get_numpy(ac)

    # Return an empty dict for the stats; exploration quantities such as
    # pre_tanh_mu_T, mu_C, and mu_E could be logged here, but are not for now.
    return ac_np, {}
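
# A minimal numeric sketch of the optimistic shift computed above: with a
# diagonal covariance Sigma and the gradient g of the Q upper bound at the
# pre-tanh policy mean, the mean is shifted by
# sqrt(2 * delta) * Sigma g / sqrt(g^T Sigma g). The toy values below are
# illustrative assumptions.
import math
import torch

delta = 20.0
std = torch.tensor([0.3, 0.5])
grad = torch.tensor([1.0, -2.0])

Sigma_T = std ** 2
denom = torch.sqrt(torch.sum(grad ** 2 * Sigma_T)) + 1e-6
mu_C = math.sqrt(2.0 * delta) * Sigma_T * grad / denom
pre_tanh_mu_E = torch.tensor([0.1, -0.2]) + mu_C  # shifted pre-tanh mean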
    def train(self, batch, batch_idxes, epoch):
        """
        Unpack data from the batch
        """
        obs = batch['obs']
        actions = batch['actions']
        next_obs = batch['next_obs']
        qpos = batch['qpos']
        qvel = batch['qvel']

        # Get the in_mdp_batch_size
        in_mdp_batch_size = obs.shape[0] // batch_idxes.shape[0]
        num_tasks = batch_idxes.shape[0]
        """
        Obtain the model prediction loss
        """
        # Note that here, we do not calculate the obs_loss.

        next_obs_loss_task_0 = []
        pred_next_obs = [net(obs, actions) for net in self.network_ensemble]

        for pred_no in pred_next_obs:

            loss = F.mse_loss(pred_no[:in_mdp_batch_size],
                              next_obs[:in_mdp_batch_size])
            next_obs_loss_task_0.append(loss)

        next_obs_magnitude = torch.mean(
            torch.norm(next_obs[:in_mdp_batch_size], dim=1))

        gt.stamp('get_transition_prediction_loss', unique=False)

        self.network_ensemble_optimizer.zero_grad()

        next_obs_loss_task_0 = torch.stack(next_obs_loss_task_0)
        next_obs_loss_task_0 = torch.sum(next_obs_loss_task_0)
        next_obs_loss_task_0.backward()

        # [loss.backward() for loss in next_obs_loss_task_0]

        self.network_ensemble_optimizer.step()

        gt.stamp('update', unique=False)
        """
        Save some statistics for eval
        """

        if self._need_to_update_eval_statistics:

            if epoch > 150:

                qpos_other_tasks = [
                    qpos[in_mdp_batch_size * i:in_mdp_batch_size * (i + 1)]
                    for i in range(0, batch_idxes.shape[0])
                ]
                qvel_other_tasks = [
                    qvel[in_mdp_batch_size * i:in_mdp_batch_size * (i + 1)]
                    for i in range(0, batch_idxes.shape[0])
                ]
                actions_other_tasks = [
                    actions[in_mdp_batch_size * i:in_mdp_batch_size * (i + 1)]
                    for i in range(0, batch_idxes.shape[0])
                ]

                pred_next_obs_other_tasks = [
                    torch.cat([
                        pred_no[in_mdp_batch_size * i:in_mdp_batch_size *
                                (i + 1)][..., None]
                        for pred_no in pred_next_obs
                    ],
                              dim=-1) for i in range(0, batch_idxes.shape[0])
                ]

                next_obs_loss_other_tasks = []
                next_obs_loss_other_tasks_std = []
                num_selected_trans_other_tasks = []
                for item in zip(pred_next_obs_other_tasks, qpos_other_tasks,
                                qvel_other_tasks, actions_other_tasks):

                    pred_no_other_task, qp_other_task, qv_other_task, a_other_task = item

                    pred_std = torch.std(pred_no_other_task, dim=-1)
                    pred_std = pred_std.squeeze()
                    pred_std = torch.mean(pred_std, dim=1)

                    mask = ptu.get_numpy(pred_std < self.std_threshold)
                    num_selected_trans_other_tasks.append(np.sum(mask))

                    mask = mask.astype(bool)
                    pred_no_other_task = ptu.get_numpy(pred_no_other_task)
                    pred_no_other_task = pred_no_other_task[mask]

                    qp_other_task = ptu.get_numpy(qp_other_task)
                    qp_other_task = qp_other_task[mask]

                    qv_other_task = ptu.get_numpy(qv_other_task)
                    qv_other_task = qv_other_task[mask]

                    a_other_task = ptu.get_numpy(a_other_task)
                    a_other_task = a_other_task[mask]

                    mse_loss = []
                    for pred_no, qp, qv, a in zip(pred_no_other_task,
                                                  qp_other_task, qv_other_task,
                                                  a_other_task):
                        self.env.set_state(qp, qv)
                        no, _, _, _ = self.env.step(a)

                        loss = (pred_no - no.reshape(-1, 1))**2
                        loss = np.mean(loss, axis=0)
                        mse_loss.append(loss)

                    if len(mse_loss) > 0:
                        mse_loss = np.stack(mse_loss)
                        mse_loss_mean = np.mean(mse_loss)
                        next_obs_loss_other_tasks.append(mse_loss_mean)

                        mse_loss_std = np.std(mse_loss, axis=1)
                        mse_loss_std = np.mean(mse_loss_std)
                        next_obs_loss_other_tasks_std.append(mse_loss_std)

                self.eval_statistics[
                    'average_task_next_obs_loss_other_tasks_mean'] = next_obs_loss_other_tasks
                self.eval_statistics[
                    'average_task_next_obs_loss_other_tasks_std'] = next_obs_loss_other_tasks_std

                self.eval_statistics[
                    'num_selected_trans_other_tasks'] = num_selected_trans_other_tasks

            self._need_to_update_eval_statistics = False
            """
            Eval should set this to None.
            This way, these statistics are only computed for one batch.
            """
            # self.eval_statistics['next_obs_loss_task_0'] = np.mean(
            #     ptu.get_numpy(torch.mean(torch.stack(next_obs_loss_task_0)))
            # )
            self.eval_statistics['next_obs_loss_task_0'] = np.mean(
                ptu.get_numpy(next_obs_loss_task_0 /
                              len(self.network_ensemble)))
            self.eval_statistics['next_obs_magnitude'] = ptu.get_numpy(
                next_obs_magnitude)
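
# A minimal sketch of the disagreement filter used above: ensemble predictions
# are stacked along a trailing dimension, the per-transition standard deviation
# across members is averaged over the observation dimension, and only
# transitions below a threshold are kept. Sizes and the threshold are
# illustrative assumptions.
import torch

ensemble_size, batch, obs_dim, std_threshold = 5, 16, 11, 0.1
pred_next_obs = torch.randn(batch, obs_dim, ensemble_size)

pred_std = torch.std(pred_next_obs, dim=-1)   # (batch, obs_dim)
pred_std = torch.mean(pred_std, dim=1)        # (batch,)
mask = pred_std < std_threshold               # trusted transitions only
selected_predictions = pred_next_obs[mask]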