Example 1
 def update_target_networks(self):
     ptu.soft_update_from_to(
         self.qf1, self.target_qf1, self.soft_target_tau
     )
     ptu.soft_update_from_to(
         self.qf2, self.target_qf2, self.soft_target_tau
     )
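Every example below relies on `ptu.soft_update_from_to` from rlkit-style torch utilities without showing its body. A minimal sketch of the Polyak-averaging update it is assumed to perform (target <- tau * source + (1 - tau) * target); the exact rlkit implementation may differ:

    import torch

    def soft_update_from_to(source: torch.nn.Module, target: torch.nn.Module, tau: float) -> None:
        # Polyak average: blend each source parameter into the matching target parameter.
        with torch.no_grad():
            for target_param, source_param in zip(target.parameters(), source.parameters()):
                target_param.data.copy_(
                    tau * source_param.data + (1.0 - tau) * target_param.data
                )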
Example 2
 def _update_target_networks(self):
     if self.use_soft_update:
         ptu.soft_update_from_to(self.policy, self.target_policy, self.tau)
         ptu.soft_update_from_to(self.gcm, self.target_gcm, self.tau)
     else:
         if self._n_env_steps_total % self.target_hard_update_period == 0:
             ptu.copy_model_params_from_to(self.gcm, self.target_gcm)
             ptu.copy_model_params_from_to(self.policy, self.target_policy)
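The hard-update branch copies parameters verbatim every `target_hard_update_period` steps. A sketch of what `ptu.copy_model_params_from_to` is assumed to do (equivalent to a soft update with tau = 1):

    import torch

    def copy_model_params_from_to(source: torch.nn.Module, target: torch.nn.Module) -> None:
        # Hard update: overwrite every target parameter with the corresponding source parameter.
        with torch.no_grad():
            for target_param, source_param in zip(target.parameters(), source.parameters()):
                target_param.data.copy_(source_param.data)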
Example 3
    def _do_training(self):
        batch = self.get_batch()
        """
        Optimize Critic/Actor.
        """
        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']

        _, _, v_pred = self.target_policy(next_obs, None)
        y_target = self.reward_scale * rewards + (
            1. - terminals) * self.discount * v_pred
        y_target = y_target.detach()
        mu, y_pred, v = self.policy(obs, actions)
        policy_loss = self.policy_criterion(y_pred, y_target)

        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()
        """
        Update Target Networks
        """
        if self.use_soft_update:
            ptu.soft_update_from_to(self.policy, self.target_policy, self.tau)
        else:
            if self._n_train_steps_total % self.target_hard_update_period == 0:
                ptu.copy_model_params_from_to(self.policy, self.target_policy)

        if self.need_to_update_eval_statistics:
            self.need_to_update_eval_statistics = False
            self.eval_statistics['Policy Loss'] = np.mean(
                ptu.get_numpy(policy_loss))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Policy v',
                    ptu.get_numpy(v),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Policy mu',
                    ptu.get_numpy(mu),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Y targets',
                    ptu.get_numpy(y_target),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Y predictions',
                    ptu.get_numpy(y_pred),
                ))
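`create_stats_ordered_dict` and `ptu.get_numpy` are used for logging throughout but never defined here. A minimal sketch of their assumed behavior (an OrderedDict of summary statistics keyed by the metric name, and a tensor-to-numpy conversion); the real helpers may report additional statistics:

    from collections import OrderedDict

    import numpy as np

    def create_stats_ordered_dict(name, data):
        # Summarize an array of per-sample values into a few scalar statistics.
        data = np.asarray(data)
        return OrderedDict([
            (name + ' Mean', float(np.mean(data))),
            (name + ' Std', float(np.std(data))),
            (name + ' Max', float(np.max(data))),
            (name + ' Min', float(np.min(data))),
        ])

    def get_numpy(tensor):
        # Assumed behavior of ptu.get_numpy: detach, move to CPU, convert to numpy.
        return tensor.detach().cpu().numpy()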
Example 4
 def _do_training(self, n_steps_total):
     raw_subtraj_batch, start_indices = (
         self.replay_buffer.train_replay_buffer.random_subtrajectories(
             self.num_subtrajs_per_batch
         )
     )
     subtraj_batch = create_torch_subtraj_batch(raw_subtraj_batch)
     if self.save_memory_gradients:
         subtraj_batch['memories'].requires_grad = True
     self.train_critic(subtraj_batch)
     self.train_policy(subtraj_batch, start_indices)
     if self.use_soft_update:
         ptu.soft_update_from_to(self.policy, self.target_policy, self.tau)
         ptu.soft_update_from_to(self.qf, self.target_qf, self.tau)
     else:
         if n_steps_total % self.target_hard_update_period == 0:
             ptu.copy_model_params_from_to(self.qf, self.target_qf)
             ptu.copy_model_params_from_to(self.policy, self.target_policy)
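Examples 2 and 4 gate a full copy on a step counter and otherwise fall back to Polyak averaging. A small helper capturing that schedule, assuming the two sketched functions above (names here are illustrative, not from the snippets):

    def update_targets(pairs, use_soft_update, tau, step, hard_update_period):
        # pairs: iterable of (source_module, target_module) to keep in sync.
        for source, target in pairs:
            if use_soft_update:
                soft_update_from_to(source, target, tau)       # Polyak average every step
            elif step % hard_update_period == 0:
                copy_model_params_from_to(source, target)      # full copy every N steps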
Example 5
    def train_from_torch(self, batch):
        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']
        """
        Compute loss
        """

        best_action_idxs = self.qf(next_obs).max(1, keepdim=True)[1]
        target_q_values = self.target_qf(next_obs).gather(
            1, best_action_idxs).detach()
        y_target = rewards + (1. - terminals) * self.discount * target_q_values
        y_target = y_target.detach()
        # actions is a one-hot vector
        y_pred = torch.sum(self.qf(obs) * actions, dim=1, keepdim=True)
        qf_loss = self.qf_criterion(y_pred, y_target)
        """
        Update networks
        """
        self.qf_optimizer.zero_grad()
        qf_loss.backward()
        self.qf_optimizer.step()
        """
        Soft target network updates
        """
        if self._n_train_steps_total % self.target_update_period == 0:
            ptu.soft_update_from_to(self.qf, self.target_qf,
                                    self.soft_target_tau)
        """
        Save some statistics for eval using just one batch.
        """
        if self._need_to_update_eval_statistics:
            self._need_to_update_eval_statistics = False
            self.eval_statistics['QF Loss'] = np.mean(ptu.get_numpy(qf_loss))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Y Predictions',
                    ptu.get_numpy(y_pred),
                ))
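Example 5 is a Double-DQN style update: the online `qf` selects the greedy next action and `target_qf` evaluates it. The same target computation with plain tensors, as a self-contained sketch (shapes and names are illustrative):

    import torch

    def double_dqn_target(q_online_next, q_target_next, rewards, terminals, discount):
        # q_online_next, q_target_next: (B, A) Q-values at the next observation.
        # rewards, terminals: (B, 1); terminals is 1.0 where the episode ended.
        best_actions = q_online_next.argmax(dim=1, keepdim=True)   # online net picks the action
        evaluated = q_target_next.gather(1, best_actions)          # target net evaluates it
        return (rewards + (1.0 - terminals) * discount * evaluated).detach()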
Example 6
    def _do_training(self):
        batch = self.get_batch()
        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']
        goals = batch['goals']
        num_steps_left = batch['num_steps_left']

        q1_pred = self.qf1(
            observations=obs,
            actions=actions,
            goals=goals,
            num_steps_left=num_steps_left,
        )
        q2_pred = self.qf2(
            observations=obs,
            actions=actions,
            goals=goals,
            num_steps_left=num_steps_left,
        )
        # Make sure policy accounts for squashing functions like tanh correctly!
        policy_outputs = self.policy(obs,
                                     goals,
                                     num_steps_left,
                                     reparameterize=self.train_policy_with_reparameterization,
                                     return_log_prob=True)
        new_actions, policy_mean, policy_log_std, log_pi = policy_outputs[:4]
        if not self.dense_rewards and not self.dense_log_pi:
            log_pi = log_pi * terminals

        """
        QF Loss
        """
        target_v_values = self.target_vf(
            observations=next_obs,
            goals=goals,
            num_steps_left=num_steps_left-1,
        )
        q_target = self.reward_scale * rewards + (1. - terminals) * self.discount * target_v_values
        q_target = q_target.detach()
        bellman_errors_1 = (q1_pred - q_target) ** 2
        bellman_errors_2 = (q2_pred - q_target) ** 2
        qf1_loss = bellman_errors_1.mean()
        qf2_loss = bellman_errors_2.mean()

        if self.use_automatic_entropy_tuning:
            """
            Alpha Loss
            """
            alpha_loss = -(self.log_alpha * (log_pi + self.target_entropy).detach()).mean()
            self.alpha_optimizer.zero_grad()
            alpha_loss.backward()
            self.alpha_optimizer.step()
            alpha = self.log_alpha.exp()
        else:
            alpha = 1

        """
        VF Loss
        """
        q1_new_actions = self.qf1(
            observations=obs,
            actions=new_actions,
            goals=goals,
            num_steps_left=num_steps_left,
        )
        q2_new_actions = self.qf2(
            observations=obs,
            actions=new_actions,
            goals=goals,
            num_steps_left=num_steps_left,
        )
        q_new_actions = torch.min(q1_new_actions, q2_new_actions)
        v_target = q_new_actions - alpha * log_pi
        v_pred = self.vf(
            observations=obs,
            goals=goals,
            num_steps_left=num_steps_left,
        )
        v_target = v_target.detach()
        bellman_errors = (v_pred - v_target) ** 2
        vf_loss = bellman_errors.mean()

        """
        Update networks
        """
        self.qf1_optimizer.zero_grad()
        qf1_loss.backward()
        self.qf1_optimizer.step()

        self.qf2_optimizer.zero_grad()
        qf2_loss.backward()
        self.qf2_optimizer.step()

        self.vf_optimizer.zero_grad()
        vf_loss.backward()
        self.vf_optimizer.step()

        """
        Policy Loss
        """
        # paper says to do + but apparently that's a typo. Do Q - V.
        if self.train_policy_with_reparameterization:
            policy_loss = (alpha * log_pi - q_new_actions).mean()
        else:
            log_policy_target = q_new_actions - v_pred
            policy_loss = (
                log_pi * (alpha * log_pi - log_policy_target).detach()
            ).mean()
        mean_reg_loss = self.policy_mean_reg_weight * (policy_mean ** 2).mean()
        std_reg_loss = self.policy_std_reg_weight * (policy_log_std ** 2).mean()
        pre_tanh_value = policy_outputs[-1]
        pre_activation_reg_loss = self.policy_pre_activation_weight * (
            (pre_tanh_value ** 2).sum(dim=1).mean()
        )
        policy_reg_loss = mean_reg_loss + std_reg_loss + pre_activation_reg_loss
        policy_loss = policy_loss + policy_reg_loss

        if self._n_train_steps_total % self.policy_update_period == 0:
            self.policy_optimizer.zero_grad()
            policy_loss.backward()
            self.policy_optimizer.step()

        if self._n_train_steps_total % self.target_update_period == 0:
            ptu.soft_update_from_to(
                self.vf, self.target_vf, self.soft_target_tau
            )

        """
        Save some statistics for eval
        """
        if self.need_to_update_eval_statistics:
            self.need_to_update_eval_statistics = False
            """
            Eval should set this to None.
            This way, these statistics are only computed for one batch.
            """
            self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
            self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))
            self.eval_statistics['VF Loss'] = np.mean(ptu.get_numpy(vf_loss))
            self.eval_statistics['Policy Loss'] = np.mean(ptu.get_numpy(
                policy_loss
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Q1 Predictions',
                ptu.get_numpy(q1_pred),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Q2 Predictions',
                ptu.get_numpy(q2_pred),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'V Predictions',
                ptu.get_numpy(v_pred),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Log Pis',
                ptu.get_numpy(log_pi),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Policy mu',
                ptu.get_numpy(policy_mean),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Policy log std',
                ptu.get_numpy(policy_log_std),
            ))
            if self.use_automatic_entropy_tuning:
                self.eval_statistics['Alpha'] = ptu.get_numpy(alpha)[0]
                self.eval_statistics['Alpha Loss'] = ptu.get_numpy(alpha_loss)[0]
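Example 6 tunes the entropy coefficient alpha automatically. The temperature update in isolation, as a runnable sketch (the learning rate and the `target_entropy = -action_dim` heuristic are assumptions, not taken from the snippet):

    import torch

    action_dim = 6
    target_entropy = -float(action_dim)
    log_alpha = torch.zeros(1, requires_grad=True)
    alpha_optimizer = torch.optim.Adam([log_alpha], lr=3e-4)

    def update_alpha(log_pi):
        # log_pi is detached so only log_alpha receives gradients.
        alpha_loss = -(log_alpha * (log_pi + target_entropy).detach()).mean()
        alpha_optimizer.zero_grad()
        alpha_loss.backward()
        alpha_optimizer.step()
        return log_alpha.exp().detach()    # alpha that weights the entropy terms

    alpha = update_alpha(torch.randn(128, 1))    # example call with a fake batch of log-probs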
Example 7
    def _do_training(self):
        batch = self.get_batch()
        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']
        goals = batch['goals']
        num_steps_left = batch['num_steps_left']
        """
        Critic operations.
        """
        next_actions = self.target_policy(
            observations=next_obs,
            goals=goals,
            num_steps_left=num_steps_left - 1,
        )
        noise = torch.normal(
            torch.zeros_like(next_actions),
            self.target_policy_noise,
        )
        noise = torch.clamp(noise, -self.target_policy_noise_clip,
                            self.target_policy_noise_clip)
        noisy_next_actions = next_actions + noise

        target_q1_values = self.target_qf1(
            observations=next_obs,
            actions=noisy_next_actions,
            goals=goals,
            num_steps_left=num_steps_left - 1,
        )
        target_q2_values = self.target_qf2(
            observations=next_obs,
            actions=noisy_next_actions,
            goals=goals,
            num_steps_left=num_steps_left - 1,
        )
        target_q_values = torch.min(target_q1_values, target_q2_values)
        q_target = self.reward_scale * rewards + (
            1. - terminals) * self.discount * target_q_values
        q_target = q_target.detach()

        q1_pred = self.qf1(
            observations=obs,
            actions=actions,
            goals=goals,
            num_steps_left=num_steps_left,
        )
        q2_pred = self.qf2(
            observations=obs,
            actions=actions,
            goals=goals,
            num_steps_left=num_steps_left,
        )

        bellman_errors_1 = (q1_pred - q_target)**2
        bellman_errors_2 = (q2_pred - q_target)**2
        qf1_loss = bellman_errors_1.mean()
        qf2_loss = bellman_errors_2.mean()
        """
        Update Networks
        """
        self.qf1_optimizer.zero_grad()
        qf1_loss.backward()
        self.qf1_optimizer.step()

        self.qf2_optimizer.zero_grad()
        qf2_loss.backward()
        self.qf2_optimizer.step()

        policy_actions, pre_tanh_value = self.policy(
            obs,
            goals,
            num_steps_left,
            return_preactivations=True,
        )
        policy_saturation_cost = F.relu(torch.abs(pre_tanh_value) - 20.0)
        q_output = self.qf1(
            observations=obs,
            actions=policy_actions,
            num_steps_left=num_steps_left,
            goals=goals,
        )

        policy_loss = -q_output.mean()
        if self.use_policy_saturation_cost:
            policy_loss = policy_loss + policy_saturation_cost.mean()

        if self._n_train_steps_total % self.policy_and_target_update_period == 0:
            self.policy_optimizer.zero_grad()
            policy_loss.backward()
            self.policy_optimizer.step()

            ptu.soft_update_from_to(self.policy, self.target_policy, self.tau)
            ptu.soft_update_from_to(self.qf1, self.target_qf1, self.tau)
            ptu.soft_update_from_to(self.qf2, self.target_qf2, self.tau)

        if self.need_to_update_eval_statistics:
            self.need_to_update_eval_statistics = False
            self.eval_statistics = OrderedDict()
            self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
            self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))
            self.eval_statistics['Policy Loss'] = np.mean(
                ptu.get_numpy(policy_loss))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q1 Predictions',
                    ptu.get_numpy(q1_pred),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q2 Predictions',
                    ptu.get_numpy(q2_pred),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q Targets',
                    ptu.get_numpy(q_target),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Bellman1 Errors',
                    ptu.get_numpy(bellman_errors_1),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Bellman2 Errors',
                    ptu.get_numpy(bellman_errors_2),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Policy Saturation Cost',
                    ptu.get_numpy(policy_saturation_cost),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Policy Action',
                    ptu.get_numpy(policy_actions),
                    exclude_abs=False,
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Policy Action Pre-tanh',
                    ptu.get_numpy(pre_tanh_value),
                    exclude_abs=False,
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Policy Q Output',
                    ptu.get_numpy(q_output),
                    exclude_abs=False,
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Policy Noisy Next Actions',
                    ptu.get_numpy(noisy_next_actions),
                    exclude_abs=False,
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Replay Buffer Actions',
                    ptu.get_numpy(actions),
                    exclude_abs=False,
                ))

            be_np = ptu.get_numpy(bellman_errors_1)
            num_steps_left_np = ptu.get_numpy(num_steps_left)

            ### tau == 0 ###
            idx_0 = np.argwhere(num_steps_left_np == 0)
            be_0 = 0
            if len(idx_0) > 0:
                be_0 = be_np[idx_0[:, 0]]
            self.eval_statistics['QF1 Loss tau=0'] = np.mean(be_0)

            ### tau == 1 ###
            idx_1 = np.argwhere(num_steps_left_np == 1)
            be_1 = 0
            if len(idx_1) > 0:
                be_1 = be_np[idx_1[:, 0]]
            self.eval_statistics['QF1 Loss tau=1'] = np.mean(be_1)

            ### tau in [2, 5) ###
            idx_2_to_5 = np.argwhere(
                np.logical_and(num_steps_left_np >= 2, num_steps_left_np < 5))
            be_2_to_5 = 0
            if len(idx_2_to_5) > 0:
                be_2_to_5 = be_np[idx_2_to_5[:, 0]]
            self.eval_statistics['QF1 Loss tau=2_to_5'] = np.mean(be_2_to_5)

            ### tau in [5, 10) ###
            idx_5_to_10 = np.argwhere(
                np.logical_and(num_steps_left_np >= 5, num_steps_left_np < 10))
            be_5_to_10 = 0
            if len(idx_5_to_10) > 0:
                be_5_to_10 = be_np[idx_5_to_10[:, 0]]
            self.eval_statistics['QF1 Loss tau=5_to_10'] = np.mean(be_5_to_10)

            ### tau in [10, max_tau] ###
            idx_10_to_end = np.argwhere(
                np.logical_and(num_steps_left_np >= 10,
                               num_steps_left_np < self.max_tau + 1))
            be_10_to_end = 0
            if len(idx_10_to_end) > 0:
                be_10_to_end = be_np[idx_10_to_end[:, 0]]
            self.eval_statistics['QF1 Loss tau=10_to_end'] = np.mean(
                be_10_to_end)
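The tau-bucketed statistics at the end of Example 7 slice the Bellman errors by `num_steps_left`. The same bucketing written with boolean masks, as a compact sketch (bucket edges copied from the snippet; the helper name is hypothetical):

    import numpy as np

    def bucketed_bellman_errors(bellman_errors, num_steps_left, max_tau):
        # Mean Bellman error per num_steps_left bucket; empty buckets report 0, as above.
        be = np.asarray(bellman_errors).reshape(-1)
        tau = np.asarray(num_steps_left).reshape(-1)
        buckets = {
            'tau=0': tau == 0,
            'tau=1': tau == 1,
            'tau=2_to_5': (tau >= 2) & (tau < 5),
            'tau=5_to_10': (tau >= 5) & (tau < 10),
            'tau=10_to_end': (tau >= 10) & (tau < max_tau + 1),
        }
        return {name: float(be[mask].mean()) if mask.any() else 0.0
                for name, mask in buckets.items()}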
Example 8
    def train_from_torch(self, batch):
        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']

        """
        Classifier and Policy
        """
        class_actions = self.policy(obs)
        class_prob = self.classifier(obs, actions)
        prob_target = 1 + rewards[:, -1]
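        # NOTE: class_prob and prob_target appear unused in this excerpt; presumably the
        # classifier loss that consumes them lives outside the snippet.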

        neg_log_prob = - torch.log(self.classifier(obs, class_actions))
        policy_loss = (neg_log_prob).mean()

        """
        QF Loss
        """
        q1_pred = self.qf1(obs, actions)
        q2_pred = self.qf2(obs, actions)
        # Make sure policy accounts for squashing functions like tanh correctly!
        # Capture the mean and log-std as well so the eval statistics below can log them.
        new_next_actions, policy_mean, policy_log_std, new_log_pi, *_ = self.policy(
            next_obs, reparameterize=True, return_log_prob=True,
        )
        # NOTE: this excerpt does not include entropy-coefficient tuning, so a fixed
        # alpha is assumed here.
        alpha = 1.
        target_q_values = torch.min(
            self.target_qf1(next_obs, new_next_actions),
            self.target_qf2(next_obs, new_next_actions),
        ) - alpha * new_log_pi

        q_target = self.reward_scale * rewards + (1. - terminals) * self.discount * target_q_values
        qf1_loss = self.qf_criterion(q1_pred, q_target.detach())
        qf2_loss = self.qf_criterion(q2_pred, q_target.detach())

        """
        Update networks
        """
        self.qf1_optimizer.zero_grad()
        qf1_loss.backward()
        self.qf1_optimizer.step()

        self.qf2_optimizer.zero_grad()
        qf2_loss.backward()
        self.qf2_optimizer.step()

        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()

        """
        Soft Updates
        """
        if self._n_train_steps_total % self.target_update_period == 0:
            ptu.soft_update_from_to(
                self.qf1, self.target_qf1, self.soft_target_tau
            )
            ptu.soft_update_from_to(
                self.qf2, self.target_qf2, self.soft_target_tau
            )

        """
        Save some statistics for eval
        """
        if self._need_to_update_eval_statistics:
            self._need_to_update_eval_statistics = False
            """
            Eval should set this to None.
            This way, these statistics are only computed for one batch.
            """

            self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
            self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))
            self.eval_statistics['Policy Loss'] = np.mean(ptu.get_numpy(
                policy_loss
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Q1 Predictions',
                ptu.get_numpy(q1_pred),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Q2 Predictions',
                ptu.get_numpy(q2_pred),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Q Targets',
                ptu.get_numpy(q_target),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Log Pis',
                ptu.get_numpy(new_log_pi),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Policy mu',
                ptu.get_numpy(policy_mean),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Policy log std',
                ptu.get_numpy(policy_log_std),
            ))
        self._n_train_steps_total += 1
Example 9
    def _train_given_data(
        self,
        rewards,
        terminals,
        obs,
        actions,
        next_obs,
        logger_prefix="",
    ):
        """
        Critic operations.
        """

        next_actions = self.target_policy(next_obs)
        noise = torch.normal(
            torch.zeros_like(next_actions),
            self.target_policy_noise,
        )
        noise = torch.clamp(noise, -self.target_policy_noise_clip,
                            self.target_policy_noise_clip)
        noisy_next_actions = next_actions + noise

        target_q1_values = self.target_qf1(next_obs, noisy_next_actions)
        target_q2_values = self.target_qf2(next_obs, noisy_next_actions)
        target_q_values = torch.min(target_q1_values, target_q2_values)
        q_target = self.reward_scale * rewards + (
            1. - terminals) * self.discount * target_q_values
        q_target = q_target.detach()

        q1_pred = self.qf1(obs, actions)
        bellman_errors_1 = (q1_pred - q_target)**2
        qf1_loss = bellman_errors_1.mean()

        q2_pred = self.qf2(obs, actions)
        bellman_errors_2 = (q2_pred - q_target)**2
        qf2_loss = bellman_errors_2.mean()
        """
        Update Networks
        """
        self.qf1_optimizer.zero_grad()
        qf1_loss.backward()
        self.qf1_optimizer.step()

        self.qf2_optimizer.zero_grad()
        qf2_loss.backward()
        self.qf2_optimizer.step()

        policy_actions = policy_loss = None
        if self._n_train_steps_total % self.policy_and_target_update_period == 0:
            policy_actions = self.policy(obs)
            q_output = self.qf1(obs, policy_actions)
            policy_loss = -q_output.mean()

            self.policy_optimizer.zero_grad()
            policy_loss.backward()
            self.policy_optimizer.step()

            ptu.soft_update_from_to(self.policy, self.target_policy, self.tau)
            ptu.soft_update_from_to(self.qf1, self.target_qf1, self.tau)
            ptu.soft_update_from_to(self.qf2, self.target_qf2, self.tau)

        if self.need_to_update_eval_statistics:
            self.need_to_update_eval_statistics = False
            if policy_loss is None:
                policy_actions = self.policy(obs)
                q_output = self.qf1(obs, policy_actions)
                policy_loss = -q_output.mean()

            self.eval_statistics[logger_prefix + 'QF1 Loss'] = np.mean(
                ptu.get_numpy(qf1_loss))
            self.eval_statistics[logger_prefix + 'QF2 Loss'] = np.mean(
                ptu.get_numpy(qf2_loss))
            self.eval_statistics[logger_prefix + 'Policy Loss'] = np.mean(
                ptu.get_numpy(policy_loss))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    logger_prefix + 'Q1 Predictions',
                    ptu.get_numpy(q1_pred),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    logger_prefix + 'Q2 Predictions',
                    ptu.get_numpy(q2_pred),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    logger_prefix + 'Q Targets',
                    ptu.get_numpy(q_target),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    logger_prefix + 'Bellman Errors 1',
                    ptu.get_numpy(bellman_errors_1),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    logger_prefix + 'Bellman Errors 2',
                    ptu.get_numpy(bellman_errors_2),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    logger_prefix + 'Policy Action',
                    ptu.get_numpy(policy_actions),
                ))
Example 10
    def _do_training(self):
        batch = self.get_batch()
        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']
        goals = batch['goals']
        num_steps_left = batch['num_steps_left']
        """
        Critic operations.
        """
        next_actions = self.target_policy(
            observations=next_obs,
            goals=goals,
            num_steps_left=num_steps_left - 1,
        )
        noise = torch.normal(
            torch.zeros_like(next_actions),
            self.target_policy_noise,
        )
        noise = torch.clamp(noise, -self.target_policy_noise_clip,
                            self.target_policy_noise_clip)
        noisy_next_actions = next_actions + noise

        target_q1_values = self.target_qf1(
            observations=next_obs,
            actions=noisy_next_actions,
            goals=goals,
            num_steps_left=num_steps_left - 1,
        )
        target_q2_values = self.target_qf2(
            observations=next_obs,
            actions=noisy_next_actions,
            goals=goals,
            num_steps_left=num_steps_left - 1,
        )
        target_q_values = torch.min(target_q1_values, target_q2_values)
        q_target = self.reward_scale * rewards + (
            1. - terminals) * self.discount * target_q_values
        q_target = q_target.detach()

        q1_pred = self.qf1(
            observations=obs,
            actions=actions,
            goals=goals,
            num_steps_left=num_steps_left,
        )
        q2_pred = self.qf2(
            observations=obs,
            actions=actions,
            goals=goals,
            num_steps_left=num_steps_left,
        )

        bellman_errors_1 = (q1_pred - q_target)**2
        bellman_errors_2 = (q2_pred - q_target)**2
        qf1_loss = bellman_errors_1.mean()
        qf2_loss = bellman_errors_2.mean()
        """
        Update Networks
        """
        self.qf1_optimizer.zero_grad()
        qf1_loss.backward()
        self.qf1_optimizer.step()

        self.qf2_optimizer.zero_grad()
        qf2_loss.backward()
        self.qf2_optimizer.step()

        policy_actions, pre_tanh_value = self.policy(
            obs,
            goals,
            num_steps_left,
            return_preactivations=True,
        )
        q_output = self.qf1(
            observations=obs,
            actions=policy_actions,
            num_steps_left=num_steps_left,
            goals=goals,
        )

        policy_loss = -q_output.mean()

        if self._n_train_steps_total % self.policy_and_target_update_period == 0:
            self.policy_optimizer.zero_grad()
            policy_loss.backward()
            self.policy_optimizer.step()

            ptu.soft_update_from_to(self.policy, self.target_policy, self.tau)
            ptu.soft_update_from_to(self.qf1, self.target_qf1, self.tau)
            ptu.soft_update_from_to(self.qf2, self.target_qf2, self.tau)

        if self.need_to_update_eval_statistics:
            self.need_to_update_eval_statistics = False
            self.eval_statistics = OrderedDict()
            self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
            self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))
            self.eval_statistics['Policy Loss'] = np.mean(
                ptu.get_numpy(policy_loss))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q1 Predictions',
                    ptu.get_numpy(q1_pred),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q2 Predictions',
                    ptu.get_numpy(q2_pred),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q Targets',
                    ptu.get_numpy(q_target),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Bellman1 Errors',
                    ptu.get_numpy(bellman_errors_1),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Bellman2 Errors',
                    ptu.get_numpy(bellman_errors_2),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Policy Action',
                    ptu.get_numpy(policy_actions),
                ))
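Examples 7, 9, and 10 all use TD3-style target policy smoothing: clipped Gaussian noise is added to the target action before taking the min of the two target Q-functions. The smoothing step in isolation (the default noise values and the final action clamp are assumptions; the snippets above do not clamp the noisy action):

    import torch

    def smoothed_target_actions(next_actions, noise_std=0.2, noise_clip=0.5, action_limit=1.0):
        # Add clipped Gaussian noise to the target policy's action.
        noise = torch.normal(torch.zeros_like(next_actions), noise_std)
        noise = torch.clamp(noise, -noise_clip, noise_clip)
        # Standard TD3 also clips the perturbed action back into the valid range.
        return torch.clamp(next_actions + noise, -action_limit, action_limit)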
Example 11
    def train_from_torch(self, batch):
        self._current_epoch += 1
        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        obs = obs.reshape((obs.shape[0], ) + (3, 48, 48))
        if not self.discrete:
            actions = batch['actions']
        else:
            actions = batch['actions'].argmax(dim=-1)
        next_obs = batch['next_observations']
        next_obs = next_obs.reshape((next_obs.shape[0], ) + (3, 48, 48))
        """
        Policy and Alpha Loss
        """
        if self.discrete:
            new_obs_actions, pi_probs, log_pi, entropies = self.policy(
                obs, None, return_log_prob=True)
            new_next_actions, pi_next_probs, new_log_pi, next_entropies = self.policy(
                next_obs, None, return_log_prob=True)
            q_vector = self.qf1.q_vector(obs)
            q2_vector = self.qf2.q_vector(obs)
            q_next_vector = self.qf1.q_vector(next_obs)
            q2_next_vector = self.qf2.q_vector(next_obs)
        else:
            new_obs_actions, policy_mean, policy_log_std, log_pi, *_ = self.policy(
                obs,
                None,
                reparameterize=True,
                return_log_prob=True,
            )

        if self.use_automatic_entropy_tuning:
            alpha_loss = -(self.log_alpha *
                           (log_pi + self.target_entropy).detach()).mean()
            self.alpha_optimizer.zero_grad()
            alpha_loss.backward()
            self.alpha_optimizer.step()
            alpha = self.log_alpha.exp()
        else:
            alpha_loss = 0
            alpha = 1

        if not self.discrete:
            if self.num_qs == 1:
                q_new_actions = self.qf1(obs, None, new_obs_actions)
            else:
                q_new_actions = torch.min(
                    self.qf1(obs, None, new_obs_actions),
                    self.qf2(obs, None, new_obs_actions),
                )

        if self.discrete:
            target_q_values = torch.min(q_vector, q2_vector)
            policy_loss = -((target_q_values * pi_probs).sum(dim=-1) +
                            alpha * entropies).mean()
        else:
            policy_loss = (alpha * log_pi - q_new_actions).mean()

        if self._current_epoch < self.policy_eval_start:
            """Start with BC"""
            policy_log_prob = self.policy.log_prob(obs, None, actions)
            policy_loss = (alpha * log_pi - policy_log_prob).mean()
            # print ('Policy Loss: ', policy_loss.item())
        """
        QF Loss
        """
        q1_pred = self.qf1(obs, None, actions)
        if self.num_qs > 1:
            q2_pred = self.qf2(obs, None, actions)

        # Make sure policy accounts for squashing functions like tanh correctly!
        if not self.discrete:
            new_next_actions, _, _, new_log_pi, *_ = self.policy(
                next_obs,
                None,
                reparameterize=True,
                return_log_prob=True,
            )
            new_curr_actions, _, _, new_curr_log_pi, *_ = self.policy(
                obs,
                None,
                reparameterize=True,
                return_log_prob=True,
            )
        else:
            new_curr_actions, pi_curr_probs, new_curr_log_pi, new_curr_entropies = self.policy(
                obs, None, return_log_prob=True)

        if not self.max_q_backup:
            if not self.discrete:
                if self.num_qs == 1:
                    target_q_values = self.target_qf1(next_obs, None,
                                                      new_next_actions)
                else:
                    target_q_values = torch.min(
                        self.target_qf1(next_obs, None, new_next_actions),
                        self.target_qf2(next_obs, None, new_next_actions),
                    )
            else:
                target_q_values = torch.min(
                    (self.target_qf1.q_vector(next_obs) *
                     pi_next_probs).sum(dim=-1),
                    (self.target_qf2.q_vector(next_obs) *
                     pi_next_probs).sum(dim=-1))
                target_q_values = target_q_values.unsqueeze(-1)

            if not self.deterministic_backup:
                target_q_values = target_q_values - alpha * new_log_pi

        if self.max_q_backup:
            """when using max q backup"""
            if not self.discrete:
                next_actions_temp, _ = self._get_policy_actions(
                    next_obs, num_actions=10, network=self.policy)
                target_qf1_values = self._get_tensor_values(
                    next_obs, next_actions_temp, network=self.target_qf1,
                ).max(1)[0].view(-1, 1)
                target_qf2_values = self._get_tensor_values(
                    next_obs, next_actions_temp, network=self.target_qf2,
                ).max(1)[0].view(-1, 1)
                target_q_values = torch.min(
                    target_qf1_values, target_qf2_values
                )  # + torch.max(target_qf1_values, target_qf2_values) * 0.25
            else:
                target_qf1_values = self.target_qf1.q_vector(next_obs).max(dim=-1)[0]
                target_qf2_values = self.target_qf2.q_vector(next_obs).max(dim=-1)[0]
                target_q_values = torch.min(target_qf1_values,
                                            target_qf2_values).unsqueeze(-1)

        q_target = self.reward_scale * rewards + (
            1. - terminals) * self.discount * target_q_values

        # Only detach the Bellman target when target networks are used; a Bellman-residual
        # variant would backpropagate through it instead.
        if self._use_target_nets:
            q_target = q_target.detach()

        qf1_loss = self.qf_criterion(q1_pred, q_target)
        if self.num_qs > 1:
            qf2_loss = self.qf_criterion(q2_pred, q_target)

        if self.hinge_bellman:
            qf1_loss = self.softplus(q_target - q1_pred).mean()
            qf2_loss = self.softplus(q_target - q2_pred).mean()

        # Add the conservative min-Q (CQL-style) penalty.
        if self.with_min_q:
            if not self.discrete:
                random_actions_tensor = torch.FloatTensor(
                    q2_pred.shape[0] * self.num_random,
                    actions.shape[-1]).uniform_(-1, 1).cuda()
                curr_actions_tensor, curr_log_pis = self._get_policy_actions(
                    obs, num_actions=self.num_random, network=self.policy)
                new_curr_actions_tensor, new_log_pis = self._get_policy_actions(
                    next_obs, num_actions=self.num_random, network=self.policy)
                q1_rand = self._get_tensor_values(obs,
                                                  random_actions_tensor,
                                                  network=self.qf1)
                q2_rand = self._get_tensor_values(obs,
                                                  random_actions_tensor,
                                                  network=self.qf2)
                q1_curr_actions = self._get_tensor_values(obs,
                                                          curr_actions_tensor,
                                                          network=self.qf1)
                q2_curr_actions = self._get_tensor_values(obs,
                                                          curr_actions_tensor,
                                                          network=self.qf2)
                q1_next_actions = self._get_tensor_values(
                    obs, new_curr_actions_tensor, network=self.qf1)
                q2_next_actions = self._get_tensor_values(
                    obs, new_curr_actions_tensor, network=self.qf2)

                # q1_next_states_actions = self._get_tensor_values(next_obs, new_curr_actions_tensor, network=self.qf1)
                # q2_next_states_actions = self._get_tensor_values(next_obs, new_curr_actions_tensor, network=self.qf2)

                cat_q1 = torch.cat([
                    q1_rand,
                    q1_pred.unsqueeze(1), q1_next_actions, q1_curr_actions
                ], 1)
                cat_q2 = torch.cat([
                    q2_rand,
                    q2_pred.unsqueeze(1), q2_next_actions, q2_curr_actions
                ], 1)
                std_q1 = torch.std(cat_q1, dim=1)
                std_q2 = torch.std(cat_q2, dim=1)

                if self.min_q_version == 3:
                    # importance-sampled version
                    random_density = np.log(0.5**curr_actions_tensor.shape[-1])
                    cat_q1 = torch.cat([
                        q1_rand - random_density,
                        q1_next_actions - new_log_pis.detach(),
                        q1_curr_actions - curr_log_pis.detach()
                    ], 1)
                    cat_q2 = torch.cat([
                        q2_rand - random_density,
                        q2_next_actions - new_log_pis.detach(),
                        q2_curr_actions - curr_log_pis.detach()
                    ], 1)

                if self.min_q_version == 0:
                    min_qf1_loss = cat_q1.mean() * self.min_q_weight
                    min_qf2_loss = cat_q2.mean() * self.min_q_weight
                elif self.min_q_version == 1:
                    """Expectation under softmax distribution"""
                    softmax_dist_1 = self.softmax(
                        cat_q1 / self.temp).detach() * self.temp
                    softmax_dist_2 = self.softmax(
                        cat_q2 / self.temp).detach() * self.temp
                    min_qf1_loss = (cat_q1 *
                                    softmax_dist_1).mean() * self.min_q_weight
                    min_qf2_loss = (cat_q2 *
                                    softmax_dist_2).mean() * self.min_q_weight
                elif self.min_q_version == 2 or self.min_q_version == 3:
                    """log sum exp for the min"""
                    min_qf1_loss = torch.logsumexp(
                        cat_q1 / self.temp,
                        dim=1,
                    ).mean() * self.min_q_weight * self.temp
                    min_qf2_loss = torch.logsumexp(
                        cat_q2 / self.temp,
                        dim=1,
                    ).mean() * self.min_q_weight * self.temp

                if self.data_subtract:
                    """Subtract the log likelihood of data"""
                    min_qf1_loss = min_qf1_loss - q1_pred.mean() * self.min_q_weight
                    min_qf2_loss = min_qf2_loss - q2_pred.mean() * self.min_q_weight
            else:
                q1_policy = (q_vector * pi_probs).sum(dim=-1)
                q2_policy = (q2_vector * pi_probs).sum(dim=-1)
                q1_next_actions = (q_next_vector * pi_next_probs).sum(dim=-1)
                q2_next_actions = (q2_next_vector * pi_next_probs).sum(dim=-1)

                if self.min_q_version == 0:
                    min_qf1_loss = (q1_policy.mean() + q1_next_actions.mean() +
                                    q_vector.mean() + q_next_vector.mean()
                                    ).mean() * self.min_q_weight
                    min_qf2_loss = (q2_policy.mean() + q1_next_actions.mean() +
                                    q2_vector.mean() + q2_next_vector.mean()
                                    ).mean() * self.min_q_weight
                elif self.min_q_version == 1:
                    min_qf1_loss = (q_vector.mean() + q_next_vector.mean()
                                    ).mean() * self.min_q_weight
                    min_qf2_loss = (q2_vector.mean() + q2_next_vector.mean()
                                    ).mean() * self.min_q_weight
                else:
                    softmax_dist_q1 = self.softmax(
                        q_vector / self.temp).detach() * self.temp
                    softmax_dist_q2 = self.softmax(
                        q2_vector / self.temp).detach() * self.temp
                    min_qf1_loss = (q_vector *
                                    softmax_dist_q1).mean() * self.min_q_weight
                    min_qf2_loss = (q2_vector *
                                    softmax_dist_q2).mean() * self.min_q_weight

                if self.data_subtract:
                    min_qf1_loss = min_qf1_loss - q1_pred.mean() * self.min_q_weight
                    min_qf2_loss = min_qf2_loss - q2_pred.mean() * self.min_q_weight

                std_q1 = torch.std(q_vector, dim=-1)
                std_q2 = torch.std(q2_vector, dim=-1)
                q1_on_policy = q1_policy.mean()
                q2_on_policy = q2_policy.mean()
                q1_random = q_vector.mean()
                q2_random = q2_vector.mean()
                q1_next_actions_mean = q1_next_actions.mean()
                q2_next_actions_mean = q2_next_actions.mean()

            if self.use_projected_grad:
                min_qf1_grad = torch.autograd.grad(
                    min_qf1_loss,
                    inputs=[p for p in self.qf1.parameters()],
                    create_graph=True,
                    retain_graph=True,
                    only_inputs=True)
                min_qf2_grad = torch.autograd.grad(
                    min_qf2_loss,
                    inputs=[p for p in self.qf2.parameters()],
                    create_graph=True,
                    retain_graph=True,
                    only_inputs=True)
                qf1_loss_grad = torch.autograd.grad(
                    qf1_loss,
                    inputs=[p for p in self.qf1.parameters()],
                    create_graph=True,
                    retain_graph=True,
                    only_inputs=True)
                qf2_loss_grad = torch.autograd.grad(
                    qf2_loss,
                    inputs=[p for p in self.qf2.parameters()],
                    create_graph=True,
                    retain_graph=True,
                    only_inputs=True)

                # this is for the offline setting
                # qf1_total_grad = self.compute_mt_grad(qf1_loss_grad, min_qf1_grad)
                # qf2_total_grad = self.compute_mt_grad(qf2_loss_grad, min_qf2_grad)
                qf1_total_grad = self.compute_new_grad(min_qf1_grad,
                                                       qf1_loss_grad)
                qf2_total_grad = self.compute_new_grad(min_qf2_grad,
                                                       qf2_loss_grad)
            else:
                if self.with_lagrange:
                    alpha_prime = torch.clamp(self.log_alpha_prime.exp(),
                                              min=0,
                                              max=2000000.0)
                    orig_min_qf1_loss = min_qf1_loss
                    orig_min_qf2_loss = min_qf2_loss
                    min_qf1_loss = alpha_prime * (min_qf1_loss -
                                                  self.target_action_gap)
                    min_qf2_loss = alpha_prime * (min_qf2_loss -
                                                  self.target_action_gap)
                    self.alpha_prime_optimizer.zero_grad()
                    alpha_prime_loss = -0.5 * (min_qf1_loss + min_qf2_loss)
                    alpha_prime_loss.backward(retain_graph=True)
                    self.alpha_prime_optimizer.step()
                qf1_loss = qf1_loss + min_qf1_loss
                qf2_loss = qf2_loss + min_qf2_loss
        """
        Update networks
        """
        # Update the Q-functions.
        self._num_q_update_steps += 1
        self.qf1_optimizer.zero_grad()
        qf1_loss.backward(retain_graph=True)
        if self.with_min_q and self.use_projected_grad:
            for (p, proj_grad) in zip(self.qf1.parameters(), qf1_total_grad):
                p.grad.data = proj_grad
        self.qf1_optimizer.step()

        if self.num_qs > 1:
            self.qf2_optimizer.zero_grad()
            qf2_loss.backward(retain_graph=True)
            if self.with_min_q and self.use_projected_grad:
                for (p, proj_grad) in zip(self.qf2.parameters(),
                                          qf2_total_grad):
                    p.grad.data = proj_grad
            self.qf2_optimizer.step()

        self._num_policy_update_steps += 1
        self.policy_optimizer.zero_grad()
        policy_loss.backward(retain_graph=False)
        self.policy_optimizer.step()
        """
        Soft Updates
        """
        if self._use_target_nets:
            if self._n_train_steps_total % self.target_update_period == 0:
                ptu.soft_update_from_to(self.qf1, self.target_qf1,
                                        self.soft_target_tau)
                if self.num_qs > 1:
                    ptu.soft_update_from_to(self.qf2, self.target_qf2,
                                            self.soft_target_tau)
        """
        Save some statistics for eval
        """
        if self._need_to_update_eval_statistics:
            self._need_to_update_eval_statistics = False
            """
            Eval should set this to None.
            This way, these statistics are only computed for one batch.
            """
            if not self.discrete:
                policy_loss = (log_pi - q_new_actions).mean()
            else:
                target_q_values = torch.min(q_vector, q2_vector)
                policy_loss = -((target_q_values * pi_probs).sum(dim=-1) +
                                alpha * entropies).mean()

            self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
            if self.num_qs > 1:
                self.eval_statistics['QF2 Loss'] = np.mean(
                    ptu.get_numpy(qf2_loss))

            if self.with_min_q and not self.discrete:
                self.eval_statistics['Std QF1 values'] = np.mean(
                    ptu.get_numpy(std_q1))
                self.eval_statistics['Std QF2 values'] = np.mean(
                    ptu.get_numpy(std_q2))
                self.eval_statistics.update(
                    create_stats_ordered_dict(
                        'QF1 in-distribution values',
                        ptu.get_numpy(q1_curr_actions),
                    ))
                self.eval_statistics.update(
                    create_stats_ordered_dict(
                        'QF2 in-distribution values',
                        ptu.get_numpy(q2_curr_actions),
                    ))
                self.eval_statistics.update(
                    create_stats_ordered_dict(
                        'QF1 random values',
                        ptu.get_numpy(q1_rand),
                    ))
                self.eval_statistics.update(
                    create_stats_ordered_dict(
                        'QF2 random values',
                        ptu.get_numpy(q2_rand),
                    ))
                self.eval_statistics.update(
                    create_stats_ordered_dict(
                        'QF1 next_actions values',
                        ptu.get_numpy(q1_next_actions),
                    ))
                self.eval_statistics.update(
                    create_stats_ordered_dict(
                        'QF2 next_actions values',
                        ptu.get_numpy(q2_next_actions),
                    ))
            elif self.with_min_q and self.discrete:
                self.eval_statistics['Std QF1 values'] = np.mean(
                    ptu.get_numpy(std_q1))
                self.eval_statistics['Std QF2 values'] = np.mean(
                    ptu.get_numpy(std_q2))
                self.eval_statistics['QF1 on policy average'] = np.mean(
                    ptu.get_numpy(q1_on_policy))
                self.eval_statistics['QF2 on policy average'] = np.mean(
                    ptu.get_numpy(q2_on_policy))
                self.eval_statistics['QF1 random average'] = np.mean(
                    ptu.get_numpy(q1_random))
                self.eval_statistics['QF2 random average'] = np.mean(
                    ptu.get_numpy(q2_random))
                self.eval_statistics[
                    'QF1 next_actions_mean average'] = np.mean(
                        ptu.get_numpy(q1_next_actions_mean))
                self.eval_statistics[
                    'QF2 next_actions_mean average'] = np.mean(
                        ptu.get_numpy(q2_next_actions_mean))

            self.eval_statistics['Num Q Updates'] = self._num_q_update_steps
            self.eval_statistics[
                'Num Policy Updates'] = self._num_policy_update_steps
            self.eval_statistics['Policy Loss'] = np.mean(
                ptu.get_numpy(policy_loss))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q1 Predictions',
                    ptu.get_numpy(q1_pred),
                ))
            if self.num_qs > 1:
                self.eval_statistics.update(
                    create_stats_ordered_dict(
                        'Q2 Predictions',
                        ptu.get_numpy(q2_pred),
                    ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q Targets',
                    ptu.get_numpy(q_target),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Log Pis',
                    ptu.get_numpy(log_pi),
                ))
            if not self.discrete:
                self.eval_statistics.update(
                    create_stats_ordered_dict(
                        'Policy mu',
                        ptu.get_numpy(policy_mean),
                    ))
                self.eval_statistics.update(
                    create_stats_ordered_dict(
                        'Policy log std',
                        ptu.get_numpy(policy_log_std),
                    ))
            else:
                self.eval_statistics['Policy entropy'] = ptu.get_numpy(
                    entropies).mean()

            if self.use_automatic_entropy_tuning:
                self.eval_statistics['Alpha'] = alpha.item()
                self.eval_statistics['Alpha Loss'] = alpha_loss.item()

            if self.with_lagrange:
                self.eval_statistics['Alpha Prime'] = alpha_prime.item()
                self.eval_statistics[
                    'Alpha Prime Loss'] = alpha_prime_loss.item()
                self.eval_statistics['Min Q1 Loss'] = orig_min_qf1_loss.item()
                self.eval_statistics['Min Q2 Loss'] = orig_min_qf2_loss.item()

        self._n_train_steps_total += 1
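Example 11 adds a conservative (CQL-style) penalty to the Q-function losses. The `min_q_version` 2/3 branch with `data_subtract` enabled reduces to a logsumexp push-down on sampled actions minus a push-up on the dataset actions; a minimal sketch:

    import torch

    def cql_penalty(cat_q, q_data, min_q_weight=1.0, temp=1.0):
        # cat_q: (B, N) Q-values at random / policy-sampled actions (importance-corrected upstream).
        # q_data: (B, 1) Q-values at the actions actually stored in the dataset.
        push_down = torch.logsumexp(cat_q / temp, dim=1).mean() * temp
        push_up = q_data.mean()
        return min_q_weight * (push_down - push_up)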
Example 12
    def train_from_torch(self, batch):
        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']
        weights = batch.get('weights', None)
        if self.reward_transform:
            rewards = self.reward_transform(rewards)

        if self.terminal_transform:
            terminals = self.terminal_transform(terminals)

        """
        Policy and Alpha Loss
        """
        (
            new_obs_actions, policy_mean, policy_log_std, log_pi, entropy,
            policy_std, mean_action_log_prob, pretanh_value, dist,
        ) = self.policy(
            obs, reparameterize=True, return_log_prob=True,
        )

        if self.use_automatic_entropy_tuning:
            alpha_loss = -(self.log_alpha * (log_pi + self.target_entropy).detach()).mean()
            self.alpha_optimizer.zero_grad()
            alpha_loss.backward()
            self.alpha_optimizer.step()
            alpha = self.log_alpha.exp()
        else:
            alpha_loss = 0
            alpha = self.alpha

        """
        QF Loss
        """
        q1_pred = self.qf1(obs, actions)
        q2_pred = self.qf2(obs, actions)
        # Make sure policy accounts for squashing functions like tanh correctly!
        new_next_actions, _, _, new_log_pi, *_ = self.policy(
            next_obs, reparameterize=True, return_log_prob=True,
        )
        target_q_values = torch.min(
            self.target_qf1(next_obs, new_next_actions),
            self.target_qf2(next_obs, new_next_actions),
        ) - alpha * new_log_pi

        q_target = self.reward_scale * rewards + (1. - terminals) * self.discount * target_q_values
        qf1_loss = self.qf_criterion(q1_pred, q_target.detach())
        qf2_loss = self.qf_criterion(q2_pred, q_target.detach())

        """
        Policy Loss
        """
        qf1_new_actions = self.qf1(obs, new_obs_actions)
        qf2_new_actions = self.qf2(obs, new_obs_actions)
        q_new_actions = torch.min(
            qf1_new_actions,
            qf2_new_actions,
        )

        # Advantage-weighted regression
        if self.awr_use_mle_for_vf:
            v_pi = self.qf1(obs, policy_mean)
        else:
            v_pi = self.qf1(obs, new_obs_actions)

        if self.awr_sample_actions:
            u = new_obs_actions
            if self.awr_min_q:
                q_adv = q_new_actions
            else:
                q_adv = qf1_new_actions
        else:
            u = actions
            if self.awr_min_q:
                q_adv = torch.min(q1_pred, q2_pred)
            else:
                q_adv = q1_pred

        if self.awr_loss_type == "mse":
            policy_logpp = -(policy_mean - actions) ** 2
        else:
            policy_logpp = dist.log_prob(u)
            policy_logpp = policy_logpp.sum(dim=1, keepdim=True)

        advantage = q_adv - v_pi
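        # AWR advantage: A(s, a) = Q(s, a) - V(s), with V(s) approximated by the
        # critic evaluated at the policy mean (MLE) or at a sampled policy action.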

        if self.weight_loss and weights is None:
            if self.use_automatic_beta_tuning:
                _, _, _, _, _, _, _, _, buffer_dist = self.buffer_policy(
                    obs, reparameterize=True, return_log_prob=True,
                )
                beta = self.log_beta.exp()
                kldiv = torch.distributions.kl.kl_divergence(dist, buffer_dist)
                beta_loss = -1 * (beta * (kldiv - self.beta_epsilon).detach()).mean()

                self.beta_optimizer.zero_grad()
                beta_loss.backward()
                self.beta_optimizer.step()
            else:
                beta = self.beta_schedule.get_value(self._n_train_steps_total)
            weights = F.softmax(advantage / beta, dim=0)
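            # Exponential-advantage weights with temperature beta, normalized over
            # the batch by the softmax; the AWR loss below rescales by len(weights)
            # so the effective scale is roughly batch-size independent.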

        policy_loss = alpha * log_pi.mean()

        if self.use_awr_update and self.weight_loss:
            policy_loss = policy_loss + self.awr_weight * (-policy_logpp * len(weights)*weights.detach()).mean()
        elif self.use_awr_update:
            policy_loss = policy_loss + self.awr_weight * (-policy_logpp).mean()

        if self.use_reparam_update:
            policy_loss = policy_loss + self.reparam_weight * (-q_new_actions).mean()

        policy_loss = self.rl_weight * policy_loss
        if self.compute_bc:
            train_policy_loss, train_logp_loss, train_mse_loss, _ = self.run_bc_batch(self.demo_train_buffer, self.policy)
            policy_loss = policy_loss + self.bc_weight * train_policy_loss

        if self.train_bc_on_rl_buffer:
            buffer_policy_loss, buffer_train_logp_loss, buffer_train_mse_loss, _ = self.run_bc_batch(self.replay_buffer, self.buffer_policy)
        """
        Update networks
        """
        if self._n_train_steps_total % self.q_update_period == 0:
            self.qf1_optimizer.zero_grad()
            qf1_loss.backward()
            self.qf1_optimizer.step()

            self.qf2_optimizer.zero_grad()
            qf2_loss.backward()
            self.qf2_optimizer.step()

        if self._n_train_steps_total % self.policy_update_period == 0 and self.update_policy:
            self.policy_optimizer.zero_grad()
            policy_loss.backward()
            self.policy_optimizer.step()

        if self.train_bc_on_rl_buffer and self._n_train_steps_total % self.policy_update_period == 0:
            self.buffer_policy_optimizer.zero_grad()
            buffer_policy_loss.backward()
            self.buffer_policy_optimizer.step()

        """
        Soft Updates
        """
        if self._n_train_steps_total % self.target_update_period == 0:
            ptu.soft_update_from_to(
                self.qf1, self.target_qf1, self.soft_target_tau
            )
            ptu.soft_update_from_to(
                self.qf2, self.target_qf2, self.soft_target_tau
            )

        """
        Save some statistics for eval
        """
        if self._need_to_update_eval_statistics:
            self._need_to_update_eval_statistics = False
            """
            Eval should set this to None.
            This way, these statistics are only computed for one batch.
            """
            policy_loss = (log_pi - q_new_actions).mean()

            self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
            self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))
            self.eval_statistics['Policy Loss'] = np.mean(ptu.get_numpy(
                policy_loss
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Q1 Predictions',
                ptu.get_numpy(q1_pred),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Q2 Predictions',
                ptu.get_numpy(q2_pred),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Q Targets',
                ptu.get_numpy(q_target),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Log Pis',
                ptu.get_numpy(log_pi),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Policy mu',
                ptu.get_numpy(policy_mean),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Policy log std',
                ptu.get_numpy(policy_log_std),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Advantage Weights',
                ptu.get_numpy(weights),
            ))
            if self.use_automatic_entropy_tuning:
                self.eval_statistics['Alpha'] = alpha.item()
                self.eval_statistics['Alpha Loss'] = alpha_loss.item()

            if self.compute_bc:
                test_policy_loss, test_logp_loss, test_mse_loss, _ = self.run_bc_batch(self.demo_test_buffer, self.policy)
                self.eval_statistics.update({
                    "bc/Train Logprob Loss": ptu.get_numpy(train_logp_loss),
                    "bc/Test Logprob Loss": ptu.get_numpy(test_logp_loss),
                    "bc/Train MSE": ptu.get_numpy(train_mse_loss),
                    "bc/Test MSE": ptu.get_numpy(test_mse_loss),
                    "bc/train_policy_loss": ptu.get_numpy(train_policy_loss),
                    "bc/test_policy_loss": ptu.get_numpy(test_policy_loss),
                })
            if self.train_bc_on_rl_buffer:
                test_policy_loss, test_logp_loss, test_mse_loss, _ = self.run_bc_batch(self.replay_buffer,
                                                                                       self.buffer_policy)
                _, _, _, _, _, _, _, _, buffer_dist = self.buffer_policy(
                    obs, reparameterize=True, return_log_prob=True,
                )

                kldiv = torch.distributions.kl.kl_divergence(dist, buffer_dist)

                self.eval_statistics.update({
                    "buffer_policy/Train Logprob Loss": ptu.get_numpy(buffer_train_logp_loss),
                    "buffer_policy/Test Logprob Loss": ptu.get_numpy(test_logp_loss),
                    "buffer_policy/Train MSE": ptu.get_numpy(buffer_train_mse_loss),
                    "buffer_policy/Test MSE": ptu.get_numpy(test_mse_loss),
                    "buffer_policy/train_policy_loss": ptu.get_numpy(buffer_policy_loss),
                    "buffer_policy/test_policy_loss": ptu.get_numpy(test_policy_loss),
                    "buffer_policy/kl_div":ptu.get_numpy(kldiv.mean()),
                })
            if self.use_automatic_beta_tuning:
                self.eval_statistics.update({
                    "adaptive_beta/beta":ptu.get_numpy(beta.mean()),
                    "adaptive_beta/beta loss": ptu.get_numpy(beta_loss.mean()),
                })

        self._n_train_steps_total += 1
Example #13
0
    def train_from_torch(self, batch):
        logger.push_tabular_prefix("train_q/")
        self.eval_statistics = dict()
        self._need_to_update_eval_statistics = True

        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']

        next_actions = self.target_policy(next_obs)
        noise = ptu.randn(next_actions.shape) * self.target_policy_noise
        noise = torch.clamp(noise, -self.target_policy_noise_clip,
                            self.target_policy_noise_clip)
        noisy_next_actions = next_actions + noise
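        # TD3 target-policy smoothing: clipped Gaussian noise is added to the
        # target policy's next action before evaluating the target critics.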

        target_q1_values = self.target_qf1(next_obs, noisy_next_actions)
        target_q2_values = self.target_qf2(next_obs, noisy_next_actions)
        target_q_values = torch.min(target_q1_values, target_q2_values)
        q_target = self.reward_scale * rewards + (
            1. - terminals) * self.discount * target_q_values
        q_target = q_target.detach()
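        # Clipped double-Q Bellman target: min over both target critics, detached
        # so the critics regress toward a fixed target.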

        q1_pred = self.qf1(obs, actions)
        bellman_errors_1 = (q1_pred - q_target)**2
        qf1_loss = bellman_errors_1.mean()

        q2_pred = self.qf2(obs, actions)
        bellman_errors_2 = (q2_pred - q_target)**2
        qf2_loss = bellman_errors_2.mean()
        """
        Update Networks
        """
        self.qf1_optimizer.zero_grad()
        qf1_loss.backward()
        self.qf1_optimizer.step()

        self.qf2_optimizer.zero_grad()
        qf2_loss.backward()
        self.qf2_optimizer.step()

        policy_actions = policy_loss = None
        if self._n_train_steps_total % self.policy_update_period == 0:
            policy_actions = self.policy(obs)
            q_output = self.qf1(obs, policy_actions)

            if self.demo_train_buffer._size >= self.bc_batch_size:
                if self.use_demo_awr:
                    train_batch = self.get_batch_from_buffer(
                        self.demo_train_buffer)
                    train_o = train_batch["observations"]
                    train_u = train_batch["actions"]
                    if self.goal_conditioned:
                        train_g = train_batch["resampled_goals"]
                        train_o = torch.cat((train_o, train_g), dim=1)
                    train_pred_u = self.policy(train_o)
                    train_error = (train_pred_u - train_u)**2
                    train_bc_loss = train_error.mean()

                    policy_q_output_demo_state = self.qf1(
                        train_o, train_pred_u)
                    demo_q_output = self.qf1(train_o, train_u)

                    advantage = demo_q_output - policy_q_output_demo_state
                    self.eval_statistics['Train BC Loss'] = np.mean(
                        ptu.get_numpy(train_bc_loss))

                    if self.awr_policy_update:
                        train_bc_loss = train_error * torch.exp(
                            advantage * self.demo_beta)
                        self.eval_statistics['Advantage'] = np.mean(
                            ptu.get_numpy(advantage))

                    if self._n_train_steps_total < self.max_steps_till_train_rl:
                        rl_weight = 0
                    else:
                        rl_weight = self.rl_weight

                    policy_loss = (-rl_weight * q_output.mean()
                                   + self.bc_weight * train_bc_loss.mean())
                else:
                    train_batch = self.get_batch_from_buffer(
                        self.demo_train_buffer)

                    train_o = train_batch["observations"]
                    # train_pred_u = self.policy(train_o)
                    if self.goal_conditioned:
                        train_g = train_batch["resampled_goals"]
                        train_o = torch.cat((train_o, train_g), dim=1)
                    train_pred_u = self.policy(train_o)
                    train_u = train_batch["actions"]
                    train_error = (train_pred_u - train_u)**2
                    train_bc_loss = train_error.mean()

                    # Advantage-weighted regression
                    policy_error = (policy_actions - actions)**2
                    policy_error = policy_error.mean(dim=1)
                    advantage = q1_pred - q_output
                    weights = F.softmax((advantage / self.beta)[:, 0], dim=0)
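                    # AWR-style weights over the replay batch: softmax of the
                    # advantage (Q(s, a) - Q(s, pi(s))) with temperature beta.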

                    if self.awr_policy_update:
                        policy_loss = self.rl_weight * (
                            policy_error * weights.detach() *
                            self.bc_batch_size).mean()
                    else:
                        policy_loss = (-self.rl_weight * q_output.mean()
                                       + self.bc_weight * train_bc_loss.mean())

                    self.eval_statistics.update(
                        create_stats_ordered_dict(
                            'Advantage Weights',
                            ptu.get_numpy(weights),
                        ))

                self.eval_statistics['BC Loss'] = np.mean(
                    ptu.get_numpy(train_bc_loss))

            else:  # Normal TD3 update
                policy_loss = -self.rl_weight * q_output.mean()

            if self.update_policy and not (self.rl_weight == 0
                                           and self.bc_weight == 0):
                self.policy_optimizer.zero_grad()
                policy_loss.backward()
                self.policy_optimizer.step()

            self.eval_statistics['Policy Loss'] = np.mean(
                ptu.get_numpy(policy_loss))

        if self._n_train_steps_total % self.target_update_period == 0:
            ptu.soft_update_from_to(self.policy, self.target_policy, self.tau)
            ptu.soft_update_from_to(self.qf1, self.target_qf1, self.tau)
            ptu.soft_update_from_to(self.qf2, self.target_qf2, self.tau)

        if self._need_to_update_eval_statistics:
            self._need_to_update_eval_statistics = False
            if policy_loss is None:
                policy_actions = self.policy(obs)
                q_output = self.qf1(obs, policy_actions)
                policy_loss = -q_output.mean()

            self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
            self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q1 Predictions',
                    ptu.get_numpy(q1_pred),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q2 Predictions',
                    ptu.get_numpy(q2_pred),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q Targets',
                    ptu.get_numpy(q_target),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Bellman Errors 1',
                    ptu.get_numpy(bellman_errors_1),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Bellman Errors 2',
                    ptu.get_numpy(bellman_errors_2),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Policy Action',
                    ptu.get_numpy(policy_actions),
                ))

            if self.demo_test_buffer._size >= self.bc_batch_size:
                train_batch = self.get_batch_from_buffer(
                    self.demo_train_buffer)
                train_u = train_batch["actions"]
                train_o = train_batch["observations"]
                if self.goal_conditioned:
                    train_g = train_batch["resampled_goals"]
                    train_o = torch.cat((train_o, train_g), dim=1)
                train_pred_u = self.policy(train_o)
                train_error = (train_pred_u - train_u)**2
                train_bc_loss = train_error

                policy_q_output_demo_state = self.qf1(train_o, train_pred_u)
                demo_q_output = self.qf1(train_o, train_u)

                train_advantage = demo_q_output - policy_q_output_demo_state

                test_batch = self.get_batch_from_buffer(self.demo_test_buffer)
                test_o = test_batch["observations"]
                test_u = test_batch["actions"]
                if self.goal_conditioned:
                    test_g = test_batch["resampled_goals"]
                    test_o = torch.cat((test_o, test_g), dim=1)
                test_pred_u = self.policy(test_o)
                test_error = (test_pred_u - test_u)**2
                test_bc_loss = test_error

                policy_q_output_demo_state = self.qf1(test_o, test_pred_u)
                demo_q_output = self.qf1(test_o, test_u)
                test_advantage = demo_q_output - policy_q_output_demo_state

                self.eval_statistics.update(
                    create_stats_ordered_dict(
                        'Train BC Loss',
                        ptu.get_numpy(train_bc_loss),
                    ))

                self.eval_statistics.update(
                    create_stats_ordered_dict(
                        'Train Demo Advantage',
                        ptu.get_numpy(train_advantage),
                    ))

                self.eval_statistics.update(
                    create_stats_ordered_dict(
                        'Test BC Loss',
                        ptu.get_numpy(test_bc_loss),
                    ))

                self.eval_statistics.update(
                    create_stats_ordered_dict(
                        'Test Demo Advantage',
                        ptu.get_numpy(test_advantage),
                    ))

        self._n_train_steps_total += 1

        logger.pop_tabular_prefix()
Example #14
0
    def train_from_torch(self, batch):
        self._current_epoch += 1
        rewards = batch['rewards']
        if self._positive_reward:
            rewards += 2.5  # Make rewards positive
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']
        """
        Behavior clone a policy
        """
        recon, mean, std = self.vae(obs, actions)
        recon_loss = self.qf_criterion(recon, actions)
        kl_loss = -0.5 * (1 + torch.log(std.pow(2)) - mean.pow(2) -
                          std.pow(2)).mean()
        vae_loss = recon_loss + 0.5 * kl_loss
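        # Conditional-VAE behavior model: reconstruction loss (qf_criterion,
        # typically MSE) between decoded and dataset actions, plus a 0.5-weighted
        # KL term toward the unit Gaussian prior.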

        self.vae_optimizer.zero_grad()
        vae_loss.backward()
        self.vae_optimizer.step()
        """
        Critic Training
        """
        with torch.no_grad():
            # Duplicate state 10 times (10 is a hyperparameter chosen by BCQ)
            state_rep = next_obs.unsqueeze(1).repeat(1, 10, 1).view(
                next_obs.shape[0] * 10, next_obs.shape[1])

            # Compute value of perturbed actions sampled from the VAE
            action_rep = self.policy(state_rep)[0]
            target_qf1 = self.target_qf1(state_rep, action_rep)
            target_qf2 = self.target_qf2(state_rep, action_rep)

            # Soft Clipped Double Q-learning
            target_Q = 0.75 * torch.min(target_qf1,
                                        target_qf2) + 0.25 * torch.max(
                                            target_qf1, target_qf2)
            target_Q = target_Q.view(next_obs.shape[0],
                                     -1).max(1)[0].view(-1, 1)
            target_Q = self.reward_scale * rewards + (
                1.0 - terminals) * self.discount * target_Q

        qf1_pred = self.qf1(obs, actions)
        qf2_pred = self.qf2(obs, actions)

        qf1_loss = (qf1_pred - target_Q.detach()).pow(2).mean()
        qf2_loss = (qf2_pred - target_Q.detach()).pow(2).mean()
        """
        Actor Training
        """
        sampled_actions, raw_sampled_actions = self.vae.decode_multiple(
            obs, num_decode=self.num_samples_mmd_match)
        actor_samples, _, _, _, _, _, _, raw_actor_actions, _ = self.policy(
            obs.unsqueeze(1).repeat(1, self.num_samples_mmd_match,
                                    1).view(-1, obs.shape[1]),
            return_log_prob=True)
        actor_samples = actor_samples.view(obs.shape[0],
                                           self.num_samples_mmd_match,
                                           actions.shape[1])
        raw_actor_actions = raw_actor_actions.view(obs.shape[0],
                                                   self.num_samples_mmd_match,
                                                   actions.shape[1])
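        # For the MMD term, num_samples_mmd_match actions per state are drawn from
        # both the VAE (a proxy for the behavior policy) and the current policy;
        # the raw (pre-squashing) actions are what the kernels compare.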

        if self._use_adv_weighting:
            qf1_orig = self.qf1(
                obs.unsqueeze(1).repeat(1, self.num_samples_mmd_match,
                                        1).view(-1, obs.shape[1]),
                sampled_actions.view(-1, actions.shape[1]))

            # Get the target_q for next state
            state_rep = next_obs.unsqueeze(1).repeat(1, 10, 1).view(
                next_obs.shape[0] * 10, next_obs.shape[1])
            action_rep = self.policy(state_rep)[0]
            # Clipped double-Q value of the repeated next states under the current critics.
            target_q = torch.min(
                self.qf1(state_rep, action_rep),
                self.qf2(state_rep, action_rep),
            )
            target_q = target_q.view(next_obs.shape[0],
                                     -1).max(1)[0].view(-1, 1)
            target_q = self.reward_scale * rewards + (
                1.0 - terminals) * self.discount * target_q

            adv_mmd = (
                qf1_orig.view(obs.shape[0], self.num_samples_mmd_match, 1) -
                target_q.unsqueeze(1))
            adv_mmd_not_clipped = adv_mmd
            adv_mmd = adv_mmd.exp().clamp_(max=10.0, min=1.0)
            adv_mmd = adv_mmd.detach()
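            # Advantage weights for the MMD penalty: Q1 of each VAE-sampled action
            # minus the one-step bootstrapped target, exponentiated, clipped to
            # [1, 10], and detached before being passed to the advantage-weighted
            # MMD losses below.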

        if self.kernel_choice == 'laplacian':
            if self._use_adv_weighting:
                mmd_loss = self.adv_mmd_loss_laplacian(raw_sampled_actions,
                                                       raw_actor_actions,
                                                       adv_mmd,
                                                       sigma=self.mmd_sigma)
            else:
                mmd_loss = self.mmd_loss_laplacian(raw_sampled_actions,
                                                   raw_actor_actions,
                                                   sigma=self.mmd_sigma)
        elif self.kernel_choice == 'gaussian':
            if self._use_adv_weighting:
                mmd_loss = self.adv_mmd_loss_gaussian(raw_sampled_actions,
                                                      raw_actor_actions,
                                                      adv_mmd,
                                                      sigma=self.mmd_sigma)
            else:
                mmd_loss = self.mmd_loss_gaussian(raw_sampled_actions,
                                                  raw_actor_actions,
                                                  sigma=self.mmd_sigma)

        action_divergence = ((sampled_actions - actor_samples)**2).sum(-1)
        raw_action_divergence = ((raw_sampled_actions -
                                  raw_actor_actions)**2).sum(-1)

        q_val1 = self.qf1(obs, actor_samples[:, 0, :])
        q_val2 = self.qf2(obs, actor_samples[:, 0, :])

        if self.policy_update_style == '0':
            policy_loss = torch.min(q_val1, q_val2)[:, 0]
        elif self.policy_update_style == '1':
            policy_loss = 0.5 * (q_val1 + q_val2)[:, 0]

        if self._n_train_steps_total >= self._bc_pretrain_steps:
            # Now we can update the policy
            if self.mode == 'auto':
                policy_loss = (-policy_loss + self.log_alpha.exp() *
                               (mmd_loss - self.target_mmd_thresh)).mean()
            else:
                policy_loss = (-policy_loss + 100 * mmd_loss).mean()
        else:
            if self.mode == 'auto':
                policy_loss = (self.log_alpha.exp() *
                               (mmd_loss - self.target_mmd_thresh)).mean()
            else:
                policy_loss = 100 * mmd_loss.mean()
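        # BEAR-style objective: during the first _bc_pretrain_steps only the MMD
        # constraint is optimized; afterwards the policy maximizes the (min or
        # averaged) Q-value subject to the MMD penalty, with alpha as the dual
        # variable in 'auto' mode and a fixed weight of 100 otherwise.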
        """
        Update Networks
        """
        self.qf1_optimizer.zero_grad()
        qf1_loss.backward()
        self.qf1_optimizer.step()

        self.qf2_optimizer.zero_grad()
        qf2_loss.backward()
        self.qf2_optimizer.step()

        self.policy_optimizer.zero_grad()
        if self.mode == 'auto':
            # Retain the graph so the alpha (dual) update below can reuse it.
            policy_loss.backward(retain_graph=True)
        else:
            policy_loss.backward()
        self.policy_optimizer.step()
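        # In 'auto' mode, the block below performs dual ascent on log_alpha: the
        # same Lagrangian is maximized with respect to alpha (hence the negated
        # loss), and log_alpha is clamped to [-5, 10] for stability.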

        if self.mode == 'auto':
            self.alpha_optimizer.zero_grad()
            (-policy_loss).backward()
            self.alpha_optimizer.step()
            self.log_alpha.data.clamp_(min=-5.0, max=10.0)
        """
        Update networks
        """
        if self._use_target_nets:
            if self._target_update_method == 'default':
                if self._n_train_steps_total % self.target_update_period == 0:
                    ptu.soft_update_from_to(self.qf1, self.target_qf1,
                                            self.soft_target_tau)
                    ptu.soft_update_from_to(self.qf2, self.target_qf2,
                                            self.soft_target_tau)
        """
        Some statistics for eval
        """
        if self._need_to_update_eval_statistics:
            self._need_to_update_eval_statistics = False
            """
            Eval should set this to None.
            This way, these statistics are only computed for one batch.
            """
            self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
            self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))
            self.eval_statistics['Num Q Updates'] = self._num_q_update_steps
            self.eval_statistics[
                'Num Policy Updates'] = self._num_policy_update_steps
            if (self._with_gradient_penalty_v1
                    or self._with_gradient_penalty_v2) and (
                        self._current_epoch >
                        self._start_epoch_grad_penalty + 1):
                self.eval_statistics['Grad QF1 Loss'] = np.mean(
                    ptu.get_numpy(grad_qf1_square) * self._grad_coefficient_q)
                self.eval_statistics['Grad QF2 Loss'] = np.mean(
                    ptu.get_numpy(grad_qf2_square) * self._grad_coefficient_q)
                self.eval_statistics['Grad Policy Loss'] = np.mean(
                    ptu.get_numpy(grad_square) * self._grad_coefficient_policy)
            self.eval_statistics['Policy Loss'] = np.mean(
                ptu.get_numpy(policy_loss))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q1 Predictions',
                    ptu.get_numpy(qf1_pred),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q2 Predictions',
                    ptu.get_numpy(qf2_pred),
                ))
            if self._use_adv_weighting:
                self.eval_statistics.update(
                    create_stats_ordered_dict(
                        'Adv MMD', ptu.get_numpy(adv_mmd_not_clipped)))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q Targets',
                    ptu.get_numpy(target_Q),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict('MMD Loss', ptu.get_numpy(mmd_loss)))
            self.eval_statistics.update(
                create_stats_ordered_dict('Action Divergence',
                                          ptu.get_numpy(action_divergence)))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Raw Action Divergence',
                    ptu.get_numpy(raw_action_divergence)))
            if self.mode == 'auto':
                self.eval_statistics['Alpha'] = self.log_alpha.exp().item()

        self._n_train_steps_total += 1