Example #1
    def train_once(self, paths):
        """Train the algorithm once.

        Args:
            paths (list[dict]): A list of collected paths.

        Returns:
            numpy.float64: Calculated mean value of undiscounted returns.

        """
        obs, actions, rewards, returns, valids, baselines = \
            self.process_samples(paths)

        if self._maximum_entropy:
            policy_entropies = self._compute_policy_entropy(obs)
            rewards += self._policy_ent_coeff * policy_entropies

        obs_flat = torch.cat(filter_valids(obs, valids))
        actions_flat = torch.cat(filter_valids(actions, valids))
        rewards_flat = torch.cat(filter_valids(rewards, valids))
        returns_flat = torch.cat(filter_valids(returns, valids))
        advs_flat = self._compute_advantage(rewards, valids, baselines)

        with torch.no_grad():
            policy_loss_before = self._compute_loss_with_adv(
                obs_flat, actions_flat, rewards_flat, advs_flat)
            vf_loss_before = self._compute_vf_loss(
                obs_flat, returns_flat)
            # kl_before = self._compute_kl_constraint(obs)
            kl_before = self._compute_kl_constraint(obs_flat)

        self._train(obs_flat, actions_flat, rewards_flat, returns_flat,
                    advs_flat)

        with torch.no_grad():
            policy_loss_after = self._compute_loss_with_adv(
                obs_flat, actions_flat, rewards_flat, advs_flat)
            vf_loss_after = self._compute_vf_loss(
                obs_flat, returns_flat)
            # kl_after = self._compute_kl_constraint(obs)
            kl_after = self._compute_kl_constraint(obs_flat)
            # policy_entropy = self._compute_policy_entropy(obs)
            policy_entropy = self._compute_policy_entropy(obs_flat)

        if self._need_to_update_eval_statistics:
            self._need_to_update_eval_statistics = False
            self.eval_statistics['LossBefore'] = policy_loss_before.item()
            self.eval_statistics['LossAfter'] = policy_loss_after.item()
            self.eval_statistics['dLoss'] = (policy_loss_before - policy_loss_after).item()
            self.eval_statistics['KLBefore'] = kl_before.item()
            self.eval_statistics['KL'] = kl_after.item()
            self.eval_statistics['Entropy'] = policy_entropy.mean().item()

            self.eval_statistics['VF LossBefore'] = vf_loss_before.item()
            self.eval_statistics['VF LossAfter'] = vf_loss_after.item()
            self.eval_statistics['VF dLoss'] = (vf_loss_before - vf_loss_after).item()

        self._old_policy = copy.deepcopy(self.policy)
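Note on the flattening step above: the example pads every path to max_path_length and then keeps only the valid timesteps via `filter_valids` before concatenating. The snippet below is a minimal, illustrative sketch of that pattern; the `filter_valids` body here is an assumption matching how the example calls it (truncate each padded path to its valid length), not the project's actual implementation.

import torch

def filter_valids(tensor, valids):
    # Illustrative stand-in: keep only the first `valid` steps of each padded path.
    # tensor: (N, P, *) padded paths; valids: valid length of each path.
    return [tensor[i, :valid] for i, valid in enumerate(valids)]

# Two paths padded to max_path_length P = 5, observation dim 3.
obs = torch.zeros(2, 5, 3)
valids = [5, 2]
obs_flat = torch.cat(filter_valids(obs, valids))
print(obs_flat.shape)  # torch.Size([7, 3]) -> one row per valid timestep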
Example #2
    def _compute_advantage(self, rewards, valids, baselines):
        r"""Compute mean value of loss.

        Notes: P is the maximum path length (self.max_path_length)

        Args:
            rewards (torch.Tensor): Acquired rewards
                with shape :math:`(N, P)`.
            valids (list[int]): Number of valid steps in each path.
            baselines (torch.Tensor): Value function estimation at each step
                with shape :math:`(N, P)`.

        Returns:
            torch.Tensor: Calculated advantage values given rewards and
                baselines with shape :math:`(N \cdot [T], )`.

        """
        advantages = compute_advantages(self.discount, self._gae_lambda,
                                        self.max_path_length, baselines,
                                        rewards)
        advantage_flat = torch.cat(filter_valids(advantages, valids))

        if self._center_adv:
            means = advantage_flat.mean()
            variance = advantage_flat.var()
            advantage_flat = (advantage_flat - means) / (variance + 1e-8)

        if self._positive_adv:
            advantage_flat -= advantage_flat.min()

        return advantage_flat
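For reference, `compute_advantages(discount, gae_lambda, max_path_length, baselines, rewards)` presumably implements generalized advantage estimation (GAE) over the padded (N, P) tensors. The sketch below is an illustrative reference implementation under that assumption, not the project's code.

import torch

def gae_advantages(discount, gae_lambda, baselines, rewards):
    # rewards, baselines: (N, P) padded per-path tensors; the value after
    # the final step of each path is treated as zero.
    N, P = rewards.shape
    advantages = torch.zeros_like(rewards)
    gae = torch.zeros(N, dtype=rewards.dtype)
    for t in reversed(range(P)):
        next_value = (baselines[:, t + 1] if t + 1 < P
                      else torch.zeros(N, dtype=rewards.dtype))
        delta = rewards[:, t] + discount * next_value - baselines[:, t]
        gae = delta + discount * gae_lambda * gae
        advantages[:, t] = gae
    return advantages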
Example #3
File: irl_ppo.py  Project: maxiaoba/rlkit
    def _compute_advantage(self, rewards, valids, baselines):
        advantages = compute_advantages(self.discount, self._gae_lambda,
                                        self.max_path_length, baselines,
                                        rewards)
        advantages_flat = torch.cat(filter_valids(advantages, valids))

        if self._center_adv:
            means = advantages_flat.mean()
            variance = advantages_flat.var()
            advantages = (advantages - means) / (variance + 1e-8)

        if self._positive_adv:
            advantages -= advantages.min()

        return advantages
Example #4
File: coma.py  Project: maxiaoba/rlkit
    def train_once(self, paths):
        """Train the algorithm once.

        Args:
            paths (list[dict]): A list of collected paths.

        Returns:
            numpy.float64: Calculated mean value of undiscounted returns.

        """
        obs_n, actions_n, rewards_n, returns_n, valids, raw_actions_n = \
            self.process_samples(paths)
        # num_paths x path_length x num_agent x dim
        num_agent = obs_n.shape[2]
        if self._entropy_reward:
            for agent in range(num_agent):
                policy_entropies = self._compute_policy_entropy(
                    self.policy_n[agent], obs_n[:, :, agent, :])
                rewards_n[:, :,
                          agent, :] += self._policy_ent_coeff * policy_entropies

        obs_input = torch.cat(filter_valids(obs_n, valids))
        actions_input = torch.cat(filter_valids(actions_n, valids))
        if raw_actions_n is None:
            raw_actions_input = None
        else:
            raw_actions_input = torch.cat(filter_valids(raw_actions_n, valids))

        rewards_input = torch.cat(filter_valids(rewards_n, valids))
        returns_input = torch.cat(filter_valids(returns_n, valids))
        policy_input = obs_input
        valid_mask = torch.ones(obs_input.shape[0]).bool()
        # num of valid samples x num_agent x dim
        advs_input = self._compute_advantage(obs_input, actions_input,
                                             returns_input)

        policy_loss_before_n, qf_loss_before_n, kl_before_n = [], [], []
        for agent in range(num_agent):
            with torch.no_grad():
                policy_loss_before = self._compute_loss_with_adv(
                    self.policy_n[agent],
                    self._old_policy_n[agent],
                    policy_input[:, agent, :],
                    actions_input[:, agent, :],
                    rewards_input[:, agent, :],
                    advs_input[:, agent, :],
                    valid_mask,
                    raw_actions=(None if (raw_actions_input is None) else
                                 raw_actions_input[:, agent, :]))
                qf_loss_before = self._compute_qf_loss(
                    self.qf_n[agent], obs_input, actions_input,
                    returns_input[:, agent, :], valid_mask)
                # kl_before = self._compute_kl_constraint(obs)
                kl_before = self._compute_kl_constraint(
                    self.policy_n[agent], self._old_policy_n[agent],
                    policy_input[:, agent, :], valid_mask)
                policy_loss_before_n.append(policy_loss_before)
                qf_loss_before_n.append(qf_loss_before)
                kl_before_n.append(kl_before)

        self._train(policy_input, obs_input, actions_input, raw_actions_input,
                    rewards_input, returns_input, advs_input, valid_mask)

        policy_loss_after_n, qf_loss_after_n, kl_after_n, policy_entropy_n = [], [], [], []
        for agent in range(num_agent):
            with torch.no_grad():
                policy_loss_after = self._compute_loss_with_adv(
                    self.policy_n[agent],
                    self._old_policy_n[agent],
                    policy_input[:, agent, :],
                    actions_input[:, agent, :],
                    rewards_input[:, agent, :],
                    advs_input[:, agent, :],
                    valid_mask,
                    raw_actions=(None if (raw_actions_input is None) else
                                 raw_actions_input[:, agent, :]))
                qf_loss_after = self._compute_qf_loss(
                    self.qf_n[agent], obs_input, actions_input,
                    returns_input[:, agent, :], valid_mask)
                # kl_after = self._compute_kl_constraint(obs)
                kl_after = self._compute_kl_constraint(
                    self.policy_n[agent], self._old_policy_n[agent],
                    policy_input[:, agent, :], valid_mask)
                policy_entropy = self._compute_policy_entropy(
                    self.policy_n[agent], policy_input[:, agent, :])
                policy_loss_after_n.append(policy_loss_after)
                qf_loss_after_n.append(qf_loss_after)
                kl_after_n.append(kl_after)
                policy_entropy_n.append(policy_entropy)

        if self._need_to_update_eval_statistics:
            self._need_to_update_eval_statistics = False
            for agent in range(num_agent):
                self.eval_statistics['LossBefore {}'.format(
                    agent)] = policy_loss_before_n[agent].item()
                self.eval_statistics['LossAfter {}'.format(
                    agent)] = policy_loss_after_n[agent].item()
                self.eval_statistics['dLoss {}'.format(agent)] = (
                    policy_loss_before_n[agent] -
                    policy_loss_after_n[agent]).item()
                self.eval_statistics['KLBefore {}'.format(
                    agent)] = kl_before_n[agent].item()
                self.eval_statistics['KL {}'.format(
                    agent)] = kl_after_n[agent].item()
                self.eval_statistics['Entropy {}'.format(
                    agent)] = policy_entropy_n[agent][valid_mask].mean().item(
                    )

                self.eval_statistics['QF LossBefore {}'.format(
                    agent)] = qf_loss_before_n[agent].item()
                self.eval_statistics['QF LossAfter {}'.format(
                    agent)] = qf_loss_after_n[agent].item()
                self.eval_statistics['QF dLoss {}'.format(agent)] = (
                    qf_loss_before_n[agent] - qf_loss_after_n[agent]).item()

        self._old_policy_n = copy.deepcopy(self.policy_n)
Example #5
    def train_once(self, paths):
        """Train the algorithm once.

        Args:
            paths (list[dict]): A list of collected paths.

        Returns:
            numpy.float64: Calculated mean value of undiscounted returns.

        """

        obs, actions, rewards, returns, valids, baselines, labels = \
            self.process_samples(paths)

        if self._maximum_entropy:
            policy_entropies = self._compute_policy_entropy(obs)
            rewards += self._policy_ent_coeff * policy_entropies
        advs = self._compute_advantage(rewards, valids, baselines)

        if self._recurrent:
            pre_actions = actions[:, :-1, :]
            policy_input = (obs, pre_actions)
            obs_input, actions_input, rewards_input, returns_input, advs_input = \
                obs, actions, rewards, returns, advs
            labels_input = labels
            valid_mask = torch.zeros(obs.shape[0], obs.shape[1]).bool()
            for i, valid in enumerate(valids):
                valid_mask[i, :valid] = True
        else:
            obs_input = torch.cat(filter_valids(obs, valids))
            actions_input = torch.cat(filter_valids(actions, valids))
            rewards_input = torch.cat(filter_valids(rewards, valids))
            returns_input = torch.cat(filter_valids(returns, valids))
            advs_input = torch.cat(filter_valids(advs, valids))
            labels_input = torch.cat(filter_valids(labels, valids))
            policy_input = obs_input
            valid_mask = torch.ones(obs_input.shape[0]).bool()
            # (num of valid samples) x ...

        with torch.no_grad():
            policy_loss_before = self._compute_loss_with_adv(
                policy_input, actions_input, rewards_input, advs_input,
                labels_input, valid_mask)
            vf_loss_before = self._compute_vf_loss(obs_input, returns_input,
                                                   valid_mask)
            # kl_before = self._compute_kl_constraint(obs)
            kl_before = self._compute_kl_constraint(policy_input, valid_mask)
            sup_loss_before, sup_accuracy_before = self._compute_sup_loss(
                obs_input, actions_input, labels_input, valid_mask)

        self._train(policy_input, obs_input, actions_input, rewards_input,
                    returns_input, advs_input, labels_input, valid_mask)

        with torch.no_grad():
            policy_loss_after = self._compute_loss_with_adv(
                policy_input, actions_input, rewards_input, advs_input,
                labels_input, valid_mask)
            vf_loss_after = self._compute_vf_loss(obs_input, returns_input,
                                                  valid_mask)
            # kl_after = self._compute_kl_constraint(obs)
            kl_after = self._compute_kl_constraint(policy_input, valid_mask)
            sup_loss_after, sup_accuracy_after = self._compute_sup_loss(
                obs_input, actions_input, labels_input, valid_mask)
            policy_entropy = self._compute_policy_entropy(policy_input)

        if self._need_to_update_eval_statistics:
            self._need_to_update_eval_statistics = False
            self.eval_statistics['LossBefore'] = policy_loss_before.item()
            self.eval_statistics['LossAfter'] = policy_loss_after.item()
            self.eval_statistics['dLoss'] = (policy_loss_before -
                                             policy_loss_after).item()
            self.eval_statistics['KLBefore'] = kl_before.item()
            self.eval_statistics['KL'] = kl_after.item()
            self.eval_statistics['Entropy'] = policy_entropy[valid_mask].mean(
            ).item()

            self.eval_statistics['VF LossBefore'] = vf_loss_before.item()
            self.eval_statistics['VF LossAfter'] = vf_loss_after.item()
            self.eval_statistics['VF dLoss'] = (vf_loss_before -
                                                vf_loss_after).item()

            self.eval_statistics['SUP LossBefore'] = sup_loss_before.item()
            self.eval_statistics['SUP LossAfter'] = sup_loss_after.item()
            self.eval_statistics['SUP dLoss'] = (sup_loss_before -
                                                 sup_loss_after).item()
            self.eval_statistics[
                'SUP AccuracyBefore'] = sup_accuracy_before.item()
            self.eval_statistics[
                'SUP AccuracyAfter'] = sup_accuracy_after.item()
            self.eval_statistics['SUP dAccuracy'] = (
                sup_accuracy_before - sup_accuracy_after).item()

        self._old_policy = copy.deepcopy(self.policy)
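The recurrent branch above builds an (N, P) boolean mask instead of flattening the paths. A small, self-contained illustration of that mask (with made-up path lengths) follows.

import torch

valids = [3, 1]            # valid steps per path (illustrative values)
max_path_length = 4
valid_mask = torch.zeros(len(valids), max_path_length).bool()
for i, valid in enumerate(valids):
    valid_mask[i, :valid] = True
# tensor([[ True,  True,  True, False],
#         [ True, False, False, False]])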
Example #6
File: ppo_sup_0.py  Project: maxiaoba/rlkit
    def train_once(self, paths):
        """Train the algorithm once.

        Args:
            paths (list[dict]): A list of collected paths.

        Returns:
            numpy.float64: Calculated mean value of undiscounted returns.

        """
        obs, actions, rewards, returns, valids, baselines, n_labels = \
            self.process_samples(paths)

        if self._maximum_entropy:
            policy_entropies = self._compute_policy_entropy(obs)
            rewards += self._policy_ent_coeff * policy_entropies

        obs_flat = torch.cat(filter_valids(obs, valids))
        actions_flat = torch.cat(filter_valids(actions, valids))
        rewards_flat = torch.cat(filter_valids(rewards, valids))
        returns_flat = torch.cat(filter_valids(returns, valids))
        advs_flat = self._compute_advantage(rewards, valids, baselines)
        n_labels_flat = [
            torch.cat(filter_valids(labels, valids)) for labels in n_labels
        ]

        self.replay_buffer.add_batch(obs_flat, n_labels_flat)
        for _ in range(self.sup_train_num):
            batch = self.replay_buffer.random_batch(self.sup_batch_size)
            sup_losses = self._train_sup_learners(batch['observations'],
                                                  batch['n_labels'])

        with torch.no_grad():
            policy_loss_before = self._compute_loss_with_adv(
                obs_flat, actions_flat, rewards_flat, advs_flat)
            vf_loss_before = self._compute_vf_loss(obs_flat, returns_flat)
            # kl_before = self._compute_kl_constraint(obs)
            kl_before = self._compute_kl_constraint(obs_flat)

        self._train(obs_flat, actions_flat, rewards_flat, returns_flat,
                    advs_flat)

        with torch.no_grad():
            policy_loss_after = self._compute_loss_with_adv(
                obs_flat, actions_flat, rewards_flat, advs_flat)
            vf_loss_after = self._compute_vf_loss(obs_flat, returns_flat)
            # kl_after = self._compute_kl_constraint(obs)
            kl_after = self._compute_kl_constraint(obs_flat)
            # policy_entropy = self._compute_policy_entropy(obs)
            policy_entropy = self._compute_policy_entropy(obs_flat)

        if self._need_to_update_eval_statistics:
            self._need_to_update_eval_statistics = False
            self.eval_statistics['LossBefore'] = policy_loss_before.item()
            self.eval_statistics['LossAfter'] = policy_loss_after.item()
            self.eval_statistics['dLoss'] = (policy_loss_before -
                                             policy_loss_after).item()
            self.eval_statistics['KLBefore'] = kl_before.item()
            self.eval_statistics['KL'] = kl_after.item()
            self.eval_statistics['Entropy'] = policy_entropy.mean().item()

            self.eval_statistics['VF LossBefore'] = vf_loss_before.item()
            self.eval_statistics['VF LossAfter'] = vf_loss_after.item()
            self.eval_statistics['VF dLoss'] = (vf_loss_before -
                                                vf_loss_after).item()
            for i in range(len(sup_losses)):
                self.eval_statistics['SUP Loss {}'.format(i)] = sup_losses[i]

        self._old_policy = copy.deepcopy(self.policy)
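The supervised-learning step above depends on a replay buffer exposing `add_batch` / `random_batch` and returning batches keyed by 'observations' and 'n_labels'. The sketch below is a minimal buffer with that interface, written as an assumption for illustration; the project's actual buffer may store and sample data differently.

import torch

class SupReplayBuffer:
    """Illustrative buffer for (observation, per-learner label) batches."""

    def __init__(self):
        self._obs, self._n_labels = None, None

    def add_batch(self, obs, n_labels):
        # obs: (B, obs_dim); n_labels: list of (B, label_dim) tensors,
        # one entry per supervised learner.
        if self._obs is None:
            self._obs, self._n_labels = obs, list(n_labels)
        else:
            self._obs = torch.cat([self._obs, obs])
            self._n_labels = [torch.cat([old, new])
                              for old, new in zip(self._n_labels, n_labels)]

    def random_batch(self, batch_size):
        # Sample indices with replacement and return the keys used above.
        idx = torch.randint(self._obs.shape[0], (batch_size,))
        return dict(observations=self._obs[idx],
                    n_labels=[labels[idx] for labels in self._n_labels])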