Example #1
 def get_diagnostics(self):
     path_lens = [len(path['actions']) for path in self._epoch_paths]
     stats = OrderedDict([
         ('num steps total', self._num_steps_total),
         ('num paths total', self._num_paths_total),
     ])
     stats.update(
         create_stats_ordered_dict(
             "path length",
             path_lens,
             always_show_all_stats=True,
         ))
     return stats
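Every example on this page funnels raw values through create_stats_ordered_dict from rlkit. As a rough mental model only (a sketch inferred from how the helper is called here, not rlkit's actual implementation, which also handles tuples, torch tensors, and nested per-path arrays), it reduces a sequence of scalars to named summary statistics:

from collections import OrderedDict

import numpy as np


def create_stats_ordered_dict_sketch(name, data, always_show_all_stats=True,
                                     exclude_max_min=False):
    # Sketch: summarize a flat sequence of scalars under a readable key prefix,
    # e.g. "path length Mean", "path length Std", ...
    # always_show_all_stats is accepted only for signature parity here; the
    # real helper uses it to decide when to emit the extra statistics.
    data = np.asarray(data, dtype=np.float64)
    stats = OrderedDict([
        ('{} Mean'.format(name), float(np.mean(data))),
        ('{} Std'.format(name), float(np.std(data))),
    ])
    if not exclude_max_min:
        stats['{} Max'.format(name)] = float(np.max(data))
        stats['{} Min'.format(name)] = float(np.min(data))
    return stats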
Example #2
    def log_diagnostics(self, paths, logger=default_logger):
        super().log_diagnostics(paths)
        MultitaskEnv.log_diagnostics(self, paths)

        statistics = OrderedDict()
        for stat_name in [
            'pos_error',
            'vel_error',
            'weighted_pos_error',
            'weighted_vel_error',
        ]:
            stat = get_stat_in_paths(paths, 'env_infos', stat_name)
            statistics.update(create_stats_ordered_dict(
                '{}'.format(stat_name),
                stat,
                always_show_all_stats=True,
            ))
            statistics.update(create_stats_ordered_dict(
                'Final {}'.format(stat_name),
                [s[-1] for s in stat],
                always_show_all_stats=True,
            ))
        weighted_error = (
            get_stat_in_paths(paths, 'env_infos', 'weighted_pos_error')
            + get_stat_in_paths(paths, 'env_infos', 'weighted_vel_error')
        )
        statistics.update(create_stats_ordered_dict(
            "Weighted Error",
            weighted_error,
            always_show_all_stats=True,
        ))
        statistics.update(create_stats_ordered_dict(
            "Final Weighted Error",
            [s[-1] for s in weighted_error],
            always_show_all_stats=True,
        ))

        for key, value in statistics.items():
            logger.record_tabular(key, value)
Example #3
    def debug_statistics(self):
        """
        Given an image $$x$$, sample a bunch of latents $$z_i$$ from the prior
        and decode them into $$\hat x_i$$.
        Compare these to $$\hat x$$, the reconstruction of $$x$$.
        Ideally
         - All the $$\hat x_i$$s do worse than $$\hat x$$ (makes sure VAE
           isn’t ignoring the latent)
         - Some $$\hat x_i$$ do better than other $$\hat x_i$$ (tests for
           coverage)
        """
        debug_batch_size = 64
        data = self.get_batch(train=False)
        reconstructions, _, _ = self.model(data)
        img = data[0]
        recon_mse = ((reconstructions[0] - img) ** 2).mean().view(-1)
        img_repeated = img.expand((debug_batch_size, img.shape[0]))

        samples = ptu.randn(debug_batch_size, self.representation_size)
        random_imgs, _ = self.model.decode(samples)
        random_mses = (random_imgs - img_repeated) ** 2
        mse_improvement = ptu.get_numpy(random_mses.mean(dim=1) - recon_mse)
        stats = create_stats_ordered_dict(
            'debug/MSE improvement over random',
            mse_improvement,
        )
        stats.update(create_stats_ordered_dict(
            'debug/MSE of random decoding',
            ptu.get_numpy(random_mses),
        ))
        stats['debug/MSE of reconstruction'] = ptu.get_numpy(
            recon_mse
        )[0]
        if self.skew_dataset:
            stats.update(create_stats_ordered_dict(
                'train weight',
                self._train_weights
            ))
        return stats
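In our reading of Example #3, 'debug/MSE improvement over random' is just the per-sample gap between the error of decoded prior samples and the error of the actual reconstruction, for a single image $$x$$:

$$\text{improvement}_i = \operatorname{MSE}(\hat x_i, x) - \operatorname{MSE}(\hat x, x), \qquad \hat x_i = \operatorname{decode}(z_i),\ z_i \sim \mathcal{N}(0, I)$$

Positive values mean decoding random latents does worse than reconstructing $$x$$, i.e. the VAE is actually using its latent code.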
Example #4
    def train_from_torch(self, batch):
        rewards = batch["rewards"] * self.reward_scale
        terminals = batch["terminals"]
        obs = batch["observations"]
        actions = batch["actions"]
        next_obs = batch["next_observations"]
        try:
            plan_lengths = batch["plan_lengths"]
            if self.single_plan_discounting:
                plan_lengths = torch.ones_like(plan_lengths)
        except KeyError:
            plan_lengths = torch.ones_like(rewards)
        """
        Compute loss
        """

        target_q_values = self.target_qf(next_obs).detach().max(
            1, keepdim=True)[0]
        y_target = (rewards + (1.0 - terminals) *
                    torch.pow(self.discount, plan_lengths) * target_q_values)
        y_target = y_target.detach()
        # actions is a one-hot vector
        y_pred = torch.sum(self.qf(obs) * actions, dim=1, keepdim=True)
        # huber loss correction.
        if self.huber_loss:
            y_target = torch.max(y_target, y_pred.sub(1))
            y_target = torch.min(y_target, y_pred.add(1))
        qf_loss = self.qf_criterion(y_pred, y_target)
        """
        Soft target network updates
        """
        self.qf_optimizer.zero_grad()
        qf_loss.backward()
        # for param in self.qf.parameters():  # introduced parameter clipping
        #     param.grad.data.clamp_(-1, 1)
        self.qf_optimizer.step()
        """
        Soft Updates
        """
        if self._n_train_steps_total % self.target_update_period == 0:
            ptu.soft_update_from_to(self.qf, self.target_qf,
                                    self.soft_target_tau)
        """
        Save some statistics for eval using just one batch.
        """
        if self._need_to_update_eval_statistics:
            self._need_to_update_eval_statistics = False
            self.eval_statistics["QF Loss"] = np.mean(ptu.get_numpy(qf_loss))
            self.eval_statistics.update(
                create_stats_ordered_dict("Y Predictions",
                                          ptu.get_numpy(y_pred)))
Example #5
 def low_train_from_torch(self, batch):
     rewards = batch['rewards']
     terminals = batch['terminals']
     obs = batch['observations']
     actions = batch['actions']
     next_obs = batch['next_observations']
     goals = batch['goals']
     # an approximation, since this doesn't account for goal switching
     next_goals = self.setter.goal_transition(obs, goals, next_obs)
     """
     Compute loss
     """
     best_action_idxs = self.low_qf(torch.cat(
         (next_obs, next_goals), dim=1)).max(1, keepdim=True)[1]
     target_q_values = self.low_target_qf(
         torch.cat((next_obs, next_goals),
                   dim=1)).gather(1, best_action_idxs).detach()
     y_target = rewards + (1. - terminals) * self.discount * target_q_values
     y_target = y_target.detach()
     # actions is a one-hot vector
     y_pred = torch.sum(self.low_qf(torch.cat(
         (obs, goals), dim=1)) * actions,
                        dim=1,
                        keepdim=True)
     qf_loss = self.qf_criterion(y_pred, y_target)
     """
     Update networks
     """
     self.low_qf_optimizer.zero_grad()
     qf_loss.backward()
     if self.grad_clip_val is not None:
         nn.utils.clip_grad_norm_(self.low_qf.parameters(),
                                  self.grad_clip_val)
     self.low_qf_optimizer.step()
     """
     Soft target network updates
     """
     if self._n_train_steps_total % self.setter_and_target_update_period == 0:
         ptu.soft_update_from_to(self.low_qf, self.low_target_qf, self.tau)
     """
     Save some statistics for eval using just one batch.
     """
     if self._need_to_update_eval_statistics:
         self._need_to_update_eval_statistics = False
         self.eval_statistics['QF Loss'] = np.mean(ptu.get_numpy(qf_loss))
         self.eval_statistics.update(
             create_stats_ordered_dict(
                 'Y Predictions',
                 ptu.get_numpy(y_pred),
             ))
Example #6
    def train_from_torch(self, batch):
        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']

        """
        Compute loss
        """

        best_action_idxs = self.qf(next_obs).max(
            1, keepdim=True
        )[1]
        target_q_values = self.target_qf(next_obs).gather(
            1, best_action_idxs
        ).detach()
        y_target = rewards + (1. - terminals) * self.discount * target_q_values
        y_target = y_target.detach()
        # actions is a one-hot vector
        y_pred = torch.sum(self.qf(obs) * actions, dim=1, keepdim=True)
        qf_loss = self.qf_criterion(y_pred, y_target)

        """
        Update networks
        """
        self.qf_optimizer.zero_grad()
        qf_loss.backward()
        self.qf_optimizer.step()

        """
        Soft target network updates
        """
        if self._n_train_steps_total % self.target_update_period == 0:
            ptu.soft_update_from_to(
                self.qf, self.target_qf, self.soft_target_tau
            )

        """
        Save some statistics for eval using just one batch.
        """
        if self._need_to_update_eval_statistics:
            self._need_to_update_eval_statistics = False
            self.eval_statistics['QF Loss'] = np.mean(ptu.get_numpy(qf_loss))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Y Predictions',
                ptu.get_numpy(y_pred),
            ))
            
        self._n_train_steps_total += 1
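For reference, the target in Example #6 matches the Double-DQN rule: the greedy action comes from the online network self.qf and is evaluated by self.target_qf. With $$d$$ the terminal flag and $$\theta$$, $$\theta^-$$ the online and target parameters:

$$y = r + (1 - d)\,\gamma\, Q_{\theta^-}\!\big(s',\ \arg\max_{a'} Q_{\theta}(s', a')\big)$$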
Example #7
 def get_diagnostics(self):
     path_lens = [len(path["actions"]) for path in self._epoch_paths]
     stats = OrderedDict([
         ("num steps total", self._num_steps_total),
         ("num paths total", self._num_paths_total),
     ])
     stats.update(
         create_stats_ordered_dict(
             "path length",
             path_lens,
             always_show_all_stats=True,
         ))
     success = [path["rewards"][-1][0] > 0 for path in self._epoch_paths]
     stats["SuccessRate"] = sum(success) / len(success)
     return stats
Example #8
    def log_diagnostics(self, paths, logger=default_logger):
        super().log_diagnostics(paths)
        MultitaskEnv.log_diagnostics(self, paths)

        statistics = OrderedDict()
        for name_in_env_infos, name_to_log in [
            ('x_pos', 'X Position'),
            ('y_pos', 'Y Position'),
            ('dist_from_origin', 'Distance from Origin'),
            ('desired_x_pos', 'Desired X Position'),
            ('desired_y_pos', 'Desired Y Position'),
            ('desired_dist_from_origin', 'Desired Distance from Origin'),
            ('pos_error', 'Distance to goal'),
        ]:
            stat = get_stat_in_paths(paths, 'env_infos', name_in_env_infos)
            statistics.update(
                create_stats_ordered_dict(
                    name_to_log,
                    stat,
                    always_show_all_stats=True,
                    exclude_max_min=True,
                ))
        for name_in_env_infos, name_to_log in [
            ('dist_from_origin', 'Distance from Origin'),
            ('desired_dist_from_origin', 'Desired Distance from Origin'),
            ('pos_error', 'Distance to goal'),
        ]:
            stat = get_stat_in_paths(paths, 'env_infos', name_in_env_infos)
            statistics.update(
                create_stats_ordered_dict(
                    'Final {}'.format(name_to_log),
                    [s[-1] for s in stat],
                    always_show_all_stats=True,
                ))
        for key, value in statistics.items():
            logger.record_tabular(key, value)
Example #9
 def get_diagnostics(self):
     if self._vae_sample_probs is None or self._vae_sample_priorities is None:
         stats = create_stats_ordered_dict(
             "VAE Sample Weights",
             np.zeros(self._size),
         )
         stats.update(
             create_stats_ordered_dict(
                 "VAE Sample Probs",
                 np.zeros(self._size),
             ))
     else:
         vae_sample_priorities = self._vae_sample_priorities[:self._size]
         vae_sample_probs = self._vae_sample_probs[:self._size]
         stats = create_stats_ordered_dict(
             "VAE Sample Weights",
             vae_sample_priorities,
         )
         stats.update(
             create_stats_ordered_dict(
                 "VAE Sample Probs",
                 vae_sample_probs,
             ))
     return stats
Example #10
 def __call__(self, paths, contexts):
     diagnostics = OrderedDict()
     for state_key in self.state_to_goal_keys_map:
         goal_key = self.state_to_goal_keys_map[state_key]
         values = []
         for i in range(len(paths)):
             state = paths[i]["observations"][-1][state_key]
             goal = contexts[i][goal_key]
             distance = np.linalg.norm(state - goal)
             values.append(distance)
         diagnostics_key = goal_key + "/final/distance"
         diagnostics.update(
             create_stats_ordered_dict(
                 diagnostics_key,
                 values,
             ))
     return diagnostics
Example #11
 def get_diagnostics(self):
     path_lens = [len(path["actions"]) for path in self._epoch_paths]
     stats = OrderedDict([
         ("num steps total", self._num_steps_total),
         ("num paths total", self._num_paths_total),
         ("num low level steps total", self._num_low_level_steps_total),
         (
             "num low level steps total true",
             self._num_low_level_steps_total_true,
         ),
     ])
     stats.update(
         create_stats_ordered_dict(
             "path length",
             path_lens,
             always_show_all_stats=True,
         ))
     return stats
Example #12
    def train_from_torch(self, batch):
        rewards = batch["rewards"] * self.reward_scale
        terminals = batch["terminals"]
        obs = batch["observations"]
        actions = batch["actions"]
        next_obs = batch["next_observations"]
        """
        Compute loss
        """

        target_q_values = self.target_qf(next_obs).detach().max(
            1, keepdim=True)[0]
        y_target = rewards + (1.0 -
                              terminals) * self.discount * target_q_values
        y_target = y_target.detach()
        # actions is a one-hot vector
        y_pred = torch.sum(self.qf(obs) * actions, dim=1, keepdim=True)
        qf_loss = self.qf_criterion(y_pred, y_target)
        """
        Soft target network updates
        """
        self.qf_optimizer.zero_grad()
        qf_loss.backward()
        self.qf_optimizer.step()
        """
        Soft Updates
        """
        if self._n_train_steps_total % self.target_update_period == 0:
            ptu.soft_update_from_to(self.qf, self.target_qf,
                                    self.soft_target_tau)
        """
        Save some statistics for eval using just one batch.
        """
        if self._need_to_update_eval_statistics:
            self._need_to_update_eval_statistics = False
            self.eval_statistics["QF Loss"] = np.mean(ptu.get_numpy(qf_loss))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    "Y Predictions",
                    ptu.get_numpy(y_pred),
                ))
        self._n_train_steps_total += 1
Example #13
 def get_diagnostics(self):
     path_lens = [len(path['actions']) for path in self._epoch_paths]
     stats = OrderedDict([
         ('num steps total', self._num_steps_total),
         ('num paths total', self._num_paths_total),
     ])
     stats.update(
         create_stats_ordered_dict(
             "path length",
             path_lens,
             always_show_all_stats=True,
         ))
     paths_policy = [
         path for path in self._epoch_paths
         if 'expert' not in path['agent_infos'][0]
     ]
     success = [path['rewards'][-1][0] > 0 for path in paths_policy]
     stats['SuccessRate'] = sum(success) / len(success)
     stats['Expert_Supervision'] = 1 - len(paths_policy) / len(
         self._epoch_paths)
     return stats
Example #14
    def _do_training(self):
        batch = self.get_batch()
        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']

        """
        Compute loss
        """

        target_q_values = self.target_qf(next_obs).detach().max(
            1, keepdim=True
        )[0]
        y_target = rewards + (1. - terminals) * self.discount * target_q_values
        y_target = y_target.detach()
        # actions is a one-hot vector
        y_pred = torch.sum(self.qf(obs) * actions, dim=1, keepdim=True)
        qf_loss = self.qf_criterion(y_pred, y_target)

        """
        Update networks
        """
        self.qf_optimizer.zero_grad()
        qf_loss.backward()
        self.qf_optimizer.step()
        self._update_target_network()

        """
        Save some statistics for eval using just one batch.
        """
        if self.need_to_update_eval_statistics:
            self.need_to_update_eval_statistics = False
            self.eval_statistics['QF Loss'] = np.mean(ptu.get_numpy(qf_loss))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Y Predictions',
                ptu.get_numpy(y_pred),
            ))
Example #15
    def _add_exploration_bonus(self, paths):
        paths = copy.deepcopy(paths)
        entropy_decreases = []
        with torch.no_grad():
            for path in paths:
                for i in range(len(path['observations']) - 1):
                    obs1 = path['observations'][i]
                    labels1 = torch.tensor(path['env_infos'][i]['sup_labels'])
                    valid_mask1 = ~torch.isnan(labels1)
                    entropy_1 = [
                        sup_learner.get_distribution(
                            torch_ify(obs1)[None, :]).entropy()
                        for sup_learner in self.sup_learners
                    ]
                    entropy_1 = torch.mean(torch.stack(entropy_1)[valid_mask1])

                    obs2 = path['observations'][i + 1]
                    labels2 = torch.tensor(path['env_infos'][i +
                                                             1]['sup_labels'])
                    valid_mask2 = ~torch.isnan(labels2)
                    entropy_2 = [
                        sup_learner.get_distribution(
                            torch_ify(obs2)[None, :]).entropy()
                        for sup_learner in self.sup_learners
                    ]
                    entropy_2 = torch.mean(torch.stack(entropy_2)[valid_mask2])

                    entropy_decrease = (entropy_1 - entropy_2).item()
                    entropy_decreases.append(entropy_decrease)
                    path['rewards'][
                        i] += self.exploration_bonus * entropy_decrease

        if self._need_to_update_eval_statistics:
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Entropy Decrease',
                    entropy_decreases,
                ))
        return paths
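Read off the loop above, the shaped reward for step $$i$$ of a path is (with $$\beta$$ = self.exploration_bonus and $$\mathcal{V}_i$$ the sup_learners whose labels at step $$i$$ are not NaN):

$$r_i \leftarrow r_i + \beta\,\big(\bar H_i - \bar H_{i+1}\big), \qquad \bar H_i = \frac{1}{|\mathcal{V}_i|}\sum_{k \in \mathcal{V}_i} H\big(p_k(\cdot \mid s_i)\big)$$

so transitions that reduce the average predictive entropy of the supervised learners receive a bonus.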
Example #16
    def get_diagnostics(self):
        path_lens = [len(path["actions"]) for path in self._epoch_paths]
        average_score = (0 if self._num_episodes == 0 else self._total_score /
                         self._num_episodes)
        epoch_score = (0 if self._epoch_episodes == 0 else self._epoch_score /
                       self._epoch_episodes)
        explored = [path["explored"][0] for path in self._epoch_paths]
        paths_explored = (0 if len(explored) == 0 else sum(explored).item() /
                          len(explored))

        stats = OrderedDict([
            ("num steps total", self._num_steps_total),
            ("num paths total", self._num_paths_total),
            ("average score", average_score),
            ("epoch score", epoch_score),
            ("plans explored", paths_explored),
        ])
        # Per-action totals; 16 is the action-space size.
        action_lengths = [0] * 16  # TODO: fix magic number
        action_counts = [0] * 16
        for path in self._epoch_paths:
            action_lengths[path["actions"][0].item()] += len(path["actions"])
            action_counts[path["actions"][0].item()] += 1

        a = {}
        for i in range(16):
            if action_counts[i] > 0:
                action_lengths[i] = action_lengths[i] / action_counts[i]
            a[f"action {i} count"] = action_counts[i]
            a[f"action {i} length"] = action_lengths[i]
        stats.update(a)

        stats.update(
            create_stats_ordered_dict("path length",
                                      path_lens,
                                      always_show_all_stats=True))
        return stats
Example #17
    def train_from_torch(self, batch):
        rewards_n = batch['rewards'].detach()
        terminals_n = batch['terminals'].detach()
        obs_n = batch['observations'].detach()
        actions_n = batch['actions'].detach()
        next_obs_n = batch['next_observations'].detach()

        batch_size = rewards_n.shape[0]
        num_agent = rewards_n.shape[1]
        whole_obs = obs_n.view(batch_size, -1)
        whole_actions = actions_n.view(batch_size, -1)
        whole_next_obs = next_obs_n.view(batch_size, -1) 

        """
        Policy operations.
        """
        online_actions_n, online_pre_values_n, online_log_pis_n = [], [], []
        for agent in range(num_agent):
            policy_actions, info = self.policy_n[agent](
                obs_n[:,agent,:], return_info=True,
            )
            online_actions_n.append(policy_actions)
            online_pre_values_n.append(info['preactivation'])
            online_log_pis_n.append(info['log_prob'])
        k0_actions = torch.stack(online_actions_n) # num_agent x batch x a_dim
        k0_actions = k0_actions.transpose(0,1).contiguous() # batch x num_agent x a_dim

        k0_inputs = torch.cat([obs_n, k0_actions],dim=-1)
        k0_contexts = self.context_graph(k0_inputs)  # batch x num_agent x c_dim

        k1_actions = self.cactor(k0_contexts, deterministic=self.deterministic_cactor_in_graph)
        k1_inputs = torch.cat([obs_n, k1_actions],dim=-1)
        k1_contexts = self.context_graph(k1_inputs)

        policy_gradients_n = []
        alpha_n = []
        for agent in range(num_agent):
            policy_actions = online_actions_n[agent]
            pre_value = online_pre_values_n[agent]
            log_pi = online_log_pis_n[agent]
            if self.pre_activation_weight > 0.:
                pre_activation_policy_loss = (
                    (pre_value**2).sum(dim=1).mean()
                )
            else:
                pre_activation_policy_loss = torch.tensor(0.).to(ptu.device) 
            if self.use_entropy_loss:
                if self.use_automatic_entropy_tuning:
                    if self.state_dependent_alpha:
                        alpha = self.log_alpha_n[agent](obs_n[:,agent,:]).exp()
                    else:
                        alpha = self.log_alpha_n[agent].exp()
                    alpha_loss = -(alpha * (log_pi + self.target_entropy).detach()).mean()
                    self.alpha_optimizer_n[agent].zero_grad()
                    alpha_loss.backward()
                    self.alpha_optimizer_n[agent].step()
                    if self.state_dependent_alpha:
                        alpha = self.log_alpha_n[agent](obs_n[:,agent,:]).exp().detach()
                    else:
                        alpha = self.log_alpha_n[agent].exp().detach()
                        alpha_n.append(alpha)
                else:
                    alpha_loss = torch.tensor(0.).to(ptu.device)
                    alpha = torch.tensor(self.init_alpha).to(ptu.device)
                    alpha_n.append(alpha)
                entropy_loss = (alpha*log_pi).mean()
            else:
                entropy_loss = torch.tensor(0.).to(ptu.device)

            q_input = torch.cat([policy_actions,k1_contexts[:,agent,:]],dim=-1)
            q1_output = self.qf1(q_input)
            q2_output = self.qf2(q_input)
            q_output = torch.min(q1_output,q2_output)
            raw_policy_loss = -q_output.mean()
            policy_loss = (
                    raw_policy_loss +
                    pre_activation_policy_loss * self.pre_activation_weight +
                    entropy_loss
            )

            policy_gradients_n.append(torch.autograd.grad(policy_loss, self.policy_n[agent].parameters(),retain_graph=True))

            if self._need_to_update_eval_statistics:
                self.eval_statistics['Policy Loss {}'.format(agent)] = np.mean(ptu.get_numpy(
                    policy_loss
                ))
                self.eval_statistics['Raw Policy Loss {}'.format(agent)] = np.mean(ptu.get_numpy(
                    raw_policy_loss
                ))
                self.eval_statistics['Preactivation Policy Loss {}'.format(agent)] = np.mean(ptu.get_numpy(
                    pre_activation_policy_loss
                ))
                self.eval_statistics['Entropy Loss {}'.format(agent)] = np.mean(ptu.get_numpy(
                    entropy_loss
                ))
                if self.use_entropy_loss:
                    if self.state_dependent_alpha:
                        self.eval_statistics.update(create_stats_ordered_dict(
                            'Alpha {}'.format(agent),
                            ptu.get_numpy(alpha),
                        ))
                    else:
                        self.eval_statistics['Alpha {} Mean'.format(agent)] = np.mean(ptu.get_numpy(
                            alpha
                        ))
                self.eval_statistics.update(create_stats_ordered_dict(
                    'Policy Action {}'.format(agent),
                    ptu.get_numpy(policy_actions),
                ))

        for agent in range(num_agent):
            # self.policy_optimizer_n[agent].zero_grad()
            for pid,p in enumerate(self.policy_n[agent].parameters()):
                p.grad = policy_gradients_n[agent][pid]
            self.policy_optimizer_n[agent].step()

        """
        Critic operations.
        """
        with torch.no_grad():
            next_actions_n, next_log_pis_n = [], []
            for agent in range(num_agent):
                next_actions, next_info = self.policy_n[agent](
                    next_obs_n[:,agent,:], return_info=True,
                    deterministic=self.deterministic_next_action,
                )
                next_actions_n.append(next_actions)
                next_log_pis_n.append(next_info['log_prob'])
            next_k0_actions = torch.stack(next_actions_n) # num_agent x batch x a_dim
            next_k0_actions = next_k0_actions.transpose(0,1).contiguous() # batch x num_agent x a_dim

            next_k0_inputs = torch.cat([next_obs_n, next_k0_actions],dim=-1)
            next_k0_contexts = self.context_graph(next_k0_inputs)  # batch x num_agent x c_dim

            next_k1_actions = self.cactor(next_k0_contexts, deterministic=self.deterministic_cactor_in_graph)
            next_k1_inputs = torch.cat([next_obs_n, next_k1_actions],dim=-1)
            next_k1_contexts = self.context_graph(next_k1_inputs)

        buffer_inputs = torch.cat([obs_n, actions_n],dim=-1)
        buffer_contexts = self.context_graph(buffer_inputs) # batch x num_agent x c_dim
        q_inputs = torch.cat([actions_n, buffer_contexts],dim=-1)
        q1_preds_n = self.qf1(q_inputs)
        q2_preds_n = self.qf2(q_inputs)

        raw_qf1_loss_n, raw_qf2_loss_n, q_target_n = [], [], []
        for agent in range(num_agent):
            with torch.no_grad():
                next_policy_actions = next_actions_n[agent]
                next_log_pi = next_log_pis_n[agent]
                next_q_input = torch.cat([next_policy_actions,next_k1_contexts[:,agent,:]],dim=-1)

                next_target_q1_values = self.target_qf1(next_q_input)
                next_target_q2_values = self.target_qf2(next_q_input)
                next_target_q_values = torch.min(next_target_q1_values, next_target_q2_values)

                if self.use_entropy_reward:
                    if self.state_dependent_alpha:
                        next_alpha = self.log_alpha_n[agent](next_obs_n[:,agent,:]).exp()
                    else:
                        next_alpha = alpha_n[agent]
                    next_target_q_values =  next_target_q_values - next_alpha * next_log_pi

                q_target = self.reward_scale*rewards_n[:,agent,:] + (1. - terminals_n[:,agent,:]) * self.discount * next_target_q_values
                q_target = torch.clamp(q_target, self.min_q_value, self.max_q_value)
                q_target_n.append(q_target)

            q1_pred = q1_preds_n[:,agent,:]
            raw_qf1_loss = self.qf_criterion(q1_pred, q_target)
            raw_qf1_loss_n.append(raw_qf1_loss)

            q2_pred = q2_preds_n[:,agent,:]
            raw_qf2_loss = self.qf_criterion(q2_pred, q_target)
            raw_qf2_loss_n.append(raw_qf2_loss)

            if self._need_to_update_eval_statistics:
                self.eval_statistics['QF1 Loss {}'.format(agent)] = np.mean(ptu.get_numpy(raw_qf1_loss))
                self.eval_statistics['QF2 Loss {}'.format(agent)] = np.mean(ptu.get_numpy(raw_qf2_loss))
                self.eval_statistics.update(create_stats_ordered_dict(
                    'Q1 Predictions {}'.format(agent),
                    ptu.get_numpy(q1_pred),
                ))
                self.eval_statistics.update(create_stats_ordered_dict(
                    'Q2 Predictions {}'.format(agent),
                    ptu.get_numpy(q2_pred),
                ))
                self.eval_statistics.update(create_stats_ordered_dict(
                    'Q Targets {}'.format(agent),
                    ptu.get_numpy(q_target),
                ))

        if self.sum_n_loss:
            raw_qf1_loss = torch.sum(torch.stack(raw_qf1_loss_n))
            raw_qf2_loss = torch.sum(torch.stack(raw_qf2_loss_n))
        else:
            raw_qf1_loss = torch.mean(torch.stack(raw_qf1_loss_n))
            raw_qf2_loss = torch.mean(torch.stack(raw_qf2_loss_n))

        if self.negative_sampling:
            perturb_actions = actions_n.clone() # batch x agent x |A|
            batch_size, num_agent, a_dim = perturb_actions.shape
            perturb_agents = torch.randint(low=0,high=num_agent,size=(batch_size,))
            neg_actions = torch.rand(batch_size,a_dim)*2.-1. # ranged in -1 to 1
            perturb_actions[torch.arange(batch_size),perturb_agents,:] = neg_actions
                
            perturb_inputs = torch.cat([obs_n,perturb_actions],dim=-1)
            perturb_contexts = self.context_graph(perturb_inputs) # batch x num_agent x c_dim
            perturb_q_inputs = torch.cat([actions_n, perturb_contexts],dim=-1)
            perturb_q1_preds = self.qf1(perturb_q_inputs)[torch.arange(batch_size),perturb_agents,:]
            perturb_q2_preds = self.qf2(perturb_q_inputs)[torch.arange(batch_size),perturb_agents,:]
            perturb_q_targets = torch.stack(q_target_n).transpose(0,1).contiguous()[torch.arange(batch_size),perturb_agents,:]

            neg_loss1 = self.qf_criterion(perturb_q1_preds, perturb_q_targets)
            neg_loss2 = self.qf_criterion(perturb_q2_preds, perturb_q_targets)
        else:
            neg_loss1, neg_loss2 = torch.tensor(0.).to(ptu.device), torch.tensor(0.).to(ptu.device)

        if self.qf_weight_decay > 0:
            reg_loss1 = self.qf_weight_decay * sum(
                torch.sum(param ** 2)
                for param in self.qf1.regularizable_parameters()
            )

            reg_loss2 = self.qf_weight_decay * sum(
                torch.sum(param ** 2)
                for param in self.qf2.regularizable_parameters()
            )
        else:
            reg_loss1, reg_loss2 = torch.tensor(0.).to(ptu.device), torch.tensor(0.).to(ptu.device)

        qf1_loss = raw_qf1_loss + reg_loss1 + neg_loss1
        qf2_loss = raw_qf2_loss + reg_loss2 + neg_loss2

        if self._need_to_update_eval_statistics:
            self.eval_statistics['raw_qf1_loss'] = np.mean(ptu.get_numpy(raw_qf1_loss))
            self.eval_statistics['raw_qf2_loss'] = np.mean(ptu.get_numpy(raw_qf2_loss))
            self.eval_statistics['neg_qf1_loss'] = np.mean(ptu.get_numpy(neg_loss1))
            self.eval_statistics['neg_qf2_loss'] = np.mean(ptu.get_numpy(neg_loss2))
            self.eval_statistics['reg_qf1_loss'] = np.mean(ptu.get_numpy(reg_loss1))
            self.eval_statistics['reg_qf2_loss'] = np.mean(ptu.get_numpy(reg_loss2))

        self.context_graph_optimizer.zero_grad()
        cg_loss = qf1_loss+qf2_loss
        cg_loss.backward()
        self.context_graph_optimizer.step()

        """
        Central actor operations.
        """
        buffer_inputs = torch.cat([obs_n, actions_n],dim=-1)
        buffer_contexts = self.context_graph(buffer_inputs) # batch x num_agent x c_dim
        cactor_loss_n = []
        for agent in range(num_agent):
            cactor_actions, cactor_info = self.cactor(
                buffer_contexts[:,agent,:], return_info=True,
            )
            cactor_pre_value = cactor_info['preactivation']
            if self.pre_activation_weight > 0:
                pre_activation_cactor_loss = (
                    (cactor_pre_value**2).sum(dim=1).mean()
                )
            else:
                pre_activation_cactor_loss = torch.tensor(0.).to(ptu.device)
            if self.use_cactor_entropy_loss:
                cactor_log_pi = cactor_info['log_prob']
                if self.use_automatic_entropy_tuning:
                    if self.state_dependent_alpha:
                        calpha = self.log_calpha_n[agent](whole_obs).exp()
                    else:
                        calpha = self.log_calpha_n[agent].exp()
                    calpha_loss = -(calpha * (cactor_log_pi + self.target_entropy).detach()).mean()
                    self.calpha_optimizer_n[agent].zero_grad()
                    calpha_loss.backward()
                    self.calpha_optimizer_n[agent].step()
                    if self.state_dependent_alpha:
                        calpha = self.log_calpha_n[agent](whole_obs).exp().detach()
                    else:
                        calpha = self.log_calpha_n[agent].exp().detach()
                else:
                    calpha_loss = torch.tensor(0.).to(ptu.device)
                    calpha = torch.tensor(self.init_alpha).to(ptu.device)
                cactor_entropy_loss = (calpha*cactor_log_pi).mean()
            else:
                cactor_entropy_loss = torch.tensor(0.).to(ptu.device)
            
            q_input = torch.cat([cactor_actions,buffer_contexts[:,agent,:]],dim=-1)
            q1_output = self.qf1(q_input)
            q2_output = self.qf2(q_input)
            q_output = torch.min(q1_output,q2_output)
            raw_cactor_loss = -q_output.mean()
            cactor_loss = (
                    raw_cactor_loss +
                    pre_activation_cactor_loss * self.pre_activation_weight +
                    cactor_entropy_loss
            )
            cactor_loss_n.append(cactor_loss)

            if self._need_to_update_eval_statistics:
                if self.use_cactor_entropy_loss:
                    if self.state_dependent_alpha:
                        self.eval_statistics.update(create_stats_ordered_dict(
                            'CAlpha {}'.format(agent),
                            ptu.get_numpy(calpha),
                        ))
                    else:
                        self.eval_statistics['CAlpha {} Mean'.format(agent)] = np.mean(ptu.get_numpy(
                            calpha
                        ))
                self.eval_statistics['Cactor Loss {}'.format(agent)] = np.mean(ptu.get_numpy(
                    cactor_loss
                ))
                self.eval_statistics['Raw Cactor Loss {}'.format(agent)] = np.mean(ptu.get_numpy(
                    raw_cactor_loss
                ))
                self.eval_statistics['Preactivation Cactor Loss {}'.format(agent)] = np.mean(ptu.get_numpy(
                    pre_activation_cactor_loss
                ))
                self.eval_statistics['Entropy Cactor Loss {}'.format(agent)] = np.mean(ptu.get_numpy(
                    cactor_entropy_loss
                ))
        if self.sum_n_loss:
            cactor_loss = torch.sum(torch.stack(cactor_loss_n))
        else:
            cactor_loss = torch.mean(torch.stack(cactor_loss_n))
        self.cactor_optimizer.zero_grad()
        cactor_loss.backward()
        self.cactor_optimizer.step()
                
        self._need_to_update_eval_statistics = False
        self._update_target_networks()
        self._n_train_steps_total += 1
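As a compact reading of the k0/k1 bookkeeping in Example #17 (notation is ours: $$g$$ is self.context_graph, $$\pi_c$$ is self.cactor, $$\pi_i$$ the per-agent policies), the contexts fed to the critics are built in two passes:

$$a^{(0)}_i \sim \pi_i(\cdot \mid o_i), \qquad c^{(0)} = g\big(o, a^{(0)}\big), \qquad a^{(1)} = \pi_c\big(c^{(0)}\big), \qquad c^{(1)} = g\big(o, a^{(1)}\big)$$

Agent $$i$$'s actor is then scored with $$\min(Q_1, Q_2)$$ on the concatenation of $$a^{(0)}_i$$ and $$c^{(1)}_i$$, and the same construction on the next observations produces the critic targets.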
Example #18
    def train_from_torch(self, batch):
        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']
        context = batch['context']

        # data is (task, batch, feat)
        # obs, actions, rewards, next_obs, terms = self.sample_sac(indices)

        # run inference in networks
        action_distrib, p_z, task_z_with_grad = self.agent(
            obs,
            context,
            return_latent_posterior_and_task_z=True,
        )
        task_z_detached = task_z_with_grad.detach()
        new_actions, log_pi, pre_tanh_value = (
            action_distrib.rsample_logprob_and_pretanh())
        log_pi = log_pi.unsqueeze(1)
        policy_mean = action_distrib.mean
        policy_log_std = torch.log(action_distrib.stddev)

        # flattens out the task dimension
        t, b, _ = obs.size()
        obs = obs.view(t * b, -1)
        actions = actions.view(t * b, -1)
        next_obs = next_obs.view(t * b, -1)
        unscaled_rewards_flat = rewards.view(t * b, 1)
        rewards_flat = unscaled_rewards_flat * self.reward_scale
        terms_flat = terminals.view(t * b, 1)

        # Q and V networks
        # encoder will only get gradients from Q nets
        if self.backprop_q_loss_into_encoder:
            q1_pred = self.qf1(obs, actions, task_z_with_grad)
            q2_pred = self.qf2(obs, actions, task_z_with_grad)
        else:
            q1_pred = self.qf1(obs, actions, task_z_detached)
            q2_pred = self.qf2(obs, actions, task_z_detached)
        v_pred = self.vf(obs, task_z_detached)
        # get targets for use in V and Q updates
        with torch.no_grad():
            target_v_values = self.target_vf(next_obs, task_z_detached)
        """
        QF, Encoder, and Decoder Loss
        """
        # note: encoder/decoder do not get grads from policy or vf
        q_target = rewards_flat + (
            1. - terms_flat) * self.discount * target_v_values
        qf_loss = torch.mean((q1_pred - q_target)**2) + torch.mean(
            (q2_pred - q_target)**2)

        # KL constraint on z if probabilistic
        kl_div = kl_divergence(p_z, self.agent.latent_prior).sum()
        kl_loss = self.kl_lambda * kl_div
        if self.train_context_decoder:
            # TODO: change to use a distribution
            reward_pred = self.context_decoder(obs, actions, task_z_with_grad)
            reward_prediction_loss = ((reward_pred -
                                       unscaled_rewards_flat)**2).mean()
            context_loss = kl_loss + reward_prediction_loss
        else:
            context_loss = kl_loss
            reward_prediction_loss = ptu.zeros(1)

        if self.train_encoder_decoder:
            self.context_optimizer.zero_grad()
        if self.train_agent:
            self.qf1_optimizer.zero_grad()
            self.qf2_optimizer.zero_grad()
        context_loss.backward(retain_graph=True)
        qf_loss.backward()
        if self.train_agent:
            self.qf1_optimizer.step()
            self.qf2_optimizer.step()
        if self.train_encoder_decoder:
            self.context_optimizer.step()
        """
        VF update
        """
        min_q_new_actions = self._min_q(obs, new_actions, task_z_detached)
        v_target = min_q_new_actions - log_pi
        vf_loss = self.vf_criterion(v_pred, v_target.detach())
        self.vf_optimizer.zero_grad()
        vf_loss.backward()
        self.vf_optimizer.step()
        self._update_target_network()
        """
        Policy update
        """
        # n.b. policy update includes dQ/da
        log_policy_target = min_q_new_actions
        policy_loss = (log_pi - log_policy_target).mean()

        mean_reg_loss = self.policy_mean_reg_weight * (policy_mean**2).mean()
        std_reg_loss = self.policy_std_reg_weight * (policy_log_std**2).mean()
        pre_activation_reg_loss = self.policy_pre_activation_weight * (
            (pre_tanh_value**2).sum(dim=1).mean())
        policy_reg_loss = mean_reg_loss + std_reg_loss + pre_activation_reg_loss
        policy_loss = policy_loss + policy_reg_loss

        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()

        # save some statistics for eval
        if self._need_to_update_eval_statistics:
            self._need_to_update_eval_statistics = False
            # eval should set this to None.
            # this way, these statistics are only computed for one batch.
            self.eval_statistics = OrderedDict()
            if self.use_information_bottleneck:
                z_mean = np.mean(np.abs(ptu.get_numpy(p_z.mean)))
                z_sig = np.mean(ptu.get_numpy(p_z.stddev))
                self.eval_statistics['Z mean-abs train'] = z_mean
                self.eval_statistics['Z variance train'] = z_sig
                self.eval_statistics['KL Divergence'] = ptu.get_numpy(kl_div)
                self.eval_statistics['KL Loss'] = ptu.get_numpy(kl_loss)

            self.eval_statistics['QF Loss'] = np.mean(ptu.get_numpy(qf_loss))
            self.eval_statistics['VF Loss'] = np.mean(ptu.get_numpy(vf_loss))
            self.eval_statistics['Policy Loss'] = np.mean(
                ptu.get_numpy(policy_loss))
            self.eval_statistics['task_embedding/kl_divergence'] = (
                ptu.get_numpy(kl_div))
            self.eval_statistics['task_embedding/kl_loss'] = (
                ptu.get_numpy(kl_loss))
            self.eval_statistics['task_embedding/reward_prediction_loss'] = (
                ptu.get_numpy(reward_prediction_loss))
            self.eval_statistics['task_embedding/context_loss'] = (
                ptu.get_numpy(context_loss))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q Predictions',
                    ptu.get_numpy(q1_pred),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'V Predictions',
                    ptu.get_numpy(v_pred),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Log Pis',
                    ptu.get_numpy(log_pi),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Policy mu',
                    ptu.get_numpy(policy_mean),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Policy log std',
                    ptu.get_numpy(policy_log_std),
                ))
Example #19
    def compute_loss(
        self,
        batch,
        skip_statistics=False,
    ) -> Tuple[SACLosses, LossStatistics, torch.Tensor]:
        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']
        weights = batch["weights"]
        """
        Policy and Alpha Loss
        """
        dist = self.policy(obs)
        new_obs_actions, log_pi = dist.rsample_and_logprob()
        log_pi = log_pi.unsqueeze(-1)
        if self.use_automatic_entropy_tuning:
            alpha_loss = -(weights.detach() *
                           (self.log_alpha *
                            (log_pi + self.target_entropy).detach())).mean()
            alpha = self.log_alpha.exp()
        else:
            alpha_loss = 0
            alpha = 1

        q_new_actions = torch.min(
            self.qf1(obs, new_obs_actions),
            self.qf2(obs, new_obs_actions),
        )
        policy_loss = (weights.detach() *
                       (alpha * log_pi - q_new_actions)).mean()
        """
        QF Loss
        """
        q1_pred = self.qf1(obs, actions)
        q2_pred = self.qf2(obs, actions)
        next_dist = self.policy(next_obs)
        new_next_actions, new_log_pi = next_dist.rsample_and_logprob()
        new_log_pi = new_log_pi.unsqueeze(-1)
        target_q_values = torch.min(
            self.target_qf1(next_obs, new_next_actions),
            self.target_qf2(next_obs, new_next_actions),
        ) - alpha * new_log_pi

        q_target = self.reward_scale * rewards + (
            1. - terminals) * self.discount * target_q_values
        qf1_loss = (weights.detach() *
                    ((q1_pred - q_target.detach())**2)).mean()
        qf2_loss = (weights.detach() *
                    ((q2_pred - q_target.detach())**2)).mean()
        errors = (
            ((torch.abs(q_target - q1_pred) + torch.abs(q_target - q2_pred)) /
             2) * weights).detach()
        """
        Save some statistics for eval
        """
        eval_statistics = OrderedDict()
        if not skip_statistics:
            eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
            eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))
            eval_statistics['Policy Loss'] = np.mean(
                ptu.get_numpy(policy_loss))
            eval_statistics.update(
                create_stats_ordered_dict(
                    'Q1 Predictions',
                    ptu.get_numpy(q1_pred),
                ))
            eval_statistics.update(
                create_stats_ordered_dict(
                    'Q2 Predictions',
                    ptu.get_numpy(q2_pred),
                ))
            eval_statistics.update(
                create_stats_ordered_dict(
                    'Q Targets',
                    ptu.get_numpy(q_target),
                ))
            eval_statistics.update(
                create_stats_ordered_dict(
                    'Log Pis',
                    ptu.get_numpy(log_pi),
                ))
            policy_statistics = add_prefix(dist.get_diagnostics(), "policy/")
            eval_statistics.update(policy_statistics)
            if self.use_automatic_entropy_tuning:
                eval_statistics['Alpha'] = alpha.item()
                eval_statistics['Alpha Loss'] = alpha_loss.item()

        loss = SACLosses(
            policy_loss=policy_loss,
            qf1_loss=qf1_loss,
            qf2_loss=qf2_loss,
            alpha_loss=alpha_loss,
        )

        return loss, eval_statistics, errors
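The losses in Example #19 are the usual SAC losses with an extra per-sample weight $$w$$ taken from batch["weights"] (presumably importance weights from a prioritized buffer; the source does not say). Roughly, with $$c$$ the reward scale and $$a' \sim \pi(\cdot \mid s')$$:

$$\mathcal{L}_\pi = \mathbb{E}\big[\,w\,\big(\alpha \log \pi(a \mid s) - \min_j Q_j(s, a)\big)\big], \qquad y = c\,r + (1 - d)\,\gamma\,\big(\min_j Q^-_j(s', a') - \alpha \log \pi(a' \mid s')\big)$$

$$\mathcal{L}_{Q_j} = \mathbb{E}\big[\,w\,\big(Q_j(s, a) - y\big)^2\big]$$

The returned errors are the per-sample weighted absolute TD errors, which a caller could use, for example, to refresh priorities.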
Example #20
    def train_from_torch(self, batch):
        rewards_n = batch['rewards'].detach()
        terminals_n = batch['terminals'].detach()
        obs_n = batch['observations'].detach()
        actions_n = batch['actions'].detach()
        next_obs_n = batch['next_observations'].detach()

        batch_size = rewards_n.shape[0]
        num_agent = rewards_n.shape[1]
        whole_obs = obs_n.view(batch_size, -1)
        whole_actions = actions_n.view(batch_size, -1)
        whole_next_obs = next_obs_n.view(batch_size, -1) 

        """
        Policy operations.
        """
        online_actions_n, online_pre_values_n, online_log_pis_n = [], [], []
        for agent in range(num_agent):
            policy_actions, info = self.policy_n[agent](
                obs_n[:,agent,:], return_info=True,
            )
            online_actions_n.append(policy_actions)
            online_pre_values_n.append(info['preactivation'])
            online_log_pis_n.append(info['log_prob'])
        k0_actions = torch.stack(online_actions_n) # num_agent x batch x a_dim
        k0_actions = k0_actions.transpose(0,1).contiguous() # batch x num_agent x a_dim

        k0_inputs = torch.cat([obs_n, k0_actions],dim=-1)
        k0_contexts = self.cgca(k0_inputs)
        k1_actions = self.cactor(k0_contexts, deterministic=self.deterministic_cactor_in_graph)

        policy_gradients_n = []
        alpha_n = []
        for agent in range(num_agent):
            policy_actions = online_actions_n[agent]
            pre_value = online_pre_values_n[agent]
            log_pi = online_log_pis_n[agent]
            if self.pre_activation_weight > 0.:
                pre_activation_policy_loss = (
                    (pre_value**2).sum(dim=1).mean()
                )
            else:
                pre_activation_policy_loss = torch.tensor(0.).to(ptu.device) 
            if self.use_entropy_loss:
                if self.use_automatic_entropy_tuning:
                    alpha = self.log_alpha_n[agent].exp()
                    alpha_loss = -(alpha * (log_pi + self.target_entropy).detach()).mean()
                    self.alpha_optimizer_n[agent].zero_grad()
                    alpha_loss.backward()
                    self.alpha_optimizer_n[agent].step()
                    alpha = self.log_alpha_n[agent].exp().detach()
                    alpha_n.append(alpha)
                else:
                    alpha_loss = torch.tensor(0.).to(ptu.device)
                    alpha = torch.tensor(self.init_alpha).to(ptu.device)
                    alpha_n.append(alpha)
                entropy_loss = (alpha*log_pi).mean()
            else:
                entropy_loss = torch.tensor(0.).to(ptu.device)

            input_actions = k1_actions.clone()
            input_actions[:,agent,:] = policy_actions
            q1_output = self.qf1_n[agent](whole_obs, input_actions.view(batch_size, -1))
            q2_output = self.qf2_n[agent](whole_obs, input_actions.view(batch_size, -1))
            q_output = torch.min(q1_output,q2_output)
            raw_policy_loss = -q_output.mean()
            policy_loss = (
                    raw_policy_loss +
                    pre_activation_policy_loss * self.pre_activation_weight +
                    entropy_loss
            )

            policy_gradients_n.append(torch.autograd.grad(policy_loss, self.policy_n[agent].parameters(),retain_graph=True))

            if self._need_to_update_eval_statistics:
                self.eval_statistics['Policy Loss {}'.format(agent)] = np.mean(ptu.get_numpy(
                    policy_loss
                ))
                self.eval_statistics['Raw Policy Loss {}'.format(agent)] = np.mean(ptu.get_numpy(
                    raw_policy_loss
                ))
                self.eval_statistics['Preactivation Policy Loss {}'.format(agent)] = np.mean(ptu.get_numpy(
                    pre_activation_policy_loss
                ))
                self.eval_statistics['Entropy Loss {}'.format(agent)] = np.mean(ptu.get_numpy(
                    entropy_loss
                ))
                if self.use_entropy_loss:
                    self.eval_statistics['Alpha {} Mean'.format(agent)] = np.mean(ptu.get_numpy(
                        alpha
                    ))
                self.eval_statistics.update(create_stats_ordered_dict(
                    'Policy Action {}'.format(agent),
                    ptu.get_numpy(policy_actions),
                ))

        for agent in range(num_agent):
            # self.policy_optimizer_n[agent].zero_grad()
            for pid,p in enumerate(self.policy_n[agent].parameters()):
                p.grad = policy_gradients_n[agent][pid]
            self.policy_optimizer_n[agent].step()

        """
        Critic operations.
        """
        with torch.no_grad():
            next_actions_n, next_log_pis_n = [], []
            for agent in range(num_agent):
                next_actions, next_info = self.policy_n[agent](
                    next_obs_n[:,agent,:], return_info=True,
                    deterministic=self.deterministic_next_action,
                )
                next_actions_n.append(next_actions)
                next_log_pis_n.append(next_info['log_prob'])
            next_k0_actions = torch.stack(next_actions_n) # num_agent x batch x a_dim
            next_k0_actions = next_k0_actions.transpose(0,1).contiguous() # batch x num_agent x a_dim

            next_k0_inputs = torch.cat([next_obs_n, next_k0_actions],dim=-1)
            next_k0_contexts = self.cgca(next_k0_inputs)
            next_k1_actions = self.cactor(next_k0_contexts, deterministic=self.deterministic_cactor_in_graph)

        for agent in range(num_agent):
            with torch.no_grad():
                input_actions = next_k1_actions.clone()
                input_actions[:,agent,:] = next_actions_n[agent]
                next_target_q1_values = self.target_qf1_n[agent](
                    whole_next_obs,
                    input_actions.view(batch_size,-1),
                )
                next_target_q2_values = self.target_qf2_n[agent](
                    whole_next_obs,
                    input_actions.view(batch_size,-1),
                )
                next_target_q_values = torch.min(next_target_q1_values, next_target_q2_values)

                if self.use_entropy_reward:
                    next_alpha = alpha_n[agent]
                    next_target_q_values =  next_target_q_values - next_alpha * next_log_pis_n[agent]

                q_target = self.reward_scale*rewards_n[:,agent,:] + (1. - terminals_n[:,agent,:]) * self.discount * next_target_q_values
                q_target = torch.clamp(q_target, self.min_q_value, self.max_q_value)

            q1_pred = self.qf1_n[agent](whole_obs, whole_actions)
            raw_qf1_loss = self.qf_criterion(q1_pred, q_target)
            if self.qf_weight_decay > 0:
                reg_loss1 = self.qf_weight_decay * sum(
                    torch.sum(param ** 2)
                    for param in self.qf1_n[agent].regularizable_parameters()
                )
                qf1_loss = raw_qf1_loss + reg_loss1
            else:
                qf1_loss = raw_qf1_loss

            q2_pred = self.qf2_n[agent](whole_obs, whole_actions)
            raw_qf2_loss = self.qf_criterion(q2_pred, q_target)
            if self.qf_weight_decay > 0:
                reg_loss2 = self.qf_weight_decay * sum(
                    torch.sum(param ** 2)
                    for param in self.qf2_n[agent].regularizable_parameters()
                )
                qf2_loss = raw_qf2_loss + reg_loss2
            else:
                qf2_loss = raw_qf2_loss

            self.qf1_optimizer_n[agent].zero_grad()
            qf1_loss.backward()
            self.qf1_optimizer_n[agent].step()

            self.qf2_optimizer_n[agent].zero_grad()
            qf2_loss.backward()
            self.qf2_optimizer_n[agent].step()

            if self._need_to_update_eval_statistics:
                self.eval_statistics['QF1 Loss {}'.format(agent)] = np.mean(ptu.get_numpy(qf1_loss))
                self.eval_statistics['QF2 Loss {}'.format(agent)] = np.mean(ptu.get_numpy(qf2_loss))
                self.eval_statistics.update(create_stats_ordered_dict(
                    'Q1 Predictions {}'.format(agent),
                    ptu.get_numpy(q1_pred),
                ))
                self.eval_statistics.update(create_stats_ordered_dict(
                    'Q2 Predictions {}'.format(agent),
                    ptu.get_numpy(q2_pred),
                ))
                self.eval_statistics.update(create_stats_ordered_dict(
                    'Q Targets {}'.format(agent),
                    ptu.get_numpy(q_target),
                ))

        """
        Central actor operations.
        """
        buffer_inputs = torch.cat([obs_n, actions_n],dim=-1)
        buffer_contexts_ca = self.cgca(buffer_inputs)
        cactor_actions, cactor_infos = self.cactor(buffer_contexts_ca,return_info=True)
        # batch x agent_num x |A|

        cactor_loss_n = []
        for agent in range(num_agent):
            cactor_pre_value = cactor_infos['preactivation'][:,agent,:]
            if self.pre_activation_weight > 0:
                pre_activation_cactor_loss = (
                    (cactor_pre_value**2).sum(dim=1).mean()
                )
            else:
                pre_activation_cactor_loss = torch.tensor(0.).to(ptu.device)
            if self.use_cactor_entropy_loss:
                cactor_log_pi = cactor_infos['log_prob'][:,agent,:]
                if self.use_automatic_entropy_tuning:
                    calpha = self.log_calpha_n[agent].exp()
                    calpha_loss = -(calpha * (cactor_log_pi + self.target_entropy).detach()).mean()
                    self.calpha_optimizer_n[agent].zero_grad()
                    calpha_loss.backward()
                    self.calpha_optimizer_n[agent].step()
                    calpha = self.log_calpha_n[agent].exp().detach()
                else:
                    calpha_loss = torch.tensor(0.).to(ptu.device)
                    calpha = torch.tensor(self.init_alpha).to(ptu.device)
                cactor_entropy_loss = (calpha*cactor_log_pi).mean()
            else:
                cactor_entropy_loss = torch.tensor(0.).to(ptu.device)
            
            current_actions = actions_n.clone()
            current_actions[:,agent,:] = cactor_actions[:,agent,:] 
            q1_output = self.qf1_n[agent](whole_obs, current_actions.view(batch_size, -1))
            q2_output = self.qf2_n[agent](whole_obs, current_actions.view(batch_size, -1))
            q_output = torch.min(q1_output,q2_output)
            raw_cactor_loss = -q_output.mean()
            cactor_loss = (
                    raw_cactor_loss +
                    pre_activation_cactor_loss * self.pre_activation_weight +
                    cactor_entropy_loss
            )
            cactor_loss_n.append(cactor_loss)

            if self._need_to_update_eval_statistics:
                if self.use_cactor_entropy_loss:
                    self.eval_statistics['CAlpha {} Mean'.format(agent)] = np.mean(ptu.get_numpy(
                        calpha
                    ))
                self.eval_statistics['Cactor Loss {}'.format(agent)] = np.mean(ptu.get_numpy(
                    cactor_loss
                ))
                self.eval_statistics['Raw Cactor Loss {}'.format(agent)] = np.mean(ptu.get_numpy(
                    raw_cactor_loss
                ))
                self.eval_statistics['Preactivation Cactor Loss {}'.format(agent)] = np.mean(ptu.get_numpy(
                    pre_activation_cactor_loss
                ))
                self.eval_statistics['Entropy Cactor Loss {}'.format(agent)] = np.mean(ptu.get_numpy(
                    cactor_entropy_loss
                ))
        cactor_loss = torch.mean(torch.stack(cactor_loss_n))
        self.cactor_optimizer.zero_grad()
        cactor_loss.backward()

        # Global L2 norm of the gradients, accumulated as plain floats.
        cgca_grad_norm = 0.
        for p in self.cgca.parameters():
            if p.grad is not None:
                cgca_grad_norm += p.grad.detach().norm(2).item() ** 2
        cgca_grad_norm = cgca_grad_norm ** 0.5

        cactor_grad_norm = 0.
        for p in self.cactor.parameters():
            if p.grad is not None:
                cactor_grad_norm += p.grad.detach().norm(2).item() ** 2
        cactor_grad_norm = cactor_grad_norm ** 0.5

        self.cactor_optimizer.step()

        if self._need_to_update_eval_statistics:
            self.eval_statistics['CGCA Gradient'] = cgca_grad_norm
            self.eval_statistics['CActor Gradient'] = cactor_grad_norm
                
        self._need_to_update_eval_statistics = False
        self._update_target_networks()
        self._n_train_steps_total += 1
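
The gradient-norm bookkeeping above accumulates squared per-parameter gradient norms and then takes a square root. A minimal, self-contained sketch of the same pattern is below; the helper name grad_l2_norm and the toy network are illustrative, not part of the trainer above.

import torch
import torch.nn as nn

def grad_l2_norm(module: nn.Module) -> float:
    """Global L2 norm of the gradients currently stored on a module."""
    total_sq = 0.0
    for p in module.parameters():
        if p.grad is not None:  # parameters that received no gradient are skipped
            total_sq += p.grad.detach().norm(2).item() ** 2
    return total_sq ** 0.5

# Toy usage: one backward pass through a small MLP, then log the norm.
net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
net(torch.randn(16, 4)).mean().backward()
print(grad_l2_norm(net))
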
Exemple #21
    def train_from_torch(self, batch):
        rewards_n = batch['rewards'].detach()
        terminals_n = batch['terminals'].detach()
        obs_n = batch['observations'].detach()
        actions_n = batch['actions'].detach()
        next_obs_n = batch['next_observations'].detach()

        batch_size = rewards_n.shape[0]
        num_agent = rewards_n.shape[1]
        whole_obs = obs_n.view(batch_size, -1)
        whole_actions = actions_n.view(batch_size, -1)
        whole_next_obs = next_obs_n.view(batch_size, -1) 

        """
        Policy operations.
        """
        online_actions_n, online_pre_values_n, online_log_pis_n = [], [], []
        for agent in range(num_agent):
            policy_actions, info = self.policy_n[agent](
                obs_n[:,agent,:], return_info=True,
            )
            online_actions_n.append(policy_actions)
            online_pre_values_n.append(info['preactivation'])
            online_log_pis_n.append(info['log_prob'])
        k0_actions = torch.stack(online_actions_n) # num_agent x batch x a_dim
        k0_actions = k0_actions.transpose(0,1).contiguous() # batch x num_agent x a_dim

        k0_inputs = torch.cat([obs_n, k0_actions],dim=-1)
        k0_contexts = self.cgca(k0_inputs)
        k1_actions = [self.cactor_n[agent](k0_contexts[:,agent,:],
                                         deterministic=self.deterministic_cactor_in_graph)
                        for agent in range(num_agent)]
        k1_actions = torch.stack(k1_actions).transpose(0,1).contiguous()

        k1_inputs = torch.cat([obs_n, k1_actions],dim=-1)
        k1_contexts_1 = self.cg1(k1_inputs)
        k1_contexts_2 = self.cg2(k1_inputs)

        policy_gradients_n = []
        alpha_n = []
        for agent in range(num_agent):
            policy_actions = online_actions_n[agent]
            pre_value = online_pre_values_n[agent]
            log_pi = online_log_pis_n[agent]
            if self.pre_activation_weight > 0.:
                pre_activation_policy_loss = (
                    (pre_value**2).sum(dim=1).mean()
                )
            else:
                pre_activation_policy_loss = torch.tensor(0.).to(ptu.device) 
            if self.use_entropy_loss:
                if self.use_automatic_entropy_tuning:
                    alpha = self.log_alpha_n[agent].exp()
                    alpha_loss = -(alpha * (log_pi + self.target_entropy).detach()).mean()
                    self.alpha_optimizer_n[agent].zero_grad()
                    alpha_loss.backward()
                    self.alpha_optimizer_n[agent].step()
                    alpha = self.log_alpha_n[agent].exp().detach()
                    alpha_n.append(alpha)
                else:
                    alpha_loss = torch.tensor(0.).to(ptu.device)
                    alpha = torch.tensor(self.init_alpha).to(ptu.device)
                    alpha_n.append(alpha)
                entropy_loss = (alpha*log_pi).mean()
            else:
                entropy_loss = torch.tensor(0.).to(ptu.device)

            q1_input = torch.cat([policy_actions,k1_contexts_1[:,agent,:]],dim=-1)
            q1_output = self.qf1_n[agent](q1_input)
            q2_input = torch.cat([policy_actions,k1_contexts_2[:,agent,:]],dim=-1)
            q2_output = self.qf2_n[agent](q2_input)
            q_output = torch.min(q1_output,q2_output)
            raw_policy_loss = -q_output.mean()
            policy_loss = (
                    raw_policy_loss +
                    pre_activation_policy_loss * self.pre_activation_weight +
                    entropy_loss
            )

            policy_gradients_n.append(torch.autograd.grad(policy_loss, self.policy_n[agent].parameters(),retain_graph=True))

            if self._need_to_update_eval_statistics:
                self.eval_statistics['Policy Loss {}'.format(agent)] = np.mean(ptu.get_numpy(
                    policy_loss
                ))
                self.eval_statistics['Raw Policy Loss {}'.format(agent)] = np.mean(ptu.get_numpy(
                    raw_policy_loss
                ))
                self.eval_statistics['Preactivation Policy Loss {}'.format(agent)] = np.mean(ptu.get_numpy(
                    pre_activation_policy_loss
                ))
                self.eval_statistics['Entropy Loss {}'.format(agent)] = np.mean(ptu.get_numpy(
                    entropy_loss
                ))
                if self.use_entropy_loss:
                    self.eval_statistics['Alpha {} Mean'.format(agent)] = np.mean(ptu.get_numpy(
                        alpha
                    ))
                self.eval_statistics.update(create_stats_ordered_dict(
                    'Policy Action {}'.format(agent),
                    ptu.get_numpy(policy_actions),
                ))

        for agent in range(num_agent):
            # self.policy_optimizer_n[agent].zero_grad()
            for pid,p in enumerate(self.policy_n[agent].parameters()):
                p.grad = policy_gradients_n[agent][pid]
            self.policy_optimizer_n[agent].step()

        """
        Critic operations.
        """
        with torch.no_grad():
            next_actions_n, next_log_pis_n = [], []
            for agent in range(num_agent):
                next_actions, next_info = self.policy_n[agent](
                    next_obs_n[:,agent,:], return_info=True,
                    deterministic=self.deterministic_next_action,
                )
                next_actions_n.append(next_actions)
                next_log_pis_n.append(next_info['log_prob'])
            next_k0_actions = torch.stack(next_actions_n) # num_agent x batch x a_dim
            next_k0_actions = next_k0_actions.transpose(0,1).contiguous() # batch x num_agent x a_dim

            next_k0_inputs = torch.cat([next_obs_n, next_k0_actions],dim=-1)
            next_k0_contexts = self.cgca(next_k0_inputs)
            next_k1_actions = [self.cactor_n[agent](next_k0_contexts[:,agent,:],
                                             deterministic=self.deterministic_cactor_in_graph)
                            for agent in range(num_agent)]
            next_k1_actions = torch.stack(next_k1_actions).transpose(0,1).contiguous()

            next_k1_inputs = torch.cat([next_obs_n, next_k1_actions],dim=-1)
            next_k1_contexts_1 = self.target_cg1(next_k1_inputs)
            next_k1_contexts_2 = self.target_cg2(next_k1_inputs)

            next_q1_inputs = torch.cat([next_k0_actions,next_k1_contexts_1],dim=-1)
            next_target_q1_values = [self.target_qf1_n[agent](next_q1_inputs[:,agent,:]) for agent in range(num_agent)]
            next_target_q1_values = torch.stack(next_target_q1_values).transpose(0,1).contiguous()

            next_q2_inputs = torch.cat([next_k0_actions,next_k1_contexts_2],dim=-1)
            next_target_q2_values = [self.target_qf2_n[agent](next_q2_inputs[:,agent,:]) for agent in range(num_agent)]
            next_target_q2_values = torch.stack(next_target_q2_values).transpose(0,1).contiguous()

            next_target_q_values = torch.min(next_target_q1_values, next_target_q2_values)

            if self.use_entropy_reward:
                next_alphas = torch.stack(alpha_n)[None,:]
                next_log_pis = torch.stack(next_log_pis_n).transpose(0,1).contiguous()
                next_target_q_values = next_target_q_values - next_alphas * next_log_pis
            q_targets = self.reward_scale*rewards_n + (1. - terminals_n) * self.discount * next_target_q_values
            q_targets = torch.clamp(q_targets, self.min_q_value, self.max_q_value)

        buffer_inputs = torch.cat([obs_n, actions_n],dim=-1)

        buffer_contexts_1 = self.cg1(buffer_inputs) # batch x num_agent x c_dim
        q1_inputs = torch.cat([actions_n, buffer_contexts_1],dim=-1)
        q1_preds = [self.qf1_n[agent](q1_inputs[:,agent,:]) for agent in range(num_agent)]
        q1_preds = torch.stack(q1_preds).transpose(0,1).contiguous()
        raw_qf1_loss = self.qf_criterion(q1_preds, q_targets)

        buffer_contexts_2 = self.cg2(buffer_inputs) # batch x num_agent x c_dim
        q2_inputs = torch.cat([actions_n, buffer_contexts_2],dim=-1)
        q2_preds = [self.qf2_n[agent](q2_inputs[:,agent,:]) for agent in range(num_agent)]
        q2_preds = torch.stack(q2_preds).transpose(0,1).contiguous()
        raw_qf2_loss = self.qf_criterion(q2_preds, q_targets)

        if self.qf_weight_decay > 0:
            reg_loss1 = self.qf_weight_decay * sum(
                torch.sum(param ** 2)
                for param in list(self.qf1.regularizable_parameters())+list(self.cg1.regularizable_parameters())
            )

            reg_loss2 = self.qf_weight_decay * sum(
                torch.sum(param ** 2)
                for param in list(self.qf2.regularizable_parameters())+list(self.cg2.regularizable_parameters())
            )
        else:
            reg_loss1, reg_loss2 = torch.tensor(0.).to(ptu.device), torch.tensor(0.).to(ptu.device)

        qf1_loss = raw_qf1_loss + reg_loss1
        qf2_loss = raw_qf2_loss + reg_loss2

        if self._need_to_update_eval_statistics:
            self.eval_statistics['Qf1 Loss'] = ptu.get_numpy(qf1_loss)
            self.eval_statistics['Qf2 Loss'] = ptu.get_numpy(qf2_loss)
            self.eval_statistics['Raw Qf1 Loss'] = ptu.get_numpy(raw_qf1_loss)
            self.eval_statistics['Raw Qf2 Loss'] = ptu.get_numpy(raw_qf2_loss)
            self.eval_statistics['Reg Qf1 Loss'] = ptu.get_numpy(reg_loss1)
            self.eval_statistics['Reg Qf2 Loss'] = ptu.get_numpy(reg_loss2)

        self.qf1_optimizer.zero_grad()
        qf1_loss.backward()
        self.qf1_optimizer.step()

        self.qf2_optimizer.zero_grad()
        qf2_loss.backward()
        self.qf2_optimizer.step()

        """
        Central actor operations.
        """
        buffer_inputs = torch.cat([obs_n, actions_n],dim=-1)
        buffer_ca_contexts = self.cgca(buffer_inputs)
        cactor_outputs = [self.cactor_n[agent](buffer_ca_contexts[:,agent,:],return_info=True)
                            for agent in range(num_agent)]
        # [(action, info),...]
        cactor_actions = torch.stack([cactor_outputs[agent][0] for agent in range(num_agent)]).transpose(0,1).contiguous()
        # batch x agent_num x |A|
        buffer_contexts_1 = self.cg1(buffer_inputs).detach()
        buffer_contexts_2 = self.cg2(buffer_inputs).detach()

        cactor_pre_values = torch.stack([cactor_outputs[agent][1]['preactivation'] for agent in range(num_agent)]).transpose(0,1).contiguous()
        if self.pre_activation_weight > 0:
            pre_activation_cactor_loss = (
                (cactor_pre_values**2).sum(dim=1).mean()
            )
        else:
            pre_activation_cactor_loss = torch.tensor(0.).to(ptu.device)
        if self.use_cactor_entropy_loss:
            cactor_log_pis = torch.stack([cactor_outputs[agent][1]['log_prob'] for agent in range(num_agent)]).transpose(0,1).contiguous()
            # batch x num_agent x 1
            if self.use_automatic_entropy_tuning:
                calphas = torch.stack(self.log_calpha_n).exp()[None,:]
                calpha_loss = -(calphas * (cactor_log_pis + self.target_entropy).detach())
                calpha_loss = calpha_loss.mean()
                self.calpha_optimizer.zero_grad()
                calpha_loss.backward()
                self.calpha_optimizer.step()
                calphas = torch.stack(self.log_calpha_n).exp().detach()
            else:
                calpha_loss = torch.tensor(0.).to(ptu.device)
                calphas = torch.stack([torch.tensor(self.init_alpha).to(ptu.device) for i in range(num_agent)])
            cactor_entropy_loss = (calphas[None,:]*cactor_log_pis).mean()
        else:
            cactor_entropy_loss = torch.tensor(0.).to(ptu.device)
        
        q1_inputs = torch.cat([cactor_actions,buffer_contexts_1],dim=-1)
        q1_outputs = [self.qf1_n[agent](q1_inputs[:,agent,:]) for agent in range(num_agent)]
        q1_outputs = torch.stack(q1_outputs).transpose(0,1).contiguous()

        q2_inputs = torch.cat([cactor_actions,buffer_contexts_2],dim=-1)
        q2_outputs = [self.qf2_n[agent](q2_inputs[:,agent,:]) for agent in range(num_agent)]
        q2_outputs = torch.stack(q2_outputs).transpose(0,1).contiguous()

        q_outputs = torch.min(q1_outputs,q2_outputs)
        raw_cactor_loss = -q_outputs.mean()
        cactor_loss = (
                raw_cactor_loss +
                pre_activation_cactor_loss * self.pre_activation_weight +
                cactor_entropy_loss
        )

        if self._need_to_update_eval_statistics:
            if self.use_cactor_entropy_loss:
                self.eval_statistics.update(create_stats_ordered_dict(
                    'CAlpha', ptu.get_numpy(calphas),
                ))
            self.eval_statistics['Cactor Loss'] = ptu.get_numpy(cactor_loss)
            self.eval_statistics['Raw Cactor Loss'] = ptu.get_numpy(raw_cactor_loss)
            self.eval_statistics['Preactivation Cactor Loss'] = ptu.get_numpy(pre_activation_cactor_loss)
            self.eval_statistics['Entropy Cactor Loss'] = ptu.get_numpy(cactor_entropy_loss)

        self.cactor_optimizer.zero_grad()
        cactor_loss.backward()
        self.cactor_optimizer.step()
                
        self._need_to_update_eval_statistics = False
        self._update_target_networks()
        self._n_train_steps_total += 1
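
Across the multi-agent examples above, the critic target follows the same clipped double-Q recipe: take the elementwise minimum of the two target critics, optionally subtract the entropy term, form the Bellman backup, and clamp it. A single-agent sketch of just that target computation is shown below; the tensor shapes and the reward_scale, discount, alpha, and clamp values are assumptions for illustration.

import torch

batch_size = 32
reward_scale, discount, alpha = 1.0, 0.99, 0.2

rewards = torch.randn(batch_size, 1)
terminals = torch.zeros(batch_size, 1)
next_log_pi = torch.randn(batch_size, 1)
# Stand-ins for target_qf1(next_obs, next_actions) and target_qf2(next_obs, next_actions).
next_q1 = torch.randn(batch_size, 1)
next_q2 = torch.randn(batch_size, 1)

with torch.no_grad():
    next_q = torch.min(next_q1, next_q2)           # clipped double-Q estimate
    next_q = next_q - alpha * next_log_pi          # optional entropy bonus
    q_target = reward_scale * rewards + (1. - terminals) * discount * next_q
    q_target = torch.clamp(q_target, -100., 100.)  # bound the targets as above
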
Exemple #22
    def train_from_torch(self, batch):
        self._current_epoch += 1
        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']

        """
        Policy and Alpha Loss
        """
        new_obs_actions, policy_mean, policy_log_std, log_pi, *_ = self.policy(
            obs, reparameterize=True, return_log_prob=True,
        )
        
        if self.use_automatic_entropy_tuning:
            alpha_loss = -(self.log_alpha * (log_pi + self.target_entropy).detach()).mean()
            self.alpha_optimizer.zero_grad()
            alpha_loss.backward()
            self.alpha_optimizer.step()
            alpha = self.log_alpha.exp()
        else:
            alpha_loss = 0
            alpha = 1

        if self.num_qs == 1:
            q_new_actions = self.qf1(obs, new_obs_actions)
        else:
            q_new_actions = torch.min(
                self.qf1(obs, new_obs_actions),
                self.qf2(obs, new_obs_actions),
            )

        policy_loss = (alpha*log_pi - q_new_actions).mean()

        if self._current_epoch < self.policy_eval_start:
            """
            For the initial few epochs, try doing behaivoral cloning, if needed
            conventionally, there's not much difference in performance with having 20k 
            gradient steps here, or not having it
            """
            policy_log_prob = self.policy.log_prob(obs, actions)
            policy_loss = (alpha * log_pi - policy_log_prob).mean()
        
        """
        QF Loss
        """
        q1_pred = self.qf1(obs, actions)
        if self.num_qs > 1:
            q2_pred = self.qf2(obs, actions)
        
        new_next_actions, _, _, new_log_pi, *_ = self.policy(
            next_obs, reparameterize=True, return_log_prob=True,
        )
        new_curr_actions, _, _, new_curr_log_pi, *_ = self.policy(
            obs, reparameterize=True, return_log_prob=True,
        )

        if not self.max_q_backup:
            if self.num_qs == 1:
                target_q_values = self.target_qf1(next_obs, new_next_actions)
            else:
                target_q_values = torch.min(
                    self.target_qf1(next_obs, new_next_actions),
                    self.target_qf2(next_obs, new_next_actions),
                )
            
            if not self.deterministic_backup:
                target_q_values = target_q_values - alpha * new_log_pi
        
        if self.max_q_backup:
            """when using max q backup"""
            next_actions_temp, _ = self._get_policy_actions(next_obs, num_actions=10, network=self.policy)
            target_qf1_values = self._get_tensor_values(next_obs, next_actions_temp, network=self.target_qf1).max(1)[0].view(-1, 1)
            target_qf2_values = self._get_tensor_values(next_obs, next_actions_temp, network=self.target_qf2).max(1)[0].view(-1, 1)
            target_q_values = torch.min(target_qf1_values, target_qf2_values)

        q_target = self.reward_scale * rewards + (1. - terminals) * self.discount * target_q_values
        q_target = q_target.detach()
            
        qf1_loss = self.qf_criterion(q1_pred, q_target)
        if self.num_qs > 1:
            qf2_loss = self.qf_criterion(q2_pred, q_target)

        ## add CQL
        random_actions_tensor = torch.FloatTensor(q2_pred.shape[0] * self.num_random, actions.shape[-1]).uniform_(-1, 1) # .cuda()
        curr_actions_tensor, curr_log_pis = self._get_policy_actions(obs, num_actions=self.num_random, network=self.policy)
        new_curr_actions_tensor, new_log_pis = self._get_policy_actions(next_obs, num_actions=self.num_random, network=self.policy)
        q1_rand = self._get_tensor_values(obs, random_actions_tensor, network=self.qf1)
        q2_rand = self._get_tensor_values(obs, random_actions_tensor, network=self.qf2)
        q1_curr_actions = self._get_tensor_values(obs, curr_actions_tensor, network=self.qf1)
        q2_curr_actions = self._get_tensor_values(obs, curr_actions_tensor, network=self.qf2)
        q1_next_actions = self._get_tensor_values(obs, new_curr_actions_tensor, network=self.qf1)
        q2_next_actions = self._get_tensor_values(obs, new_curr_actions_tensor, network=self.qf2)

        cat_q1 = torch.cat(
            [q1_rand, q1_pred.unsqueeze(1), q1_next_actions, q1_curr_actions], 1
        )
        cat_q2 = torch.cat(
            [q2_rand, q2_pred.unsqueeze(1), q2_next_actions, q2_curr_actions], 1
        )
        std_q1 = torch.std(cat_q1, dim=1)
        std_q2 = torch.std(cat_q2, dim=1)

        if self.min_q_version == 3:
            # importance-sampled version
            random_density = np.log(0.5 ** curr_actions_tensor.shape[-1])
            cat_q1 = torch.cat(
                [q1_rand - random_density, q1_next_actions - new_log_pis.detach(), q1_curr_actions - curr_log_pis.detach()], 1
            )
            cat_q2 = torch.cat(
                [q2_rand - random_density, q2_next_actions - new_log_pis.detach(), q2_curr_actions - curr_log_pis.detach()], 1
            )
            
        min_qf1_loss = torch.logsumexp(cat_q1 / self.temp, dim=1,).mean() * self.min_q_weight * self.temp
        min_qf2_loss = torch.logsumexp(cat_q2 / self.temp, dim=1,).mean() * self.min_q_weight * self.temp
                    
        """Subtract the log likelihood of data"""
        min_qf1_loss = min_qf1_loss - q1_pred.mean() * self.min_q_weight
        min_qf2_loss = min_qf2_loss - q2_pred.mean() * self.min_q_weight
        
        if self.with_lagrange:
            alpha_prime = torch.clamp(self.log_alpha_prime.exp(), min=0.0, max=1000000.0)
            min_qf1_loss = alpha_prime * (min_qf1_loss - self.target_action_gap)
            min_qf2_loss = alpha_prime * (min_qf2_loss - self.target_action_gap)

            self.alpha_prime_optimizer.zero_grad()
            alpha_prime_loss = (-min_qf1_loss - min_qf2_loss)*0.5 
            alpha_prime_loss.backward(retain_graph=True)
            self.alpha_prime_optimizer.step()

        qf1_loss = qf1_loss + min_qf1_loss
        qf2_loss = qf2_loss + min_qf2_loss

        """
        Update networks
        """
        # Update the Q-functions
        self._num_q_update_steps += 1
        self.qf1_optimizer.zero_grad()
        qf1_loss.backward(retain_graph=True)
        self.qf1_optimizer.step()

        if self.num_qs > 1:
            self.qf2_optimizer.zero_grad()
            qf2_loss.backward(retain_graph=True)
            self.qf2_optimizer.step()

        self._num_policy_update_steps += 1
        self.policy_optimizer.zero_grad()
        policy_loss.backward(retain_graph=False)
        self.policy_optimizer.step()

        """
        Soft Updates
        """
        ptu.soft_update_from_to(
            self.qf1, self.target_qf1, self.soft_target_tau
        )
        if self.num_qs > 1:
            ptu.soft_update_from_to(
                self.qf2, self.target_qf2, self.soft_target_tau
            )
        """
        Save some statistics for eval
        """
        if self._need_to_update_eval_statistics:
            self._need_to_update_eval_statistics = False
            """
            Eval should set this to None.
            This way, these statistics are only computed for one batch.
            """
            policy_loss = (log_pi - q_new_actions).mean()

            self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
            self.eval_statistics['min QF1 Loss'] = np.mean(ptu.get_numpy(min_qf1_loss))
            if self.num_qs > 1:
                self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))
                self.eval_statistics['min QF2 Loss'] = np.mean(ptu.get_numpy(min_qf2_loss))

            if not self.discrete:
                self.eval_statistics['Std QF1 values'] = np.mean(ptu.get_numpy(std_q1))
                self.eval_statistics['Std QF2 values'] = np.mean(ptu.get_numpy(std_q2))
                self.eval_statistics.update(create_stats_ordered_dict(
                    'QF1 in-distribution values',
                    ptu.get_numpy(q1_curr_actions),
                ))
                self.eval_statistics.update(create_stats_ordered_dict(
                    'QF2 in-distribution values',
                    ptu.get_numpy(q2_curr_actions),
                ))
                self.eval_statistics.update(create_stats_ordered_dict(
                    'QF1 random values',
                    ptu.get_numpy(q1_rand),
                ))
                self.eval_statistics.update(create_stats_ordered_dict(
                    'QF2 random values',
                    ptu.get_numpy(q2_rand),
                ))
                self.eval_statistics.update(create_stats_ordered_dict(
                    'QF1 next_actions values',
                    ptu.get_numpy(q1_next_actions),
                ))
                self.eval_statistics.update(create_stats_ordered_dict(
                    'QF2 next_actions values',
                    ptu.get_numpy(q2_next_actions),
                ))
                self.eval_statistics.update(create_stats_ordered_dict(
                    'actions', 
                    ptu.get_numpy(actions)
                ))
                self.eval_statistics.update(create_stats_ordered_dict(
                    'rewards',
                    ptu.get_numpy(rewards)
                ))

            self.eval_statistics['Num Q Updates'] = self._num_q_update_steps
            self.eval_statistics['Num Policy Updates'] = self._num_policy_update_steps
            self.eval_statistics['Policy Loss'] = np.mean(ptu.get_numpy(
                policy_loss
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Q1 Predictions',
                ptu.get_numpy(q1_pred),
            ))
            if self.num_qs > 1:
                self.eval_statistics.update(create_stats_ordered_dict(
                    'Q2 Predictions',
                    ptu.get_numpy(q2_pred),
                ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Q Targets',
                ptu.get_numpy(q_target),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Log Pis',
                ptu.get_numpy(log_pi),
            ))
            if not self.discrete:
                self.eval_statistics.update(create_stats_ordered_dict(
                    'Policy mu',
                    ptu.get_numpy(policy_mean),
                ))
                self.eval_statistics.update(create_stats_ordered_dict(
                    'Policy log std',
                    ptu.get_numpy(policy_log_std),
                ))
            
            if self.use_automatic_entropy_tuning:
                self.eval_statistics['Alpha'] = alpha.item()
                self.eval_statistics['Alpha Loss'] = alpha_loss.item()
            
            if self.with_lagrange:
                self.eval_statistics['Alpha_prime'] = alpha_prime.item()
                self.eval_statistics['min_q1_loss'] = ptu.get_numpy(min_qf1_loss).mean()
                self.eval_statistics['min_q2_loss'] = ptu.get_numpy(min_qf2_loss).mean()
                self.eval_statistics['threshold action gap'] = self.target_action_gap
                self.eval_statistics['alpha prime loss'] = alpha_prime_loss.item()
            
        self._n_train_steps_total += 1
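
The conservative term added to the Q losses above is a logsumexp over Q-values of random, current-policy, and next-state actions, minus the Q-values of dataset actions, scaled by min_q_weight. A toy reproduction of that penalty with random tensors standing in for critic outputs is sketched below; temp and min_q_weight are arbitrary illustrative values.

import torch

batch_size, num_sampled = 32, 10
temp, min_q_weight = 1.0, 5.0

q_dataset = torch.randn(batch_size, 1)                   # Q(s, a) on dataset actions
q_sampled = torch.randn(batch_size, 3 * num_sampled, 1)  # Q on random / current / next actions

cat_q = torch.cat([q_sampled, q_dataset.unsqueeze(1)], dim=1)
# Push down a soft maximum over sampled actions, push up the dataset actions.
min_q_loss = (torch.logsumexp(cat_q / temp, dim=1).mean() * temp
              - q_dataset.mean()) * min_q_weight
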
Exemple #23
    def train_from_torch(self, batch):
        rewards_n = batch['rewards'].detach()
        terminals_n = batch['terminals'].detach()
        obs_n = batch['observations'].detach()
        actions_n = batch['actions'].detach()
        next_obs_n = batch['next_observations'].detach()

        batch_size = rewards_n.shape[0]
        num_agent = rewards_n.shape[1]
        """
        Policy operations.
        """
        online_actions_n, online_pre_values_n, online_log_pis_n = [], [], []
        for agent in range(num_agent):
            if self.shared_obs:
                policy_actions, info = self.policy_n[agent](
                    obs_n,
                    return_info=True,
                )
            else:
                policy_actions, info = self.policy_n[agent](
                    obs_n[:, agent, :],
                    return_info=True,
                )
            online_actions_n.append(policy_actions)
            online_pre_values_n.append(info['preactivation'])
            online_log_pis_n.append(info['log_prob'])
        k0_actions = torch.stack(online_actions_n)  # num_agent x batch x a_dim
        k0_actions = k0_actions.transpose(
            0, 1).contiguous()  # batch x num_agent x a_dim

        if self.shared_obs:
            k0_inputs = torch.cat(
                [obs_n, k0_actions.reshape(batch_size, -1)], dim=-1)
        else:
            k0_inputs = torch.cat([obs_n, k0_actions], dim=-1)
        k1_actions = self.cactor(
            k0_inputs, deterministic=self.deterministic_cactor_in_graph)

        if self.shared_obs:
            k1_inputs = torch.cat(
                [obs_n, k1_actions.reshape(batch_size, -1)], dim=-1)
        else:
            k1_inputs = torch.cat([obs_n, k1_actions], dim=-1)
        k1_contexts_1 = self.cg1(k1_inputs)
        k1_contexts_2 = self.cg2(k1_inputs)

        q1_inputs = torch.cat([k0_actions, k1_contexts_1], dim=-1)
        q1_outputs = self.qf1(q1_inputs)

        q2_inputs = torch.cat([k0_actions, k1_contexts_2], dim=-1)
        q2_outputs = self.qf2(q2_inputs)

        min_q_outputs = torch.min(q1_outputs,
                                  q2_outputs)  # batch x num_agent x 1

        policy_gradients_n = []
        alpha_n = []
        for agent in range(num_agent):
            policy_actions = online_actions_n[agent]
            pre_value = online_pre_values_n[agent]
            log_pi = online_log_pis_n[agent]
            if self.pre_activation_weight > 0.:
                pre_activation_policy_loss = ((pre_value**2).sum(dim=1).mean())
            else:
                pre_activation_policy_loss = torch.tensor(0.).to(ptu.device)
            if self.use_entropy_loss:
                if self.use_automatic_entropy_tuning:
                    alpha = self.log_alpha_n[agent].exp()
                    alpha_loss = -(
                        alpha *
                        (log_pi + self.target_entropy).detach()).mean()
                    self.alpha_optimizer_n[agent].zero_grad()
                    alpha_loss.backward()
                    self.alpha_optimizer_n[agent].step()
                    alpha = self.log_alpha_n[agent].exp().detach()
                    alpha_n.append(alpha)
                else:
                    alpha_loss = torch.tensor(0.).to(ptu.device)
                    alpha = torch.tensor(self.init_alpha).to(ptu.device)
                    alpha_n.append(alpha)
                entropy_loss = (alpha * log_pi).mean()
            else:
                entropy_loss = torch.tensor(0.).to(ptu.device)

            raw_policy_loss = -min_q_outputs[:, agent, :].mean()
            policy_loss = (
                raw_policy_loss +
                pre_activation_policy_loss * self.pre_activation_weight +
                entropy_loss)

            policy_gradients_n.append(
                torch.autograd.grad(policy_loss,
                                    self.policy_n[agent].parameters(),
                                    retain_graph=True))

            if self._need_to_update_eval_statistics:
                self.eval_statistics['Policy Loss {}'.format(agent)] = np.mean(
                    ptu.get_numpy(policy_loss))
                self.eval_statistics['Raw Policy Loss {}'.format(
                    agent)] = np.mean(ptu.get_numpy(raw_policy_loss))
                self.eval_statistics[
                    'Preactivation Policy Loss {}'.format(agent)] = np.mean(
                        ptu.get_numpy(pre_activation_policy_loss))
                self.eval_statistics['Entropy Loss {}'.format(
                    agent)] = np.mean(ptu.get_numpy(entropy_loss))
                if self.use_entropy_loss:
                    self.eval_statistics['Alpha {} Mean'.format(
                        agent)] = np.mean(ptu.get_numpy(alpha))
                self.eval_statistics.update(
                    create_stats_ordered_dict(
                        'Policy Action {}'.format(agent),
                        ptu.get_numpy(policy_actions),
                    ))

        for agent in range(num_agent):
            # self.policy_optimizer_n[agent].zero_grad()
            for pid, p in enumerate(self.policy_n[agent].parameters()):
                p.grad = policy_gradients_n[agent][pid]
            self.policy_optimizer_n[agent].step()
        """
        Critic operations.
        """
        with torch.no_grad():
            next_actions_n, next_log_pis_n = [], []
            for agent in range(num_agent):
                if self.shared_obs:
                    next_actions, next_info = self.policy_n[agent](
                        next_obs_n,
                        return_info=True,
                        deterministic=self.deterministic_next_action,
                    )
                else:
                    next_actions, next_info = self.policy_n[agent](
                        next_obs_n[:, agent, :],
                        return_info=True,
                        deterministic=self.deterministic_next_action,
                    )
                next_actions_n.append(next_actions)
                next_log_pis_n.append(next_info['log_prob'])
            next_k0_actions = torch.stack(
                next_actions_n)  # num_agent x batch x a_dim
            next_k0_actions = next_k0_actions.transpose(
                0, 1).contiguous()  # batch x num_agent x a_dim

            if self.shared_obs:
                next_k0_inputs = torch.cat(
                    [next_obs_n,
                     next_k0_actions.reshape(batch_size, -1)],
                    dim=-1)
            else:
                next_k0_inputs = torch.cat([next_obs_n, next_k0_actions],
                                           dim=-1)
            next_k1_actions = self.cactor(
                next_k0_inputs,
                deterministic=self.deterministic_cactor_in_graph)

            if self.shared_obs:
                next_k1_inputs = torch.cat(
                    [next_obs_n,
                     next_k1_actions.reshape(batch_size, -1)],
                    dim=-1)
            else:
                next_k1_inputs = torch.cat([next_obs_n, next_k1_actions],
                                           dim=-1)
            next_k1_contexts_1 = self.target_cg1(next_k1_inputs)
            next_k1_contexts_2 = self.target_cg2(next_k1_inputs)

            next_q1_inputs = torch.cat([next_k0_actions, next_k1_contexts_1],
                                       dim=-1)
            next_target_q1_values = self.target_qf1(next_q1_inputs)
            next_q2_inputs = torch.cat([next_k0_actions, next_k1_contexts_2],
                                       dim=-1)
            next_target_q2_values = self.target_qf2(next_q2_inputs)
            next_target_q_values = torch.min(next_target_q1_values,
                                             next_target_q2_values)

            if self.use_entropy_reward:
                next_alphas = torch.stack(alpha_n)[None, :]
                next_log_pis = torch.stack(next_log_pis_n).transpose(
                    0, 1).contiguous()
                next_target_q_values = next_target_q_values - next_alphas * next_log_pis
            q_targets = self.reward_scale * rewards_n + (
                1. - terminals_n) * self.discount * next_target_q_values
            q_targets = torch.clamp(q_targets, self.min_q_value,
                                    self.max_q_value)

        if self.grad_loss:
            k0_actions = actions_n.clone().detach()
            k0_actions.requires_grad = True
        else:
            k0_actions = actions_n
        k1_actions = actions_n
        if self.shared_obs:
            buffer_inputs = torch.cat(
                [obs_n, k0_actions.reshape(batch_size, -1)], dim=-1)
        else:
            buffer_inputs = torch.cat([obs_n, k0_actions], dim=-1)

        buffer_contexts_1 = self.cg1(
            buffer_inputs)  # batch x num_agent x c_dim
        q1_inputs = torch.cat([k1_actions, buffer_contexts_1], dim=-1)
        q1_preds = self.qf1(q1_inputs)
        raw_qf1_loss = self.qf_criterion(q1_preds, q_targets)

        buffer_contexts_2 = self.cg2(
            buffer_inputs)  # batch x num_agent x c_dim
        q2_inputs = torch.cat([k1_actions, buffer_contexts_2], dim=-1)
        q2_preds = self.qf2(q2_inputs)
        raw_qf2_loss = self.qf_criterion(q2_preds, q_targets)

        if self.negative_sampling:
            batch_size, num_agent, a_dim = actions_n.shape
            perturb_agents = torch.randint(low=0,
                                           high=num_agent,
                                           size=(batch_size, )).to(ptu.device)
            neg_actions = (torch.rand(batch_size, num_agent, a_dim) * 2. -
                           1.).to(ptu.device)  # uniform in [-1, 1)
            perturb_k0_actions = actions_n.clone()  # batch x agent x |A|
            perturb_k0_actions[torch.arange(batch_size),
                               perturb_agents, :] = neg_actions[
                                   torch.arange(batch_size), perturb_agents, :]
            perturb_k1_actions = neg_actions.clone()
            perturb_k1_actions[torch.arange(batch_size),
                               perturb_agents, :] = actions_n[
                                   torch.arange(batch_size), perturb_agents, :]

            if self.shared_obs:
                perturb_inputs = torch.cat(
                    [obs_n, perturb_k0_actions.reshape(batch_size, -1)],
                    dim=-1)
            else:
                perturb_inputs = torch.cat([obs_n, perturb_k0_actions], dim=-1)

            perturb_contexts_1 = self.cg1(
                perturb_inputs)  # batch x num_agent x c_dim
            perturb_q1_inputs = torch.cat(
                [perturb_k1_actions, perturb_contexts_1], dim=-1)
            perturb_q1_preds = self.qf1(perturb_q1_inputs)[
                torch.arange(batch_size), perturb_agents, :]

            perturb_contexts_2 = self.cg2(
                perturb_inputs)  # batch x num_agent x c_dim
            perturb_q2_inputs = torch.cat(
                [perturb_k1_actions, perturb_contexts_2], dim=-1)
            perturb_q2_preds = self.qf2(perturb_q2_inputs)[
                torch.arange(batch_size), perturb_agents, :]

            perturb_q_targets = q_targets[torch.arange(batch_size),
                                          perturb_agents, :]

            neg_loss1 = self.qf_criterion(perturb_q1_preds, perturb_q_targets)
            neg_loss2 = self.qf_criterion(perturb_q2_preds, perturb_q_targets)
        else:
            neg_loss1, neg_loss2 = torch.tensor(0.).to(
                ptu.device), torch.tensor(0.).to(ptu.device)

        if self.grad_loss:
            grad_loss1 = 0
            for agent in range(num_agent):
                grads = torch.autograd.grad(torch.sum(q1_preds[:, agent, :]),
                                            k0_actions,
                                            retain_graph=True,
                                            create_graph=True)[0]
                grad_loss1 += grads[:, agent, :].norm(2)
            grad_loss2 = 0
            for agent in range(num_agent):
                grads = torch.autograd.grad(torch.sum(q2_preds[:, agent, :]),
                                            k0_actions,
                                            retain_graph=True,
                                            create_graph=True)[0]
                grad_loss2 += grads[:, agent, :].norm(2)
        else:
            grad_loss1, grad_loss2 = torch.tensor(0.).to(
                ptu.device), torch.tensor(0.).to(ptu.device)

        if self.qf_weight_decay > 0:
            reg_loss1 = self.qf_weight_decay * sum(
                torch.sum(param**2)
                for param in list(self.qf1.regularizable_parameters()) +
                list(self.cg1.regularizable_parameters()))

            reg_loss2 = self.qf_weight_decay * sum(
                torch.sum(param**2)
                for param in list(self.qf2.regularizable_parameters()) +
                list(self.cg2.regularizable_parameters()))
        else:
            reg_loss1, reg_loss2 = torch.tensor(0.).to(
                ptu.device), torch.tensor(0.).to(ptu.device)

        qf1_loss = raw_qf1_loss + neg_loss1 + grad_loss1 + reg_loss1
        qf2_loss = raw_qf2_loss + neg_loss2 + grad_loss2 + reg_loss2

        if self._need_to_update_eval_statistics:
            self.eval_statistics['Qf1 Loss'] = ptu.get_numpy(qf1_loss)
            self.eval_statistics['Qf2 Loss'] = ptu.get_numpy(qf2_loss)
            self.eval_statistics['Raw Qf1 Loss'] = ptu.get_numpy(raw_qf1_loss)
            self.eval_statistics['Raw Qf2 Loss'] = ptu.get_numpy(raw_qf2_loss)
            self.eval_statistics['Neg Qf1 Loss'] = ptu.get_numpy(neg_loss1)
            self.eval_statistics['Neg Qf2 Loss'] = ptu.get_numpy(neg_loss2)
            self.eval_statistics['Grad Qf1 Loss'] = ptu.get_numpy(grad_loss1)
            self.eval_statistics['Grad Qf2 Loss'] = ptu.get_numpy(grad_loss2)
            self.eval_statistics['Reg Qf1 Loss'] = ptu.get_numpy(reg_loss1)
            self.eval_statistics['Reg Qf2 Loss'] = ptu.get_numpy(reg_loss2)

        self.qf1_optimizer.zero_grad()
        qf1_loss.backward()
        self.qf1_optimizer.step()

        self.qf2_optimizer.zero_grad()
        qf2_loss.backward()
        self.qf2_optimizer.step()
        """
        Central actor operations.
        """
        if self.shared_obs:
            buffer_inputs = torch.cat(
                [obs_n, actions_n.reshape(batch_size, -1)], dim=-1)
        else:
            buffer_inputs = torch.cat([obs_n, actions_n], dim=-1)
        cactor_actions, cactor_infos = self.cactor(buffer_inputs,
                                                   return_info=True)
        # batch x agent_num x |A|
        buffer_contexts_1 = self.cg1(buffer_inputs).detach()
        buffer_contexts_2 = self.cg2(buffer_inputs).detach()

        cactor_pre_values = cactor_infos['preactivation']
        if self.pre_activation_weight > 0:
            pre_activation_cactor_loss = ((cactor_pre_values**2).sum(
                dim=1).mean())
        else:
            pre_activation_cactor_loss = torch.tensor(0.).to(ptu.device)
        if self.use_cactor_entropy_loss:
            cactor_log_pis = cactor_infos['log_prob']  # batch x num_agent x 1
            if self.use_automatic_entropy_tuning:
                calphas = torch.stack(self.log_calpha_n).exp()[None, :]
                calpha_loss = -(
                    calphas * (cactor_log_pis + self.target_entropy).detach())
                calpha_loss = calpha_loss.mean()
                self.calpha_optimizer.zero_grad()
                calpha_loss.backward()
                self.calpha_optimizer.step()
                calphas = torch.stack(self.log_calpha_n).exp().detach()
            else:
                calpha_loss = torch.tensor(0.).to(ptu.device)
                calphas = torch.stack([
                    torch.tensor(self.init_alpha).to(ptu.device)
                    for i in range(num_agent)
                ])
            cactor_entropy_loss = (calphas[None, :] * cactor_log_pis).mean()
        else:
            cactor_entropy_loss = torch.tensor(0.).to(ptu.device)

        q1_inputs = torch.cat([cactor_actions, buffer_contexts_1], dim=-1)
        q1_outputs = self.qf1(q1_inputs)
        q2_inputs = torch.cat([cactor_actions, buffer_contexts_2], dim=-1)
        q2_outputs = self.qf2(q2_inputs)
        q_outputs = torch.min(q1_outputs, q2_outputs)
        raw_cactor_loss = -q_outputs.mean()
        cactor_loss = (
            raw_cactor_loss +
            pre_activation_cactor_loss * self.pre_activation_weight +
            cactor_entropy_loss)

        if self._need_to_update_eval_statistics:
            if self.use_cactor_entropy_loss:
                self.eval_statistics.update(
                    create_stats_ordered_dict(
                        'CAlpha',
                        ptu.get_numpy(calphas),
                    ))
            self.eval_statistics['Cactor Loss'] = ptu.get_numpy(cactor_loss)
            self.eval_statistics['Raw Cactor Loss'] = ptu.get_numpy(
                raw_cactor_loss)
            self.eval_statistics['Preactivation Cactor Loss'] = ptu.get_numpy(
                pre_activation_cactor_loss)
            self.eval_statistics['Entropy Cactor Loss'] = ptu.get_numpy(
                cactor_entropy_loss)

        self.cactor_optimizer.zero_grad()
        cactor_loss.backward()
        self.cactor_optimizer.step()

        self._need_to_update_eval_statistics = False
        self._update_target_networks()
        self._n_train_steps_total += 1
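
The grad_loss branch above penalizes the norm of each agent's Q-value gradient with respect to the joint action, using torch.autograd.grad with create_graph=True so the penalty itself stays differentiable. A minimal single-agent sketch of that pattern follows; the toy critic and tensor sizes are assumptions.

import torch
import torch.nn as nn

critic = nn.Sequential(nn.Linear(4 + 2, 32), nn.ReLU(), nn.Linear(32, 1))
obs = torch.randn(16, 4)
actions = torch.randn(16, 2, requires_grad=True)   # leaf tensor we differentiate w.r.t.

q = critic(torch.cat([obs, actions], dim=-1))
grads = torch.autograd.grad(q.sum(), actions, create_graph=True)[0]
grad_penalty = grads.norm(2)   # differentiable; can be added to the critic loss
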
Exemple #24
    def test_from_torch(self, batch):
        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']
        weights = batch.get('weights', None)
        if self.reward_transform:
            rewards = self.reward_transform(rewards)

        if self.terminal_transform:
            terminals = self.terminal_transform(terminals)
        """
        Policy and Alpha Loss
        """
        dist = self.policy(obs)
        new_obs_actions, log_pi = dist.rsample_and_logprob()
        policy_mle = dist.mle_estimate()

        if self.use_automatic_entropy_tuning:
            alpha_loss = -(self.log_alpha *
                           (log_pi + self.target_entropy).detach()).mean()
            alpha = self.log_alpha.exp()
        else:
            alpha_loss = 0
            alpha = self.alpha

        q1_pred = self.qf1(obs, actions)
        q2_pred = self.qf2(obs, actions)
        # Make sure policy accounts for squashing functions like tanh correctly!
        next_dist = self.policy(next_obs)
        new_next_actions, new_log_pi = next_dist.rsample_and_logprob()
        target_q_values = torch.min(
            self.target_qf1(next_obs, new_next_actions),
            self.target_qf2(next_obs, new_next_actions),
        ) - alpha * new_log_pi

        q_target = self.reward_scale * rewards + (
            1. - terminals) * self.discount * target_q_values
        qf1_loss = self.qf_criterion(q1_pred, q_target.detach())
        qf2_loss = self.qf_criterion(q2_pred, q_target.detach())

        qf1_new_actions = self.qf1(obs, new_obs_actions)
        qf2_new_actions = self.qf2(obs, new_obs_actions)
        q_new_actions = torch.min(
            qf1_new_actions,
            qf2_new_actions,
        )

        policy_loss = (log_pi - q_new_actions).mean()

        self.eval_statistics['validation/QF1 Loss'] = np.mean(
            ptu.get_numpy(qf1_loss))
        self.eval_statistics['validation/QF2 Loss'] = np.mean(
            ptu.get_numpy(qf2_loss))
        self.eval_statistics['validation/Policy Loss'] = np.mean(
            ptu.get_numpy(policy_loss))
        self.eval_statistics.update(
            create_stats_ordered_dict(
                'validation/Q1 Predictions',
                ptu.get_numpy(q1_pred),
            ))
        self.eval_statistics.update(
            create_stats_ordered_dict(
                'validation/Q2 Predictions',
                ptu.get_numpy(q2_pred),
            ))
        self.eval_statistics.update(
            create_stats_ordered_dict(
                'validation/Q Targets',
                ptu.get_numpy(q_target),
            ))
        self.eval_statistics.update(
            create_stats_ordered_dict(
                'validation/Log Pis',
                ptu.get_numpy(log_pi),
            ))
        policy_statistics = add_prefix(dist.get_diagnostics(),
                                       "validation/policy/")
        self.eval_statistics.update(policy_statistics)
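
Since the validation pass above never steps an optimizer, the whole method can be wrapped in torch.no_grad() to avoid building autograd graphs that are immediately discarded. A hedged sketch of such a wrapper, assuming trainer exposes the test_from_torch method above:

import torch

def evaluate_batch(trainer, batch):
    """Compute validation statistics without autograd bookkeeping."""
    with torch.no_grad():
        trainer.test_from_torch(batch)
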
Exemple #25
    def train_from_torch(self, batch, train=True, pretrain=False,):
        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']
        if self.reward_transform:
            rewards = self.reward_transform(rewards)

        if self.terminal_transform:
            terminals = self.terminal_transform(terminals)
        """
        Policy and Alpha Loss
        """
        dist = self.policy(obs)

        """
        QF Loss
        """
        q1_pred = self.qf1(obs, actions)
        q2_pred = self.qf2(obs, actions)
        target_vf_pred = self.vf(next_obs).detach()

        q_target = self.reward_scale * rewards + (1. - terminals) * self.discount * target_vf_pred
        q_target = q_target.detach()
        qf1_loss = self.qf_criterion(q1_pred, q_target)
        qf2_loss = self.qf_criterion(q2_pred, q_target)

        """
        VF Loss
        """
        q_pred = torch.min(
            self.target_qf1(obs, actions),
            self.target_qf2(obs, actions),
        ).detach()
        vf_pred = self.vf(obs)
        vf_err = vf_pred - q_pred
        vf_sign = (vf_err > 0).float()
        vf_weight = (1 - vf_sign) * self.quantile + vf_sign * (1 - self.quantile)
        vf_loss = (vf_weight * (vf_err ** 2)).mean()

        """
        Policy Loss
        """
        policy_logpp = dist.log_prob(actions)

        adv = q_pred - vf_pred
        exp_adv = torch.exp(adv / self.beta)
        if self.clip_score is not None:
            exp_adv = torch.clamp(exp_adv, max=self.clip_score)

        weights = exp_adv[:, 0].detach()
        policy_loss = (-policy_logpp * weights).mean()

        """
        Update networks
        """
        if self._n_train_steps_total % self.q_update_period == 0:
            self.qf1_optimizer.zero_grad()
            qf1_loss.backward()
            self.qf1_optimizer.step()

            self.qf2_optimizer.zero_grad()
            qf2_loss.backward()
            self.qf2_optimizer.step()

            self.vf_optimizer.zero_grad()
            vf_loss.backward()
            self.vf_optimizer.step()

        if self._n_train_steps_total % self.policy_update_period == 0:
            self.policy_optimizer.zero_grad()
            policy_loss.backward()
            self.policy_optimizer.step()

        """
        Soft Updates
        """
        if self._n_train_steps_total % self.target_update_period == 0:
            ptu.soft_update_from_to(
                self.qf1, self.target_qf1, self.soft_target_tau
            )
            ptu.soft_update_from_to(
                self.qf2, self.target_qf2, self.soft_target_tau
            )

        """
        Save some statistics for eval
        """
        if self._need_to_update_eval_statistics:
            self._need_to_update_eval_statistics = False
            """
            Eval should set this to None.
            This way, these statistics are only computed for one batch.
            """
            self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
            self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))
            self.eval_statistics['Policy Loss'] = np.mean(ptu.get_numpy(
                policy_loss
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Q1 Predictions',
                ptu.get_numpy(q1_pred),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Q2 Predictions',
                ptu.get_numpy(q2_pred),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Q Targets',
                ptu.get_numpy(q_target),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'rewards',
                ptu.get_numpy(rewards),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'terminals',
                ptu.get_numpy(terminals),
            ))
            self.eval_statistics['replay_buffer_len'] = self.replay_buffer._size
            policy_statistics = add_prefix(dist.get_diagnostics(), "policy/")
            self.eval_statistics.update(policy_statistics)
            self.eval_statistics.update(create_stats_ordered_dict(
                'Advantage Weights',
                ptu.get_numpy(weights),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Advantage Score',
                ptu.get_numpy(adv),
            ))

            self.eval_statistics.update(create_stats_ordered_dict(
                'V1 Predictions',
                ptu.get_numpy(vf_pred),
            ))
            self.eval_statistics['VF Loss'] = np.mean(ptu.get_numpy(vf_loss))

        self._n_train_steps_total += 1
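
The value update above is an expectile regression (overestimates of Q weighted by 1 - quantile, underestimates by quantile), and the policy update is advantage-weighted regression with weights exp(adv / beta). A toy reproduction of both losses with random tensors is below; quantile, beta, and clip_score are arbitrary illustrative values.

import torch

quantile, beta, clip_score = 0.7, 3.0, 100.0

q_pred = torch.randn(32, 1)        # stand-in for min of the two target critics
vf_pred = torch.randn(32, 1)       # stand-in for vf(obs)
policy_logpp = torch.randn(32)     # stand-in for dist.log_prob(actions)

# Expectile (asymmetric squared-error) value loss.
vf_err = vf_pred - q_pred
vf_sign = (vf_err > 0).float()
vf_weight = (1 - vf_sign) * quantile + vf_sign * (1 - quantile)
vf_loss = (vf_weight * vf_err ** 2).mean()

# Advantage-weighted policy loss.
adv = q_pred - vf_pred
weights = torch.clamp(torch.exp(adv / beta), max=clip_score)[:, 0].detach()
policy_loss = (-policy_logpp * weights).mean()
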
Exemple #26
    def train_from_torch(self, batch):
        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']

        """
        Critic operations.
        """

        next_actions = self.target_policy(next_obs)
        noise = ptu.randn(next_actions.shape) * self.target_policy_noise
        noise = torch.clamp(
            noise,
            -self.target_policy_noise_clip,
            self.target_policy_noise_clip
        )
        noisy_next_actions = next_actions + noise

        target_q1_values = self.target_qf1(next_obs, noisy_next_actions)
        target_q2_values = self.target_qf2(next_obs, noisy_next_actions)
        target_q_values = torch.min(target_q1_values, target_q2_values)
        q_target = self.reward_scale * rewards + (1. - terminals) * self.discount * target_q_values
        q_target = q_target.detach()

        q1_pred = self.qf1(obs, actions)
        bellman_errors_1 = (q1_pred - q_target) ** 2
        qf1_loss = bellman_errors_1.mean()

        q2_pred = self.qf2(obs, actions)
        bellman_errors_2 = (q2_pred - q_target) ** 2
        qf2_loss = bellman_errors_2.mean()

        """
        Update Networks
        """
        self.qf1_optimizer.zero_grad()
        qf1_loss.backward()
        self.qf1_optimizer.step()

        self.qf2_optimizer.zero_grad()
        qf2_loss.backward()
        self.qf2_optimizer.step()

        policy_actions = policy_loss = None
        if self._n_train_steps_total % self.policy_and_target_update_period == 0:
            policy_actions = self.policy(obs)
            q_output = self.qf1(obs, policy_actions)
            policy_loss = - q_output.mean()

            self.policy_optimizer.zero_grad()
            policy_loss.backward()
            self.policy_optimizer.step()

            ptu.soft_update_from_to(self.policy, self.target_policy, self.tau)
            ptu.soft_update_from_to(self.qf1, self.target_qf1, self.tau)
            ptu.soft_update_from_to(self.qf2, self.target_qf2, self.tau)

        if self._need_to_update_eval_statistics:
            self._need_to_update_eval_statistics = False
            if policy_loss is None:
                policy_actions = self.policy(obs)
                q_output = self.qf1(obs, policy_actions)
                policy_loss = - q_output.mean()

            self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
            self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))
            self.eval_statistics['Policy Loss'] = np.mean(ptu.get_numpy(
                policy_loss
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Q1 Predictions',
                ptu.get_numpy(q1_pred),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Q2 Predictions',
                ptu.get_numpy(q2_pred),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Q Targets',
                ptu.get_numpy(q_target),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Bellman Errors 1',
                ptu.get_numpy(bellman_errors_1),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Bellman Errors 2',
                ptu.get_numpy(bellman_errors_2),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Policy Action',
                ptu.get_numpy(policy_actions),
            ))
        self._n_train_steps_total += 1
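
The target networks in this TD3 update are maintained by Polyak averaging. A
minimal sketch of what ptu.soft_update_from_to(source, target, tau) is assumed
to do (the library helper may differ in detail):

import torch


def soft_update_from_to(source, target, tau):
    """Sketch of assumed behavior: target <- (1 - tau) * target + tau * source."""
    with torch.no_grad():
        for target_param, param in zip(target.parameters(), source.parameters()):
            target_param.data.mul_(1.0 - tau).add_(param.data, alpha=tau)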
Example #27
0
    def train_from_torch(self, batch):
        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']
        """
        Policy and Alpha Loss
        """
        pis = self.policy(obs)
        if self.use_automatic_entropy_tuning:
            alpha_loss = -(pis.detach() * self.log_alpha.exp() *
                           (torch.log(pis + 1e-3) +
                            self.target_entropy).detach()).sum(-1).mean()
            self.alpha_optimizer.zero_grad()
            alpha_loss.backward()
            self.alpha_optimizer.step()
            alpha = self.log_alpha.exp()
        else:
            alpha_loss = 0
            alpha = 1

        min_q = torch.min(self.qf1(obs), self.qf2(obs)).detach()
        policy_loss = (pis *
                       (alpha * torch.log(pis + 1e-3) - min_q)).sum(-1).mean()
        """
        QF Loss
        """
        new_pis = self.policy(next_obs).detach()
        target_min_q_values = torch.min(
            self.target_qf1(next_obs),
            self.target_qf2(next_obs),
        )
        target_q_values = (
            new_pis *
            (target_min_q_values - alpha * torch.log(new_pis + 1e-3))).sum(
                -1, keepdim=True)
        q_target = self.reward_scale * rewards + (
            1. - terminals) * self.discount * target_q_values

        q1_pred = torch.sum(self.qf1(obs) * actions.detach(),
                            dim=-1,
                            keepdim=True)
        q2_pred = torch.sum(self.qf2(obs) * actions.detach(),
                            dim=-1,
                            keepdim=True)
        qf1_loss = self.qf_criterion(q1_pred, q_target.detach())
        qf2_loss = self.qf_criterion(q2_pred, q_target.detach())
        """
        Update networks
        """
        self.qf1_optimizer.zero_grad()
        qf1_loss.backward()
        self.qf1_optimizer.step()

        self.qf2_optimizer.zero_grad()
        qf2_loss.backward()
        self.qf2_optimizer.step()

        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()
        """
        Soft Updates
        """
        if self._n_train_steps_total % self.target_update_period == 0:
            ptu.soft_update_from_to(self.qf1, self.target_qf1,
                                    self.soft_target_tau)
            ptu.soft_update_from_to(self.qf2, self.target_qf2,
                                    self.soft_target_tau)
        """
        Save some statistics for eval
        """
        if self._need_to_update_eval_statistics:
            self._need_to_update_eval_statistics = False
            """
            Eval should set this to None.
            This way, these statistics are only computed for one batch.
            """

            self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
            self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))
            self.eval_statistics['Policy Loss'] = np.mean(
                ptu.get_numpy(policy_loss))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q1 Predictions',
                    ptu.get_numpy(q1_pred),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q2 Predictions',
                    ptu.get_numpy(q2_pred),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q Targets',
                    ptu.get_numpy(q_target),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Pis',
                    ptu.get_numpy(pis),
                ))
            if self.use_automatic_entropy_tuning:
                self.eval_statistics['Alpha'] = alpha.item()
                self.eval_statistics['Alpha Loss'] = alpha_loss.item()
        self._n_train_steps_total += 1
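
In this discrete-action SAC variant the Q-networks output one value per action,
so expectations over the policy are exact sums. The Bellman target and policy
objective implemented above are (with a small 1e-3 added inside the logs for
numerical stability):

$$ y = \text{reward\_scale} \cdot r + \gamma (1 - d) \sum_{a'} \pi(a' \mid s')
\Big( \min_{i \in \{1, 2\}} Q_{\bar\theta_i}(s', a') - \alpha \log \pi(a' \mid s') \Big) $$

$$ J_\pi = \mathbb{E}_{s} \Big[ \sum_{a} \pi(a \mid s)
\big( \alpha \log \pi(a \mid s) - \min_i Q_{\theta_i}(s, a) \big) \Big] $$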
Example #28
0
    def train_from_torch(self, batch):
        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']
        """
        Policy and Alpha Loss
        """
        new_obs_actions, policy_mean, policy_log_std, log_pi, *_ = self.policy(
            obs,
            reparameterize=True,
            return_log_prob=True,
        )
        if self.use_automatic_entropy_tuning:
            alpha_loss = -(self.log_alpha *
                           (log_pi + self.target_entropy).detach()).mean()
            self.alpha_optimizer.zero_grad()
            alpha_loss.backward()
            self.alpha_optimizer.step()
            alpha = self.log_alpha.exp()
        else:
            alpha_loss = 0
            alpha = 1

        q_new_actions = torch.min(
            self.qf1(obs, new_obs_actions),
            self.qf2(obs, new_obs_actions),
        )
        policy_loss = (alpha * log_pi - q_new_actions).mean()
        """
        QF Loss
        """
        q1_pred = self.qf1(obs, actions)
        q2_pred = self.qf2(obs, actions)
        # Make sure policy accounts for squashing functions like tanh correctly!
        new_next_actions, _, _, new_log_pi, *_ = self.policy(
            next_obs,
            reparameterize=True,
            return_log_prob=True,
        )
        target_q_values = torch.min(
            self.target_qf1(next_obs, new_next_actions),
            self.target_qf2(next_obs, new_next_actions),
        ) - alpha * new_log_pi

        q_target = self.reward_scale * rewards + (
            1. - terminals) * self.discount * target_q_values
        qf1_loss = self.qf_criterion(q1_pred, q_target.detach())
        qf2_loss = self.qf_criterion(q2_pred, q_target.detach())
        """
        Update networks
        """
        self.qf1_optimizer.zero_grad()
        qf1_loss.backward()
        self.qf1_optimizer.step()

        self.qf2_optimizer.zero_grad()
        qf2_loss.backward()
        self.qf2_optimizer.step()

        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()
        """
        Soft Updates
        """
        if self._n_train_steps_total % self.target_update_period == 0:
            ptu.soft_update_from_to(self.qf1, self.target_qf1,
                                    self.soft_target_tau)
            ptu.soft_update_from_to(self.qf2, self.target_qf2,
                                    self.soft_target_tau)
        """
        Save some statistics for eval
        """
        if self._need_to_update_eval_statistics:
            self._need_to_update_eval_statistics = False
            """
            Eval should set this to None.
            This way, these statistics are only computed for one batch.
            """
            policy_loss = (log_pi - q_new_actions).mean()

            self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
            self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))
            self.eval_statistics['Policy Loss'] = np.mean(
                ptu.get_numpy(policy_loss))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q1 Predictions',
                    ptu.get_numpy(q1_pred),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q2 Predictions',
                    ptu.get_numpy(q2_pred),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q Targets',
                    ptu.get_numpy(q_target),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Log Pis',
                    ptu.get_numpy(log_pi),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Policy mu',
                    ptu.get_numpy(policy_mean),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Policy log std',
                    ptu.get_numpy(policy_log_std),
                ))
            if self.use_automatic_entropy_tuning:
                self.eval_statistics['Alpha'] = alpha.item()
                self.eval_statistics['Alpha Loss'] = alpha_loss.item()
        self._n_train_steps_total += 1
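
This is the standard twin-critic SAC update. The target value and the
reparameterized policy loss computed above correspond to:

$$ y = \text{reward\_scale} \cdot r + \gamma (1 - d)
\big( \min_i Q_{\bar\theta_i}(s', a') - \alpha \log \pi(a' \mid s') \big),
\qquad a' \sim \pi(\cdot \mid s') $$

$$ J_\pi = \mathbb{E}_{s,\, a \sim \pi}
\big[ \alpha \log \pi(a \mid s) - \min_i Q_{\theta_i}(s, a) \big] $$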
Example #29
0
    def train_from_torch(self, batch):
        self._current_epoch += 1
        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']
        """
        Behavior clone a policy
        """
        recon, mean, std = self.vae(obs, actions)
        recon_loss = self.qf_criterion(recon, actions)
        kl_loss = -0.5 * (1 + torch.log(std.pow(2)) - mean.pow(2) -
                          std.pow(2)).mean()
        vae_loss = recon_loss + 0.5 * kl_loss

        self.vae_optimizer.zero_grad()
        vae_loss.backward()
        self.vae_optimizer.step()
        """
        Critic Training
        """
        with torch.no_grad():
            # Duplicate state 10 times (10 is a hyperparameter chosen by BCQ)
            state_rep = next_obs.unsqueeze(1).repeat(1, 10, 1).view(
                next_obs.shape[0] * 10, next_obs.shape[1])  # 10BxS
            # Compute value of actions sampled from the current policy at the repeated states
            action_rep = self.policy(state_rep)[0]
            target_qf1 = self.target_qf1(state_rep, action_rep)
            target_qf2 = self.target_qf2(state_rep, action_rep)

            # Soft Clipped Double Q-learning
            target_Q = 0.75 * torch.min(target_qf1,
                                        target_qf2) + 0.25 * torch.max(
                                            target_qf1, target_qf2)
            target_Q = target_Q.view(next_obs.shape[0],
                                     -1).max(1)[0].view(-1, 1)
            target_Q = self.reward_scale * rewards + (
                1.0 - terminals) * self.discount * target_Q  # Bx1

        qf1_pred = self.qf1(obs, actions)  # Bx1
        qf2_pred = self.qf2(obs, actions)  # Bx1
        qf1_loss = (qf1_pred - target_Q.detach()).pow(2).mean()
        qf2_loss = (qf2_pred - target_Q.detach()).pow(2).mean()
        """
        Actor Training
        """
        sampled_actions, raw_sampled_actions = self.vae.decode_multiple(
            obs, num_decode=self.num_samples_mmd_match)
        actor_samples, _, _, _, _, _, _, raw_actor_actions = self.policy(
            obs.unsqueeze(1).repeat(1, self.num_samples_mmd_match,
                                    1).view(-1, obs.shape[1]),
            return_log_prob=True)
        actor_samples = actor_samples.view(obs.shape[0],
                                           self.num_samples_mmd_match,
                                           actions.shape[1])
        raw_actor_actions = raw_actor_actions.view(obs.shape[0],
                                                   self.num_samples_mmd_match,
                                                   actions.shape[1])

        if self.kernel_choice == 'laplacian':
            mmd_loss = self.mmd_loss_laplacian(raw_sampled_actions,
                                               raw_actor_actions,
                                               sigma=self.mmd_sigma)
        elif self.kernel_choice == 'gaussian':
            mmd_loss = self.mmd_loss_gaussian(raw_sampled_actions,
                                              raw_actor_actions,
                                              sigma=self.mmd_sigma)

        action_divergence = ((sampled_actions - actor_samples)**2).sum(-1)
        raw_action_divergence = ((raw_sampled_actions -
                                  raw_actor_actions)**2).sum(-1)

        q_val1 = self.qf1(obs, actor_samples[:, 0, :])
        q_val2 = self.qf2(obs, actor_samples[:, 0, :])

        if self.policy_update_style == '0':
            policy_loss = torch.min(q_val1, q_val2)[:, 0]
        elif self.policy_update_style == '1':
            policy_loss = ((q_val1 + q_val2) / 2.0)[:, 0]  # mean of the two critics

        if self._n_train_steps_total >= 40000:
            # Now we can update the policy
            if self.mode == 'auto':
                policy_loss = (-policy_loss + self.log_alpha.exp() *
                               (mmd_loss - self.target_mmd_thresh)).mean()
            else:
                policy_loss = (-policy_loss + 100 * mmd_loss).mean()
        else:
            if self.mode == 'auto':
                policy_loss = (self.log_alpha.exp() *
                               (mmd_loss - self.target_mmd_thresh)).mean()
            else:
                policy_loss = 100 * mmd_loss.mean()
        """
        Update Networks
        """
        self.qf1_optimizer.zero_grad()
        qf1_loss.backward()
        self.qf1_optimizer.step()

        self.qf2_optimizer.zero_grad()
        qf2_loss.backward()
        self.qf2_optimizer.step()

        self.policy_optimizer.zero_grad()
        if self.mode == 'auto':
            policy_loss.backward(retain_graph=True)
        else:
            policy_loss.backward()
        self.policy_optimizer.step()

        if self.mode == 'auto':
            self.alpha_optimizer.zero_grad()
            (-policy_loss).backward()
            self.alpha_optimizer.step()
            self.log_alpha.data.clamp_(min=-5.0, max=10.0)
        """
        Update networks
        """
        if self._n_train_steps_total % self.target_update_period == 0:
            ptu.soft_update_from_to(self.qf1, self.target_qf1,
                                    self.soft_target_tau)
            ptu.soft_update_from_to(self.qf2, self.target_qf2,
                                    self.soft_target_tau)
        """
        Some statistics for eval
        """
        if self._need_to_update_eval_statistics:
            self._need_to_update_eval_statistics = False
            """
            Eval should set this to None.
            This way, these statistics are only computed for one batch.
            """
            self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
            self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))
            self.eval_statistics['Num Q Updates'] = self._num_q_update_steps
            self.eval_statistics[
                'Num Policy Updates'] = self._num_policy_update_steps
            self.eval_statistics['Policy Loss'] = np.mean(
                ptu.get_numpy(policy_loss))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q1 Predictions',
                    ptu.get_numpy(qf1_pred),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q2 Predictions',
                    ptu.get_numpy(qf2_pred),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q Targets',
                    ptu.get_numpy(target_Q),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict('MMD Loss', ptu.get_numpy(mmd_loss)))
            self.eval_statistics.update(
                create_stats_ordered_dict('Action Divergence',
                                          ptu.get_numpy(action_divergence)))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Raw Action Divergence',
                    ptu.get_numpy(raw_action_divergence)))
            if self.mode == 'auto':
                self.eval_statistics['Alpha'] = self.log_alpha.exp().item()

        self._n_train_steps_total += 1
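
The MMD penalty above compares K raw (pre-tanh) actions per state drawn from the
VAE with K drawn from the current policy. A minimal sketch of a batched
Laplacian-kernel MMD estimator with this (B, K, D) layout, assuming
mmd_loss_laplacian behaves along these lines (the exact bandwidth handling in
the class may differ):

import torch


def mmd_loss_laplacian(samples_1, samples_2, sigma=0.2):
    """Sketch: per-state MMD estimate between two (B, K, D) action sets."""
    # Pairwise differences within and across the two action sets: (B, K, K, D).
    diff_xx = samples_1.unsqueeze(2) - samples_1.unsqueeze(1)
    diff_yy = samples_2.unsqueeze(2) - samples_2.unsqueeze(1)
    diff_xy = samples_1.unsqueeze(2) - samples_2.unsqueeze(1)
    # Laplacian kernel k(a, b) = exp(-||a - b||_1 / (2 * sigma)).
    k_xx = torch.exp(-diff_xx.abs().sum(-1) / (2.0 * sigma))
    k_yy = torch.exp(-diff_yy.abs().sum(-1) / (2.0 * sigma))
    k_xy = torch.exp(-diff_xy.abs().sum(-1) / (2.0 * sigma))
    mmd_sq = (k_xx.mean(dim=(1, 2)) + k_yy.mean(dim=(1, 2))
              - 2.0 * k_xy.mean(dim=(1, 2)))
    return (mmd_sq + 1e-6).sqrt()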
Example #30
0
    def train_from_torch(
        self,
        batch,
        train=True,
        pretrain=False,
    ):
        """

        :param batch:
        :param train:
        :param pretrain:
        :return:
        """
        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']
        weights = batch.get('weights', None)
        if self.reward_transform:
            rewards = self.reward_transform(rewards)

        if self.terminal_transform:
            terminals = self.terminal_transform(terminals)
        """
        Policy and Alpha Loss
        """
        dist = self.policy(obs)
        new_obs_actions, log_pi = dist.rsample_and_logprob()
        policy_mle = dist.mle_estimate()

        if self.brac:
            buf_dist = self.buffer_policy(obs)
            buf_log_pi = buf_dist.log_prob(actions)
            rewards = rewards + buf_log_pi

        if self.use_automatic_entropy_tuning:
            alpha_loss = -(self.log_alpha *
                           (log_pi + self.target_entropy).detach()).mean()
            self.alpha_optimizer.zero_grad()
            alpha_loss.backward()
            self.alpha_optimizer.step()
            alpha = self.log_alpha.exp()
        else:
            alpha_loss = 0
            alpha = self.alpha
        """
        QF Loss
        """
        q1_pred = self.qf1(obs, actions)
        q2_pred = self.qf2(obs, actions)
        # Make sure policy accounts for squashing functions like tanh correctly!
        next_dist = self.policy(next_obs)
        new_next_actions, new_log_pi = next_dist.rsample_and_logprob()
        target_q_values = torch.min(
            self.target_qf1(next_obs, new_next_actions),
            self.target_qf2(next_obs, new_next_actions),
        ) - alpha * new_log_pi

        q_target = self.reward_scale * rewards + (
            1. - terminals) * self.discount * target_q_values
        qf1_loss = self.qf_criterion(q1_pred, q_target.detach())
        qf2_loss = self.qf_criterion(q2_pred, q_target.detach())
        """
        Policy Loss
        """
        qf1_new_actions = self.qf1(obs, new_obs_actions)
        qf2_new_actions = self.qf2(obs, new_obs_actions)
        q_new_actions = torch.min(
            qf1_new_actions,
            qf2_new_actions,
        )

        # Advantage-weighted regression
        if self.awr_use_mle_for_vf:
            v1_pi = self.qf1(obs, policy_mle)
            v2_pi = self.qf2(obs, policy_mle)
            v_pi = torch.min(v1_pi, v2_pi)
        else:
            if self.vf_K > 1:
                vs = []
                for i in range(self.vf_K):
                    u = dist.sample()
                    q1 = self.qf1(obs, u)
                    q2 = self.qf2(obs, u)
                    v = torch.min(q1, q2)
                    # v = q1
                    vs.append(v)
                v_pi = torch.cat(vs, 1).mean(dim=1, keepdim=True)
            else:
                # v_pi = self.qf1(obs, new_obs_actions)
                v1_pi = self.qf1(obs, new_obs_actions)
                v2_pi = self.qf2(obs, new_obs_actions)
                v_pi = torch.min(v1_pi, v2_pi)

        if self.awr_sample_actions:
            u = new_obs_actions
            if self.awr_min_q:
                q_adv = q_new_actions
            else:
                q_adv = qf1_new_actions
        elif self.buffer_policy_sample_actions:
            buf_dist = self.buffer_policy(obs)
            u, _ = buf_dist.rsample_and_logprob()
            qf1_buffer_actions = self.qf1(obs, u)
            qf2_buffer_actions = self.qf2(obs, u)
            q_buffer_actions = torch.min(
                qf1_buffer_actions,
                qf2_buffer_actions,
            )
            if self.awr_min_q:
                q_adv = q_buffer_actions
            else:
                q_adv = qf1_buffer_actions
        else:
            u = actions
            if self.awr_min_q:
                q_adv = torch.min(q1_pred, q2_pred)
            else:
                q_adv = q1_pred

        policy_logpp = dist.log_prob(u)

        if self.use_automatic_beta_tuning:
            buffer_dist = self.buffer_policy(obs)
            beta = self.log_beta.exp()
            kldiv = torch.distributions.kl.kl_divergence(dist, buffer_dist)
            beta_loss = -1 * (beta *
                              (kldiv - self.beta_epsilon).detach()).mean()

            self.beta_optimizer.zero_grad()
            beta_loss.backward()
            self.beta_optimizer.step()
        else:
            beta = self.beta_schedule.get_value(self._n_train_steps_total)

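        # AWR score used below: an advantage estimate A(s, a) = Q(s, a) - V(s),
        # where V(s) is approximated above from actions drawn from the current
        # policy (or from its MLE action when awr_use_mle_for_vf is set).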
        if self.normalize_over_state == "advantage":
            score = q_adv - v_pi
            if self.mask_positive_advantage:
                score = torch.sign(score)
        elif self.normalize_over_state == "Z":
            buffer_dist = self.buffer_policy(obs)
            K = self.Z_K
            buffer_obs = []
            buffer_actions = []
            log_bs = []
            log_pis = []
            for i in range(K):
                u = buffer_dist.sample()
                log_b = buffer_dist.log_prob(u)
                log_pi = dist.log_prob(u)
                buffer_obs.append(obs)
                buffer_actions.append(u)
                log_bs.append(log_b)
                log_pis.append(log_pi)
            buffer_obs = torch.cat(buffer_obs, 0)
            buffer_actions = torch.cat(buffer_actions, 0)
            p_buffer = torch.exp(torch.cat(log_bs, 0).sum(dim=1, ))
            log_pi = torch.cat(log_pis, 0)
            log_pi = log_pi.sum(dim=1, )
            q1_b = self.qf1(buffer_obs, buffer_actions)
            q2_b = self.qf2(buffer_obs, buffer_actions)
            q_b = torch.min(q1_b, q2_b)
            q_b = torch.reshape(q_b, (-1, K))
            adv_b = q_b - v_pi
            # Z = torch.exp(adv_b / beta).mean(dim=1, keepdim=True)
            # score = torch.exp((q_adv - v_pi) / beta) / Z
            # score = score / sum(score)
            logK = torch.log(ptu.tensor(float(K)))
            logZ = torch.logsumexp(adv_b / beta - logK, dim=1, keepdim=True)
            logS = (q_adv - v_pi) / beta - logZ
            # logZ = torch.logsumexp(q_b/beta - logK, dim=1, keepdim=True)
            # logS = q_adv/beta - logZ
            score = F.softmax(logS, dim=0)  # score / sum(score)
        else:
            raise ValueError('Unknown normalize_over_state: {}'.format(
                self.normalize_over_state))

        if self.clip_score is not None:
            score = torch.clamp(score, max=self.clip_score)

        if self.weight_loss and weights is None:
            if self.normalize_over_batch is True:
                weights = F.softmax(score / beta, dim=0)
            elif self.normalize_over_batch == "whiten":
                adv_mean = torch.mean(score)
                adv_std = torch.std(score) + 1e-5
                normalized_score = (score - adv_mean) / adv_std
                weights = torch.exp(normalized_score / beta)
            elif self.normalize_over_batch == "exp":
                weights = torch.exp(score / beta)
            elif self.normalize_over_batch == "step_fn":
                weights = (score > 0).float()
            elif self.normalize_over_batch is False:
                weights = score
            else:
                raise ValueError('Unknown normalize_over_batch: {}'.format(
                    self.normalize_over_batch))
        weights = weights[:, 0]

        policy_loss = alpha * log_pi.mean()

        if self.use_awr_update and self.weight_loss:
            policy_loss = policy_loss + self.awr_weight * (
                -policy_logpp * len(weights) * weights.detach()).mean()
        elif self.use_awr_update:
            policy_loss = policy_loss + self.awr_weight * (
                -policy_logpp).mean()

        if self.use_reparam_update:
            policy_loss = policy_loss + self.reparam_weight * (
                -q_new_actions).mean()

        policy_loss = self.rl_weight * policy_loss
        if self.compute_bc:
            train_policy_loss, train_logp_loss, train_mse_loss, _ = self.run_bc_batch(
                self.demo_train_buffer, self.policy)
            policy_loss = policy_loss + self.bc_weight * train_policy_loss

        if not pretrain and self.buffer_policy_reset_period > 0 and self._n_train_steps_total % self.buffer_policy_reset_period == 0:
            del self.buffer_policy_optimizer
            self.buffer_policy_optimizer = self.optimizer_class(
                self.buffer_policy.parameters(),
                weight_decay=self.policy_weight_decay,
                lr=self.policy_lr,
            )
            self.optimizers[self.buffer_policy] = self.buffer_policy_optimizer
            for i in range(self.num_buffer_policy_train_steps_on_reset):
                if self.train_bc_on_rl_buffer:
                    if self.advantage_weighted_buffer_loss:
                        buffer_dist = self.buffer_policy(obs)
                        buffer_u = actions
                        buffer_new_obs_actions, _ = buffer_dist.rsample_and_logprob(
                        )
                        buffer_policy_logpp = buffer_dist.log_prob(buffer_u)
                        buffer_policy_logpp = buffer_policy_logpp[:, None]

                        buffer_q1_pred = self.qf1(obs, buffer_u)
                        buffer_q2_pred = self.qf2(obs, buffer_u)
                        buffer_q_adv = torch.min(buffer_q1_pred,
                                                 buffer_q2_pred)

                        buffer_v1_pi = self.qf1(obs, buffer_new_obs_actions)
                        buffer_v2_pi = self.qf2(obs, buffer_new_obs_actions)
                        buffer_v_pi = torch.min(buffer_v1_pi, buffer_v2_pi)

                        buffer_score = buffer_q_adv - buffer_v_pi
                        buffer_weights = F.softmax(buffer_score / beta, dim=0)
                        buffer_policy_loss = self.awr_weight * (
                            -buffer_policy_logpp * len(buffer_weights) *
                            buffer_weights.detach()).mean()
                    else:
                        buffer_policy_loss, buffer_train_logp_loss, buffer_train_mse_loss, _ = self.run_bc_batch(
                            self.replay_buffer.train_replay_buffer,
                            self.buffer_policy)

                    self.buffer_policy_optimizer.zero_grad()
                    buffer_policy_loss.backward(retain_graph=True)
                    self.buffer_policy_optimizer.step()

        if self.train_bc_on_rl_buffer:
            if self.advantage_weighted_buffer_loss:
                buffer_dist = self.buffer_policy(obs)
                buffer_u = actions
                buffer_new_obs_actions, _ = buffer_dist.rsample_and_logprob()
                buffer_policy_logpp = buffer_dist.log_prob(buffer_u)
                buffer_policy_logpp = buffer_policy_logpp[:, None]

                buffer_q1_pred = self.qf1(obs, buffer_u)
                buffer_q2_pred = self.qf2(obs, buffer_u)
                buffer_q_adv = torch.min(buffer_q1_pred, buffer_q2_pred)

                buffer_v1_pi = self.qf1(obs, buffer_new_obs_actions)
                buffer_v2_pi = self.qf2(obs, buffer_new_obs_actions)
                buffer_v_pi = torch.min(buffer_v1_pi, buffer_v2_pi)

                buffer_score = buffer_q_adv - buffer_v_pi
                buffer_weights = F.softmax(buffer_score / beta, dim=0)
                buffer_policy_loss = self.awr_weight * (
                    -buffer_policy_logpp * len(buffer_weights) *
                    buffer_weights.detach()).mean()
            else:
                buffer_policy_loss, buffer_train_logp_loss, buffer_train_mse_loss, _ = self.run_bc_batch(
                    self.replay_buffer.train_replay_buffer, self.buffer_policy)
        """
        Update networks
        """
        if self._n_train_steps_total % self.q_update_period == 0:
            self.qf1_optimizer.zero_grad()
            qf1_loss.backward()
            self.qf1_optimizer.step()

            self.qf2_optimizer.zero_grad()
            qf2_loss.backward()
            self.qf2_optimizer.step()

        if self._n_train_steps_total % self.policy_update_period == 0 and self.update_policy:
            self.policy_optimizer.zero_grad()
            policy_loss.backward()
            self.policy_optimizer.step()

        if self.train_bc_on_rl_buffer and self._n_train_steps_total % self.policy_update_period == 0:
            self.buffer_policy_optimizer.zero_grad()
            buffer_policy_loss.backward()
            self.buffer_policy_optimizer.step()
        """
        Soft Updates
        """
        if self._n_train_steps_total % self.target_update_period == 0:
            ptu.soft_update_from_to(self.qf1, self.target_qf1,
                                    self.soft_target_tau)
            ptu.soft_update_from_to(self.qf2, self.target_qf2,
                                    self.soft_target_tau)
        """
        Save some statistics for eval
        """
        if self._need_to_update_eval_statistics:
            self._need_to_update_eval_statistics = False
            """
            Eval should set this to None.
            This way, these statistics are only computed for one batch.
            """
            policy_loss = (log_pi - q_new_actions).mean()

            self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
            self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))
            self.eval_statistics['Policy Loss'] = np.mean(
                ptu.get_numpy(policy_loss))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q1 Predictions',
                    ptu.get_numpy(q1_pred),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q2 Predictions',
                    ptu.get_numpy(q2_pred),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q Targets',
                    ptu.get_numpy(q_target),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Log Pis',
                    ptu.get_numpy(log_pi),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'rewards',
                    ptu.get_numpy(rewards),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'terminals',
                    ptu.get_numpy(terminals),
                ))
            policy_statistics = add_prefix(dist.get_diagnostics(), "policy/")
            self.eval_statistics.update(policy_statistics)
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Advantage Weights',
                    ptu.get_numpy(weights),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Advantage Score',
                    ptu.get_numpy(score),
                ))

            if self.normalize_over_state == "Z":
                self.eval_statistics.update(
                    create_stats_ordered_dict(
                        'logZ',
                        ptu.get_numpy(logZ),
                    ))

            if self.use_automatic_entropy_tuning:
                self.eval_statistics['Alpha'] = alpha.item()
                self.eval_statistics['Alpha Loss'] = alpha_loss.item()

            if self.compute_bc:
                test_policy_loss, test_logp_loss, test_mse_loss, _ = self.run_bc_batch(
                    self.demo_test_buffer, self.policy)
                self.eval_statistics.update({
                    "bc/Train Logprob Loss":
                    ptu.get_numpy(train_logp_loss),
                    "bc/Test Logprob Loss":
                    ptu.get_numpy(test_logp_loss),
                    "bc/Train MSE":
                    ptu.get_numpy(train_mse_loss),
                    "bc/Test MSE":
                    ptu.get_numpy(test_mse_loss),
                    "bc/train_policy_loss":
                    ptu.get_numpy(train_policy_loss),
                    "bc/test_policy_loss":
                    ptu.get_numpy(test_policy_loss),
                })
            if self.train_bc_on_rl_buffer:
                _, buffer_train_logp_loss, _, _ = self.run_bc_batch(
                    self.replay_buffer.train_replay_buffer, self.buffer_policy)

                _, buffer_test_logp_loss, _, _ = self.run_bc_batch(
                    self.replay_buffer.validation_replay_buffer,
                    self.buffer_policy)
                buffer_dist = self.buffer_policy(obs)
                kldiv = torch.distributions.kl.kl_divergence(dist, buffer_dist)

                _, train_offline_logp_loss, _, _ = self.run_bc_batch(
                    self.demo_train_buffer, self.buffer_policy)

                _, test_offline_logp_loss, _, _ = self.run_bc_batch(
                    self.demo_test_buffer, self.buffer_policy)

                self.eval_statistics.update({
                    "buffer_policy/Train Online Logprob":
                    -1 * ptu.get_numpy(buffer_train_logp_loss),
                    "buffer_policy/Test Online Logprob":
                    -1 * ptu.get_numpy(buffer_test_logp_loss),
                    "buffer_policy/Train Offline Logprob":
                    -1 * ptu.get_numpy(train_offline_logp_loss),
                    "buffer_policy/Test Offline Logprob":
                    -1 * ptu.get_numpy(test_offline_logp_loss),
                    "buffer_policy/train_policy_loss":
                    ptu.get_numpy(buffer_policy_loss),
                    # "buffer_policy/test_policy_loss": ptu.get_numpy(buffer_test_policy_loss),
                    "buffer_policy/kl_div":
                    ptu.get_numpy(kldiv.mean()),
                })
            if self.use_automatic_beta_tuning:
                self.eval_statistics.update({
                    "adaptive_beta/beta":
                    ptu.get_numpy(beta.mean()),
                    "adaptive_beta/beta loss":
                    ptu.get_numpy(beta_loss.mean()),
                })

            if self.validation_qlearning:
                train_data = self.replay_buffer.validation_replay_buffer.random_batch(
                    self.bc_batch_size)
                train_data = np_to_pytorch_batch(train_data)
                obs = train_data['observations']
                next_obs = train_data['next_observations']
                # goals = train_data['resampled_goals']
                train_data['observations'] = obs  # torch.cat((obs, goals), dim=1)
                train_data['next_observations'] = next_obs  # torch.cat((next_obs, goals), dim=1)
                self.test_from_torch(train_data)

        self._n_train_steps_total += 1
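
In the default "advantage" path with batch-softmax weighting, the AWR term in
the policy loss above reduces to a weighted negative log-likelihood of the
buffer actions (the batch mean of N * w_i * log pi cancels the 1/N):

$$ A(s_i, a_i) \approx Q(s_i, a_i) - V(s_i), \qquad
w_i = \frac{\exp(A(s_i, a_i)/\beta)}{\sum_j \exp(A(s_j, a_j)/\beta)} $$

$$ \mathcal{L}_{\text{AWR}} = -\sum_i w_i \log \pi(a_i \mid s_i) $$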