Exemplo n.º 1
0
def aggregate_joint_policies(game, total_policies,
                             probabilities_of_playing_policies):
    """Combine a mixture of joint policies into one aggregate policy.

    The result is a single callable policy that is realization-equivalent to
    playing the joint strategy `total_policies[i]` with probability
    `probabilities_of_playing_policies[i]`. It can be queried at any
    information state via `action_probabilities(state, player_id)`.

    Args:
      game: The open_spiel game.
      total_policies: A list of joint policies used for training. Each entry
        is itself a list of `policy.Policy` objects, one per player. The
        number of players is taken from the length of the first entry.
      probabilities_of_playing_policies: A list of floats giving the
        probability of playing each joint policy in `total_policies`.

    Returns:
      The value returned by `JointPolicyAggregator.aggregate`: a callable
      object representing the aggregated policy.
    """
    joint_aggregator = policy_aggregator_joint.JointPolicyAggregator(game)
    # Player ids 0..n-1, where n is the size of the first joint policy.
    player_ids = range(len(total_policies[0]))
    return joint_aggregator.aggregate(
        player_ids, total_policies, probabilities_of_playing_policies)
    def test_policy_aggregation_random(self, game_name):
        """Checks that mixing uniform-random joint policies stays uniform.

        Builds `num_joint_policies` joint policies in which every player plays
        `policy.UniformRandomPolicy` and mixes them with equal weights. It then
        checks that each aggregated per-state action distribution is uniform.

        Args:
          game_name: Name of the game passed to `rl_environment.Environment`
            (presumably supplied by a test parameterization outside this view).
        """
        env = rl_environment.Environment(game_name)
        num_players = 2
        num_joint_policies = 4

        joint_policies = [[
            policy.UniformRandomPolicy(env.game) for _ in range(num_players)
        ] for _ in range(num_joint_policies)]
        # Equal mixing weight for every joint policy.
        probabilities = np.ones(len(joint_policies))
        probabilities /= np.sum(probabilities)

        pol_ag = policy_aggregator_joint.JointPolicyAggregator(env.game)
        # Derive the player ids from num_players instead of hard-coding [0, 1]
        # so they stay in sync if num_players changes.
        player_ids = list(range(num_players))
        aggr_policy = pol_ag.aggregate(player_ids, joint_policies, probabilities)

        self.assertLen(aggr_policy.policies, num_players)
        for player in range(num_players):
            player_policy = aggr_policy.policies[player]
            self.assertNotEmpty(player_policy)
            for state_action_probs in player_policy.values():
                probs = list(state_action_probs.values())
                expected_prob = 1. / len(probs)
                for prob in probs:
                    # The aggregated probabilities come out of floating-point
                    # arithmetic in the aggregator, so compare with a
                    # tolerance rather than exact equality.
                    self.assertAlmostEqual(expected_prob, prob)