Example 1
    def test_runs_with_uniform_policies(self, game_name):
        game = pyspiel.load_game(game_name)
        calc = action_value.TreeWalkCalculator(game)

        calc.compute_all_states_action_values([
            policy.PolicyFromCallable(game, _uniform_policy),
            policy.PolicyFromCallable(game, _uniform_policy)
        ])
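
This fragment assumes imports and a `_uniform_policy` helper that are not shown. A minimal sketch of what they would look like, inferred from the calls above (the helper body is an assumption):

import pyspiel
from open_spiel.python import policy
from open_spiel.python.algorithms import action_value


def _uniform_policy(state):
  # Callable policies map a state to a list of (action, probability)
  # pairs; here every legal action gets equal weight. (Assumed helper,
  # not shown in the original fragment.)
  actions = state.legal_actions()
  prob = 1.0 / len(actions)
  return [(action, prob) for action in actions]

`policy.PolicyFromCallable` wraps such a callable into a policy object; it appears in older OpenSpiel releases, while newer ones provide `policy.UniformRandomPolicy(game)` for the same purpose.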
Example 2
  def __init__(self, game):
    if game.num_players() != 2:
      raise ValueError("Only supports 2-player games.")
    self.game = game
    self._num_players = game.num_players()
    self._num_actions = game.num_distinct_actions()

    self._action_value_calculator = action_value.TreeWalkCalculator(game)
    # best_responder[i] is a best response to the provided policy for player i.
    # It is therefore a policy for player (1-i).
    self._best_responder = {0: None, 1: None}
    self._all_states = None
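
The constructor only accepts 2-player games. A quick illustration of the guard, with the enclosing class written as `Calculator` since the snippet does not show its name (a hypothetical name):

import pyspiel

game = pyspiel.load_game("kuhn_poker")  # 2 players: accepted
calc = Calculator(game)  # hypothetical name for the class above

# A 3-player variant would be rejected:
three_player = pyspiel.load_game(
    "kuhn_poker", {"players": pyspiel.GameParameter(3)})
# Calculator(three_player) raises ValueError("Only supports 2-player games.")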
Example 3
  def test_kuhn_poker_always_pass_p0(self):
    game = pyspiel.load_game("kuhn_poker")
    calc = action_value.TreeWalkCalculator(game)
    uniform_policy = policy.TabularPolicy(game)
    always_pass_policy = policy.FirstActionPolicy(game).to_tabular()
    returned_values = calc([always_pass_policy, uniform_policy],
                           always_pass_policy)

    # Action 0 == Pass. Action 1 == Bet.
    # Some values are 0 because those states are never reached, and thus the
    # expected value of those nodes is undefined.
    np.testing.assert_array_almost_equal(
        np.asarray([
            # Player 0 states
            [-1.0, -0.5],    # '0'
            [-1.0, -2.0],    # '0pb'
            [-0.5, 0.5],     # '1'
            [-1.0, 0.0],     # '1pb'
            [0.0, 1.5],      # '2'
            [-1.0, 2.0],     # '2pb'
            # Player 1 states
            [0.0, 1.0],      # '1p'
            [0, 0],          # Unreachable
            [1.0, 1.0],      # '2p'
            [0, 0],          # Unreachable
            [-1.0, 1.0],     # '0p'
            [0, 0],          # Unreachable
        ]), returned_values.action_values)

    np.testing.assert_array_almost_equal(
        np.asarray([
            # Player 0 states
            1 / 3,  # '0'
            1 / 6,  # '0pb'
            1 / 3,  # '1'
            1 / 6,  # '1pb'
            1 / 3,  # '2'
            1 / 6,  # '2pb'
            # Player 1 states
            1 / 3,  # '1p'
            0.0,  # '1b': zero because player 0 always plays pass
            1 / 3,  # '2p'
            0.0,  # '2b': zero because player 0 always plays pass
            1 / 3,  # '0p'
            0.0,  # '0b': zero because player 0 always plays pass
        ]),
        returned_values.counterfactual_reach_probs)

    # The player reach probabilities are all one, even though player 0 only
    # plays pass, because the nodes player 0 makes unreachable are terminal
    # nodes: e.g. 'x x b b p' has a player-0 reach of 0, but it is a terminal
    # node, so it does not appear in the tabular policy states.
    np.testing.assert_array_equal(
        [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
        returned_values.player_reach_probs)

    np.testing.assert_array_almost_equal(
        np.asarray([
            np.array([-1/3, -1/6]),
            np.array([-1/6, -1/3]),
            np.array([-1/6, 1/6]),
            np.array([-1/6, 0.]),
            np.array([0., 0.5]),
            np.array([-1/6, 1/3]),
            np.array([0., 1/3]),
            np.array([0., 0.]),
            np.array([1/3, 1/3]),
            np.array([0., 0.]),
            np.array([-1/3, 1/3]),
            np.array([0., 0.])
        ]), returned_values.sum_cfr_reach_by_action_value)
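
The four returned arrays are mutually consistent: each row of `sum_cfr_reach_by_action_value` is the state's counterfactual reach probability times its action values, e.g. for '0', 1/3 * [-1.0, -0.5] = [-1/3, -1/6]. A check one could append to the test above, assuming the returned fields behave as NumPy arrays (which the assertions suggest):

cf_reach = np.asarray(returned_values.counterfactual_reach_probs)
action_values = np.asarray(returned_values.action_values)
# Per-state: cfr_reach * action_values should match the summed field.
np.testing.assert_array_almost_equal(
    cf_reach[:, None] * action_values,
    returned_values.sum_cfr_reach_by_action_value)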
Example 4
  def test_runs_with_uniform_policies(self, game_name):
    game = pyspiel.load_game(game_name)
    calc = action_value.TreeWalkCalculator(game)
    uniform_policy = policy.TabularPolicy(game)
    calc.compute_all_states_action_values([uniform_policy, uniform_policy])
Example 5
    def test_kuhn_poker_always_pass_p0(self):
        game = pyspiel.load_game("kuhn_poker")
        calc = action_value.TreeWalkCalculator(game)

        for always_pass_policy in [
                lambda state: [(0, 1.0), (1, 0.0)],
                # On purpose, we use a policy that does not list all the
                # legal actions.
                lambda state: [(0, 1.0)],
        ]:
            tabular_policy = policy.tabular_policy_from_policy(
                game, policy.PolicyFromCallable(game, always_pass_policy))

            # States are ordered using tabular_policy.states_per_player:
            # ['0', '0pb', '1', '1pb', '2', '2pb'] +
            # ['1p', '1b', '2p', '2b', '0p', '0b']
            np.testing.assert_array_equal(
                np.asarray([
                    [1., 0.],
                    [1., 0.],
                    [1., 0.],
                    [1., 0.],
                    [1., 0.],
                    [1., 0.],
                    [1., 0.],
                    [1., 0.],
                    [1., 0.],
                    [1., 0.],
                    [1., 0.],
                    [1., 0.],
                ]), tabular_policy.action_probability_array)

            returned_values = calc([
                policy.PolicyFromCallable(game, always_pass_policy),
                policy.PolicyFromCallable(game, _uniform_policy)
            ], tabular_policy)

            # Action 0 == Pass. Action 1 == Bet.
            # Some values are 0 because those states are never reached, and
            # thus the expected value of those nodes is undefined.
            np.testing.assert_array_almost_equal(
                np.asarray([
                    [-1.0, -0.5],
                    [-1.0, -2.0],
                    [-0.5, 0.5],
                    [-1.0, 0.0],
                    [0.0, 1.5],
                    [-1.0, 2.0],
                    [0.0, 1.0],
                    [0, 0],
                    [1.0, 1.0],
                    [0, 0],
                    [-1.0, 1.0],
                    [0, 0],
                ]), returned_values.action_values)

            np.testing.assert_array_almost_equal(
                np.asarray([
                    # Player 0 states
                    1 / 3,  # '0'
                    1 / 6,  # '0pb'
                    1 / 3,  # '1'
                    1 / 6,  # '1pb'
                    1 / 3,  # '2'
                    1 / 6,  # '2pb'
                    # Player 1 states
                    1 / 3,  # '1p'
                    0.0,  # '1b': zero because player 0 always plays pass
                    1 / 3,  # '2p'
                    0.0,  # '2b': zero because player 0 always plays pass
                    1 / 3,  # '0p'
                    0.0,  # '0b': zero because player 0 always plays pass
                ]),
                returned_values.counterfactual_reach_probs)

            # The player reach probabilities are all one, even though player 0
            # only plays pass, because the nodes player 0 makes unreachable are
            # terminal nodes: e.g. 'x x b b p' has a player-0 reach of 0, but
            # it is a terminal node, so it does not appear in the tabular
            # policy states.
            np.testing.assert_array_equal(
                [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
                returned_values.player_reach_probs)

            np.testing.assert_array_almost_equal(
                np.asarray([
                    np.array([-1 / 3, -1 / 6]),
                    np.array([-1 / 6, -1 / 3]),
                    np.array([-1 / 6, 1 / 6]),
                    np.array([-1 / 6, 0.]),
                    np.array([0., 0.5]),
                    np.array([-1 / 6, 1 / 3]),
                    np.array([0., 1 / 3]),
                    np.array([0., 0.]),
                    np.array([1 / 3, 1 / 3]),
                    np.array([0., 0.]),
                    np.array([-1 / 3, 1 / 3]),
                    np.array([0., 0.])
                ]), returned_values.sum_cfr_reach_by_action_value)
Example 6
  def test_runs_with_uniform_policies(self, game_name, num_players):
    game = pyspiel.load_game(
        game_name, {"players": pyspiel.GameParameter(num_players)})
    calc = action_value.TreeWalkCalculator(game)
    uniform_policy = policy.TabularPolicy(game)
    calc.compute_all_states_action_values([uniform_policy] * num_players)
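
Putting the pieces together, here is a self-contained sketch of using `TreeWalkCalculator` outside a test harness. It assumes an OpenSpiel install and relies on `policy.TabularPolicy(game)` being initialized to the uniform random policy, as Examples 4 and 6 do:

import pyspiel
from open_spiel.python import policy
from open_spiel.python.algorithms import action_value

game = pyspiel.load_game("kuhn_poker")
calc = action_value.TreeWalkCalculator(game)

# Walk the game tree once under a uniform policy for both players.
uniform = policy.TabularPolicy(game)
calc.compute_all_states_action_values([uniform, uniform])

# Calling the calculator returns the per-state quantities asserted on in
# Examples 3 and 5: action_values, counterfactual_reach_probs,
# player_reach_probs, and sum_cfr_reach_by_action_value.
values = calc([uniform, uniform], uniform)
print(values.action_values)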