Пример #1
0
  def test__update_current_policy(self):
    """`cfr._update_current_policy` fills the tabular policy from regrets.

    Every Kuhn poker information state receives a row of non-negative
    cumulative regrets whose sum is positive; the policy written into
    `tabular_policy` is expected to be exactly each row divided by its sum.
    """
    kuhn = pyspiel.load_game("kuhn_poker")
    tabular_policy = policy.TabularPolicy(kuhn)

    # Row r holds the regrets [2r, 2r + 1] for the two legal actions.
    regrets = np.arange(0, 12 * 2).reshape((12, 2))
    row_sums = np.sum(regrets, axis=-1, keepdims=True)
    expected_policy = regrets / row_sums

    # Information-state keys, listed so that the k-th key is paired with row k
    # of `regrets`. This order has to agree with the row order of
    # `tabular_policy.action_probability_array` for the final comparison.
    ordered_keys = (
        u"0", u"0pb", u"1", u"1pb", u"2", u"2pb",
        u"1p", u"1b", u"2p", u"2b", u"0p", u"0b",
    )

    info_state_nodes = {}
    for row, key in enumerate(ordered_keys):
      info_state_nodes[key] = cfr._InfoStateNode(
          legal_actions=[0, 1],
          index_in_tabular_policy=None,
          cumulative_regret=dict(enumerate(regrets[row])),
          cumulative_policy=None)

    cfr._update_current_policy(tabular_policy, info_state_nodes)

    np.testing.assert_array_equal(expected_policy,
                                  tabular_policy.action_probability_array)
Пример #2
0
 def evaluate_and_update_policy(self):
   """Performs a single step of policy evaluation and policy improvement.

   Increments the iteration counter `t`. With alternating updates enabled,
   it then does the following for each player in turn:
     1. calls `_compute_counterfactual_regret_for_player` from the root node;
     2. multiplies every cumulative regret of that player's information states
        by t**alpha / (t**alpha + 1) when the regret is >= 0, and by
        t**beta / (t**beta + 1) otherwise;
     3. refreshes `self._current_policy` from all information-state nodes.

   When `self._alternating_updates` is false only the counter is advanced.
   NOTE(review): no simultaneous-update branch is visible here; presumably the
   constructor rules that mode out -- confirm.
   """
   self._iteration += 1
   if self._alternating_updates:
     # The discount factors depend only on the iteration counter and on
     # alpha/beta, so compute them once per call instead of once per action.
     positive_discount = (
         self._iteration**self.alpha / (self._iteration**self.alpha + 1))
     negative_discount = (
         self._iteration**self.beta / (self._iteration**self.beta + 1))
     for current_player in range(self._game.num_players()):
       self._compute_counterfactual_regret_for_player(
           self._root_node,
           policies=None,
           reach_probabilities=np.ones(self._game.num_players() + 1),
           player=current_player)
       for info_state in self._player_nodes[current_player]:
         regrets = info_state.cumulative_regret
         # Only existing keys are reassigned, so iterating while updating
         # values is safe.
         for action, regret in regrets.items():
           # Non-negative and negative regrets decay at different rates.
           regrets[action] *= (
               positive_discount if regret >= 0 else negative_discount)
       cfr._update_current_policy(self._current_policy, self._info_state_nodes)  # pylint: disable=protected-access