Example #1
        def get_mask_diagnostics(unused):
            from rlkit.core.logging import append_log, add_prefix, OrderedDict
            log = OrderedDict()
            for prefix, collector in zip(addl_log_prefixes, addl_collectors):
                paths = collector.collect_new_paths(
                    max_path_length,
                    variant['algo_kwargs']['num_eval_steps_per_epoch'],
                    discard_incomplete_paths=True,
                )
                old_path_info = eval_env.get_diagnostics(paths)

                keys_to_keep = []
                for key in old_path_info.keys():
                    if ('env_infos' in key) and ('final' in key) and ('Mean' in key):
                        keys_to_keep.append(key)
                path_info = OrderedDict()
                for key in keys_to_keep:
                    path_info[key] = old_path_info[key]

                generic_info = add_prefix(
                    path_info,
                    prefix,
                )
                append_log(log, generic_info)

            for collector in addl_collectors:
                collector.end_epoch(0)
            return log
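These snippets all use add_prefix, and several also use append_log, from rlkit.core.logging. As a rough sketch inferred from the call sites on this page (an assumption, not the library's verbatim source), the two helpers behave roughly like this:

    from collections import OrderedDict

    def add_prefix(log_dict, prefix, divider=''):
        # Return a new OrderedDict whose keys are prefix + divider + original key.
        with_prefix = OrderedDict()
        for key, value in log_dict.items():
            with_prefix[prefix + divider + key] = value
        return with_prefix

    def append_log(log, to_add, prefix=None):
        # Merge to_add into log, optionally prefixing its keys first.
        if prefix is not None:
            to_add = add_prefix(to_add, prefix)
        log.update(to_add)
        return log

For instance, add_prefix({'QF1 Loss': 0.5}, prefix='trainer/') would produce {'trainer/QF1 Loss': 0.5}, and add_prefix(stats, 'eval_env', divider='/') would namespace every key as 'eval_env/...'.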
Example #2
 def get_eval_diagnostics(key_to_paths):
     stats = OrderedDict()
     for eval_env_name, paths in key_to_paths.items():
         env, _ = eval_env_name_to_env_and_context_distrib[eval_env_name]
         stats.update(
             add_prefix(
                 env.get_diagnostics(paths),
                 eval_env_name,
                 divider='/',
             ))
         stats.update(
             add_prefix(
                 eval_util.get_generic_path_information(paths),
                 eval_env_name,
                 divider='/',
             ))
     return stats
Example #3
 def train(self):
     # first train only the Q function
     iteration = 0
     for i in range(self.num_batches):
         train_data = self.replay_buffer.random_batch(self.batch_size)
         train_data = np_to_pytorch_batch(train_data)
         obs = train_data['observations']
         next_obs = train_data['next_observations']
         train_data['observations'] = obs
         train_data['next_observations'] = next_obs
         self.trainer.train_from_torch(train_data)
         if i % self.logging_period == 0:
             stats_with_prefix = add_prefix(
                 self.trainer.eval_statistics, prefix="trainer/")
             self.trainer.end_epoch(iteration)
             iteration += 1
             logger.record_dict(stats_with_prefix)
             logger.dump_tabular(with_prefix=True, with_timestamp=False)
Example #4
    def get_mask_diagnostics(unused):
        from rlkit.core.logging import append_log, add_prefix, OrderedDict
        from rlkit.misc import eval_util
        log = OrderedDict()
        for prefix, collector in zip(log_prefixes, collectors):
            paths = collector.collect_new_paths(
                max_path_length,
                masking_eval_steps,
                discard_incomplete_paths=True,
            )
            generic_info = add_prefix(
                eval_util.get_generic_path_information(paths),
                prefix,
            )
            append_log(log, generic_info)

        for collector in collectors:
            collector.end_epoch(0)
        return log
Example #5
    def pretrain_q_with_bc_data(self, batch_size):
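        # Redirect tabular output to pretrain_q.csv for the duration of pretraining so the
        # pretraining metrics do not interleave with progress.csv; the outputs are swapped
        # back once pretraining finishes.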
        logger.remove_tabular_output('progress.csv',
                                     relative_to_snapshot_dir=True)
        logger.add_tabular_output('pretrain_q.csv',
                                  relative_to_snapshot_dir=True)

        prev_time = time.time()
        for i in range(self.num_pretrain_steps):
            self.eval_statistics = dict()
            if i % self.pretraining_logging_period == 0:
                self._need_to_update_eval_statistics = True
            train_data = self.replay_buffer.random_batch(self.bc_batch_size)
            train_data = np_to_pytorch_batch(train_data)
            obs = train_data['observations']
            next_obs = train_data['next_observations']
            # goals = train_data['resampled_goals']
            train_data['observations'] = obs  # torch.cat((obs, goals), dim=1)
            train_data[
                'next_observations'] = next_obs  # torch.cat((next_obs, goals), dim=1)
            self.train_from_torch(train_data, pretrain=True)

            if i % self.pretraining_logging_period == 0:
                self.eval_statistics["batch"] = i
                self.eval_statistics["epoch_time"] = time.time() - prev_time
                stats_with_prefix = add_prefix(self.eval_statistics,
                                               prefix="trainer/")
                logger.record_dict(stats_with_prefix)
                logger.dump_tabular(with_prefix=True, with_timestamp=False)
                prev_time = time.time()

        logger.remove_tabular_output(
            'pretrain_q.csv',
            relative_to_snapshot_dir=True,
        )
        logger.add_tabular_output(
            'progress.csv',
            relative_to_snapshot_dir=True,
        )

        self._need_to_update_eval_statistics = True
        self.eval_statistics = dict()
Example #6
    def test_from_torch(self, batch):
        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']
        weights = batch.get('weights', None)
        if self.reward_transform:
            rewards = self.reward_transform(rewards)

        if self.terminal_transform:
            terminals = self.terminal_transform(terminals)
        """
        Policy and Alpha Loss
        """
        dist = self.policy(obs)
        new_obs_actions, log_pi = dist.rsample_and_logprob()
        policy_mle = dist.mle_estimate()

        if self.use_automatic_entropy_tuning:
            alpha_loss = -(self.log_alpha *
                           (log_pi + self.target_entropy).detach()).mean()
            alpha = self.log_alpha.exp()
        else:
            alpha_loss = 0
            alpha = self.alpha

        q1_pred = self.qf1(obs, actions)
        q2_pred = self.qf2(obs, actions)
        # Make sure policy accounts for squashing functions like tanh correctly!
        next_dist = self.policy(next_obs)
        new_next_actions, new_log_pi = next_dist.rsample_and_logprob()
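        # Entropy-regularized target: bootstrap from the min of the two target critics and
        # subtract alpha times the log-prob of the freshly sampled next action.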
        target_q_values = torch.min(
            self.target_qf1(next_obs, new_next_actions),
            self.target_qf2(next_obs, new_next_actions),
        ) - alpha * new_log_pi

        q_target = self.reward_scale * rewards + (
            1. - terminals) * self.discount * target_q_values
        qf1_loss = self.qf_criterion(q1_pred, q_target.detach())
        qf2_loss = self.qf_criterion(q2_pred, q_target.detach())

        qf1_new_actions = self.qf1(obs, new_obs_actions)
        qf2_new_actions = self.qf2(obs, new_obs_actions)
        q_new_actions = torch.min(
            qf1_new_actions,
            qf2_new_actions,
        )

        policy_loss = (log_pi - q_new_actions).mean()

        self.eval_statistics['validation/QF1 Loss'] = np.mean(
            ptu.get_numpy(qf1_loss))
        self.eval_statistics['validation/QF2 Loss'] = np.mean(
            ptu.get_numpy(qf2_loss))
        self.eval_statistics['validation/Policy Loss'] = np.mean(
            ptu.get_numpy(policy_loss))
        self.eval_statistics.update(
            create_stats_ordered_dict(
                'validation/Q1 Predictions',
                ptu.get_numpy(q1_pred),
            ))
        self.eval_statistics.update(
            create_stats_ordered_dict(
                'validation/Q2 Predictions',
                ptu.get_numpy(q2_pred),
            ))
        self.eval_statistics.update(
            create_stats_ordered_dict(
                'validation/Q Targets',
                ptu.get_numpy(q_target),
            ))
        self.eval_statistics.update(
            create_stats_ordered_dict(
                'validation/Log Pis',
                ptu.get_numpy(log_pi),
            ))
        policy_statistics = add_prefix(dist.get_diagnostics(),
                                       "validation/policy/")
        self.eval_statistics.update(policy_statistics)
Example #7
    def pretrain_q_with_bc_data(self):
        """

        :return:
        """
        logger.remove_tabular_output('progress.csv',
                                     relative_to_snapshot_dir=True)
        logger.add_tabular_output('pretrain_q.csv',
                                  relative_to_snapshot_dir=True)

        self.update_policy = False
        # first train only the Q function
        for i in range(self.q_num_pretrain1_steps):
            self.eval_statistics = dict()

            train_data = self.replay_buffer.random_batch(self.bc_batch_size)
            train_data = np_to_pytorch_batch(train_data)
            obs = train_data['observations']
            next_obs = train_data['next_observations']
            # goals = train_data['resampled_goals']
            train_data['observations'] = obs  # torch.cat((obs, goals), dim=1)
            train_data[
                'next_observations'] = next_obs  # torch.cat((next_obs, goals), dim=1)
            self.train_from_torch(train_data, pretrain=True)
            if i % self.pretraining_logging_period == 0:
                stats_with_prefix = add_prefix(self.eval_statistics,
                                               prefix="trainer/")
                logger.record_dict(stats_with_prefix)
                logger.dump_tabular(with_prefix=True, with_timestamp=False)

        self.update_policy = True
        # then train policy and Q function together
        prev_time = time.time()
        for i in range(self.q_num_pretrain2_steps):
            self.eval_statistics = dict()
            if i % self.pretraining_logging_period == 0:
                self._need_to_update_eval_statistics = True
            train_data = self.replay_buffer.random_batch(self.bc_batch_size)
            train_data = np_to_pytorch_batch(train_data)
            obs = train_data['observations']
            next_obs = train_data['next_observations']
            # goals = train_data['resampled_goals']
            train_data['observations'] = obs  # torch.cat((obs, goals), dim=1)
            train_data[
                'next_observations'] = next_obs  # torch.cat((next_obs, goals), dim=1)
            self.train_from_torch(train_data, pretrain=True)

            if i % self.pretraining_logging_period == 0:
                self.eval_statistics["batch"] = i
                self.eval_statistics["epoch_time"] = time.time() - prev_time
                stats_with_prefix = add_prefix(self.eval_statistics,
                                               prefix="trainer/")
                logger.record_dict(stats_with_prefix)
                logger.dump_tabular(with_prefix=True, with_timestamp=False)
                prev_time = time.time()

        logger.remove_tabular_output(
            'pretrain_q.csv',
            relative_to_snapshot_dir=True,
        )
        logger.add_tabular_output(
            'progress.csv',
            relative_to_snapshot_dir=True,
        )

        self._need_to_update_eval_statistics = True
        self.eval_statistics = dict()

        if self.post_pretrain_hyperparams:
            self.set_algorithm_weights(**self.post_pretrain_hyperparams)
Example #8
    def train_from_torch(self, batch, train=True, pretrain=False,):
        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']
        if self.reward_transform:
            rewards = self.reward_transform(rewards)

        if self.terminal_transform:
            terminals = self.terminal_transform(terminals)
        """
        Policy and Alpha Loss
        """
        dist = self.policy(obs)

        """
        QF Loss
        """
        q1_pred = self.qf1(obs, actions)
        q2_pred = self.qf2(obs, actions)
        target_vf_pred = self.vf(next_obs).detach()

        q_target = self.reward_scale * rewards + (1. - terminals) * self.discount * target_vf_pred
        q_target = q_target.detach()
        qf1_loss = self.qf_criterion(q1_pred, q_target)
        qf2_loss = self.qf_criterion(q2_pred, q_target)

        """
        VF Loss
        """
        q_pred = torch.min(
            self.target_qf1(obs, actions),
            self.target_qf2(obs, actions),
        ).detach()
        vf_pred = self.vf(obs)
        vf_err = vf_pred - q_pred
        vf_sign = (vf_err > 0).float()
        vf_weight = (1 - vf_sign) * self.quantile + vf_sign * (1 - self.quantile)
        vf_loss = (vf_weight * (vf_err ** 2)).mean()

        """
        Policy Loss
        """
        policy_logpp = dist.log_prob(actions)

        adv = q_pred - vf_pred
        exp_adv = torch.exp(adv / self.beta)
        if self.clip_score is not None:
            exp_adv = torch.clamp(exp_adv, max=self.clip_score)

        weights = exp_adv[:, 0].detach()
        policy_loss = (-policy_logpp * weights).mean()

        """
        Update networks
        """
        if self._n_train_steps_total % self.q_update_period == 0:
            self.qf1_optimizer.zero_grad()
            qf1_loss.backward()
            self.qf1_optimizer.step()

            self.qf2_optimizer.zero_grad()
            qf2_loss.backward()
            self.qf2_optimizer.step()

            self.vf_optimizer.zero_grad()
            vf_loss.backward()
            self.vf_optimizer.step()

        if self._n_train_steps_total % self.policy_update_period == 0:
            self.policy_optimizer.zero_grad()
            policy_loss.backward()
            self.policy_optimizer.step()

        """
        Soft Updates
        """
        if self._n_train_steps_total % self.target_update_period == 0:
            ptu.soft_update_from_to(
                self.qf1, self.target_qf1, self.soft_target_tau
            )
            ptu.soft_update_from_to(
                self.qf2, self.target_qf2, self.soft_target_tau
            )

        """
        Save some statistics for eval
        """
        if self._need_to_update_eval_statistics:
            self._need_to_update_eval_statistics = False
            """
            Eval should set this to None.
            This way, these statistics are only computed for one batch.
            """
            self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
            self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))
            self.eval_statistics['Policy Loss'] = np.mean(ptu.get_numpy(
                policy_loss
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Q1 Predictions',
                ptu.get_numpy(q1_pred),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Q2 Predictions',
                ptu.get_numpy(q2_pred),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Q Targets',
                ptu.get_numpy(q_target),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'rewards',
                ptu.get_numpy(rewards),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'terminals',
                ptu.get_numpy(terminals),
            ))
            self.eval_statistics['replay_buffer_len'] = self.replay_buffer._size
            policy_statistics = add_prefix(dist.get_diagnostics(), "policy/")
            self.eval_statistics.update(policy_statistics)
            self.eval_statistics.update(create_stats_ordered_dict(
                'Advantage Weights',
                ptu.get_numpy(weights),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Advantage Score',
                ptu.get_numpy(adv),
            ))

            self.eval_statistics.update(create_stats_ordered_dict(
                'V1 Predictions',
                ptu.get_numpy(vf_pred),
            ))
            self.eval_statistics['VF Loss'] = np.mean(ptu.get_numpy(vf_loss))

        self._n_train_steps_total += 1
Example #9
    def compute_loss(
        self,
        batch,
        skip_statistics=False,
    ) -> Tuple[SACLosses, LossStatistics]:
        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']
        collisions = batch['collision']
        risk = batch['risk']
        """
        Policy and Alpha Loss
        """
        dist = self.policy(obs)
        new_obs_actions, log_pi = dist.rsample_and_logprob()
        log_pi = log_pi.unsqueeze(-1)
        if self.use_automatic_entropy_tuning:
            alpha_loss = -(self.log_alpha *
                           (log_pi + self.target_entropy).detach()).mean()
            alpha = self.log_alpha.exp()
        else:
            alpha_loss = 0
            alpha = 1

        ## TODO: finetune the loss function BELOW
        q_new_actions = torch.min(
            self.qf1(obs, new_obs_actions),
            self.qf2(obs, new_obs_actions),
        )

        # r_new_actions = torch.min(
        #     self.rf1(obs, new_obs_actions),
        #     self.rf2(obs, new_obs_actions),
        # )
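        # Risk penalty: Hardtanh(0, 1) zeroes the penalty while the predicted risk stays
        # below the bound delta and grows with the excess (capped at 1) once it exceeds it.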
        r_new_actions = self.rf1(obs, new_obs_actions)
        r_bound = self.delta
        r_loss_coeff = self.risk_coeff
        r_diff = r_new_actions - r_bound
        m = nn.Hardtanh(0, 1)
        r_policy_loss = m(r_diff) * r_loss_coeff
        # TODO(cyrushx): Add risk in policy loss.
        # policy_loss = (alpha*log_pi - q_new_actions + 1.*r_new_actions).mean()
        policy_loss = (alpha * log_pi - q_new_actions + r_policy_loss).mean()
        """
        QF Loss
        """
        q1_pred = self.qf1(obs, actions)
        q2_pred = self.qf2(obs, actions)
        next_dist = self.policy(next_obs)
        new_next_actions, new_log_pi = next_dist.rsample_and_logprob()
        new_log_pi = new_log_pi.unsqueeze(-1)
        target_q_values = torch.min(
            self.target_qf1(next_obs, new_next_actions),
            self.target_qf2(next_obs, new_next_actions),
        ) - alpha * new_log_pi

        q_target = self.reward_scale * rewards + (
            1. - terminals) * self.discount * target_q_values
        qf1_loss = self.qf_criterion(q1_pred, q_target.detach())
        qf2_loss = self.qf_criterion(q2_pred, q_target.detach())
        """
        Risk Critic Loss
        """
        r1_pred = self.rf1(obs, actions)
        r2_pred = self.rf2(obs, actions)
        # TODO(cyrushx): Replace target_r_values with ground truth risk values.
        target_r_values = self.target_rf1(next_obs, new_next_actions)
        # target_r_values = torch.min(
        #     self.target_rf1(next_obs, new_next_actions),
        #     self.target_rf2(next_obs, new_next_actions),
        # ) - alpha * new_log_pi

        r_target = risk
        rf1_loss = self.rf_criterion(r1_pred, r_target.detach())
        rf2_loss = self.rf_criterion(r2_pred, r_target.detach())
        """
        Save some statistics for eval
        """
        eval_statistics = OrderedDict()
        if not skip_statistics:
            eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
            eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))
            eval_statistics['RF1 Loss'] = np.mean(ptu.get_numpy(rf1_loss))
            eval_statistics['RF2 Loss'] = np.mean(ptu.get_numpy(rf2_loss))
            eval_statistics['Policy Loss'] = np.mean(
                ptu.get_numpy(policy_loss))
            eval_statistics.update(
                create_stats_ordered_dict(
                    'Q1 Predictions',
                    ptu.get_numpy(q1_pred),
                ))
            eval_statistics.update(
                create_stats_ordered_dict(
                    'Q2 Predictions',
                    ptu.get_numpy(q2_pred),
                ))
            eval_statistics.update(
                create_stats_ordered_dict(
                    'Q Targets',
                    ptu.get_numpy(q_target),
                ))
            eval_statistics.update(
                create_stats_ordered_dict(
                    'R1 Predictions',
                    ptu.get_numpy(r1_pred),
                ))
            eval_statistics.update(
                create_stats_ordered_dict(
                    'R2 Predictions',
                    ptu.get_numpy(r2_pred),
                ))
            eval_statistics.update(
                create_stats_ordered_dict(
                    'R Targets',
                    ptu.get_numpy(r_target),
                ))
            eval_statistics.update(
                create_stats_ordered_dict(
                    'Collisions',
                    ptu.get_numpy(collisions),
                ))
            eval_statistics.update(
                create_stats_ordered_dict(
                    'Log Pis',
                    ptu.get_numpy(log_pi),
                ))
            policy_statistics = add_prefix(dist.get_diagnostics(), "policy/")
            eval_statistics.update(policy_statistics)
            if self.use_automatic_entropy_tuning:
                eval_statistics['Alpha'] = alpha.item()
                eval_statistics['Alpha Loss'] = alpha_loss.item()

        loss = SACLosses(
            policy_loss=policy_loss,
            qf1_loss=qf1_loss,
            qf2_loss=qf2_loss,
            alpha_loss=alpha_loss,
            rf1_loss=rf1_loss,
            rf2_loss=rf2_loss,
        )

        return loss, eval_statistics
Example #10
    def compute_loss(
        self,
        batch,
        skip_statistics=False,
    ) -> Tuple[SACLosses, LossStatistics]:
        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']
        collisions = batch['collision']
        # import IPython; IPython.embed()  # leftover debugging breakpoint; keep disabled
        """
        Policy and Alpha Loss
        """
        dist = self.policy(obs)
        new_obs_actions, log_pi = dist.rsample_and_logprob()
        log_pi = log_pi.unsqueeze(-1)
        if self.use_automatic_entropy_tuning:
            alpha_loss = -(self.log_alpha * (log_pi + self.target_entropy).detach()).mean()
            alpha = self.log_alpha.exp()
        else:
            alpha_loss = 0
            alpha = 1

        ## TODO: finetune the loss function BELOW
        q_new_actions = torch.min(
            self.qf1(obs, new_obs_actions),
            self.qf2(obs, new_obs_actions),
        )
        r_new_actions = self.rf1(obs, new_obs_actions)
        # r_new_actions = torch.min(
        #     self.rf1(obs, new_obs_actions),
        #     self.rf2(obs, new_obs_actions),
        # )

        # Compute an approximate step function, such that f(risk <= r_bound) = 1; f(risk > r_bound) = 0.
        r_bound = 0.3
        r_left = r_bound - r_new_actions
        m = nn.Hardtanh(0, 0.01)
        r_step = m(r_left) * 100
        policy_loss = (alpha * log_pi - q_new_actions * r_step).mean()

        """
        QF Loss
        """
        q1_pred = self.qf1(obs, actions)
        q2_pred = self.qf2(obs, actions)
        next_dist = self.policy(next_obs)
        new_next_actions, new_log_pi = next_dist.rsample_and_logprob()
        new_log_pi = new_log_pi.unsqueeze(-1)
        target_q_values = torch.min(
            self.target_qf1(next_obs, new_next_actions),
            self.target_qf2(next_obs, new_next_actions),
        ) - alpha * new_log_pi

        q_target = self.reward_scale * rewards + (1. - terminals) * self.discount * target_q_values
        qf1_loss = self.qf_criterion(q1_pred, q_target.detach())
        qf2_loss = self.qf_criterion(q2_pred, q_target.detach())

        """
        Risk Critic Loss
        """
        r1_pred = self.rf1(obs, actions)
        r2_pred = self.rf2(obs, actions)
        target_r_values = self.target_rf1(next_obs, new_next_actions)
        # target_r_values = torch.min(
        #     self.target_rf1(next_obs, new_next_actions),
        #     self.target_rf2(next_obs, new_next_actions),
        # ) - alpha * new_log_pi

        r_target = collisions + (1. - terminals) * (1 - collisions) * target_r_values
        rf1_loss = self.qf_criterion(r1_pred, r_target.detach())
        rf2_loss = self.qf_criterion(r2_pred, r_target.detach())

        """
        Save some statistics for eval
        """
        eval_statistics = OrderedDict()
        if not skip_statistics:
            eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
            eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))
            eval_statistics['RF1 Loss'] = np.mean(ptu.get_numpy(rf1_loss))
            eval_statistics['RF2 Loss'] = np.mean(ptu.get_numpy(rf2_loss))
            eval_statistics['Policy Loss'] = np.mean(ptu.get_numpy(
                policy_loss
            ))
            eval_statistics.update(create_stats_ordered_dict(
                'Q1 Predictions',
                ptu.get_numpy(q1_pred),
            ))
            eval_statistics.update(create_stats_ordered_dict(
                'Q2 Predictions',
                ptu.get_numpy(q2_pred),
            ))
            eval_statistics.update(create_stats_ordered_dict(
                'Q Targets',
                ptu.get_numpy(q_target),
            ))
            eval_statistics.update(create_stats_ordered_dict(
                'R1 Predictions',
                ptu.get_numpy(r1_pred),
            ))
            eval_statistics.update(create_stats_ordered_dict(
                'R2 Predictions',
                ptu.get_numpy(r2_pred),
            ))
            eval_statistics.update(create_stats_ordered_dict(
                'R Targets',
                ptu.get_numpy(r_target),
            ))
            eval_statistics.update(create_stats_ordered_dict(
                'Log Pis',
                ptu.get_numpy(log_pi),
            ))
            policy_statistics = add_prefix(dist.get_diagnostics(), "policy/")
            eval_statistics.update(policy_statistics)
            if self.use_automatic_entropy_tuning:
                eval_statistics['Alpha'] = alpha.item()
                eval_statistics['Alpha Loss'] = alpha_loss.item()

        loss = SACLosses(
            policy_loss=policy_loss,
            qf1_loss=qf1_loss,
            qf2_loss=qf2_loss,
            alpha_loss=alpha_loss,
            rf1_loss=rf1_loss,
            rf2_loss=rf2_loss,
        )

        return loss, eval_statistics
Example #11
 def get_snapshot(self):
     snapshot = {}
     for name, collector in self.path_collectors.items():
         snapshot.update(
             add_prefix(collector.get_snapshot(), name, divider='/'), )
     return snapshot
Example #12
    def compute_loss(
        self,
        batch,
        return_statistics=False,
    ) -> Union[PGRLosses, Tuple[PGRLosses, MutableMapping]]:
        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']
        eval_statistics = OrderedDict()
        """
        Policy and Alpha Loss
        """
        dist = self.policy(obs)
        new_obs_actions, log_pi = dist.rsample_and_logprob()
        log_pi = log_pi.unsqueeze(-1)
        if self.use_automatic_entropy_tuning:
            alpha_loss = -(self.log_alpha *
                           (log_pi + self.target_entropy).detach()).mean()
        else:
            alpha_loss = 0

        alpha = self.get_alpha()

        if not self._qfs_were_initialized and self._auto_init_qf_bias:
            average_value = (rewards - alpha * log_pi).mean()
            self.qf1.last_fc.bias.data = average_value
            self.qf2.last_fc.bias.data = average_value
            self._qfs_were_initialized = True  # mark as done so the bias is only initialized once

        q_new_actions = torch.min(
            self.qf1(obs, new_obs_actions),
            self.qf2(obs, new_obs_actions),
        )
        """
        QF Loss
        """
        bootstrap_value, q1_pred, q2_pred, bootstrap_log_pi_term = (
            self.get_bootstrap_stats(
                obs,
                actions,
                next_obs,
            ))
        # Use the unscaled bootstrap values/rewards so that the weight on the
        # the Q-value/reward has the correct scale relative to the other terms
        raw_discount = self.get_discount_factor(
            bootstrap_value,
            rewards,
            obs,
            actions,
        )
        discount = (self._weight_on_prior_discount * self.discount +
                    (1 - self._weight_on_prior_discount) * raw_discount)
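        # The effective discount interpolates between the fixed prior discount and the
        # data-driven estimate returned by get_discount_factor.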
        q_target = self._compute_target_q_value(
            discount,
            rewards,
            terminals,
            bootstrap_value,
            eval_statistics,
            return_statistics,
        )
        policy_loss = (alpha * log_pi - q_new_actions).mean()
        qf1_loss = self.qf_criterion(q1_pred, q_target.detach())
        qf2_loss = self.qf_criterion(q2_pred, q_target.detach())
        """
        Save some statistics for eval
        """
        if return_statistics:
            eval_statistics.update(
                create_stats_ordered_dict(
                    'rewards',
                    ptu.get_numpy(rewards),
                ))
            eval_statistics.update(
                create_stats_ordered_dict(
                    'bootstrap log pi',
                    ptu.get_numpy(bootstrap_log_pi_term),
                ))
            if isinstance(discount, torch.Tensor):
                eval_statistics.update(
                    create_stats_ordered_dict(
                        'discount factor',
                        ptu.get_numpy(raw_discount),
                    ))
            else:
                eval_statistics.update(
                    create_stats_ordered_dict(
                        'discount factor',
                        np.array([raw_discount]),
                    ))
            if isinstance(discount, torch.Tensor):
                eval_statistics.update(
                    create_stats_ordered_dict(
                        'used discount factor',
                        ptu.get_numpy(discount),
                    ))
            else:
                eval_statistics.update(
                    create_stats_ordered_dict(
                        'used discount factor',
                        np.array([discount]),
                    ))
            eval_statistics[
                'weight on prior discount'] = self._weight_on_prior_discount
            reward_scale = self.reward_scale
            if isinstance(reward_scale, torch.Tensor):
                reward_scale = ptu.get_numpy(reward_scale)
            eval_statistics['reward scale'] = reward_scale
            eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
            eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))
            eval_statistics['Policy Loss'] = np.mean(
                ptu.get_numpy(policy_loss))
            eval_statistics['Policy Q-only Loss'] = np.mean(
                ptu.get_numpy(-q_new_actions))
            eval_statistics['Policy entropy-only Loss'] = np.mean(
                ptu.get_numpy(alpha * log_pi))
            eval_statistics.update(
                create_stats_ordered_dict(
                    'Q1 Predictions',
                    ptu.get_numpy(q1_pred),
                ))
            eval_statistics.update(
                create_stats_ordered_dict(
                    'Q2 Predictions',
                    ptu.get_numpy(q2_pred),
                ))
            eval_statistics.update(
                create_stats_ordered_dict(
                    'Q Targets',
                    ptu.get_numpy(q_target),
                ))
            eval_statistics.update(
                create_stats_ordered_dict(
                    'Log Pis',
                    ptu.get_numpy(log_pi),
                ))
            policy_statistics = add_prefix(dist.get_diagnostics(), "policy/")
            eval_statistics.update(policy_statistics)
            if self.use_automatic_entropy_tuning:
                eval_statistics['Alpha'] = alpha.item()
                eval_statistics['Alpha Loss'] = alpha_loss.item()

        losses = PGRLosses(
            policy_loss=policy_loss,
            qf1_loss=qf1_loss,
            qf2_loss=qf2_loss,
            alpha_loss=alpha_loss,
        )
        if return_statistics:
            return losses, eval_statistics
        else:
            return losses
Example #13
 def _get_vae_diagnostics(self):
     return add_prefix(
         self.model_trainer.get_diagnostics(),
         prefix='vae_trainer/',
     )
Example #14
    def train_from_torch(self, batch):
        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']
        context = batch['context']

        if self.reward_transform:
            rewards = self.reward_transform(rewards)

        if self.terminal_transform:
            terminals = self.terminal_transform(terminals)
        """
        Policy and Alpha Loss
        """
        dist, p_z, task_z_with_grad = self.agent(
            obs,
            context,
            return_latent_posterior_and_task_z=True,
        )
        task_z_detached = task_z_with_grad.detach()
        new_obs_actions, log_pi = dist.rsample_and_logprob()
        log_pi = log_pi.unsqueeze(1)
        next_dist = self.agent(next_obs, context)

        if self._debug_ignore_context:
            task_z_with_grad = task_z_with_grad * 0

        # flattens out the task dimension
        t, b, _ = obs.size()
        obs = obs.view(t * b, -1)
        actions = actions.view(t * b, -1)
        next_obs = next_obs.view(t * b, -1)
        unscaled_rewards_flat = rewards.view(t * b, 1)
        rewards_flat = unscaled_rewards_flat * self.reward_scale
        terms_flat = terminals.view(t * b, 1)

        if self.use_automatic_entropy_tuning:
            alpha_loss = -(self.log_alpha *
                           (log_pi + self.target_entropy).detach()).mean()
            self.alpha_optimizer.zero_grad()
            alpha_loss.backward()
            self.alpha_optimizer.step()
            alpha = self.log_alpha.exp()
        else:
            alpha_loss = 0
            alpha = self.alpha
        """
        QF Loss
        """
        if self.backprop_q_loss_into_encoder:
            q1_pred = self.qf1(obs, actions, task_z_with_grad)
            q2_pred = self.qf2(obs, actions, task_z_with_grad)
        else:
            q1_pred = self.qf1(obs, actions, task_z_detached)
            q2_pred = self.qf2(obs, actions, task_z_detached)
        # Make sure policy accounts for squashing functions like tanh correctly!
        new_next_actions, new_log_pi = next_dist.rsample_and_logprob()
        new_log_pi = new_log_pi.unsqueeze(1)
        with torch.no_grad():
            target_q_values = torch.min(
                self.target_qf1(next_obs, new_next_actions, task_z_detached),
                self.target_qf2(next_obs, new_next_actions, task_z_detached),
            ) - alpha * new_log_pi

        q_target = rewards_flat + (
            1. - terms_flat) * self.discount * target_q_values
        qf1_loss = self.qf_criterion(q1_pred, q_target.detach())
        qf2_loss = self.qf_criterion(q2_pred, q_target.detach())
        """
        Context Encoder Loss
        """
        if self._debug_use_ground_truth_context:
            kl_div = kl_loss = ptu.zeros(0)
        else:
            kl_div = kl_divergence(p_z,
                                   self.agent.latent_prior).mean(dim=0).sum()
            kl_loss = self.kl_lambda * kl_div

        if self.train_context_decoder:
            # TODO: change to use a distribution
            reward_pred = self.context_decoder(obs, actions, task_z_with_grad)
            reward_prediction_loss = ((reward_pred -
                                       unscaled_rewards_flat)**2).mean()
            context_loss = kl_loss + reward_prediction_loss
        else:
            context_loss = kl_loss
            reward_prediction_loss = ptu.zeros(1)
        """
        Policy Loss
        """
        qf1_new_actions = self.qf1(obs, new_obs_actions, task_z_detached)
        qf2_new_actions = self.qf2(obs, new_obs_actions, task_z_detached)
        q_new_actions = torch.min(
            qf1_new_actions,
            qf2_new_actions,
        )

        # Advantage-weighted regression
        if self.vf_K > 1:
            vs = []
            for i in range(self.vf_K):
                u = dist.sample()
                q1 = self.qf1(obs, u, task_z_detached)
                q2 = self.qf2(obs, u, task_z_detached)
                v = torch.min(q1, q2)
                # v = q1
                vs.append(v)
            v_pi = torch.cat(vs, 1).mean(dim=1)
        else:
            # v_pi = self.qf1(obs, new_obs_actions)
            v1_pi = self.qf1(obs, new_obs_actions, task_z_detached)
            v2_pi = self.qf2(obs, new_obs_actions, task_z_detached)
            v_pi = torch.min(v1_pi, v2_pi)

        u = actions
        if self.awr_min_q:
            q_adv = torch.min(q1_pred, q2_pred)
        else:
            q_adv = q1_pred

        policy_logpp = dist.log_prob(u)

        if self.use_automatic_beta_tuning:
            buffer_dist = self.buffer_policy(obs)
            beta = self.log_beta.exp()
            kldiv = torch.distributions.kl.kl_divergence(dist, buffer_dist)
            beta_loss = -1 * (beta *
                              (kldiv - self.beta_epsilon).detach()).mean()

            self.beta_optimizer.zero_grad()
            beta_loss.backward()
            self.beta_optimizer.step()
        else:
            beta = self.beta_schedule.get_value(self._n_train_steps_total)
            beta_loss = ptu.zeros(1)

        score = q_adv - v_pi
        if self.mask_positive_advantage:
            score = torch.sign(score)

        if self.clip_score is not None:
            score = torch.clamp(score, max=self.clip_score)

        weights = batch.get('weights', None)
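        # Turn advantage scores into AWR weights: softmax over the batch, whitened or raw
        # exponentiation, a 0/1 step function, the raw score itself, or uniform weights.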
        if self.weight_loss and weights is None:
            if self.normalize_over_batch == True:
                weights = F.softmax(score / beta, dim=0)
            elif self.normalize_over_batch == "whiten":
                adv_mean = torch.mean(score)
                adv_std = torch.std(score) + 1e-5
                normalized_score = (score - adv_mean) / adv_std
                weights = torch.exp(normalized_score / beta)
            elif self.normalize_over_batch == "exp":
                weights = torch.exp(score / beta)
            elif self.normalize_over_batch == "step_fn":
                weights = (score > 0).float()
            elif self.normalize_over_batch == False:
                weights = score
            elif self.normalize_over_batch == 'uniform':
                weights = F.softmax(ptu.ones_like(score) / beta, dim=0)
            else:
                raise ValueError(self.normalize_over_batch)
        weights = weights[:, 0]

        policy_loss = alpha * log_pi.mean()

        if self.use_awr_update and self.weight_loss:
            policy_loss = policy_loss + self.awr_weight * (
                -policy_logpp * len(weights) * weights.detach()).mean()
        elif self.use_awr_update:
            policy_loss = policy_loss + self.awr_weight * (
                -policy_logpp).mean()

        if self.use_reparam_update:
            policy_loss = policy_loss + self.train_reparam_weight * (
                -q_new_actions).mean()

        policy_loss = self.rl_weight * policy_loss
        """
        Update networks
        """
        if self._n_train_steps_total % self.q_update_period == 0:
            if self.train_encoder_decoder:
                self.context_optimizer.zero_grad()
            if self.train_agent:
                self.qf1_optimizer.zero_grad()
                self.qf2_optimizer.zero_grad()
            context_loss.backward(retain_graph=True)
            # retain graph because the encoder is trained by both QF losses
            qf1_loss.backward(retain_graph=True)
            qf2_loss.backward()
            if self.train_agent:
                self.qf1_optimizer.step()
                self.qf2_optimizer.step()
            if self.train_encoder_decoder:
                self.context_optimizer.step()

        if self.train_agent:
            if self._n_train_steps_total % self.policy_update_period == 0 and self.update_policy:
                self.policy_optimizer.zero_grad()
                policy_loss.backward()
                self.policy_optimizer.step()
        self._num_gradient_steps += 1
        """
        Soft Updates
        """
        if self._n_train_steps_total % self.target_update_period == 0:
            ptu.soft_update_from_to(self.qf1, self.target_qf1,
                                    self.soft_target_tau)
            ptu.soft_update_from_to(self.qf2, self.target_qf2,
                                    self.soft_target_tau)
        """
        Save some statistics for eval
        """
        if self._need_to_update_eval_statistics:
            self._need_to_update_eval_statistics = False
            """
            Eval should set this to None.
            This way, these statistics are only computed for one batch.
            """
            policy_loss = (log_pi - q_new_actions).mean()

            self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
            self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))
            self.eval_statistics['Policy Loss'] = np.mean(
                ptu.get_numpy(policy_loss))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q1 Predictions',
                    ptu.get_numpy(q1_pred),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q2 Predictions',
                    ptu.get_numpy(q2_pred),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q Targets',
                    ptu.get_numpy(q_target),
                ))
            self.eval_statistics['task_embedding/kl_divergence'] = (
                ptu.get_numpy(kl_div))
            self.eval_statistics['task_embedding/kl_loss'] = (
                ptu.get_numpy(kl_loss))
            self.eval_statistics['task_embedding/reward_prediction_loss'] = (
                ptu.get_numpy(reward_prediction_loss))
            self.eval_statistics['task_embedding/context_loss'] = (
                ptu.get_numpy(context_loss))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Log Pis',
                    ptu.get_numpy(log_pi),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'rewards',
                    ptu.get_numpy(rewards),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'terminals',
                    ptu.get_numpy(terminals),
                ))
            policy_statistics = add_prefix(dist.get_diagnostics(), "policy/")
            self.eval_statistics.update(policy_statistics)
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Advantage Weights',
                    ptu.get_numpy(weights),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Advantage Score',
                    ptu.get_numpy(score),
                ))
            self.eval_statistics['reparam_weight'] = self.train_reparam_weight
            self.eval_statistics['num_gradient_steps'] = (
                self._num_gradient_steps)

            if self.use_automatic_entropy_tuning:
                self.eval_statistics['Alpha'] = alpha.item()
                self.eval_statistics['Alpha Loss'] = alpha_loss.item()

            if self.use_automatic_beta_tuning:
                self.eval_statistics.update({
                    "adaptive_beta/beta":
                    ptu.get_numpy(beta.mean()),
                    "adaptive_beta/beta loss":
                    ptu.get_numpy(beta_loss.mean()),
                })

        self._n_train_steps_total += 1
Example #15
    def train(self):
        # first train only the Q function
        iteration = 0
        timer.return_global_times = True
        timer.reset()
        for i in range(self.num_batches):
            if self.use_meta_learning_buffer:
                train_data = self.meta_replay_buffer.sample_meta_batch(
                    rl_batch_size=self.batch_size,
                    meta_batch_size=self.meta_batch_size,
                    embedding_batch_size=self.task_embedding_batch_size,
                )
                train_data = np_to_pytorch_batch(train_data)
            else:
                task_indices = np.random.choice(
                    self.train_tasks, self.meta_batch_size,
                )
                train_data = self.replay_buffer.sample_batch(
                    task_indices,
                    self.batch_size,
                )
                train_data = np_to_pytorch_batch(train_data)
                obs = train_data['observations']
                next_obs = train_data['next_observations']
                train_data['observations'] = obs
                train_data['next_observations'] = next_obs
                train_data['context'] = (
                    self.task_embedding_replay_buffer.sample_context(
                        task_indices,
                        self.task_embedding_batch_size,
                    ))
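            # The resulting batch holds per-task RL transitions plus a context batch that
            # the agent uses to infer the task embedding.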
            timer.start_timer('train', unique=False)
            self.trainer.train_from_torch(train_data)
            timer.stop_timer('train')
            if i % self.logging_period == 0 or i == self.num_batches - 1:
                stats_with_prefix = add_prefix(
                    self.trainer.eval_statistics, prefix="trainer/")
                self.trainer.end_epoch(iteration)
                logger.record_dict(stats_with_prefix)
                timer.start_timer('extra_fns', unique=False)
                for fn in self._extra_eval_fns:
                    extra_stats = fn()
                    logger.record_dict(extra_stats)
                timer.stop_timer('extra_fns')


                # TODO: evaluate during offline RL
                # eval_stats = self.get_eval_statistics()
                # eval_stats_with_prefix = add_prefix(eval_stats, prefix="eval/")
                # logger.record_dict(eval_stats_with_prefix)

                logger.record_tabular('iteration', iteration)
                logger.record_dict(_get_epoch_timings())
                try:
                    import os
                    import psutil
                    process = psutil.Process(os.getpid())
                    logger.record_tabular('RAM Usage (Mb)', int(process.memory_info().rss / 1000000))
                except ImportError:
                    pass
                logger.dump_tabular(with_prefix=True, with_timestamp=False)
                iteration += 1
Example #16
    def train_from_torch(
        self,
        batch,
        train=True,
        pretrain=False,
    ):
        """

        :param batch:
        :param train:
        :param pretrain:
        :return:
        """
        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']
        weights = batch.get('weights', None)
        if self.reward_transform:
            rewards = self.reward_transform(rewards)

        if self.terminal_transform:
            terminals = self.terminal_transform(terminals)
        """
        Policy and Alpha Loss
        """
        dist = self.policy(obs)
        new_obs_actions, log_pi = dist.rsample_and_logprob()
        policy_mle = dist.mle_estimate()

        if self.brac:
            buf_dist = self.buffer_policy(obs)
            buf_log_pi = buf_dist.log_prob(actions)
            rewards = rewards + buf_log_pi

        if self.use_automatic_entropy_tuning:
            alpha_loss = -(self.log_alpha *
                           (log_pi + self.target_entropy).detach()).mean()
            self.alpha_optimizer.zero_grad()
            alpha_loss.backward()
            self.alpha_optimizer.step()
            alpha = self.log_alpha.exp()
        else:
            alpha_loss = 0
            alpha = self.alpha
        """
        QF Loss
        """
        q1_pred = self.qf1(obs, actions)
        q2_pred = self.qf2(obs, actions)
        # Make sure policy accounts for squashing functions like tanh correctly!
        next_dist = self.policy(next_obs)
        new_next_actions, new_log_pi = next_dist.rsample_and_logprob()
        target_q_values = torch.min(
            self.target_qf1(next_obs, new_next_actions),
            self.target_qf2(next_obs, new_next_actions),
        ) - alpha * new_log_pi

        q_target = self.reward_scale * rewards + (
            1. - terminals) * self.discount * target_q_values
        qf1_loss = self.qf_criterion(q1_pred, q_target.detach())
        qf2_loss = self.qf_criterion(q2_pred, q_target.detach())
        """
        Policy Loss
        """
        qf1_new_actions = self.qf1(obs, new_obs_actions)
        qf2_new_actions = self.qf2(obs, new_obs_actions)
        q_new_actions = torch.min(
            qf1_new_actions,
            qf2_new_actions,
        )

        # Advantage-weighted regression
        if self.awr_use_mle_for_vf:
            v1_pi = self.qf1(obs, policy_mle)
            v2_pi = self.qf2(obs, policy_mle)
            v_pi = torch.min(v1_pi, v2_pi)
        else:
            if self.vf_K > 1:
                vs = []
                for i in range(self.vf_K):
                    u = dist.sample()
                    q1 = self.qf1(obs, u)
                    q2 = self.qf2(obs, u)
                    v = torch.min(q1, q2)
                    # v = q1
                    vs.append(v)
                v_pi = torch.cat(vs, 1).mean(dim=1)
            else:
                # v_pi = self.qf1(obs, new_obs_actions)
                v1_pi = self.qf1(obs, new_obs_actions)
                v2_pi = self.qf2(obs, new_obs_actions)
                v_pi = torch.min(v1_pi, v2_pi)

        if self.awr_sample_actions:
            u = new_obs_actions
            if self.awr_min_q:
                q_adv = q_new_actions
            else:
                q_adv = qf1_new_actions
        elif self.buffer_policy_sample_actions:
            buf_dist = self.buffer_policy(obs)
            u, _ = buf_dist.rsample_and_logprob()
            qf1_buffer_actions = self.qf1(obs, u)
            qf2_buffer_actions = self.qf2(obs, u)
            q_buffer_actions = torch.min(
                qf1_buffer_actions,
                qf2_buffer_actions,
            )
            if self.awr_min_q:
                q_adv = q_buffer_actions
            else:
                q_adv = qf1_buffer_actions
        else:
            u = actions
            if self.awr_min_q:
                q_adv = torch.min(q1_pred, q2_pred)
            else:
                q_adv = q1_pred

        policy_logpp = dist.log_prob(u)

        if self.use_automatic_beta_tuning:
            buffer_dist = self.buffer_policy(obs)
            beta = self.log_beta.exp()
            kldiv = torch.distributions.kl.kl_divergence(dist, buffer_dist)
            beta_loss = -1 * (beta *
                              (kldiv - self.beta_epsilon).detach()).mean()

            self.beta_optimizer.zero_grad()
            beta_loss.backward()
            self.beta_optimizer.step()
        else:
            beta = self.beta_schedule.get_value(self._n_train_steps_total)

        if self.normalize_over_state == "advantage":
            score = q_adv - v_pi
            if self.mask_positive_advantage:
                score = torch.sign(score)
        elif self.normalize_over_state == "Z":
            buffer_dist = self.buffer_policy(obs)
            K = self.Z_K
            buffer_obs = []
            buffer_actions = []
            log_bs = []
            log_pis = []
            for i in range(K):
                u = buffer_dist.sample()
                log_b = buffer_dist.log_prob(u)
                log_pi = dist.log_prob(u)
                buffer_obs.append(obs)
                buffer_actions.append(u)
                log_bs.append(log_b)
                log_pis.append(log_pi)
            buffer_obs = torch.cat(buffer_obs, 0)
            buffer_actions = torch.cat(buffer_actions, 0)
            p_buffer = torch.exp(torch.cat(log_bs, 0).sum(dim=1, ))
            log_pi = torch.cat(log_pis, 0)
            log_pi = log_pi.sum(dim=1, )
            q1_b = self.qf1(buffer_obs, buffer_actions)
            q2_b = self.qf2(buffer_obs, buffer_actions)
            q_b = torch.min(q1_b, q2_b)
            q_b = torch.reshape(q_b, (-1, K))
            adv_b = q_b - v_pi
            # if self._n_train_steps_total % 100 == 0:
            #     import ipdb; ipdb.set_trace()
            # Z = torch.exp(adv_b / beta).mean(dim=1, keepdim=True)
            # score = torch.exp((q_adv - v_pi) / beta) / Z
            # score = score / sum(score)
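            # Numerically stable version of the commented-out computation above: logZ
            # estimates log E_b[exp(adv / beta)] over the K buffer-policy actions, and the
            # final softmax self-normalizes the scores across the batch.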
            logK = torch.log(ptu.tensor(float(K)))
            logZ = torch.logsumexp(adv_b / beta - logK, dim=1, keepdim=True)
            logS = (q_adv - v_pi) / beta - logZ
            # logZ = torch.logsumexp(q_b/beta - logK, dim=1, keepdim=True)
            # logS = q_adv/beta - logZ
            score = F.softmax(logS, dim=0)  # score / sum(score)
        else:
            raise ValueError(self.normalize_over_state)

        if self.clip_score is not None:
            score = torch.clamp(score, max=self.clip_score)

        if self.weight_loss and weights is None:
            if self.normalize_over_batch == True:  # compare to True so the string-valued branches below stay reachable
                weights = F.softmax(score / beta, dim=0)
            elif self.normalize_over_batch == "whiten":
                adv_mean = torch.mean(score)
                adv_std = torch.std(score) + 1e-5
                normalized_score = (score - adv_mean) / adv_std
                weights = torch.exp(normalized_score / beta)
            elif self.normalize_over_batch == "exp":
                weights = torch.exp(score / beta)
            elif self.normalize_over_batch == "step_fn":
                weights = (score > 0).float()
            elif not self.normalize_over_batch:
                weights = score
            else:
                raise ValueError(self.normalize_over_batch)
        weights = weights[:, 0]

        policy_loss = alpha * log_pi.mean()

        if self.use_awr_update and self.weight_loss:
            policy_loss = policy_loss + self.awr_weight * (
                -policy_logpp * len(weights) * weights.detach()).mean()
        elif self.use_awr_update:
            policy_loss = policy_loss + self.awr_weight * (
                -policy_logpp).mean()

        if self.use_reparam_update:
            policy_loss = policy_loss + self.reparam_weight * (
                -q_new_actions).mean()

        policy_loss = self.rl_weight * policy_loss
        if self.compute_bc:
            train_policy_loss, train_logp_loss, train_mse_loss, _ = self.run_bc_batch(
                self.demo_train_buffer, self.policy)
            policy_loss = policy_loss + self.bc_weight * train_policy_loss

        if not pretrain and self.buffer_policy_reset_period > 0 and self._n_train_steps_total % self.buffer_policy_reset_period == 0:
            del self.buffer_policy_optimizer
            self.buffer_policy_optimizer = self.optimizer_class(
                self.buffer_policy.parameters(),
                weight_decay=self.policy_weight_decay,
                lr=self.policy_lr,
            )
            self.optimizers[self.buffer_policy] = self.buffer_policy_optimizer
            for i in range(self.num_buffer_policy_train_steps_on_reset):
                if self.train_bc_on_rl_buffer:
                    if self.advantage_weighted_buffer_loss:
                        buffer_dist = self.buffer_policy(obs)
                        buffer_u = actions
                        buffer_new_obs_actions, _ = buffer_dist.rsample_and_logprob(
                        )
                        buffer_policy_logpp = buffer_dist.log_prob(buffer_u)
                        buffer_policy_logpp = buffer_policy_logpp[:, None]

                        buffer_q1_pred = self.qf1(obs, buffer_u)
                        buffer_q2_pred = self.qf2(obs, buffer_u)
                        buffer_q_adv = torch.min(buffer_q1_pred,
                                                 buffer_q2_pred)

                        buffer_v1_pi = self.qf1(obs, buffer_new_obs_actions)
                        buffer_v2_pi = self.qf2(obs, buffer_new_obs_actions)
                        buffer_v_pi = torch.min(buffer_v1_pi, buffer_v2_pi)

                        buffer_score = buffer_q_adv - buffer_v_pi
                        buffer_weights = F.softmax(buffer_score / beta, dim=0)
                        buffer_policy_loss = self.awr_weight * (
                            -buffer_policy_logpp * len(buffer_weights) *
                            buffer_weights.detach()).mean()
                    else:
                        buffer_policy_loss, buffer_train_logp_loss, buffer_train_mse_loss, _ = self.run_bc_batch(
                            self.replay_buffer.train_replay_buffer,
                            self.buffer_policy)

                    self.buffer_policy_optimizer.zero_grad()
                    buffer_policy_loss.backward(retain_graph=True)
                    self.buffer_policy_optimizer.step()

        if self.train_bc_on_rl_buffer:
            if self.advantage_weighted_buffer_loss:
                buffer_dist = self.buffer_policy(obs)
                buffer_u = actions
                buffer_new_obs_actions, _ = buffer_dist.rsample_and_logprob()
                buffer_policy_logpp = buffer_dist.log_prob(buffer_u)
                buffer_policy_logpp = buffer_policy_logpp[:, None]

                buffer_q1_pred = self.qf1(obs, buffer_u)
                buffer_q2_pred = self.qf2(obs, buffer_u)
                buffer_q_adv = torch.min(buffer_q1_pred, buffer_q2_pred)

                buffer_v1_pi = self.qf1(obs, buffer_new_obs_actions)
                buffer_v2_pi = self.qf2(obs, buffer_new_obs_actions)
                buffer_v_pi = torch.min(buffer_v1_pi, buffer_v2_pi)

                buffer_score = buffer_q_adv - buffer_v_pi
                buffer_weights = F.softmax(buffer_score / beta, dim=0)
                buffer_policy_loss = self.awr_weight * (
                    -buffer_policy_logpp * len(buffer_weights) *
                    buffer_weights.detach()).mean()
            else:
                buffer_policy_loss, buffer_train_logp_loss, buffer_train_mse_loss, _ = self.run_bc_batch(
                    self.replay_buffer.train_replay_buffer, self.buffer_policy)
        """
        Update networks
        """
        if self._n_train_steps_total % self.q_update_period == 0:
            self.qf1_optimizer.zero_grad()
            qf1_loss.backward()
            self.qf1_optimizer.step()

            self.qf2_optimizer.zero_grad()
            qf2_loss.backward()
            self.qf2_optimizer.step()

        if self._n_train_steps_total % self.policy_update_period == 0 and self.update_policy:
            self.policy_optimizer.zero_grad()
            policy_loss.backward()
            self.policy_optimizer.step()

        if self.train_bc_on_rl_buffer and self._n_train_steps_total % self.policy_update_period == 0:
            self.buffer_policy_optimizer.zero_grad()
            buffer_policy_loss.backward()
            self.buffer_policy_optimizer.step()
        """
        Soft Updates
        """
        if self._n_train_steps_total % self.target_update_period == 0:
            ptu.soft_update_from_to(self.qf1, self.target_qf1,
                                    self.soft_target_tau)
            ptu.soft_update_from_to(self.qf2, self.target_qf2,
                                    self.soft_target_tau)
        """
        Save some statistics for eval
        """
        if self._need_to_update_eval_statistics:
            self._need_to_update_eval_statistics = False
            """
            Eval should set this to None.
            This way, these statistics are only computed for one batch.
            """
            policy_loss = (log_pi - q_new_actions).mean()

            self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
            self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))
            self.eval_statistics['Policy Loss'] = np.mean(
                ptu.get_numpy(policy_loss))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q1 Predictions',
                    ptu.get_numpy(q1_pred),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q2 Predictions',
                    ptu.get_numpy(q2_pred),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q Targets',
                    ptu.get_numpy(q_target),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Log Pis',
                    ptu.get_numpy(log_pi),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'rewards',
                    ptu.get_numpy(rewards),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'terminals',
                    ptu.get_numpy(terminals),
                ))
            policy_statistics = add_prefix(dist.get_diagnostics(), "policy/")
            self.eval_statistics.update(policy_statistics)
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Advantage Weights',
                    ptu.get_numpy(weights),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Advantage Score',
                    ptu.get_numpy(score),
                ))

            if self.normalize_over_state == "Z":
                self.eval_statistics.update(
                    create_stats_ordered_dict(
                        'logZ',
                        ptu.get_numpy(logZ),
                    ))

            if self.use_automatic_entropy_tuning:
                self.eval_statistics['Alpha'] = alpha.item()
                self.eval_statistics['Alpha Loss'] = alpha_loss.item()

            if self.compute_bc:
                test_policy_loss, test_logp_loss, test_mse_loss, _ = self.run_bc_batch(
                    self.demo_test_buffer, self.policy)
                self.eval_statistics.update({
                    "bc/Train Logprob Loss":
                    ptu.get_numpy(train_logp_loss),
                    "bc/Test Logprob Loss":
                    ptu.get_numpy(test_logp_loss),
                    "bc/Train MSE":
                    ptu.get_numpy(train_mse_loss),
                    "bc/Test MSE":
                    ptu.get_numpy(test_mse_loss),
                    "bc/train_policy_loss":
                    ptu.get_numpy(train_policy_loss),
                    "bc/test_policy_loss":
                    ptu.get_numpy(test_policy_loss),
                })
            if self.train_bc_on_rl_buffer:
                _, buffer_train_logp_loss, _, _ = self.run_bc_batch(
                    self.replay_buffer.train_replay_buffer, self.buffer_policy)

                _, buffer_test_logp_loss, _, _ = self.run_bc_batch(
                    self.replay_buffer.validation_replay_buffer,
                    self.buffer_policy)
                buffer_dist = self.buffer_policy(obs)
                kldiv = torch.distributions.kl.kl_divergence(dist, buffer_dist)

                _, train_offline_logp_loss, _, _ = self.run_bc_batch(
                    self.demo_train_buffer, self.buffer_policy)

                _, test_offline_logp_loss, _, _ = self.run_bc_batch(
                    self.demo_test_buffer, self.buffer_policy)

                self.eval_statistics.update({
                    "buffer_policy/Train Online Logprob":
                    -1 * ptu.get_numpy(buffer_train_logp_loss),
                    "buffer_policy/Test Online Logprob":
                    -1 * ptu.get_numpy(buffer_test_logp_loss),
                    "buffer_policy/Train Offline Logprob":
                    -1 * ptu.get_numpy(train_offline_logp_loss),
                    "buffer_policy/Test Offline Logprob":
                    -1 * ptu.get_numpy(test_offline_logp_loss),
                    "buffer_policy/train_policy_loss":
                    ptu.get_numpy(buffer_policy_loss),
                    # "buffer_policy/test_policy_loss": ptu.get_numpy(buffer_test_policy_loss),
                    "buffer_policy/kl_div":
                    ptu.get_numpy(kldiv.mean()),
                })
            if self.use_automatic_beta_tuning:
                self.eval_statistics.update({
                    "adaptive_beta/beta":
                    ptu.get_numpy(beta.mean()),
                    "adaptive_beta/beta loss":
                    ptu.get_numpy(beta_loss.mean()),
                })

            if self.validation_qlearning:
                train_data = self.replay_buffer.validation_replay_buffer.random_batch(
                    self.bc_batch_size)
                train_data = np_to_pytorch_batch(train_data)
                obs = train_data['observations']
                next_obs = train_data['next_observations']
                # goals = train_data['resampled_goals']
                train_data['observations'] = obs  # torch.cat((obs, goals), dim=1)
                train_data['next_observations'] = next_obs  # torch.cat((next_obs, goals), dim=1)
                self.test_from_torch(train_data)

        self._n_train_steps_total += 1
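A minimal, self-contained sketch of the advantage-weighting step used in the policy update above (the normalize_over_state="advantage", normalize_over_batch=True path). The function name, stand-in Q-values, and tensor shapes below are illustrative assumptions, not part of the trainer's API:

import torch
import torch.nn.functional as F

def awr_weights(q1, q2, q_pi, beta=1.0, clip_score=None):
    # Clipped double-Q advantage of the taken actions: A = min(Q1, Q2) - V(s),
    # with V(s) approximated by the Q-value of a fresh policy sample (q_pi).
    score = torch.min(q1, q2) - q_pi
    if clip_score is not None:
        score = torch.clamp(score, max=clip_score)
    # Softmax over the batch turns the scores into normalized weights,
    # mirroring weights = F.softmax(score / beta, dim=0) in the trainer.
    return F.softmax(score / beta, dim=0).detach()

# Usage with random stand-in Q-values for a batch of 8 transitions:
q1, q2, q_pi = torch.randn(8, 1), torch.randn(8, 1), torch.randn(8, 1)
weights = awr_weights(q1, q2, q_pi, beta=0.5)
policy_logpp = torch.randn(8, 1)  # stand-in for log pi(a|s) of buffer actions
awr_loss = (-policy_logpp * len(weights) * weights).mean()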
Exemple #17
0
    def compute_loss(
        self,
        batch,
        skip_statistics=False,
    ) -> Tuple[SACLosses, LossStatistics]:
        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']
        """
        Policy and Alpha Loss
        """
        dist = self.policy(obs)
        new_obs_actions, log_pi = dist.rsample_and_logprob()
        log_pi = log_pi.unsqueeze(-1)
        if self.use_automatic_entropy_tuning:
            alpha_loss = -(self.log_alpha *
                           (log_pi + self.target_entropy).detach()).mean()
            alpha = self.log_alpha.exp()
        else:
            alpha_loss = 0
            alpha = 1

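        # Policy objective: minimize E[alpha * log pi(a|s) - min_i Q_i(s, a)]
        # over reparameterized actions sampled from the current policy.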
        q_new_actions = torch.min(
            self.qf1(obs, new_obs_actions),
            self.qf2(obs, new_obs_actions),
        )
        policy_loss = (alpha * log_pi - q_new_actions).mean()
        """
        QF Loss
        """
        q1_pred = self.qf1(obs, actions)
        q2_pred = self.qf2(obs, actions)
        next_dist = self.policy(next_obs)
        new_next_actions, new_log_pi = next_dist.rsample_and_logprob()
        new_log_pi = new_log_pi.unsqueeze(-1)
        target_q_values = torch.min(
            self.target_qf1(next_obs, new_next_actions),
            self.target_qf2(next_obs, new_next_actions),
        ) - alpha * new_log_pi

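        # Entropy-regularized Bellman target; (1 - terminals) stops
        # bootstrapping at terminal transitions.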
        q_target = (self.reward_scale * rewards +
                    (1. - terminals) * self.discount * target_q_values)
        qf1_loss = self.qf_criterion(q1_pred, q_target.detach())
        qf2_loss = self.qf_criterion(q2_pred, q_target.detach())
        """
        Save some statistics for eval
        """
        eval_statistics = OrderedDict()
        if not skip_statistics:
            eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
            eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))
            eval_statistics['Policy Loss'] = np.mean(
                ptu.get_numpy(policy_loss))
            eval_statistics.update(
                create_stats_ordered_dict(
                    'Q1 Predictions',
                    ptu.get_numpy(q1_pred),
                ))
            eval_statistics.update(
                create_stats_ordered_dict(
                    'Q2 Predictions',
                    ptu.get_numpy(q2_pred),
                ))
            eval_statistics.update(
                create_stats_ordered_dict(
                    'Q Targets',
                    ptu.get_numpy(q_target),
                ))
            eval_statistics.update(
                create_stats_ordered_dict(
                    'Log Pis',
                    ptu.get_numpy(log_pi),
                ))
            eval_statistics.update(
                create_stats_ordered_dict(
                    'rewards',
                    ptu.get_numpy(rewards),
                ))
            eval_statistics.update(
                create_stats_ordered_dict(
                    'terminals',
                    ptu.get_numpy(terminals),
                ))
            reward_scale = self.reward_scale
            if isinstance(reward_scale, torch.Tensor):
                reward_scale = ptu.get_numpy(reward_scale)
            eval_statistics['reward scale'] = reward_scale
            policy_statistics = add_prefix(dist.get_diagnostics(), "policy/")
            eval_statistics.update(policy_statistics)
            if self.use_automatic_entropy_tuning:
                eval_statistics['Alpha'] = alpha.item()
                eval_statistics['Alpha Loss'] = alpha_loss.item()

        loss = SACLosses(
            policy_loss=policy_loss,
            qf1_loss=qf1_loss,
            qf2_loss=qf2_loss,
            alpha_loss=alpha_loss,
        )

        return loss, eval_statistics
    def get_diagnostics(self):
        diagnostics = OrderedDict()
        for name, collector in self.path_collectors.items():
            diagnostics.update(
                add_prefix(collector.get_diagnostics(), name, divider='/'))
        return diagnostics
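The SACLosses bundle returned by compute_loss can then be consumed by a training step roughly as follows. This is a hedged sketch: the optimizer attribute names (in particular alpha_optimizer) and the stepping order are assumptions for illustration, not necessarily the library's exact implementation:

def train_step_sketch(trainer, batch):
    # Compute all SAC losses in one pass, then step each optimizer.
    losses, stats = trainer.compute_loss(batch, skip_statistics=False)

    if trainer.use_automatic_entropy_tuning:
        trainer.alpha_optimizer.zero_grad()  # assumed attribute name
        losses.alpha_loss.backward()
        trainer.alpha_optimizer.step()

    trainer.policy_optimizer.zero_grad()
    losses.policy_loss.backward()
    trainer.policy_optimizer.step()

    trainer.qf1_optimizer.zero_grad()
    losses.qf1_loss.backward()
    trainer.qf1_optimizer.step()

    trainer.qf2_optimizer.zero_grad()
    losses.qf2_loss.backward()
    trainer.qf2_optimizer.step()

    return stats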