Example #1
 def update_target_networks(self):
     ptu.soft_update_from_to(
         self.qf1, self.target_qf1, self.soft_target_tau
     )
     ptu.soft_update_from_to(
         self.qf2, self.target_qf2, self.soft_target_tau
     )
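The examples on this page all call ptu.soft_update_from_to from an rlkit-style pytorch_util module to Polyak-average a target network toward its online counterpart. For reference, here is a minimal sketch of such a helper, assuming the rlkit convention soft_update_from_to(source, target, tau); it is a sketch, not necessarily the exact implementation behind every example below:

import torch.nn as nn

def soft_update_from_to(source: nn.Module, target: nn.Module, tau: float) -> None:
    # Polyak averaging: target <- (1 - tau) * target + tau * source
    for target_param, param in zip(target.parameters(), source.parameters()):
        target_param.data.copy_(
            target_param.data * (1.0 - tau) + param.data * tau
        )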
Example #2
 def _update_target_network(self, high=True):
     if high:
         ptu.soft_update_from_to(self.high_vf, self.high_vf_target,
                                 self.soft_target_tau)
     else:
         ptu.soft_update_from_to(self.low_vf, self.low_vf_target,
                                 self.soft_target_tau)
Example #3
 def _update_target_networks(self):
     if self.use_soft_update:
         ptu.soft_update_from_to(self.qf1, self.target_qf1, self.tau)
         ptu.soft_update_from_to(self.qf2, self.target_qf2, self.tau)
     else:
         if self._n_train_steps_total % self.target_hard_update_period == 0:
             ptu.copy_model_params_from_to(self.qf1, self.target_qf1)
             ptu.copy_model_params_from_to(self.qf2, self.target_qf2)
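Example #3 switches between soft (Polyak) updates and periodic hard updates. A hard update simply copies every online parameter into the target network; a minimal sketch, assuming the rlkit-style copy_model_params_from_to(source, target) signature used in these examples:

import torch.nn as nn

def copy_model_params_from_to(source: nn.Module, target: nn.Module) -> None:
    # Hard update: overwrite each target parameter with its source parameter.
    for target_param, param in zip(target.parameters(), source.parameters()):
        target_param.data.copy_(param.data)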
Example #4
 def _update_target_networks(self):
     if self.use_soft_update:
         ptu.soft_update_from_to(self.policy, self.target_policy, self.tau)
         ptu.soft_update_from_to(self.qf, self.target_qf, self.tau)
     else:
         if self._n_env_steps_total % self.target_hard_update_period == 0:
             ptu.copy_model_params_from_to(self.qf, self.target_qf)
             ptu.copy_model_params_from_to(self.policy, self.target_policy)
Example #5
    def train_from_torch(self, batch):
        rewards = batch["rewards"]
        terminals = batch["terminals"]
        obs = batch["observations"]
        actions = batch["actions"]
        next_obs = batch["next_observations"]
        try:
            plan_lengths = batch["plan_lengths"]
            if self.single_plan_discounting:
                plan_lengths = torch.ones_like(plan_lengths)
        except KeyError:
            plan_lengths = torch.ones_like(rewards)

        """
        Compute loss
        """

        best_action_idxs = self.qf(next_obs).max(1, keepdim=True)[1]
        target_q_values = self.target_qf(next_obs).gather(1, best_action_idxs).detach()
        y_target = (
            rewards
            + (1.0 - terminals)
            * torch.pow(self.discount, plan_lengths)
            * target_q_values
        )
        y_target = y_target.detach()
        # actions is a one-hot vector
        y_pred = torch.sum(self.qf(obs) * actions, dim=1, keepdim=True)
        if self.huber_loss:
            y_target = torch.max(y_target, y_pred.sub(1))
            y_target = torch.min(y_target, y_pred.add(1))
        qf_loss = self.qf_criterion(y_pred, y_target)

        """
        Update networks
        """
        self.qf_optimizer.zero_grad()
        qf_loss.backward()
        self.qf_optimizer.step()

        """
        Soft target network updates
        """
        if self._n_train_steps_total % self.target_update_period == 0:
            ptu.soft_update_from_to(self.qf, self.target_qf, self.soft_target_tau)

        """
        Save some statistics for eval using just one batch.
        """
        if self._need_to_update_eval_statistics:
            self._need_to_update_eval_statistics = False
            self.eval_statistics["QF Loss"] = np.mean(ptu.get_numpy(qf_loss))
            self.eval_statistics.update(
                create_stats_ordered_dict("Y Predictions", ptu.get_numpy(y_pred))
            )
Example #6
    def train_from_torch(self, batch):
        rewards = batch['rewards'] * self.reward_scale
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']

        """
        Compute loss
        """
        qf_losses = []
        for ensemble_idx in range(self.ensemble_size):
            qf = self.qfs[ensemble_idx]
            target_qf = self.target_qfs[ensemble_idx]

            target_q_values = target_qf(next_obs).detach().max(
                1, keepdim=True
            )[0]
            y_target = rewards + (1. - terminals) * self.discount * target_q_values
            y_target = y_target.detach()
            # actions is a one-hot vector
            y_pred = torch.sum(qf(obs) * actions, dim=1, keepdim=True)
            qf_loss = self.qf_criterion(y_pred, y_target)
            qf_losses.append(qf_loss)

            """
            Save some statistics for eval using just one batch.
            """
            if self._need_to_update_eval_statistics:
                if ensemble_idx == self.ensemble_size - 1:
                    self._need_to_update_eval_statistics = False
                self.eval_statistics['QF %d  Loss' % ensemble_idx] = np.mean(ptu.get_numpy(qf_loss))
                self.eval_statistics.update(create_stats_ordered_dict(
                    'Y %d Predictions' % ensemble_idx,
                    ptu.get_numpy(y_pred),
                ))

        """
        Soft target network updates
        """
        self.qf_optimizer.zero_grad()
        total_qf_loss = sum(qf_losses)
        total_qf_loss.backward()
        self.qf_optimizer.step()

        for ensemble_idx in range(self.ensemble_size):
            qf = self.qfs[ensemble_idx]
            target_qf = self.target_qfs[ensemble_idx]
            """
            Soft Updates
            """
            if self._n_train_steps_total % self.target_update_period == 0:
                ptu.soft_update_from_to(
                    qf, target_qf, self.soft_target_tau
                )
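Example #6 backpropagates the summed ensemble loss through a single self.qf_optimizer, which implies one optimizer built over the parameters of every ensemble member. A minimal, self-contained sketch of that setup; the network sizes, learning rate, and variable names here are placeholder assumptions:

import itertools
import torch.nn as nn
from torch import optim

ensemble_size, obs_dim, n_actions = 5, 8, 4
# Stand-in Q networks; in the trainer these would be self.qfs.
qfs = [nn.Linear(obs_dim, n_actions) for _ in range(ensemble_size)]
qf_optimizer = optim.Adam(
    itertools.chain(*(qf.parameters() for qf in qfs)),  # all members share one optimizer
    lr=3e-4,
)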
Example #7
    def _do_training(self):
        batch = self.get_batch()
        """
        Optimize Critic/Actor.
        """
        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']

        _, _, v_pred = self.target_policy(next_obs, None)
        y_target = self.reward_scale * rewards + (
            1. - terminals) * self.discount * v_pred
        y_target = y_target.detach()
        mu, y_pred, v = self.policy(obs, actions)
        policy_loss = self.policy_criterion(y_pred, y_target)

        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()
        """
        Update Target Networks
        """
        if self.use_soft_update:
            ptu.soft_update_from_to(self.policy, self.target_policy, self.tau)
        else:
            if self._n_train_steps_total % self.target_hard_update_period == 0:
                ptu.copy_model_params_from_to(self.policy, self.target_policy)

        if self.need_to_update_eval_statistics:
            self.need_to_update_eval_statistics = False
            self.eval_statistics['Policy Loss'] = np.mean(
                ptu.get_numpy(policy_loss))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Policy v',
                    ptu.get_numpy(v),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Policy mu',
                    ptu.get_numpy(mu),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Y targets',
                    ptu.get_numpy(y_target),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Y predictions',
                    ptu.get_numpy(y_pred),
                ))
Example #8
    def train_from_torch(self, batch):
        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']

        """
        Compute loss
        """

        best_action_idxs = self.qf(next_obs).max(
            1, keepdim=True
        )[1]
        target_q_values = self.target_qf(next_obs).gather(
            1, best_action_idxs
        ).detach()
        y_target = rewards + (1. - terminals) * self.discount * target_q_values
        y_target = y_target.detach()
        # actions is a one-hot vector
        y_pred = torch.sum(self.qf(obs) * actions, dim=1, keepdim=True)
        qf_loss = self.qf_criterion(y_pred, y_target)

        """
        Update networks
        """
        self.qf_optimizer.zero_grad()
        qf_loss.backward()
        self.qf_optimizer.step()

        """
        Soft target network updates
        """
        if self._n_train_steps_total % self.target_update_period == 0:
            ptu.soft_update_from_to(
                self.qf, self.target_qf, self.soft_target_tau
            )

        """
        Save some statistics for eval using just one batch.
        """
        if self._need_to_update_eval_statistics:
            self._need_to_update_eval_statistics = False
            self.eval_statistics['QF Loss'] = np.mean(ptu.get_numpy(qf_loss))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Y Predictions',
                ptu.get_numpy(y_pred),
            ))
            
        self._n_train_steps_total += 1
Example #9
 def low_train_from_torch(self, batch):
     rewards = batch['rewards']
     terminals = batch['terminals']
     obs = batch['observations']
     actions = batch['actions']
     next_obs = batch['next_observations']
     goals = batch['goals']
     # an approximation: this does not account for goal switching
     next_goals = self.setter.goal_transition(obs, goals, next_obs)
     """
     Compute loss
     """
     best_action_idxs = self.low_qf(torch.cat(
         (next_obs, next_goals), dim=1)).max(1, keepdim=True)[1]
     target_q_values = self.low_target_qf(
         torch.cat((next_obs, next_goals),
                   dim=1)).gather(1, best_action_idxs).detach()
     y_target = rewards + (1. - terminals) * self.discount * target_q_values
     y_target = y_target.detach()
     # actions is a one-hot vector
     y_pred = torch.sum(self.low_qf(torch.cat(
         (obs, goals), dim=1)) * actions,
                        dim=1,
                        keepdim=True)
     qf_loss = self.qf_criterion(y_pred, y_target)
     """
     Update networks
     """
     self.low_qf_optimizer.zero_grad()
     qf_loss.backward()
     if self.grad_clip_val is not None:
         nn.utils.clip_grad_norm_(self.low_qf.parameters(),
                                  self.grad_clip_val)
     self.low_qf_optimizer.step()
     """
     Soft target network updates
     """
     if self._n_train_steps_total % self.setter_and_target_update_period == 0:
         ptu.soft_update_from_to(self.low_qf, self.low_target_qf, self.tau)
     """
     Save some statistics for eval using just one batch.
     """
     if self._need_to_update_eval_statistics:
         self._need_to_update_eval_statistics = False
         self.eval_statistics['QF Loss'] = np.mean(ptu.get_numpy(qf_loss))
         self.eval_statistics.update(
             create_stats_ordered_dict(
                 'Y Predictions',
                 ptu.get_numpy(y_pred),
             ))
Example #10
 def _do_training(self, n_steps_total):
     raw_subtraj_batch, start_indices = (
         self.replay_buffer.train_replay_buffer.random_subtrajectories(
             self.num_subtrajs_per_batch))
     subtraj_batch = create_torch_subtraj_batch(raw_subtraj_batch)
     if self.save_memory_gradients:
         subtraj_batch['memories'].requires_grad = True
     self.train_critic(subtraj_batch)
     self.train_policy(subtraj_batch, start_indices)
     if self.use_soft_update:
         ptu.soft_update_from_to(self.policy, self.target_policy, self.tau)
         ptu.soft_update_from_to(self.qf, self.target_qf, self.tau)
     else:
         if n_steps_total % self.target_hard_update_period == 0:
             ptu.copy_model_params_from_to(self.qf, self.target_qf)
             ptu.copy_model_params_from_to(self.policy, self.target_policy)
Example #11
    def train_from_torch(self, batch):
        rewards = batch["rewards"] * self.reward_scale
        terminals = batch["terminals"]
        obs = batch["observations"]
        actions = batch["actions"]
        next_obs = batch["next_observations"]
        """
        Compute loss
        """

        target_q_values = self.target_qf(next_obs).detach().max(
            1, keepdim=True)[0]
        y_target = rewards + (1.0 -
                              terminals) * self.discount * target_q_values
        y_target = y_target.detach()
        # actions is a one-hot vector
        y_pred = torch.sum(self.qf(obs) * actions, dim=1, keepdim=True)
        qf_loss = self.qf_criterion(y_pred, y_target)
        """
        Soft target network updates
        """
        self.qf_optimizer.zero_grad()
        qf_loss.backward()
        self.qf_optimizer.step()
        """
        Soft Updates
        """
        if self._n_train_steps_total % self.target_update_period == 0:
            ptu.soft_update_from_to(self.qf, self.target_qf,
                                    self.soft_target_tau)
        """
        Save some statistics for eval using just one batch.
        """
        if self._need_to_update_eval_statistics:
            self._need_to_update_eval_statistics = False
            self.eval_statistics["QF Loss"] = np.mean(ptu.get_numpy(qf_loss))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    "Y Predictions",
                    ptu.get_numpy(y_pred),
                ))
        self._n_train_steps_total += 1
Example #12
    def _pre_take_step(self, indices, context1, context1_):
        #num_tasks = len(indices)

        # data is (task, batch, feat)

        #positive sample

        self.curl_optimizer.zero_grad()
        self.encoder_optimizer.zero_grad()
        if self.use_information_bottleneck:
            kl_div = self.agent.compute_kl_div()
            kl_loss = self.kl_lambda / 1000 * kl_div
            kl_loss.backward(retain_graph=True)
        # NOTE: `loss` is the contrastive (CURL) loss computed from
        # context1/context1_; its computation is omitted in this excerpt.
        loss.backward()
        self.curl_optimizer.step()
        self.encoder_optimizer.step()

        ptu.soft_update_from_to(self.agent.context_encoder,
                                self.agent.context_encoder_target,
                                self.encoder_tau)
Example #13
 def _update_target_networks(self):
     for cg1, target_cg1, qf1, target_qf1, cg2, target_cg2, qf2, target_qf2 in \
         zip(self.cg1_n, self.target_cg1_n, self.qf1_n, self.target_qf1_n,
             self.cg2_n, self.target_cg2_n, self.qf2_n, self.target_qf2_n):
         if self.use_soft_update:
             ptu.soft_update_from_to(cg1, target_cg1, self.tau)
             ptu.soft_update_from_to(qf1, target_qf1, self.tau)
             ptu.soft_update_from_to(cg2, target_cg2, self.tau)
             ptu.soft_update_from_to(qf2, target_qf2, self.tau)
         else:
             if self._n_train_steps_total % self.target_hard_update_period == 0:
                 ptu.copy_model_params_from_to(cg1, target_cg1)
                 ptu.copy_model_params_from_to(qf1, target_qf1)
                 ptu.copy_model_params_from_to(cg2, target_cg2)
                 ptu.copy_model_params_from_to(qf2, target_qf2)
Example #14
 def _update_target_networks(self):
     for policy, target_policy, qf, target_qf in \
         zip(self.policy_n, self.target_policy_n, self.qf_n, self.target_qf_n):
         if self.use_soft_update:
             ptu.soft_update_from_to(policy, target_policy, self.tau)
             ptu.soft_update_from_to(qf, target_qf, self.tau)
         else:
             if self._n_train_steps_total % self.target_hard_update_period == 0:
                 ptu.copy_model_params_from_to(qf, target_qf)
                 ptu.copy_model_params_from_to(policy, target_policy)
     if self.double_q:
         for qf2, target_qf2 in zip(self.qf2_n, self.target_qf2_n):
             if self.use_soft_update:
                 ptu.soft_update_from_to(qf2, target_qf2, self.tau)
             else:
                 if self._n_train_steps_total % self.target_hard_update_period == 0:
                     ptu.copy_model_params_from_to(qf2, target_qf2)
Example #15
    def train_from_torch(
        self,
        batch,
        train=True,
        pretrain=False,
    ):
        """

        :param batch:
        :param train:
        :param pretrain:
        :return:
        """
        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']
        weights = batch.get('weights', None)
        if self.reward_transform:
            rewards = self.reward_transform(rewards)

        if self.terminal_transform:
            terminals = self.terminal_transform(terminals)
        """
        Policy and Alpha Loss
        """
        dist = self.policy(obs)
        new_obs_actions, log_pi = dist.rsample_and_logprob()
        policy_mle = dist.mle_estimate()

        if self.brac:
            buf_dist = self.buffer_policy(obs)
            buf_log_pi = buf_dist.log_prob(actions)
            rewards = rewards + buf_log_pi

        if self.use_automatic_entropy_tuning:
            alpha_loss = -(self.log_alpha *
                           (log_pi + self.target_entropy).detach()).mean()
            self.alpha_optimizer.zero_grad()
            alpha_loss.backward()
            self.alpha_optimizer.step()
            alpha = self.log_alpha.exp()
        else:
            alpha_loss = 0
            alpha = self.alpha
        """
        QF Loss
        """
        q1_pred = self.qf1(obs, actions)
        q2_pred = self.qf2(obs, actions)
        # Make sure policy accounts for squashing functions like tanh correctly!
        next_dist = self.policy(next_obs)
        new_next_actions, new_log_pi = next_dist.rsample_and_logprob()
        target_q_values = torch.min(
            self.target_qf1(next_obs, new_next_actions),
            self.target_qf2(next_obs, new_next_actions),
        ) - alpha * new_log_pi

        q_target = self.reward_scale * rewards + (
            1. - terminals) * self.discount * target_q_values
        qf1_loss = self.qf_criterion(q1_pred, q_target.detach())
        qf2_loss = self.qf_criterion(q2_pred, q_target.detach())
        """
        Policy Loss
        """
        qf1_new_actions = self.qf1(obs, new_obs_actions)
        qf2_new_actions = self.qf2(obs, new_obs_actions)
        q_new_actions = torch.min(
            qf1_new_actions,
            qf2_new_actions,
        )

        # Advantage-weighted regression
        if self.awr_use_mle_for_vf:
            v1_pi = self.qf1(obs, policy_mle)
            v2_pi = self.qf2(obs, policy_mle)
            v_pi = torch.min(v1_pi, v2_pi)
        else:
            if self.vf_K > 1:
                vs = []
                for i in range(self.vf_K):
                    u = dist.sample()
                    q1 = self.qf1(obs, u)
                    q2 = self.qf2(obs, u)
                    v = torch.min(q1, q2)
                    # v = q1
                    vs.append(v)
                v_pi = torch.cat(vs, 1).mean(dim=1)
            else:
                # v_pi = self.qf1(obs, new_obs_actions)
                v1_pi = self.qf1(obs, new_obs_actions)
                v2_pi = self.qf2(obs, new_obs_actions)
                v_pi = torch.min(v1_pi, v2_pi)

        if self.awr_sample_actions:
            u = new_obs_actions
            if self.awr_min_q:
                q_adv = q_new_actions
            else:
                q_adv = qf1_new_actions
        elif self.buffer_policy_sample_actions:
            buf_dist = self.buffer_policy(obs)
            u, _ = buf_dist.rsample_and_logprob()
            qf1_buffer_actions = self.qf1(obs, u)
            qf2_buffer_actions = self.qf2(obs, u)
            q_buffer_actions = torch.min(
                qf1_buffer_actions,
                qf2_buffer_actions,
            )
            if self.awr_min_q:
                q_adv = q_buffer_actions
            else:
                q_adv = qf1_buffer_actions
        else:
            u = actions
            if self.awr_min_q:
                q_adv = torch.min(q1_pred, q2_pred)
            else:
                q_adv = q1_pred

        policy_logpp = dist.log_prob(u)

        if self.use_automatic_beta_tuning:
            buffer_dist = self.buffer_policy(obs)
            beta = self.log_beta.exp()
            kldiv = torch.distributions.kl.kl_divergence(dist, buffer_dist)
            beta_loss = -1 * (beta *
                              (kldiv - self.beta_epsilon).detach()).mean()

            self.beta_optimizer.zero_grad()
            beta_loss.backward()
            self.beta_optimizer.step()
        else:
            beta = self.beta_schedule.get_value(self._n_train_steps_total)

        if self.normalize_over_state == "advantage":
            score = q_adv - v_pi
            if self.mask_positive_advantage:
                score = torch.sign(score)
        elif self.normalize_over_state == "Z":
            buffer_dist = self.buffer_policy(obs)
            K = self.Z_K
            buffer_obs = []
            buffer_actions = []
            log_bs = []
            log_pis = []
            for i in range(K):
                u = buffer_dist.sample()
                log_b = buffer_dist.log_prob(u)
                log_pi = dist.log_prob(u)
                buffer_obs.append(obs)
                buffer_actions.append(u)
                log_bs.append(log_b)
                log_pis.append(log_pi)
            buffer_obs = torch.cat(buffer_obs, 0)
            buffer_actions = torch.cat(buffer_actions, 0)
            p_buffer = torch.exp(torch.cat(log_bs, 0).sum(dim=1, ))
            log_pi = torch.cat(log_pis, 0)
            log_pi = log_pi.sum(dim=1, )
            q1_b = self.qf1(buffer_obs, buffer_actions)
            q2_b = self.qf2(buffer_obs, buffer_actions)
            q_b = torch.min(q1_b, q2_b)
            q_b = torch.reshape(q_b, (-1, K))
            adv_b = q_b - v_pi
            # if self._n_train_steps_total % 100 == 0:
            #     import ipdb; ipdb.set_trace()
            # Z = torch.exp(adv_b / beta).mean(dim=1, keepdim=True)
            # score = torch.exp((q_adv - v_pi) / beta) / Z
            # score = score / sum(score)
            logK = torch.log(ptu.tensor(float(K)))
            logZ = torch.logsumexp(adv_b / beta - logK, dim=1, keepdim=True)
            logS = (q_adv - v_pi) / beta - logZ
            # logZ = torch.logsumexp(q_b/beta - logK, dim=1, keepdim=True)
            # logS = q_adv/beta - logZ
            score = F.softmax(logS, dim=0)  # score / sum(score)
        else:
            raise ValueError(
                "Unknown normalize_over_state: %s" % self.normalize_over_state)

        if self.clip_score is not None:
            score = torch.clamp(score, max=self.clip_score)

        if self.weight_loss and weights is None:
            # Compare against True explicitly: the string modes below
            # ("whiten", "exp", "step_fn") are also truthy.
            if self.normalize_over_batch == True:
                weights = F.softmax(score / beta, dim=0)
            elif self.normalize_over_batch == "whiten":
                adv_mean = torch.mean(score)
                adv_std = torch.std(score) + 1e-5
                normalized_score = (score - adv_mean) / adv_std
                weights = torch.exp(normalized_score / beta)
            elif self.normalize_over_batch == "exp":
                weights = torch.exp(score / beta)
            elif self.normalize_over_batch == "step_fn":
                weights = (score > 0).float()
            elif not self.normalize_over_batch:
                weights = score
            else:
                raise ValueError(
                    "Unknown normalize_over_batch: %s"
                    % self.normalize_over_batch)
        weights = weights[:, 0]

        policy_loss = alpha * log_pi.mean()

        if self.use_awr_update and self.weight_loss:
            policy_loss = policy_loss + self.awr_weight * (
                -policy_logpp * len(weights) * weights.detach()).mean()
        elif self.use_awr_update:
            policy_loss = policy_loss + self.awr_weight * (
                -policy_logpp).mean()

        if self.use_reparam_update:
            policy_loss = policy_loss + self.reparam_weight * (
                -q_new_actions).mean()

        policy_loss = self.rl_weight * policy_loss
        if self.compute_bc:
            train_policy_loss, train_logp_loss, train_mse_loss, _ = self.run_bc_batch(
                self.demo_train_buffer, self.policy)
            policy_loss = policy_loss + self.bc_weight * train_policy_loss

        if not pretrain and self.buffer_policy_reset_period > 0 and self._n_train_steps_total % self.buffer_policy_reset_period == 0:
            del self.buffer_policy_optimizer
            self.buffer_policy_optimizer = self.optimizer_class(
                self.buffer_policy.parameters(),
                weight_decay=self.policy_weight_decay,
                lr=self.policy_lr,
            )
            self.optimizers[self.buffer_policy] = self.buffer_policy_optimizer
            for i in range(self.num_buffer_policy_train_steps_on_reset):
                if self.train_bc_on_rl_buffer:
                    if self.advantage_weighted_buffer_loss:
                        buffer_dist = self.buffer_policy(obs)
                        buffer_u = actions
                        buffer_new_obs_actions, _ = buffer_dist.rsample_and_logprob(
                        )
                        buffer_policy_logpp = buffer_dist.log_prob(buffer_u)
                        buffer_policy_logpp = buffer_policy_logpp[:, None]

                        buffer_q1_pred = self.qf1(obs, buffer_u)
                        buffer_q2_pred = self.qf2(obs, buffer_u)
                        buffer_q_adv = torch.min(buffer_q1_pred,
                                                 buffer_q2_pred)

                        buffer_v1_pi = self.qf1(obs, buffer_new_obs_actions)
                        buffer_v2_pi = self.qf2(obs, buffer_new_obs_actions)
                        buffer_v_pi = torch.min(buffer_v1_pi, buffer_v2_pi)

                        buffer_score = buffer_q_adv - buffer_v_pi
                        buffer_weights = F.softmax(buffer_score / beta, dim=0)
                        buffer_policy_loss = self.awr_weight * (
                            -buffer_policy_logpp * len(buffer_weights) *
                            buffer_weights.detach()).mean()
                    else:
                        buffer_policy_loss, buffer_train_logp_loss, buffer_train_mse_loss, _ = self.run_bc_batch(
                            self.replay_buffer.train_replay_buffer,
                            self.buffer_policy)

                    self.buffer_policy_optimizer.zero_grad()
                    buffer_policy_loss.backward(retain_graph=True)
                    self.buffer_policy_optimizer.step()

        if self.train_bc_on_rl_buffer:
            if self.advantage_weighted_buffer_loss:
                buffer_dist = self.buffer_policy(obs)
                buffer_u = actions
                buffer_new_obs_actions, _ = buffer_dist.rsample_and_logprob()
                buffer_policy_logpp = buffer_dist.log_prob(buffer_u)
                buffer_policy_logpp = buffer_policy_logpp[:, None]

                buffer_q1_pred = self.qf1(obs, buffer_u)
                buffer_q2_pred = self.qf2(obs, buffer_u)
                buffer_q_adv = torch.min(buffer_q1_pred, buffer_q2_pred)

                buffer_v1_pi = self.qf1(obs, buffer_new_obs_actions)
                buffer_v2_pi = self.qf2(obs, buffer_new_obs_actions)
                buffer_v_pi = torch.min(buffer_v1_pi, buffer_v2_pi)

                buffer_score = buffer_q_adv - buffer_v_pi
                buffer_weights = F.softmax(buffer_score / beta, dim=0)
                buffer_policy_loss = self.awr_weight * (
                    -buffer_policy_logpp * len(buffer_weights) *
                    buffer_weights.detach()).mean()
            else:
                buffer_policy_loss, buffer_train_logp_loss, buffer_train_mse_loss, _ = self.run_bc_batch(
                    self.replay_buffer.train_replay_buffer, self.buffer_policy)
        """
        Update networks
        """
        if self._n_train_steps_total % self.q_update_period == 0:
            self.qf1_optimizer.zero_grad()
            qf1_loss.backward()
            self.qf1_optimizer.step()

            self.qf2_optimizer.zero_grad()
            qf2_loss.backward()
            self.qf2_optimizer.step()

        if self._n_train_steps_total % self.policy_update_period == 0 and self.update_policy:
            self.policy_optimizer.zero_grad()
            policy_loss.backward()
            self.policy_optimizer.step()

        if self.train_bc_on_rl_buffer and self._n_train_steps_total % self.policy_update_period == 0:
            self.buffer_policy_optimizer.zero_grad()
            buffer_policy_loss.backward()
            self.buffer_policy_optimizer.step()
        """
        Soft Updates
        """
        if self._n_train_steps_total % self.target_update_period == 0:
            ptu.soft_update_from_to(self.qf1, self.target_qf1,
                                    self.soft_target_tau)
            ptu.soft_update_from_to(self.qf2, self.target_qf2,
                                    self.soft_target_tau)
        """
        Save some statistics for eval
        """
        if self._need_to_update_eval_statistics:
            self._need_to_update_eval_statistics = False
            """
            Eval should set this to None.
            This way, these statistics are only computed for one batch.
            """
            policy_loss = (log_pi - q_new_actions).mean()

            self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
            self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))
            self.eval_statistics['Policy Loss'] = np.mean(
                ptu.get_numpy(policy_loss))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q1 Predictions',
                    ptu.get_numpy(q1_pred),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q2 Predictions',
                    ptu.get_numpy(q2_pred),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q Targets',
                    ptu.get_numpy(q_target),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Log Pis',
                    ptu.get_numpy(log_pi),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'rewards',
                    ptu.get_numpy(rewards),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'terminals',
                    ptu.get_numpy(terminals),
                ))
            policy_statistics = add_prefix(dist.get_diagnostics(), "policy/")
            self.eval_statistics.update(policy_statistics)
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Advantage Weights',
                    ptu.get_numpy(weights),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Advantage Score',
                    ptu.get_numpy(score),
                ))

            if self.normalize_over_state == "Z":
                self.eval_statistics.update(
                    create_stats_ordered_dict(
                        'logZ',
                        ptu.get_numpy(logZ),
                    ))

            if self.use_automatic_entropy_tuning:
                self.eval_statistics['Alpha'] = alpha.item()
                self.eval_statistics['Alpha Loss'] = alpha_loss.item()

            if self.compute_bc:
                test_policy_loss, test_logp_loss, test_mse_loss, _ = self.run_bc_batch(
                    self.demo_test_buffer, self.policy)
                self.eval_statistics.update({
                    "bc/Train Logprob Loss":
                    ptu.get_numpy(train_logp_loss),
                    "bc/Test Logprob Loss":
                    ptu.get_numpy(test_logp_loss),
                    "bc/Train MSE":
                    ptu.get_numpy(train_mse_loss),
                    "bc/Test MSE":
                    ptu.get_numpy(test_mse_loss),
                    "bc/train_policy_loss":
                    ptu.get_numpy(train_policy_loss),
                    "bc/test_policy_loss":
                    ptu.get_numpy(test_policy_loss),
                })
            if self.train_bc_on_rl_buffer:
                _, buffer_train_logp_loss, _, _ = self.run_bc_batch(
                    self.replay_buffer.train_replay_buffer, self.buffer_policy)

                _, buffer_test_logp_loss, _, _ = self.run_bc_batch(
                    self.replay_buffer.validation_replay_buffer,
                    self.buffer_policy)
                buffer_dist = self.buffer_policy(obs)
                kldiv = torch.distributions.kl.kl_divergence(dist, buffer_dist)

                _, train_offline_logp_loss, _, _ = self.run_bc_batch(
                    self.demo_train_buffer, self.buffer_policy)

                _, test_offline_logp_loss, _, _ = self.run_bc_batch(
                    self.demo_test_buffer, self.buffer_policy)

                self.eval_statistics.update({
                    "buffer_policy/Train Online Logprob":
                    -1 * ptu.get_numpy(buffer_train_logp_loss),
                    "buffer_policy/Test Online Logprob":
                    -1 * ptu.get_numpy(buffer_test_logp_loss),
                    "buffer_policy/Train Offline Logprob":
                    -1 * ptu.get_numpy(train_offline_logp_loss),
                    "buffer_policy/Test Offline Logprob":
                    -1 * ptu.get_numpy(test_offline_logp_loss),
                    "buffer_policy/train_policy_loss":
                    ptu.get_numpy(buffer_policy_loss),
                    # "buffer_policy/test_policy_loss": ptu.get_numpy(buffer_test_policy_loss),
                    "buffer_policy/kl_div":
                    ptu.get_numpy(kldiv.mean()),
                })
            if self.use_automatic_beta_tuning:
                self.eval_statistics.update({
                    "adaptive_beta/beta":
                    ptu.get_numpy(beta.mean()),
                    "adaptive_beta/beta loss":
                    ptu.get_numpy(beta_loss.mean()),
                })

            if self.validation_qlearning:
                train_data = self.replay_buffer.validation_replay_buffer.random_batch(
                    self.bc_batch_size)
                train_data = np_to_pytorch_batch(train_data)
                obs = train_data['observations']
                next_obs = train_data['next_observations']
                # goals = train_data['resampled_goals']
                train_data[
                    'observations'] = obs  # torch.cat((obs, goals), dim=1)
                train_data[
                    'next_observations'] = next_obs  # torch.cat((next_obs, goals), dim=1)
                self.test_from_torch(train_data)

        self._n_train_steps_total += 1
Example #16
    def train_from_torch(self, batch):
        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']

        """
        Critic operations.
        """

        next_actions = self.target_policy(next_obs)
        noise = ptu.randn(next_actions.shape) * self.target_policy_noise
        noise = torch.clamp(
            noise,
            -self.target_policy_noise_clip,
            self.target_policy_noise_clip
        )
        noisy_next_actions = next_actions + noise

        target_q1_values = self.target_qf1(next_obs, noisy_next_actions)
        target_q2_values = self.target_qf2(next_obs, noisy_next_actions)
        target_q_values = torch.min(target_q1_values, target_q2_values)
        q_target = self.reward_scale * rewards + (1. - terminals) * self.discount * target_q_values
        q_target = q_target.detach()

        q1_pred = self.qf1(obs, actions)
        bellman_errors_1 = (q1_pred - q_target) ** 2
        qf1_loss = bellman_errors_1.mean()

        q2_pred = self.qf2(obs, actions)
        bellman_errors_2 = (q2_pred - q_target) ** 2
        qf2_loss = bellman_errors_2.mean()

        """
        Update Networks
        """
        self.qf1_optimizer.zero_grad()
        qf1_loss.backward()
        self.qf1_optimizer.step()

        self.qf2_optimizer.zero_grad()
        qf2_loss.backward()
        self.qf2_optimizer.step()

        policy_actions = policy_loss = None
        if self._n_train_steps_total % self.policy_and_target_update_period == 0:
            policy_actions = self.policy(obs)
            q_output = self.qf1(obs, policy_actions)
            policy_loss = - q_output.mean()

            self.policy_optimizer.zero_grad()
            policy_loss.backward()
            self.policy_optimizer.step()

            ptu.soft_update_from_to(self.policy, self.target_policy, self.tau)
            ptu.soft_update_from_to(self.qf1, self.target_qf1, self.tau)
            ptu.soft_update_from_to(self.qf2, self.target_qf2, self.tau)

        if self._need_to_update_eval_statistics:
            self._need_to_update_eval_statistics = False
            if policy_loss is None:
                policy_actions = self.policy(obs)
                q_output = self.qf1(obs, policy_actions)
                policy_loss = - q_output.mean()

            self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
            self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))
            self.eval_statistics['Policy Loss'] = np.mean(ptu.get_numpy(
                policy_loss
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Q1 Predictions',
                ptu.get_numpy(q1_pred),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Q2 Predictions',
                ptu.get_numpy(q2_pred),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Q Targets',
                ptu.get_numpy(q_target),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Bellman Errors 1',
                ptu.get_numpy(bellman_errors_1),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Bellman Errors 2',
                ptu.get_numpy(bellman_errors_2),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Policy Action',
                ptu.get_numpy(policy_actions),
            ))
        self._n_train_steps_total += 1
Example #17
    def train_from_torch(self, batch, train=True, pretrain=False,):
        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']
        if self.reward_transform:
            rewards = self.reward_transform(rewards)

        if self.terminal_transform:
            terminals = self.terminal_transform(terminals)
        """
        Policy and Alpha Loss
        """
        dist = self.policy(obs)

        """
        QF Loss
        """
        q1_pred = self.qf1(obs, actions)
        q2_pred = self.qf2(obs, actions)
        target_vf_pred = self.vf(next_obs).detach()

        q_target = self.reward_scale * rewards + (1. - terminals) * self.discount * target_vf_pred
        q_target = q_target.detach()
        qf1_loss = self.qf_criterion(q1_pred, q_target)
        qf2_loss = self.qf_criterion(q2_pred, q_target)

        """
        VF Loss
        """
        q_pred = torch.min(
            self.target_qf1(obs, actions),
            self.target_qf2(obs, actions),
        ).detach()
        vf_pred = self.vf(obs)
        vf_err = vf_pred - q_pred
        vf_sign = (vf_err > 0).float()
        vf_weight = (1 - vf_sign) * self.quantile + vf_sign * (1 - self.quantile)
        vf_loss = (vf_weight * (vf_err ** 2)).mean()

        """
        Policy Loss
        """
        policy_logpp = dist.log_prob(actions)

        adv = q_pred - vf_pred
        exp_adv = torch.exp(adv / self.beta)
        if self.clip_score is not None:
            exp_adv = torch.clamp(exp_adv, max=self.clip_score)

        weights = exp_adv[:, 0].detach()
        policy_loss = (-policy_logpp * weights).mean()

        """
        Update networks
        """
        if self._n_train_steps_total % self.q_update_period == 0:
            self.qf1_optimizer.zero_grad()
            qf1_loss.backward()
            self.qf1_optimizer.step()

            self.qf2_optimizer.zero_grad()
            qf2_loss.backward()
            self.qf2_optimizer.step()

            self.vf_optimizer.zero_grad()
            vf_loss.backward()
            self.vf_optimizer.step()

        if self._n_train_steps_total % self.policy_update_period == 0:
            self.policy_optimizer.zero_grad()
            policy_loss.backward()
            self.policy_optimizer.step()

        """
        Soft Updates
        """
        if self._n_train_steps_total % self.target_update_period == 0:
            ptu.soft_update_from_to(
                self.qf1, self.target_qf1, self.soft_target_tau
            )
            ptu.soft_update_from_to(
                self.qf2, self.target_qf2, self.soft_target_tau
            )

        """
        Save some statistics for eval
        """
        if self._need_to_update_eval_statistics:
            self._need_to_update_eval_statistics = False
            """
            Eval should set this to None.
            This way, these statistics are only computed for one batch.
            """
            self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
            self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))
            self.eval_statistics['Policy Loss'] = np.mean(ptu.get_numpy(
                policy_loss
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Q1 Predictions',
                ptu.get_numpy(q1_pred),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Q2 Predictions',
                ptu.get_numpy(q2_pred),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Q Targets',
                ptu.get_numpy(q_target),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'rewards',
                ptu.get_numpy(rewards),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'terminals',
                ptu.get_numpy(terminals),
            ))
            self.eval_statistics['replay_buffer_len'] = self.replay_buffer._size
            policy_statistics = add_prefix(dist.get_diagnostics(), "policy/")
            self.eval_statistics.update(policy_statistics)
            self.eval_statistics.update(create_stats_ordered_dict(
                'Advantage Weights',
                ptu.get_numpy(weights),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Advantage Score',
                ptu.get_numpy(adv),
            ))

            self.eval_statistics.update(create_stats_ordered_dict(
                'V1 Predictions',
                ptu.get_numpy(vf_pred),
            ))
            self.eval_statistics['VF Loss'] = np.mean(ptu.get_numpy(vf_loss))

        self._n_train_steps_total += 1
Example #18
    def train_from_torch(self, batch):
        self._current_epoch += 1
        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']
        """
        Behavior clone a policy
        """
        recon, mean, std = self.vae(obs, actions)
        recon_loss = self.qf_criterion(recon, actions)
        kl_loss = -0.5 * (1 + torch.log(std.pow(2)) - mean.pow(2) -
                          std.pow(2)).mean()
        vae_loss = recon_loss + 0.5 * kl_loss

        self.vae_optimizer.zero_grad()
        vae_loss.backward()
        self.vae_optimizer.step()
        """
        Critic Training
        """
        # import ipdb; ipdb.set_trace()
        with torch.no_grad():
            # Duplicate state 10 times (10 is a hyperparameter chosen by BCQ)
            state_rep = next_obs.unsqueeze(1).repeat(1, 10, 1).view(
                next_obs.shape[0] * 10, next_obs.shape[1])  # 10BxS
            # Compute value of perturbed actions sampled from the VAE
            action_rep = self.policy(state_rep)[0]
            target_qf1 = self.target_qf1(state_rep, action_rep)
            target_qf2 = self.target_qf2(state_rep, action_rep)

            # Soft Clipped Double Q-learning
            target_Q = 0.75 * torch.min(target_qf1,
                                        target_qf2) + 0.25 * torch.max(
                                            target_qf1, target_qf2)
            target_Q = target_Q.view(next_obs.shape[0],
                                     -1).max(1)[0].view(-1, 1)
            target_Q = self.reward_scale * rewards + (
                1.0 - terminals) * self.discount * target_Q  # Bx1

        qf1_pred = self.qf1(obs, actions)  # Bx1
        qf2_pred = self.qf2(obs, actions)  # Bx1
        qf1_loss = (qf1_pred - target_Q.detach()).pow(2).mean()
        qf2_loss = (qf2_pred - target_Q.detach()).pow(2).mean()
        """
        Actor Training
        """
        sampled_actions, raw_sampled_actions = self.vae.decode_multiple(
            obs, num_decode=self.num_samples_mmd_match)
        actor_samples, _, _, _, _, _, _, raw_actor_actions = self.policy(
            obs.unsqueeze(1).repeat(1, self.num_samples_mmd_match,
                                    1).view(-1, obs.shape[1]),
            return_log_prob=True)
        actor_samples = actor_samples.view(obs.shape[0],
                                           self.num_samples_mmd_match,
                                           actions.shape[1])
        raw_actor_actions = raw_actor_actions.view(obs.shape[0],
                                                   self.num_samples_mmd_match,
                                                   actions.shape[1])

        if self.kernel_choice == 'laplacian':
            mmd_loss = self.mmd_loss_laplacian(raw_sampled_actions,
                                               raw_actor_actions,
                                               sigma=self.mmd_sigma)
        elif self.kernel_choice == 'gaussian':
            mmd_loss = self.mmd_loss_gaussian(raw_sampled_actions,
                                              raw_actor_actions,
                                              sigma=self.mmd_sigma)

        action_divergence = ((sampled_actions - actor_samples)**2).sum(-1)
        raw_action_divergence = ((raw_sampled_actions -
                                  raw_actor_actions)**2).sum(-1)

        q_val1 = self.qf1(obs, actor_samples[:, 0, :])
        q_val2 = self.qf2(obs, actor_samples[:, 0, :])

        if self.policy_update_style == '0':
            policy_loss = torch.min(q_val1, q_val2)[:, 0]
        elif self.policy_update_style == '1':
            # elementwise mean of the two critics
            policy_loss = ((q_val1 + q_val2) / 2.0)[:, 0]

        if self._n_train_steps_total >= 40000:
            # Now we can update the policy
            if self.mode == 'auto':
                policy_loss = (-policy_loss + self.log_alpha.exp() *
                               (mmd_loss - self.target_mmd_thresh)).mean()
            else:
                policy_loss = (-policy_loss + 100 * mmd_loss).mean()
        else:
            if self.mode == 'auto':
                policy_loss = (self.log_alpha.exp() *
                               (mmd_loss - self.target_mmd_thresh)).mean()
            else:
                policy_loss = 100 * mmd_loss.mean()
        """
        Update Networks
        """
        self.qf1_optimizer.zero_grad()
        qf1_loss.backward()
        self.qf1_optimizer.step()

        self.qf2_optimizer.zero_grad()
        qf2_loss.backward()
        self.qf2_optimizer.step()

        self.policy_optimizer.zero_grad()
        if self.mode == 'auto':
            # retain the graph so the alpha (dual) update below can reuse it
            policy_loss.backward(retain_graph=True)
        else:
            policy_loss.backward()
        self.policy_optimizer.step()

        if self.mode == 'auto':
            self.alpha_optimizer.zero_grad()
            (-policy_loss).backward()
            self.alpha_optimizer.step()
            self.log_alpha.data.clamp_(min=-5.0, max=10.0)
        """
        Update networks
        """
        if self._n_train_steps_total % self.target_update_period == 0:
            ptu.soft_update_from_to(self.qf1, self.target_qf1,
                                    self.soft_target_tau)
            ptu.soft_update_from_to(self.qf2, self.target_qf2,
                                    self.soft_target_tau)
        """
        Some statistics for eval
        """
        if self._need_to_update_eval_statistics:
            self._need_to_update_eval_statistics = False
            """
            Eval should set this to None.
            This way, these statistics are only computed for one batch.
            """
            self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
            self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))
            self.eval_statistics['Num Q Updates'] = self._num_q_update_steps
            self.eval_statistics[
                'Num Policy Updates'] = self._num_policy_update_steps
            self.eval_statistics['Policy Loss'] = np.mean(
                ptu.get_numpy(policy_loss))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q1 Predictions',
                    ptu.get_numpy(qf1_pred),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q2 Predictions',
                    ptu.get_numpy(qf2_pred),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q Targets',
                    ptu.get_numpy(target_Q),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict('MMD Loss', ptu.get_numpy(mmd_loss)))
            self.eval_statistics.update(
                create_stats_ordered_dict('Action Divergence',
                                          ptu.get_numpy(action_divergence)))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Raw Action Divergence',
                    ptu.get_numpy(raw_action_divergence)))
            if self.mode == 'auto':
                self.eval_statistics['Alpha'] = self.log_alpha.exp().item()

        self._n_train_steps_total += 1
Example #19
    def train_from_torch(self, batch):
        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']
        """
        Policy and Alpha Loss
        """
        pis = self.policy(obs)
        if self.use_automatic_entropy_tuning:
            alpha_loss = -(pis.detach() * self.log_alpha.exp() *
                           (torch.log(pis + 1e-3) +
                            self.target_entropy).detach()).sum(-1).mean()
            self.alpha_optimizer.zero_grad()
            alpha_loss.backward()
            self.alpha_optimizer.step()
            alpha = self.log_alpha.exp()
        else:
            alpha_loss = 0
            alpha = 1

        min_q = torch.min(self.qf1(obs), self.qf2(obs)).detach()
        policy_loss = (pis *
                       (alpha * torch.log(pis + 1e-3) - min_q)).sum(-1).mean()
        """
        QF Loss
        """
        new_pis = self.policy(next_obs).detach()
        target_min_q_values = torch.min(
            self.target_qf1(next_obs),
            self.target_qf2(next_obs),
        )
        target_q_values = (
            new_pis *
            (target_min_q_values - alpha * torch.log(new_pis + 1e-3))).sum(
                -1, keepdim=True)
        q_target = self.reward_scale * rewards + (
            1. - terminals) * self.discount * target_q_values

        q1_pred = torch.sum(self.qf1(obs) * actions.detach(),
                            dim=-1,
                            keepdim=True)
        q2_pred = torch.sum(self.qf2(obs) * actions.detach(),
                            dim=-1,
                            keepdim=True)
        qf1_loss = self.qf_criterion(q1_pred, q_target.detach())
        qf2_loss = self.qf_criterion(q2_pred, q_target.detach())
        """
        Update networks
        """
        self.qf1_optimizer.zero_grad()
        qf1_loss.backward()
        self.qf1_optimizer.step()

        self.qf2_optimizer.zero_grad()
        qf2_loss.backward()
        self.qf2_optimizer.step()

        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()
        """
        Soft Updates
        """
        if self._n_train_steps_total % self.target_update_period == 0:
            ptu.soft_update_from_to(self.qf1, self.target_qf1,
                                    self.soft_target_tau)
            ptu.soft_update_from_to(self.qf2, self.target_qf2,
                                    self.soft_target_tau)
        """
        Save some statistics for eval
        """
        if self._need_to_update_eval_statistics:
            self._need_to_update_eval_statistics = False
            """
            Eval should set this to None.
            This way, these statistics are only computed for one batch.
            """

            self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
            self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))
            self.eval_statistics['Policy Loss'] = np.mean(
                ptu.get_numpy(policy_loss))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q1 Predictions',
                    ptu.get_numpy(q1_pred),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q2 Predictions',
                    ptu.get_numpy(q2_pred),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q Targets',
                    ptu.get_numpy(q_target),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Pis',
                    ptu.get_numpy(pis),
                ))
            if self.use_automatic_entropy_tuning:
                self.eval_statistics['Alpha'] = alpha.item()
                self.eval_statistics['Alpha Loss'] = alpha_loss.item()
        self._n_train_steps_total += 1
Example #20
    def _do_training(self):
        batch = self.get_batch()
        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']
        goals = batch['goals']
        num_steps_left = batch['num_steps_left']

        q1_pred = self.qf1(
            observations=obs,
            actions=actions,
            goals=goals,
            num_steps_left=num_steps_left,
        )
        q2_pred = self.qf2(
            observations=obs,
            actions=actions,
            goals=goals,
            num_steps_left=num_steps_left,
        )
        # Make sure policy accounts for squashing functions like tanh correctly!
        policy_outputs = self.policy(
            obs,
            goals,
            num_steps_left,
            reparameterize=self.train_policy_with_reparameterization,
            return_log_prob=True)
        new_actions, policy_mean, policy_log_std, log_pi = policy_outputs[:4]
        if not self.dense_rewards:
            log_pi = log_pi * terminals
        """
        QF Loss
        """
        target_v_values = self.target_vf(
            observations=next_obs,
            goals=goals,
            num_steps_left=num_steps_left - 1,
        )
        q_target = self.reward_scale * rewards + (
            1. - terminals) * self.discount * target_v_values
        q_target = q_target.detach()
        bellman_errors_1 = (q1_pred - q_target)**2
        bellman_errors_2 = (q2_pred - q_target)**2
        qf1_loss = bellman_errors_1.mean()
        qf2_loss = bellman_errors_2.mean()
        """
        VF Loss
        """
        q1_new_actions = self.qf1(
            observations=obs,
            actions=new_actions,
            goals=goals,
            num_steps_left=num_steps_left,
        )
        q2_new_actions = self.qf2(
            observations=obs,
            actions=new_actions,
            goals=goals,
            num_steps_left=num_steps_left,
        )
        q_new_actions = torch.min(q1_new_actions, q2_new_actions)
        v_target = q_new_actions - log_pi
        v_pred = self.vf(
            observations=obs,
            goals=goals,
            num_steps_left=num_steps_left,
        )
        v_target = v_target.detach()
        bellman_errors = (v_pred - v_target)**2
        vf_loss = bellman_errors.mean()
        """
        Update networks
        """
        self.qf1_optimizer.zero_grad()
        qf1_loss.backward()
        self.qf1_optimizer.step()

        self.qf2_optimizer.zero_grad()
        qf2_loss.backward()
        self.qf2_optimizer.step()

        self.vf_optimizer.zero_grad()
        vf_loss.backward()
        self.vf_optimizer.step()
        """
        Policy Loss
        """
        # paper says to do + but apparently that's a typo. Do Q - V.
        if self.train_policy_with_reparameterization:
            policy_loss = (log_pi - q_new_actions).mean()
        else:
            log_policy_target = q_new_actions - v_pred
            policy_loss = (log_pi *
                           (log_pi - log_policy_target).detach()).mean()
        mean_reg_loss = self.policy_mean_reg_weight * (policy_mean**2).mean()
        std_reg_loss = self.policy_std_reg_weight * (policy_log_std**2).mean()
        pre_tanh_value = policy_outputs[-1]
        pre_activation_reg_loss = self.policy_pre_activation_weight * (
            (pre_tanh_value**2).sum(dim=1).mean())
        policy_reg_loss = mean_reg_loss + std_reg_loss + pre_activation_reg_loss
        policy_loss = policy_loss + policy_reg_loss

        if self._n_train_steps_total % self.policy_update_period == 0:
            self.policy_optimizer.zero_grad()
            policy_loss.backward()
            self.policy_optimizer.step()

        if self._n_train_steps_total % self.target_update_period == 0:
            ptu.soft_update_from_to(self.vf, self.target_vf,
                                    self.soft_target_tau)
        """
        Save some statistics for eval
        """
        if self.need_to_update_eval_statistics:
            self.need_to_update_eval_statistics = False
            """
            Eval should set this to None.
            This way, these statistics are only computed for one batch.
            """
            self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
            self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))
            self.eval_statistics['VF Loss'] = np.mean(ptu.get_numpy(vf_loss))
            self.eval_statistics['Policy Loss'] = np.mean(
                ptu.get_numpy(policy_loss))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q1 Predictions',
                    ptu.get_numpy(q1_pred),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q2 Predictions',
                    ptu.get_numpy(q2_pred),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'V Predictions',
                    ptu.get_numpy(v_pred),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Log Pis',
                    ptu.get_numpy(log_pi),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Policy mu',
                    ptu.get_numpy(policy_mean),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Policy log std',
                    ptu.get_numpy(policy_log_std),
                ))
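For reference, the value-function target computed in this example is the clipped double-Q estimate minus the log-probability of the freshly sampled action. A minimal sketch with illustrative tensor names:

import torch

def soft_value_target(q1_new_actions: torch.Tensor,
                      q2_new_actions: torch.Tensor,
                      log_pi: torch.Tensor) -> torch.Tensor:
    # V_target = min(Q1, Q2) - log pi(a|s), detached so the VF regression
    # does not backpropagate into the Q-networks or the policy.
    return (torch.min(q1_new_actions, q2_new_actions) - log_pi).detach()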
Exemple #21
0
    def _do_training(self):
        batch = self.get_batch()
        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']

        q1_pred = self.qf1(obs, actions)
        q2_pred = self.qf2(obs, actions)
        v_pred = self.vf(obs)
        # Make sure policy accounts for squashing functions like tanh correctly!
        policy_outputs = self.policy(obs,
                                     reparameterize=self.train_policy_with_reparameterization,
                                     return_log_prob=True)
        new_actions, policy_mean, policy_log_std, log_pi = policy_outputs[:4]

        """
        Alpha Loss (if applicable)
        """
        if self.use_automatic_entropy_tuning:
            """
            Alpha Loss
            """
            alpha_loss = -(self.log_alpha * (log_pi + self.target_entropy).detach()).mean()
            self.alpha_optimizer.zero_grad()
            alpha_loss.backward()
            self.alpha_optimizer.step()
            alpha = self.log_alpha.exp()
        else:
            alpha = 1
            alpha_loss = 0

        """
        QF Loss
        """
        target_v_values = self.target_vf(next_obs)
        q_target = self.reward_scale * rewards + (1. - terminals) * self.discount * target_v_values
        qf1_loss = self.qf_criterion(q1_pred, q_target.detach())
        qf2_loss = self.qf_criterion(q2_pred, q_target.detach())

        """
        VF Loss
        """
        q_new_actions = torch.min(
            self.qf1(obs, new_actions),
            self.qf2(obs, new_actions),
        )
        v_target = q_new_actions - alpha*log_pi
        vf_loss = self.vf_criterion(v_pred, v_target.detach())

        """
        Update networks
        """
        self.qf1_optimizer.zero_grad()
        qf1_loss.backward()
        self.qf1_optimizer.step()

        self.qf2_optimizer.zero_grad()
        qf2_loss.backward()
        self.qf2_optimizer.step()

        self.vf_optimizer.zero_grad()
        vf_loss.backward()
        self.vf_optimizer.step()

        policy_loss = None
        if self._n_train_steps_total % self.policy_update_period == 0:
            """
            Policy Loss
            """
            if self.train_policy_with_reparameterization:
                policy_loss = (alpha*log_pi - q_new_actions).mean()
            else:
                log_policy_target = q_new_actions - v_pred
                policy_loss = (
                    log_pi * (alpha*log_pi - log_policy_target).detach()
                ).mean()
            mean_reg_loss = self.policy_mean_reg_weight * (policy_mean**2).mean()
            std_reg_loss = self.policy_std_reg_weight * (policy_log_std**2).mean()
            pre_tanh_value = policy_outputs[-1]
            pre_activation_reg_loss = self.policy_pre_activation_weight * (
                (pre_tanh_value**2).sum(dim=1).mean()
            )
            policy_reg_loss = mean_reg_loss + std_reg_loss + pre_activation_reg_loss
            policy_loss = policy_loss + policy_reg_loss

            self.policy_optimizer.zero_grad()
            policy_loss.backward()
            self.policy_optimizer.step()

        if self._n_train_steps_total % self.target_update_period == 0:
            ptu.soft_update_from_to(
                self.vf, self.target_vf, self.soft_target_tau
            )

        """
        Save some statistics for eval using just one batch.
        """
        if self.need_to_update_eval_statistics:
            self.need_to_update_eval_statistics = False
            if policy_loss is None:
                if self.train_policy_with_reparameterization:
                    policy_loss = (log_pi - q_new_actions).mean()
                else:
                    log_policy_target = q_new_actions - v_pred
                    policy_loss = (
                        log_pi * (log_pi - log_policy_target).detach()
                    ).mean()

                mean_reg_loss = self.policy_mean_reg_weight * (policy_mean**2).mean()
                std_reg_loss = self.policy_std_reg_weight * (policy_log_std**2).mean()
                pre_tanh_value = policy_outputs[-1]
                pre_activation_reg_loss = self.policy_pre_activation_weight * (
                    (pre_tanh_value**2).sum(dim=1).mean()
                )
                policy_reg_loss = mean_reg_loss + std_reg_loss + pre_activation_reg_loss
                policy_loss = policy_loss + policy_reg_loss

            self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
            self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))
            self.eval_statistics['VF Loss'] = np.mean(ptu.get_numpy(vf_loss))
            self.eval_statistics['Policy Loss'] = np.mean(ptu.get_numpy(
                policy_loss
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Q1 Predictions',
                ptu.get_numpy(q1_pred),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Q2 Predictions',
                ptu.get_numpy(q2_pred),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'V Predictions',
                ptu.get_numpy(v_pred),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Log Pis',
                ptu.get_numpy(log_pi),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Policy mu',
                ptu.get_numpy(policy_mean),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Policy log std',
                ptu.get_numpy(policy_log_std),
            ))
            if self.use_automatic_entropy_tuning:
                self.eval_statistics['Alpha'] = alpha.item()
                self.eval_statistics['Alpha Loss'] = alpha_loss.item()
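The automatic entropy tuning used above adjusts the temperature so that the policy entropy tracks a target value. A minimal sketch of the same update in isolation; the learning rate and target entropy are illustrative assumptions:

import torch

log_alpha = torch.zeros(1, requires_grad=True)
alpha_optimizer = torch.optim.Adam([log_alpha], lr=3e-4)
target_entropy = -6.0  # commonly set to -|A|, e.g. -6 for a 6-dimensional action space

def update_alpha(log_pi: torch.Tensor) -> torch.Tensor:
    # Gradient step on alpha so that the policy's entropy is pushed toward target_entropy.
    alpha_loss = -(log_alpha * (log_pi + target_entropy).detach()).mean()
    alpha_optimizer.zero_grad()
    alpha_loss.backward()
    alpha_optimizer.step()
    return log_alpha.exp().detach()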
Exemple #22
0
    def _do_training(self):
        batch = self.get_batch()
        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']
        goals = batch['goals']
        num_steps_left = batch['num_steps_left']
        """
        Critic operations.
        """
        next_actions = self.target_policy(
            observations=next_obs,
            goals=goals,
            num_steps_left=num_steps_left - 1,
        )
        noise = torch.normal(
            torch.zeros_like(next_actions),
            self.target_policy_noise,
        )
        noise = torch.clamp(noise, -self.target_policy_noise_clip,
                            self.target_policy_noise_clip)
        noisy_next_actions = next_actions + noise

        target_q1_values = self.target_qf1(
            observations=next_obs,
            actions=noisy_next_actions,
            goals=goals,
            num_steps_left=num_steps_left - 1,
        )
        target_q2_values = self.target_qf2(
            observations=next_obs,
            actions=noisy_next_actions,
            goals=goals,
            num_steps_left=num_steps_left - 1,
        )
        target_q_values = torch.min(target_q1_values, target_q2_values)
        q_target = self.reward_scale * rewards + (
            1. - terminals) * self.discount * target_q_values
        q_target = q_target.detach()

        q1_pred = self.qf1(
            observations=obs,
            actions=actions,
            goals=goals,
            num_steps_left=num_steps_left,
        )
        q2_pred = self.qf2(
            observations=obs,
            actions=actions,
            goals=goals,
            num_steps_left=num_steps_left,
        )

        bellman_errors_1 = (q1_pred - q_target)**2
        bellman_errors_2 = (q2_pred - q_target)**2
        qf1_loss = bellman_errors_1.mean()
        qf2_loss = bellman_errors_2.mean()
        """
        Update Networks
        """
        self.qf1_optimizer.zero_grad()
        qf1_loss.backward()
        self.qf1_optimizer.step()

        self.qf2_optimizer.zero_grad()
        qf2_loss.backward()
        self.qf2_optimizer.step()

        policy_actions, pre_tanh_value = self.policy(
            obs,
            goals,
            num_steps_left,
            return_preactivations=True,
        )
        q_output = self.qf1(
            observations=obs,
            actions=policy_actions,
            num_steps_left=num_steps_left,
            goals=goals,
        )

        policy_loss = -q_output.mean()

        if self._n_train_steps_total % self.policy_and_target_update_period == 0:
            self.policy_optimizer.zero_grad()
            policy_loss.backward()
            self.policy_optimizer.step()

            ptu.soft_update_from_to(self.policy, self.target_policy, self.tau)
            ptu.soft_update_from_to(self.qf1, self.target_qf1, self.tau)
            ptu.soft_update_from_to(self.qf2, self.target_qf2, self.tau)

        if self.need_to_update_eval_statistics:
            self.need_to_update_eval_statistics = False
            self.eval_statistics = OrderedDict()
            self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
            self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))
            self.eval_statistics['Policy Loss'] = np.mean(
                ptu.get_numpy(policy_loss))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q1 Predictions',
                    ptu.get_numpy(q1_pred),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q2 Predictions',
                    ptu.get_numpy(q2_pred),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q Targets',
                    ptu.get_numpy(q_target),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Bellman1 Errors',
                    ptu.get_numpy(bellman_errors_1),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Bellman2 Errors',
                    ptu.get_numpy(bellman_errors_2),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Policy Action',
                    ptu.get_numpy(policy_actions),
                ))
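The target policy smoothing in this example (clipped Gaussian noise added to the target policy's action before evaluating the target Q-functions) can be isolated as a small helper. A sketch, with the default noise scales as illustrative values:

import torch

def smoothed_target_actions(next_actions: torch.Tensor,
                            noise_std: float = 0.2,
                            noise_clip: float = 0.5) -> torch.Tensor:
    # TD3-style smoothing: perturb the target action with clipped Gaussian noise
    # so the Bellman target is less sensitive to sharp peaks in the learned Q.
    noise = torch.normal(torch.zeros_like(next_actions), noise_std)
    noise = torch.clamp(noise, -noise_clip, noise_clip)
    return next_actions + noise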
Exemple #23
0
    def train_from_torch(self, batch):
        self._current_epoch += 1
        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']

        """
        Policy and Alpha Loss
        """
        new_obs_actions, policy_mean, policy_log_std, log_pi, *_ = self.policy(
            obs, reparameterize=True, return_log_prob=True,
        )
        
        if self.use_automatic_entropy_tuning:
            alpha_loss = -(self.log_alpha * (log_pi + self.target_entropy).detach()).mean()
            self.alpha_optimizer.zero_grad()
            alpha_loss.backward()
            self.alpha_optimizer.step()
            alpha = self.log_alpha.exp()
        else:
            alpha_loss = 0
            alpha = 1

        if self.num_qs == 1:
            q_new_actions = self.qf1(obs, new_obs_actions)
        else:
            q_new_actions = torch.min(
                self.qf1(obs, new_obs_actions),
                self.qf2(obs, new_obs_actions),
            )

        policy_loss = (alpha*log_pi - q_new_actions).mean()

        if self._current_epoch < self.policy_eval_start:
            """
            For the initial few epochs, optionally do behavioral cloning.
            In practice there is little difference in performance between running
            roughly 20k behavioral-cloning gradient steps here and skipping them.
            """
            policy_log_prob = self.policy.log_prob(obs, actions)
            policy_loss = (alpha * log_pi - policy_log_prob).mean()
        
        """
        QF Loss
        """
        q1_pred = self.qf1(obs, actions)
        if self.num_qs > 1:
            q2_pred = self.qf2(obs, actions)
        
        new_next_actions, _, _, new_log_pi, *_ = self.policy(
            next_obs, reparameterize=True, return_log_prob=True,
        )
        new_curr_actions, _, _, new_curr_log_pi, *_ = self.policy(
            obs, reparameterize=True, return_log_prob=True,
        )

        if not self.max_q_backup:
            if self.num_qs == 1:
                target_q_values = self.target_qf1(next_obs, new_next_actions)
            else:
                target_q_values = torch.min(
                    self.target_qf1(next_obs, new_next_actions),
                    self.target_qf2(next_obs, new_next_actions),
                )
            
            if not self.deterministic_backup:
                target_q_values = target_q_values - alpha * new_log_pi
        
        if self.max_q_backup:
            """when using max q backup"""
            next_actions_temp, _ = self._get_policy_actions(next_obs, num_actions=10, network=self.policy)
            target_qf1_values = self._get_tensor_values(next_obs, next_actions_temp, network=self.target_qf1).max(1)[0].view(-1, 1)
            target_qf2_values = self._get_tensor_values(next_obs, next_actions_temp, network=self.target_qf2).max(1)[0].view(-1, 1)
            target_q_values = torch.min(target_qf1_values, target_qf2_values)

        q_target = self.reward_scale * rewards + (1. - terminals) * self.discount * target_q_values
        q_target = q_target.detach()
            
        qf1_loss = self.qf_criterion(q1_pred, q_target)
        if self.num_qs > 1:
            qf2_loss = self.qf_criterion(q2_pred, q_target)

        ## add CQL
        random_actions_tensor = torch.FloatTensor(q2_pred.shape[0] * self.num_random, actions.shape[-1]).uniform_(-1, 1) # .cuda()
        curr_actions_tensor, curr_log_pis = self._get_policy_actions(obs, num_actions=self.num_random, network=self.policy)
        new_curr_actions_tensor, new_log_pis = self._get_policy_actions(next_obs, num_actions=self.num_random, network=self.policy)
        q1_rand = self._get_tensor_values(obs, random_actions_tensor, network=self.qf1)
        q2_rand = self._get_tensor_values(obs, random_actions_tensor, network=self.qf2)
        q1_curr_actions = self._get_tensor_values(obs, curr_actions_tensor, network=self.qf1)
        q2_curr_actions = self._get_tensor_values(obs, curr_actions_tensor, network=self.qf2)
        q1_next_actions = self._get_tensor_values(obs, new_curr_actions_tensor, network=self.qf1)
        q2_next_actions = self._get_tensor_values(obs, new_curr_actions_tensor, network=self.qf2)

        cat_q1 = torch.cat(
            [q1_rand, q1_pred.unsqueeze(1), q1_next_actions, q1_curr_actions], 1
        )
        cat_q2 = torch.cat(
            [q2_rand, q2_pred.unsqueeze(1), q2_next_actions, q2_curr_actions], 1
        )
        std_q1 = torch.std(cat_q1, dim=1)
        std_q2 = torch.std(cat_q2, dim=1)

        if self.min_q_version == 3:
            # importance sampled version
            random_density = np.log(0.5 ** curr_actions_tensor.shape[-1])
            cat_q1 = torch.cat(
                [q1_rand - random_density, q1_next_actions - new_log_pis.detach(), q1_curr_actions - curr_log_pis.detach()], 1
            )
            cat_q2 = torch.cat(
                [q2_rand - random_density, q2_next_actions - new_log_pis.detach(), q2_curr_actions - curr_log_pis.detach()], 1
            )
            
        min_qf1_loss = torch.logsumexp(cat_q1 / self.temp, dim=1,).mean() * self.min_q_weight * self.temp
        min_qf2_loss = torch.logsumexp(cat_q2 / self.temp, dim=1,).mean() * self.min_q_weight * self.temp
                    
        """Subtract the log likelihood of data"""
        min_qf1_loss = min_qf1_loss - q1_pred.mean() * self.min_q_weight
        min_qf2_loss = min_qf2_loss - q2_pred.mean() * self.min_q_weight
        
        if self.with_lagrange:
            alpha_prime = torch.clamp(self.log_alpha_prime.exp(), min=0.0, max=1000000.0)
            min_qf1_loss = alpha_prime * (min_qf1_loss - self.target_action_gap)
            min_qf2_loss = alpha_prime * (min_qf2_loss - self.target_action_gap)

            self.alpha_prime_optimizer.zero_grad()
            alpha_prime_loss = (-min_qf1_loss - min_qf2_loss)*0.5 
            alpha_prime_loss.backward(retain_graph=True)
            self.alpha_prime_optimizer.step()

        qf1_loss = qf1_loss + min_qf1_loss
        qf2_loss = qf2_loss + min_qf2_loss

        """
        Update networks
        """
        # Update the Q-functions
        self._num_q_update_steps += 1
        self.qf1_optimizer.zero_grad()
        qf1_loss.backward(retain_graph=True)
        self.qf1_optimizer.step()

        if self.num_qs > 1:
            self.qf2_optimizer.zero_grad()
            qf2_loss.backward(retain_graph=True)
            self.qf2_optimizer.step()

        self._num_policy_update_steps += 1
        self.policy_optimizer.zero_grad()
        policy_loss.backward(retain_graph=False)
        self.policy_optimizer.step()

        """
        Soft Updates
        """
        ptu.soft_update_from_to(
            self.qf1, self.target_qf1, self.soft_target_tau
        )
        if self.num_qs > 1:
            ptu.soft_update_from_to(
                self.qf2, self.target_qf2, self.soft_target_tau
            )
        """
        Save some statistics for eval
        """
        if self._need_to_update_eval_statistics:
            self._need_to_update_eval_statistics = False
            """
            Eval should set this to None.
            This way, these statistics are only computed for one batch.
            """
            policy_loss = (log_pi - q_new_actions).mean()

            self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
            self.eval_statistics['min QF1 Loss'] = np.mean(ptu.get_numpy(min_qf1_loss))
            if self.num_qs > 1:
                self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))
                self.eval_statistics['min QF2 Loss'] = np.mean(ptu.get_numpy(min_qf2_loss))

            if not self.discrete:
                self.eval_statistics['Std QF1 values'] = np.mean(ptu.get_numpy(std_q1))
                self.eval_statistics['Std QF2 values'] = np.mean(ptu.get_numpy(std_q2))
                self.eval_statistics.update(create_stats_ordered_dict(
                    'QF1 in-distribution values',
                    ptu.get_numpy(q1_curr_actions),
                ))
                self.eval_statistics.update(create_stats_ordered_dict(
                    'QF2 in-distribution values',
                    ptu.get_numpy(q2_curr_actions),
                ))
                self.eval_statistics.update(create_stats_ordered_dict(
                    'QF1 random values',
                    ptu.get_numpy(q1_rand),
                ))
                self.eval_statistics.update(create_stats_ordered_dict(
                    'QF2 random values',
                    ptu.get_numpy(q2_rand),
                ))
                self.eval_statistics.update(create_stats_ordered_dict(
                    'QF1 next_actions values',
                    ptu.get_numpy(q1_next_actions),
                ))
                self.eval_statistics.update(create_stats_ordered_dict(
                    'QF2 next_actions values',
                    ptu.get_numpy(q2_next_actions),
                ))
                self.eval_statistics.update(create_stats_ordered_dict(
                    'actions', 
                    ptu.get_numpy(actions)
                ))
                self.eval_statistics.update(create_stats_ordered_dict(
                    'rewards',
                    ptu.get_numpy(rewards)
                ))

            self.eval_statistics['Num Q Updates'] = self._num_q_update_steps
            self.eval_statistics['Num Policy Updates'] = self._num_policy_update_steps
            self.eval_statistics['Policy Loss'] = np.mean(ptu.get_numpy(
                policy_loss
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Q1 Predictions',
                ptu.get_numpy(q1_pred),
            ))
            if self.num_qs > 1:
                self.eval_statistics.update(create_stats_ordered_dict(
                    'Q2 Predictions',
                    ptu.get_numpy(q2_pred),
                ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Q Targets',
                ptu.get_numpy(q_target),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Log Pis',
                ptu.get_numpy(log_pi),
            ))
            if not self.discrete:
                self.eval_statistics.update(create_stats_ordered_dict(
                    'Policy mu',
                    ptu.get_numpy(policy_mean),
                ))
                self.eval_statistics.update(create_stats_ordered_dict(
                    'Policy log std',
                    ptu.get_numpy(policy_log_std),
                ))
            
            if self.use_automatic_entropy_tuning:
                self.eval_statistics['Alpha'] = alpha.item()
                self.eval_statistics['Alpha Loss'] = alpha_loss.item()
            
            if self.with_lagrange:
                self.eval_statistics['Alpha_prime'] = alpha_prime.item()
                self.eval_statistics['min_q1_loss'] = ptu.get_numpy(min_qf1_loss).mean()
                self.eval_statistics['min_q2_loss'] = ptu.get_numpy(min_qf2_loss).mean()
                self.eval_statistics['threshold action gap'] = self.target_action_gap
                self.eval_statistics['alpha prime loss'] = alpha_prime_loss.item()
            
        self._n_train_steps_total += 1
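The conservative term added to the Q-losses above is a log-sum-exp over Q-values of sampled actions minus the Q-values of dataset actions. A minimal sketch of that penalty, with the weight and temperature as illustrative values:

import torch

def cql_penalty(cat_q: torch.Tensor, q_data: torch.Tensor,
                temp: float = 1.0, min_q_weight: float = 5.0) -> torch.Tensor:
    # Push down a soft maximum of Q over sampled actions (random, current-policy,
    # next-state-policy) and push up the Q-values of actions in the dataset.
    logsumexp_term = torch.logsumexp(cat_q / temp, dim=1).mean() * temp
    return min_q_weight * (logsumexp_term - q_data.mean())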
Exemple #24
0
    def _do_training_step(self, epoch, loop_iter):
        '''
            Train the encoder and the policy
        '''
        self.encoder_optimizer.zero_grad()
        self.policy_optimizer.zero_grad()

        # prep the batches
        # OLD VERSION -----------------------------------------------------------------------------
        # context_batch, context_pred_batch, test_pred_batch, mask = self._get_training_batch()

        # post_dist = self.encoder(context_batch, mask)
        # z = post_dist.sample() # N_tasks x Dim
        # # z = post_dist.mean

        # # convert it to a pytorch tensor
        # # note that our objective says we should maximize likelihood of
        # # BOTH the context_batch and the test_batch
        # obs_batch = np.concatenate((context_pred_batch['observations'], test_pred_batch['observations']), axis=0)
        # obs_batch = Variable(ptu.from_numpy(obs_batch), requires_grad=False)

        # acts_batch = np.concatenate((context_pred_batch['actions'], test_pred_batch['actions']), axis=0)
        # acts_batch = Variable(ptu.from_numpy(acts_batch), requires_grad=False)

        # # make z's for expert samples
        # context_pred_z = z.repeat(1, self.num_context_trajs_for_training * self.train_samples_per_traj).view(
        #     -1,
        #     z.size(1)
        # )
        # test_pred_z = z.repeat(1, self.num_test_trajs_for_training * self.train_samples_per_traj).view(
        #     -1,
        #     z.size(1)
        # )
        # z_batch = torch.cat([context_pred_z, test_pred_z], dim=0)
        # NEW VERSION (this is more fair to this model) -------------------------------------------
        context_batch, mask, pred_batch = self._get_training_batch(epoch)

        post_dist = self.encoder(context_batch, mask)
        z = post_dist.sample()  # N_tasks x Dim
        # z = post_dist.mean

        obs_batch = Variable(ptu.from_numpy(pred_batch['observations']),
                             requires_grad=False)
        acts_batch = Variable(ptu.from_numpy(pred_batch['actions']),
                              requires_grad=False)
        z_batch = z.repeat(1, self.policy_optim_batch_size_per_task).view(
            -1, z.size(1))

        input_batch = torch.cat([obs_batch, z_batch], dim=-1)

        if self.use_mse_objective:
            pred_acts = self.policy(input_batch)[1]
            recon_loss = self.mse_loss(pred_acts, acts_batch)
        else:
            recon_loss = -1.0 * self.policy.get_log_prob(
                input_batch, acts_batch).mean()

        # add KL loss term
        cur_KL_beta = linear_schedule(
            self._n_train_steps_total * self.num_update_loops_per_train_call +
            loop_iter - self.KL_ramp_up_start_iter, 0.0, self.max_KL_beta,
            self.KL_ramp_up_end_iter - self.KL_ramp_up_start_iter)
        KL_loss = self._compute_KL_loss(post_dist)
        if cur_KL_beta == 0.0: KL_loss = KL_loss.detach()

        loss = recon_loss + cur_KL_beta * KL_loss
        loss.backward()

        self.policy_optimizer.step()
        self.encoder_optimizer.step()

        if self.use_target_policy:
            ptu.soft_update_from_to(self.policy, self.target_policy,
                                    self.soft_target_policy_tau)
        if self.use_target_enc:
            ptu.soft_update_from_to(self.encoder, self.target_enc,
                                    self.soft_target_enc_tau)
        """
        Save some statistics for eval
        """
        if self.eval_statistics is None:
            """
            Eval should set this to None.
            This way, these statistics are only computed for one batch.
            """
            self.eval_statistics = OrderedDict()
            if self.use_target_policy:
                enc_to_use = self.target_enc if self.use_target_enc else self.encoder
                pol_to_use = self.target_policy

                if self.use_mse_objective:
                    pred_acts = pol_to_use(input_batch)[1]
                    target_loss = self.mse_loss(pred_acts, acts_batch)
                    self.eval_statistics['Target MSE Loss'] = np.mean(
                        ptu.get_numpy(target_loss))
                else:
                    target_loss = -1.0 * pol_to_use.get_log_prob(
                        input_batch, acts_batch).mean()
                    self.eval_statistics['Target Neg Log Like'] = np.mean(
                        ptu.get_numpy(target_loss))
            else:
                if self.use_mse_objective:
                    self.eval_statistics['Target MSE Loss'] = np.mean(
                        ptu.get_numpy(recon_loss))
                else:
                    self.eval_statistics['Target Neg Log Like'] = np.mean(
                        ptu.get_numpy(recon_loss))
            self.eval_statistics['Target KL'] = np.mean(ptu.get_numpy(KL_loss))
            self.eval_statistics['Cur KL Beta'] = cur_KL_beta
            self.eval_statistics['Max KL Beta'] = self.max_KL_beta

            self.eval_statistics['Avg Post Mean Abs'] = np.mean(
                np.abs(ptu.get_numpy(post_dist.mean)))
            self.eval_statistics['Avg Post Cov Abs'] = np.mean(
                np.abs(ptu.get_numpy(post_dist.cov)))
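The KL weight above is ramped with a linear_schedule helper that is not shown in this snippet. A sketch of a schedule consistent with how it is called here; the signature is inferred from the call site, so this is an assumption rather than the actual implementation:

def linear_schedule(t: float, start_value: float, end_value: float, duration: float) -> float:
    # Linearly interpolate from start_value to end_value over `duration` steps,
    # clamping before the start and after the end of the ramp.
    if duration <= 0:
        return end_value
    frac = min(max(t / duration, 0.0), 1.0)
    return start_value + frac * (end_value - start_value)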
Exemple #25
0
 def _update_target_network(self):
     ptu.soft_update_from_to(self.vf, self.target_vf, self.soft_target_tau)
Exemple #26
0
    def _take_step(self, indices, obs_enc, act_enc, rewards_enc):

        num_tasks = len(indices)

        import time
        t6 = time.time()

        # data is (task, batch, feat)
        batch = self.replay_loader.next()
        # print('sample', time.time() - t6)
        t7 = time.time()
        obs, actions, rewards, next_obs, terms = [x.cuda() for x in batch]
        # print('to_cuda', time.time() - t7)

        t5 = time.time()
        enc_data = self.prepare_encoder_data(obs_enc, act_enc, rewards_enc)
        # print('prep enc data', time.time() - t5)

        self.cnn_optimizer.zero_grad()
        self.qf1_optimizer.zero_grad()
        self.context_optimizer.zero_grad()

        t5 = time.time()

        # run inference in networks
        q1_pred, q1_next_pred, q2_next_pred, policy_outputs, task_z = self.policy(obs, actions, next_obs, enc_data, obs_enc, act_enc)
        #print('policy', time.time() - t5)

        # new_actions, policy_mean, policy_log_std, log_pi = policy_outputs[:4]
        new_actions = policy_outputs

        # KL constraint on z if probabilistic

        t4 = time.time()
        kl_loss = 0
        if self.use_information_bottleneck:
            kl_div = self.policy.compute_kl_div()
            kl_loss = self.kl_lambda * kl_div
        #print('kl', time.time() - t4)

            # kl_loss.backward(retain_graph=True)

        # qf and encoder update (note encoder does not get grads from policy or vf)
        rewards_flat = rewards.view(self.batch_size * num_tasks, -1)

        # scale rewards for Bellman update
        rewards_flat = rewards_flat * self.reward_scale
        terms_flat = terms.view(self.batch_size * num_tasks, -1)
        actions = actions.view(self.batch_size * num_tasks, -1)

        t3 = time.time()

        best_action_idxs = q1_next_pred.max(
            1, keepdim=True
        )[1]
        target_q_values = q2_next_pred.gather(
            1, best_action_idxs
        ).detach()

        #print('get actions', time.time() - t3)

        y_target = rewards_flat + (1. - terms_flat) * self.discount * target_q_values
        y_target = y_target.detach()
        t2 = time.time()

        # actions is a one-hot vector
        y_pred = torch.sum(q1_pred * actions, dim=1, keepdim=True)
        qf_loss = self.qf_criterion(y_pred, y_target)

        #print('compute loss', time.time() - t2)
        t1 = time.time()

        """
        Update networks
        """
        loss = qf_loss + kl_loss
        loss.backward()
        #print('backward', time.time() - t1)
        t0 = time.time()

        self.qf1_optimizer.step()
        self.cnn_optimizer.step()  
        self.context_optimizer.step()
        #print('step', time.time() - t0)

        """
        Soft target network updates
        """
        if self.target_update_period > 1 and self._n_train_steps_total % self.target_update_period == 0:
            ptu.soft_update_from_to(
                self.policy.qf1, self.policy.qf2, 1
            )
        else:
            ptu.soft_update_from_to(
                self.policy.qf1, self.policy.qf2, self.soft_target_tau,
            )
            
        # save some statistics for eval
        if self.eval_statistics is None:
            # eval should set this to None.
            # this way, these statistics are only computed for one batch.
            # TODO this is kind of annoying and higher variance, why not just average
            # across all the train steps?
            self.eval_statistics = OrderedDict()
            if self.use_information_bottleneck:
                z_mean = np.mean(np.abs(ptu.get_numpy(self.policy.z_dists[0].mean)))
                z_sig = np.mean(ptu.get_numpy(self.policy.z_dists[0].variance))
                self.eval_statistics['Z mean train'] = z_mean
                self.eval_statistics['Z variance train'] = z_sig
                self.eval_statistics['KL Divergence'] = ptu.get_numpy(kl_div)
                self.eval_statistics['KL Loss'] = ptu.get_numpy(kl_loss)

            self.eval_statistics['QF Loss'] = np.mean(ptu.get_numpy(qf_loss))
            # self.eval_statistics['VF Loss'] = np.mean(ptu.get_numpy(vf_loss))
            # self.eval_statistics['Policy Loss'] = np.mean(ptu.get_numpy(
            #     policy_loss
            # ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Q Predictions',
                ptu.get_numpy(q1_pred),
            ))
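The Bellman target above follows the double-Q recipe: one Q-head picks the argmax action at the next state and the other head evaluates it. A minimal sketch of the same computation on plain tensors:

import torch

def double_q_target(q_online_next: torch.Tensor, q_target_next: torch.Tensor,
                    rewards: torch.Tensor, terminals: torch.Tensor,
                    discount: float) -> torch.Tensor:
    # Select greedy actions with one Q-head, evaluate them with the other,
    # then form the usual (1 - terminal)-masked bootstrap target.
    best_action_idxs = q_online_next.max(1, keepdim=True)[1]
    target_q = q_target_next.gather(1, best_action_idxs).detach()
    return (rewards + (1.0 - terminals) * discount * target_q).detach()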
Exemple #27
0
    def train_from_torch(self, batch):
        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']
        """
        Policy and Alpha Loss
        """
        new_obs_actions, policy_mean, policy_log_std, log_pi, *_ = self.policy(
            obs,
            reparameterize=True,
            return_log_prob=True,
        )
        if self.use_automatic_entropy_tuning:
            alpha_loss = -(self.log_alpha *
                           (log_pi + self.target_entropy).detach()).mean()
            self.alpha_optimizer.zero_grad()
            alpha_loss.backward()
            self.alpha_optimizer.step()
            alpha = self.log_alpha.exp()
        else:
            alpha_loss = 0
            alpha = 1

        q_new_actions = torch.min(
            self.qf1(obs, new_obs_actions),
            self.qf2(obs, new_obs_actions),
        )
        policy_loss = (alpha * log_pi - q_new_actions).mean()
        """
        QF Loss
        """
        q1_pred = self.qf1(obs, actions)
        q2_pred = self.qf2(obs, actions)
        # Make sure policy accounts for squashing functions like tanh correctly!
        new_next_actions, _, _, new_log_pi, *_ = self.policy(
            next_obs,
            reparameterize=True,
            return_log_prob=True,
        )
        target_q_values = torch.min(
            self.target_qf1(next_obs, new_next_actions),
            self.target_qf2(next_obs, new_next_actions),
        ) - alpha * new_log_pi

        q_target = self.reward_scale * rewards + (
            1. - terminals) * self.discount * target_q_values
        qf1_loss = self.qf_criterion(q1_pred, q_target.detach())
        qf2_loss = self.qf_criterion(q2_pred, q_target.detach())
        """
        Update networks
        """
        self.qf1_optimizer.zero_grad()
        qf1_loss.backward()
        self.qf1_optimizer.step()

        self.qf2_optimizer.zero_grad()
        qf2_loss.backward()
        self.qf2_optimizer.step()

        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()
        """
        Soft Updates
        """
        if self._n_train_steps_total % self.target_update_period == 0:
            ptu.soft_update_from_to(self.qf1, self.target_qf1,
                                    self.soft_target_tau)
            ptu.soft_update_from_to(self.qf2, self.target_qf2,
                                    self.soft_target_tau)
        """
        Save some statistics for eval
        """
        if self._need_to_update_eval_statistics:
            self._need_to_update_eval_statistics = False
            """
            Eval should set this to None.
            This way, these statistics are only computed for one batch.
            """
            policy_loss = (log_pi - q_new_actions).mean()

            self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
            self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))
            self.eval_statistics['Policy Loss'] = np.mean(
                ptu.get_numpy(policy_loss))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q1 Predictions',
                    ptu.get_numpy(q1_pred),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q2 Predictions',
                    ptu.get_numpy(q2_pred),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q Targets',
                    ptu.get_numpy(q_target),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Log Pis',
                    ptu.get_numpy(log_pi),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Policy mu',
                    ptu.get_numpy(policy_mean),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Policy log std',
                    ptu.get_numpy(policy_log_std),
                ))
            if self.use_automatic_entropy_tuning:
                self.eval_statistics['Alpha'] = alpha.item()
                self.eval_statistics['Alpha Loss'] = alpha_loss.item()
        self._n_train_steps_total += 1
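For reference, the bootstrap target used in this example can be written as a small standalone function; a sketch with illustrative argument names:

import torch

def sac_q_target(rewards: torch.Tensor, terminals: torch.Tensor,
                 target_q1: torch.Tensor, target_q2: torch.Tensor,
                 new_log_pi: torch.Tensor, alpha: float,
                 discount: float, reward_scale: float = 1.0) -> torch.Tensor:
    # Soft Bellman target: clipped double-Q bootstrap minus the entropy term.
    target_v = torch.min(target_q1, target_q2) - alpha * new_log_pi
    return (reward_scale * rewards + (1.0 - terminals) * discount * target_v).detach()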
Exemple #28
0
    def train_from_torch(self, batch):
        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']
        masks = batch['masks']
        
        # variables for logging
        tot_qf1_loss, tot_qf2_loss, tot_q1_pred, tot_q2_pred, tot_q_target = 0, 0, 0, 0, 0
        tot_log_pi, tot_policy_mean, tot_policy_log_std, tot_policy_loss = 0, 0, 0, 0
        tot_alpha, tot_alpha_loss = 0, 0
        
        std_Q_actor_list = self.corrective_feedback(obs=obs, update_type=0)
        std_Q_critic_list = self.corrective_feedback(obs=next_obs, update_type=1)
        
        for en_index in range(self.num_ensemble):
            mask = masks[:,en_index].reshape(-1, 1)

            """
            Policy and Alpha Loss
            """
            new_obs_actions, policy_mean, policy_log_std, log_pi, *_ = self.policy[en_index](
                obs, reparameterize=True, return_log_prob=True,
            )
            if self.use_automatic_entropy_tuning:
                alpha_loss = -(self.log_alpha[en_index] * (log_pi + self.target_entropy).detach()) * mask
                alpha_loss = alpha_loss.sum() / (mask.sum() + 1)
                self.alpha_optimizer[en_index].zero_grad()
                alpha_loss.backward()
                self.alpha_optimizer[en_index].step()
                alpha = self.log_alpha[en_index].exp()
            else:
                alpha_loss = 0
                alpha = 1

            q_new_actions = torch.min(
                self.qf1[en_index](obs, new_obs_actions),
                self.qf2[en_index](obs, new_obs_actions),
            )
            
            if self.feedback_type == 0 or self.feedback_type == 2:
                std_Q = std_Q_actor_list[en_index]
            else:
                std_Q = std_Q_actor_list[0]
                
            if self.feedback_type == 1 or self.feedback_type == 0:
                weight_actor_Q = torch.sigmoid(-std_Q*self.temperature_act) + 0.5
            else:
                weight_actor_Q = 2*torch.sigmoid(-std_Q*self.temperature_act)
            policy_loss = (alpha*log_pi - q_new_actions - self.expl_gamma * std_Q) * mask * weight_actor_Q.detach()
            policy_loss = policy_loss.sum() / (mask.sum() + 1)

            """
            QF Loss
            """
            q1_pred = self.qf1[en_index](obs, actions)
            q2_pred = self.qf2[en_index](obs, actions)
            
            # Make sure policy accounts for squashing functions like tanh correctly!
            new_next_actions, _, _, new_log_pi, *_ = self.policy[en_index](
                next_obs, reparameterize=True, return_log_prob=True,
            )
            target_q_values = torch.min(
                self.target_qf1[en_index](next_obs, new_next_actions),
                self.target_qf2[en_index](next_obs, new_next_actions),
            ) - alpha * new_log_pi
            
            if self.feedback_type == 0 or self.feedback_type == 2:
                if self.feedback_type == 0:
                    weight_target_Q = torch.sigmoid(-std_Q_critic_list[en_index]*self.temperature) + 0.5
                else:
                    weight_target_Q = 2*torch.sigmoid(-std_Q_critic_list[en_index]*self.temperature)
            else:
                if self.feedback_type == 1:
                    weight_target_Q = torch.sigmoid(-std_Q_critic_list[0]*self.temperature) + 0.5
                else:
                    weight_target_Q = 2*torch.sigmoid(-std_Q_critic_list[0]*self.temperature)
            q_target = self.reward_scale * rewards + (1. - terminals) * self.discount * target_q_values
            qf1_loss = self.qf_criterion(q1_pred, q_target.detach()) * mask * (weight_target_Q.detach())
            qf2_loss = self.qf_criterion(q2_pred, q_target.detach()) * mask * (weight_target_Q.detach())
            qf1_loss = qf1_loss.sum() / (mask.sum() + 1)
            qf2_loss = qf2_loss.sum() / (mask.sum() + 1)
            
            """
            Update networks
            """
            self.qf1_optimizer[en_index].zero_grad()
            qf1_loss.backward()
            self.qf1_optimizer[en_index].step()

            self.qf2_optimizer[en_index].zero_grad()
            qf2_loss.backward()
            self.qf2_optimizer[en_index].step()

            self.policy_optimizer[en_index].zero_grad()
            policy_loss.backward()
            self.policy_optimizer[en_index].step()

            """
            Soft Updates
            """
            if self._n_train_steps_total % self.target_update_period == 0:
                ptu.soft_update_from_to(
                    self.qf1[en_index], self.target_qf1[en_index], self.soft_target_tau
                )
                ptu.soft_update_from_to(
                    self.qf2[en_index], self.target_qf2[en_index], self.soft_target_tau
                )
                
            """
            Statistics for log
            """
            tot_qf1_loss += qf1_loss * (1/self.num_ensemble)
            tot_qf2_loss += qf2_loss * (1/self.num_ensemble)
            tot_q1_pred += q1_pred * (1/self.num_ensemble)
            tot_q2_pred += q2_pred * (1/self.num_ensemble)
            tot_q_target += q_target * (1/self.num_ensemble)
            tot_log_pi += log_pi * (1/self.num_ensemble)
            tot_policy_mean += policy_mean * (1/self.num_ensemble)
            tot_policy_log_std += policy_log_std * (1/self.num_ensemble)
            tot_alpha += alpha.item() * (1/self.num_ensemble)
            tot_alpha_loss += alpha_loss.item()
            tot_policy_loss += (log_pi - q_new_actions).mean() * (1/self.num_ensemble)

        """
        Save some statistics for eval
        """
        if self._need_to_update_eval_statistics:
            self._need_to_update_eval_statistics = False
            """
            Eval should set this to None.
            This way, these statistics are only computed for one batch.
            """
            self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(tot_qf1_loss))
            self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(tot_qf2_loss))
            self.eval_statistics['Policy Loss'] = np.mean(ptu.get_numpy(
                tot_policy_loss
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Q1 Predictions',
                ptu.get_numpy(tot_q1_pred),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Q2 Predictions',
                ptu.get_numpy(tot_q2_pred),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Q Targets',
                ptu.get_numpy(tot_q_target),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Log Pis',
                ptu.get_numpy(tot_log_pi),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Policy mu',
                ptu.get_numpy(tot_policy_mean),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Policy log std',
                ptu.get_numpy(tot_policy_log_std),
            ))
            if self.use_automatic_entropy_tuning:
                self.eval_statistics['Alpha'] = tot_alpha
                self.eval_statistics['Alpha Loss'] = tot_alpha_loss
                
        self._n_train_steps_total += 1
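The ensemble example above down-weights transitions on which the Q-ensemble disagrees. The weighting itself reduces to a small function of the ensemble standard deviation; a sketch mirroring the two variants used in the loop above:

import torch

def uncertainty_weight(std_q: torch.Tensor, temperature: float,
                       centered: bool = True) -> torch.Tensor:
    # Map ensemble disagreement (std of Q estimates) to a per-sample weight:
    # high disagreement -> small weight. The `centered` variant stays in (0.5, 1.5),
    # the other in (0, 2), matching the two feedback types above.
    if centered:
        return torch.sigmoid(-std_q * temperature) + 0.5
    return 2.0 * torch.sigmoid(-std_q * temperature)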
Exemple #29
0
    def _do_reward_training(self, epoch):
        '''
            Train the discriminator
        '''
        self.disc_optimizer.zero_grad()

        expert_batch = self.get_disc_training_batch(self.disc_optim_batch_size,
                                                    True)
        policy_batch = self.get_disc_training_batch(self.disc_optim_batch_size,
                                                    False)

        expert_obs = expert_batch['observations']
        policy_obs = policy_batch['observations']
        if self.wrap_absorbing:
            expert_obs = torch.cat(
                [expert_obs, expert_batch['absorbing'][:, 0:1]], dim=-1)
            policy_obs = torch.cat(
                [policy_obs, policy_batch['absorbing'][:, 0:1]], dim=-1)
        if not self.state_only:
            expert_actions = expert_batch['actions']
            policy_actions = policy_batch['actions']

        if self.use_disc_input_noise:
            noise_scale = linear_schedule(epoch,
                                          self.disc_input_noise_scale_start,
                                          self.disc_input_noise_scale_end,
                                          self.epochs_till_end_scale)
            if noise_scale > 0.0:
                expert_obs = expert_obs + noise_scale * Variable(
                    torch.randn(expert_obs.size()))
                if not self.state_only:
                    expert_actions = expert_actions + noise_scale * Variable(
                        torch.randn(expert_actions.size()))

                policy_obs = policy_obs + noise_scale * Variable(
                    torch.randn(policy_obs.size()))
                if not self.state_only:
                    policy_actions = policy_actions + noise_scale * Variable(
                        torch.randn(policy_actions.size()))

        obs = torch.cat([expert_obs, policy_obs], dim=0)
        if not self.state_only:
            actions = torch.cat([expert_actions, policy_actions], dim=0)

        if self.state_only:
            disc_logits = self.discriminator(obs, None)
        else:
            disc_logits = self.discriminator(obs, actions)
        disc_preds = (disc_logits > 0).type(disc_logits.data.type())
        disc_ce_loss = self.bce(disc_logits, self.bce_targets)
        accuracy = (disc_preds == self.bce_targets).type(
            torch.FloatTensor).mean()

        disc_ce_loss.backward()

        ce_grad_norm = 0.0
        for name, param in self.discriminator.named_parameters():
            if param.grad is not None:
                if self.disc_grad_buffer_is_empty:
                    self.disc_grad_buffer[name] = param.grad.data.clone()
                else:
                    self.disc_grad_buffer[name].copy_(param.grad.data)

                param_norm = param.grad.data.norm(2)
                ce_grad_norm += param_norm**2
        ce_grad_norm = ce_grad_norm**0.5
        self.disc_grad_buffer_is_empty = False

        ce_clip_coef = self.disc_ce_grad_clip / (ce_grad_norm + 1e-6)
        if ce_clip_coef < 1.:
            for name, grad in self.disc_grad_buffer.items():
                grad.mul_(ce_clip_coef)

        if ce_clip_coef < 1.0: ce_grad_norm *= ce_clip_coef
        self.max_disc_ce_grad = max(ce_grad_norm, self.max_disc_ce_grad)
        self.disc_ce_grad_norm += ce_grad_norm
        self.disc_ce_grad_norm_counter += 1

        self.disc_optimizer.zero_grad()

        if self.use_grad_pen:
            eps = Variable(torch.rand(expert_obs.size(0), 1))
            if ptu.gpu_enabled(): eps = eps.cuda()

            interp_obs = eps * expert_obs + (1 - eps) * policy_obs
            interp_obs = interp_obs.detach()
            interp_obs.requires_grad = True
            if self.state_only:
                gradients = autograd.grad(
                    outputs=self.discriminator(interp_obs, None).sum(),
                    inputs=[interp_obs],
                    # grad_outputs=torch.ones(exp_specs['batch_size'], 1).cuda(),
                    create_graph=True,
                    retain_graph=True,
                    only_inputs=True)
                total_grad = gradients[0]
            else:
                interp_actions = eps * expert_actions + (1 -
                                                         eps) * policy_actions
                interp_actions = interp_actions.detach()
                interp_actions.requires_grad = True
                gradients = autograd.grad(
                    outputs=self.discriminator(interp_obs,
                                               interp_actions).sum(),
                    inputs=[interp_obs, interp_actions],
                    # grad_outputs=torch.ones(exp_specs['batch_size'], 1).cuda(),
                    create_graph=True,
                    retain_graph=True,
                    only_inputs=True)
                total_grad = torch.cat([gradients[0], gradients[1]], dim=1)

            # GP from Gulrajani et al.
            gradient_penalty = ((total_grad.norm(2, dim=1) - 1)**2).mean()
            disc_grad_pen_loss = gradient_penalty * self.grad_pen_weight

            # # GP from Mescheder et al.
            # gradient_penalty = (total_grad.norm(2, dim=1) ** 2).mean()
            # disc_grad_pen_loss = gradient_penalty * 0.5 * self.grad_pen_weight

            disc_grad_pen_loss.backward()

            gp_grad_norm = 0.0
            for p in list(
                    filter(lambda p: p.grad is not None,
                           self.discriminator.parameters())):
                param_norm = p.grad.data.norm(2)
                gp_grad_norm += param_norm**2
            gp_grad_norm = gp_grad_norm**0.5

            gp_clip_coef = self.disc_gp_grad_clip / (gp_grad_norm + 1e-6)
            if gp_clip_coef < 1.:
                for p in self.discriminator.parameters():
                    p.grad.data.mul_(gp_clip_coef)

            if gp_clip_coef < 1.: gp_grad_norm *= gp_clip_coef
            self.max_disc_gp_grad = max(gp_grad_norm, self.max_disc_gp_grad)
            self.disc_gp_grad_norm += gp_grad_norm
            self.disc_gp_grad_norm_counter += 1

        # now add back the gradients from the CE loss
        for name, param in self.discriminator.named_parameters():
            param.grad.data.add_(self.disc_grad_buffer[name])

        self.disc_optimizer.step()

        if self.use_target_disc:
            ptu.soft_update_from_to(self.discriminator, self.target_disc,
                                    self.soft_target_disc_tau)
        """
        Save some statistics for eval
        """
        if self.rewardf_eval_statistics is None:
            """
            Eval should set this to None.
            This way, these statistics are only computed for one batch.
            """
            self.rewardf_eval_statistics = OrderedDict()

            if self.use_target_disc:
                if self.state_only:
                    target_disc_logits = self.target_disc(obs, None)
                else:
                    target_disc_logits = self.target_disc(obs, actions)
                target_disc_preds = (target_disc_logits > 0).type(
                    target_disc_logits.data.type())
                target_disc_ce_loss = self.bce(target_disc_logits,
                                               self.bce_targets)
                target_accuracy = (target_disc_preds == self.bce_targets).type(
                    torch.FloatTensor).mean()

                if self.use_grad_pen:
                    eps = Variable(torch.rand(expert_obs.size(0), 1))
                    if ptu.gpu_enabled(): eps = eps.cuda()

                    interp_obs = eps * expert_obs + (1 - eps) * policy_obs
                    interp_obs = interp_obs.detach()
                    interp_obs.requires_grad = True
                    if self.state_only:
                        target_gradients = autograd.grad(
                            outputs=self.target_disc(interp_obs, None).sum(),
                            inputs=[interp_obs],
                            # grad_outputs=torch.ones(exp_specs['batch_size'], 1).cuda(),
                            create_graph=True,
                            retain_graph=True,
                            only_inputs=True)
                        total_target_grad = target_gradients[0]
                    else:
                        interp_actions = eps * expert_actions + (
                            1 - eps) * policy_actions
                        interp_actions = interp_actions.detach()
                        interp_actions.requires_grad = True
                        target_gradients = autograd.grad(
                            outputs=self.target_disc(interp_obs,
                                                     interp_actions).sum(),
                            inputs=[interp_obs, interp_actions],
                            # grad_outputs=torch.ones(exp_specs['batch_size'], 1).cuda(),
                            create_graph=True,
                            retain_graph=True,
                            only_inputs=True)
                        total_target_grad = torch.cat(
                            [target_gradients[0], target_gradients[1]], dim=1)

                    # GP from Gulrajani et al.
                    target_gradient_penalty = ((
                        total_target_grad.norm(2, dim=1) - 1)**2).mean()

                    # # GP from Mescheder et al.
                    # target_gradient_penalty = (total_target_grad.norm(2, dim=1) ** 2).mean()

                self.rewardf_eval_statistics['Target Disc CE Loss'] = np.mean(
                    ptu.get_numpy(target_disc_ce_loss))
                self.rewardf_eval_statistics['Target Disc Acc'] = np.mean(
                    ptu.get_numpy(target_accuracy))
                if self.use_grad_pen:
                    self.rewardf_eval_statistics['Target Grad Pen'] = np.mean(
                        ptu.get_numpy(target_gradient_penalty))
                    self.rewardf_eval_statistics['Target Grad Pen W'] = np.mean(
                        self.grad_pen_weight)

            self.rewardf_eval_statistics['Disc CE Loss'] = np.mean(
                ptu.get_numpy(disc_ce_loss))
            self.rewardf_eval_statistics['Disc Acc'] = np.mean(
                ptu.get_numpy(accuracy))
            if self.use_grad_pen:
                self.rewardf_eval_statistics['Grad Pen'] = np.mean(
                    ptu.get_numpy(gradient_penalty))
                self.rewardf_eval_statistics['Grad Pen W'] = np.mean(
                    self.grad_pen_weight)
                self.rewardf_eval_statistics[
                    'Disc Avg CE Grad Norm this epoch'] = np.mean(
                        self.disc_ce_grad_norm /
                        self.disc_ce_grad_norm_counter)
                self.rewardf_eval_statistics[
                    'Disc Max CE Grad Norm this epoch'] = np.mean(
                        self.max_disc_ce_grad)
                self.rewardf_eval_statistics[
                    'Disc Avg GP Grad Norm this epoch'] = np.mean(
                        self.disc_gp_grad_norm /
                        self.disc_gp_grad_norm_counter)
                self.rewardf_eval_statistics[
                    'Disc Max GP Grad Norm this epoch'] = np.mean(
                        self.max_disc_gp_grad)
            if self.use_disc_input_noise:
                self.rewardf_eval_statistics[
                    'Disc Input Noise Scale'] = noise_scale

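            # Reset the per-epoch gradient-norm trackers.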
            self.max_disc_ce_grad = 0.0
            self.disc_ce_grad_norm = 0.0
            self.disc_ce_grad_norm_counter = 0.0
            self.max_disc_gp_grad = 0.0
            self.disc_gp_grad_norm = 0.0
            self.disc_gp_grad_norm_counter = 0.0

    def train_from_torch(self, batch):
        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']
        skills = batch['skills']
        """
        MI estimator between proprioceptive and exteroceptive sensors
        """
        prio_obs = obs[:, :self.prio_extrio_bound]
        extrio_obs = obs[:, self.prio_extrio_bound:]

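        # SMILE-style MI lower bound: the critic scores are clipped
        # (clip=self.smile_clip) to reduce the variance of the estimate.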
        mi_btw_states = estimate_mutual_information(
            "smile",
            prio_obs,
            extrio_obs,
            critic_fn=self.mi_estimator,
            clip=self.smile_clip)
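        # Maximise the MI estimate by minimising its negative.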
        mi_loss = -mi_btw_states
        """
        DF Loss and Intrinsic Reward
        """
        z_hat = torch.argmax(skills, dim=1)
        d_pred = self.df(next_obs)
        d_pred_log_softmax = F.log_softmax(d_pred, 1)
        _, pred_z = torch.max(d_pred_log_softmax, dim=1, keepdim=True)
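        # DIAYN-style intrinsic reward: log q(z | s') - log p(z),
        # with a uniform skill prior p(z) = 1 / skill_dim.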
        skill_mi_rewards = d_pred_log_softmax[torch.arange(d_pred.shape[0]),
                                              z_hat] - math.log(
                                                  1 / self.policy.skill_dim)
        df_loss = self.df_criterion(d_pred, z_hat)

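        # Purely intrinsic reward; the commented-out '+ rewards' term would add
        # the environment reward back in.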
        rewards = skill_mi_rewards.reshape(-1, 1) + mi_btw_states.reshape(
            -1, 1)  #+ rewards
        """
        Policy and Alpha Loss
        """
        new_obs_actions, policy_mean, policy_log_std, log_pi, *_ = self.policy(
            obs,
            skill_vec=skills,
            reparameterize=True,
            return_log_prob=True,
        )
        obs_skills = torch.cat((obs, skills), dim=1)
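        # Standard SAC temperature update: drive the policy entropy towards
        # self.target_entropy.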
        if self.use_automatic_entropy_tuning:
            alpha_loss = -(self.log_alpha *
                           (log_pi + self.target_entropy).detach()).mean()
            self.alpha_optimizer.zero_grad()
            alpha_loss.backward()
            self.alpha_optimizer.step()
            alpha = self.log_alpha.exp()
        else:
            alpha_loss = 0
            alpha = 1

        q_new_actions = torch.min(
            self.qf1(obs_skills, new_obs_actions),
            self.qf2(obs_skills, new_obs_actions),
        )
        policy_loss = (alpha * log_pi - q_new_actions).mean()
        """
        QF Loss
        """
        q1_pred = self.qf1(obs_skills, actions)
        q2_pred = self.qf2(obs_skills, actions)
        # Make sure policy accounts for squashing functions like tanh correctly!
        new_next_actions, _, _, new_log_pi, *_ = self.policy(
            next_obs,
            skill_vec=skills,
            reparameterize=True,
            return_log_prob=True,
        )
        next_obs_skills = torch.cat((next_obs, skills), dim=1)
        target_q_values = torch.min(
            self.target_qf1(next_obs_skills, new_next_actions),
            self.target_qf2(next_obs_skills, new_next_actions),
        ) - alpha * new_log_pi

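        # Soft Bellman target: r + gamma * (min_i target_Q_i(s', a') - alpha * log pi(a'|s')).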
        q_target = self.reward_scale * rewards + (
            1. - terminals) * self.discount * target_q_values
        qf1_loss = self.qf_criterion(q1_pred, q_target.detach())
        qf2_loss = self.qf_criterion(q2_pred, q_target.detach())
        """
        Update networks
        """
        self.df_optimizer.zero_grad()
        df_loss.backward()
        self.df_optimizer.step()

        self.mi_optimizer.zero_grad()
        mi_loss.backward()
        self.mi_optimizer.step()

        self.qf1_optimizer.zero_grad()
        self.qf2_optimizer.zero_grad()
        self.policy_optimizer.zero_grad()

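        # Note: q_new_actions is not detached, so policy_loss.backward() also
        # deposits gradients into qf1/qf2; the Q optimizer steps below apply
        # both the TD and the policy-loss contributions.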
        qf1_loss.backward()
        qf2_loss.backward()
        policy_loss.backward()

        self.qf2_optimizer.step()
        self.qf1_optimizer.step()
        self.policy_optimizer.step()
        """
        Soft Updates
        """
        if self._n_train_steps_total % self.target_update_period == 0:
            ptu.soft_update_from_to(self.qf1, self.target_qf1,
                                    self.soft_target_tau)
            ptu.soft_update_from_to(self.qf2, self.target_qf2,
                                    self.soft_target_tau)
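        # A minimal sketch of what a Polyak update like ptu.soft_update_from_to
        # is assumed to do (illustration only, not necessarily the library's code):
        #
        #   def soft_update_from_to(source, target, tau):
        #       for t_param, s_param in zip(target.parameters(), source.parameters()):
        #           t_param.data.copy_(t_param.data * (1.0 - tau) + s_param.data * tau)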
        """
        Save some statistics for eval
        """
        df_accuracy = torch.eq(z_hat, pred_z.reshape(-1)).float().mean()

        if self._need_to_update_eval_statistics:
            self._need_to_update_eval_statistics = False
            """
            Eval should reset this flag.
            This way, these statistics are only computed for one batch.
            """
            policy_loss = (log_pi - q_new_actions).mean()
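            # The policy loss above is re-evaluated without the temperature
            # term, purely for logging.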

            self.eval_statistics['MI btw States Rewards'] = np.mean(
                ptu.get_numpy(mi_btw_states))
            self.eval_statistics['MI btw Skill Rewards'] = np.mean(
                ptu.get_numpy(skill_mi_rewards))
            self.eval_statistics['Sum Rewards'] = np.mean(
                ptu.get_numpy(rewards))
            self.eval_statistics['DF Loss'] = np.mean(ptu.get_numpy(df_loss))
            self.eval_statistics['DF Accuracy'] = np.mean(
                ptu.get_numpy(df_accuracy))
            self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
            self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))
            self.eval_statistics['Policy Loss'] = np.mean(
                ptu.get_numpy(policy_loss))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q1 Predictions',
                    ptu.get_numpy(q1_pred),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q2 Predictions',
                    ptu.get_numpy(q2_pred),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'D Predictions',
                    ptu.get_numpy(pred_z),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q Targets',
                    ptu.get_numpy(q_target),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Log Pis',
                    ptu.get_numpy(log_pi),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Policy mu',
                    ptu.get_numpy(policy_mean),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Policy log std',
                    ptu.get_numpy(policy_log_std),
                ))
            if self.use_automatic_entropy_tuning:
                self.eval_statistics['Alpha'] = alpha.item()
                self.eval_statistics['Alpha Loss'] = alpha_loss.item()
        self._n_train_steps_total += 1