コード例 #1
0
ファイル: centralized_ma_ppo.py プロジェクト: parachutel/DICG
    def _compute_kl_constraint(self, obs, avail_actions, actions=None):
        """Compute KL divergence.

        Compute the KL divergence between the old policy distribution and
        current policy distribution.

        Args:
            obs (torch.Tensor): Observation from the environment.

        Returns:
            torch.Tensor: Calculated mean KL divergence.

        """
        if self.policy.recurrent:
            with torch.no_grad():
                if hasattr(self.policy, 'dicg'):
                    old_dist, _ = self._old_policy.forward(
                        obs, avail_actions, actions)
                else:
                    old_dist = self._old_policy.forward(
                        obs, avail_actions, actions)

            if hasattr(self.policy, 'dicg'):
                new_dist, _ = self.policy.forward(obs, avail_actions, actions)
            else:
                new_dist = self.policy.forward(obs, avail_actions, actions)

        else:
            flat_obs = flatten_batch(obs)
            flat_avail_actions = flatten_batch(avail_actions)
            with torch.no_grad():
                if hasattr(self.policy, 'dicg'):
                    old_dist, _ = self._old_policy.forward(
                        flat_obs, flat_avail_actions)
                else:
                    old_dist = self._old_policy.forward(
                        flat_obs, flat_avail_actions)

            if hasattr(self.policy, 'dicg'):
                new_dist, _ = self.policy.forward(flat_obs, flat_avail_actions)
            else:
                new_dist = self.policy.forward(flat_obs, flat_avail_actions)

        kl_constraint = torch.distributions.kl.kl_divergence(
            old_dist, new_dist)

        return kl_constraint.mean()
コード例 #2
0
    def _compute_kl_constraint(self, obs):
        """Return the mean KL divergence between the old and current policy.

        The observation batch is flattened once and fed to both policies.
        The old policy is evaluated inside ``torch.no_grad()`` while the
        current policy is evaluated normally, so gradients can flow only
        through the current policy's distribution.

        Args:
            obs (torch.Tensor): Observation from the environment.

        Returns:
            torch.Tensor: Mean of KL(old || current) over all elements.

        """
        observations = flatten_batch(obs)

        with torch.no_grad():
            reference = self._old_policy.forward(observations)
        current = self.policy.forward(observations)

        return torch.distributions.kl.kl_divergence(reference, current).mean()