def test_runs_with_uniform_policies(self, game_name):
    """Smoke test: the tree-walk calculator runs on `game_name` with uniform play."""
    game = pyspiel.load_game(game_name)
    calculator = action_value.TreeWalkCalculator(game)
    # One independent uniform callable-policy wrapper per player.
    player_policies = [
        policy.PolicyFromCallable(game, _uniform_policy) for _ in range(2)
    ]
    calculator.compute_all_states_action_values(player_policies)
def __init__(self, game):
    """Initializes the calculator for a two-player `game`.

    Args:
      game: An open_spiel game object.

    Raises:
      ValueError: If `game` is not a 2-player game.
    """
    if game.num_players() != 2:
        raise ValueError("Only supports 2-player games.")
    self.game = game
    self._num_players = game.num_players()
    self._num_actions = game.num_distinct_actions()
    self._action_value_calculator = action_value.TreeWalkCalculator(game)
    # best_responder[i] is a best response to the policy provided for player
    # i; it is therefore a policy for player (1 - i). Both entries start
    # unset and are filled in lazily.
    self._best_responder = dict.fromkeys((0, 1))
    self._all_states = None
def test_kuhn_poker_always_pass_p0(self):
    """Checks action values in Kuhn poker when player 0 always passes."""
    game = pyspiel.load_game("kuhn_poker")
    calc = action_value.TreeWalkCalculator(game)
    uniform_policy = policy.TabularPolicy(game)
    always_pass_policy = policy.FirstActionPolicy(game).to_tabular()
    returned_values = calc([always_pass_policy, uniform_policy],
                           always_pass_policy)

    # Action 0 == Pass. Action 1 == Bet
    # Some values are 0 because the states are not reached, thus the expected
    # value of that node is undefined.
    expected_action_values = np.asarray([
        # Player 0 states
        [-1.0, -0.5],  # '0'
        [-1.0, -2.0],  # '0pb'
        [-0.5, 0.5],  # '1'
        [-1.0, 0.0],  # '1pb'
        [0.0, 1.5],  # '2'
        [-1.0, 2.0],  # '2pb'
        # Player 1 states
        [0.0, 1.0],  # '1p'
        [0, 0],  # Unreachable
        [1.0, 1.0],  # '2p'
        [0, 0],  # Unreachable
        [-1.0, 1.0],  # '0p'
        [0, 0],  # Unreachable
    ])
    np.testing.assert_array_almost_equal(expected_action_values,
                                         returned_values.action_values)

    expected_cf_reach_probs = np.asarray([
        # Player 0 states
        1 / 3,  # '0'
        1 / 6,  # '0pb'
        1 / 3,  # '1'
        1 / 6,  # '1pb'
        1 / 3,  # '2'
        1 / 6,  # '2pb'
        # Player 1 states
        1 / 3,  # '1p'
        0.0,  # '1b': zero because player 0 always play pass
        1 / 3,  # 2p'
        0.0,  # '2b': zero because player 0 always play pass
        1 / 3,  # '0p'
        0.0,  # '0b': zero because player 0 always play pass
    ])
    np.testing.assert_array_almost_equal(
        expected_cf_reach_probs, returned_values.counterfactual_reach_probs)

    # The reach probabilities are always one, even though we have player 0
    # who only plays pass, because the unreachable nodes for player 0 are
    # terminal nodes: e.g. 'x x b b p' has a player 0 reach of 0, but it is
    # a terminal node, thus it does not appear in the tabular policy
    # states.
    np.testing.assert_array_equal(
        [1.0] * 12, returned_values.player_reach_probs)

    expected_sum_cfr_reach = np.asarray([
        [-1 / 3, -1 / 6],
        [-1 / 6, -1 / 3],
        [-1 / 6, 1 / 6],
        [-1 / 6, 0.],
        [0., 0.5],
        [-1 / 6, 1 / 3],
        [0., 1 / 3],
        [0., 0.],
        [1 / 3, 1 / 3],
        [0., 0.],
        [-1 / 3, 1 / 3],
        [0., 0.],
    ])
    np.testing.assert_array_almost_equal(
        expected_sum_cfr_reach, returned_values.sum_cfr_reach_by_action_value)
def test_runs_with_uniform_policies(self, game_name):
    """Smoke test: the computation runs on `game_name` with a uniform tabular policy."""
    game = pyspiel.load_game(game_name)
    calculator = action_value.TreeWalkCalculator(game)
    # A freshly-built TabularPolicy is uniform over the legal actions; the
    # same policy object is used for both players.
    uniform = policy.TabularPolicy(game)
    calculator.compute_all_states_action_values([uniform, uniform])
def test_kuhn_poker_always_pass_p0(self):
    """Checks action values in Kuhn poker when player 0 always passes.

    The test runs twice: once with a callable policy listing both actions
    explicitly, and once with a policy that omits action 1 entirely, to
    verify both convert to the same always-pass tabular policy.
    """
    game = pyspiel.load_game("kuhn_poker")
    calc = action_value.TreeWalkCalculator(game)
    for always_pass_policy in [
        lambda state: [(0, 1.0), (1, 0.0)],
        # On purpose, we use a policy that does not list all the legal
        # actions: the omitted action 1 must get probability 0 in the
        # tabular conversion checked below.
        # BUG FIX: this second entry previously duplicated the first lambda,
        # so the partial-policy conversion path was never exercised.
        lambda state: [(0, 1.0)],
    ]:
        tabular_policy = policy.tabular_policy_from_policy(
            game, policy.PolicyFromCallable(game, always_pass_policy))
        # States are ordered using tabular_policy.states_per_player:
        # ['0', '0pb', '1', '1pb', '2', '2pb'] +
        # ['1p', '1b', '2p', '2b', '0p', '0b']
        np.testing.assert_array_equal(
            np.asarray([
                [1., 0.],
                [1., 0.],
                [1., 0.],
                [1., 0.],
                [1., 0.],
                [1., 0.],
                [1., 0.],
                [1., 0.],
                [1., 0.],
                [1., 0.],
                [1., 0.],
                [1., 0.],
            ]), tabular_policy.action_probability_array)
        returned_values = calc([
            policy.PolicyFromCallable(game, always_pass_policy),
            policy.PolicyFromCallable(game, _uniform_policy)
        ], tabular_policy)
        # Action 0 == Pass. Action 1 == Bet
        # Some values are 0 because the states are not reached, thus the
        # expected value of that node is undefined.
        np.testing.assert_array_almost_equal(
            np.asarray([
                [-1.0, -0.5],
                [-1.0, -2.0],
                [-0.5, 0.5],
                [-1.0, 0.0],
                [0.0, 1.5],
                [-1.0, 2.0],
                [0.0, 1.0],
                [0, 0],
                [1.0, 1.0],
                [0, 0],
                [-1.0, 1.0],
                [0, 0],
            ]), returned_values.action_values)
        np.testing.assert_array_almost_equal(
            np.asarray([
                # Player 0 states
                1 / 3,  # '0'
                1 / 6,  # '0pb'
                1 / 3,  # '1'
                1 / 6,  # '1pb'
                1 / 3,  # '2'
                1 / 6,  # '2pb'
                # Player 1 states
                1 / 3,  # '1p'
                0.0,  # '1b': zero because player 0 always play pass
                1 / 3,  # 2p'
                0.0,  # '2b': zero because player 0 always play pass
                1 / 3,  # '0p'
                0.0,  # '0b': zero because player 0 always play pass
            ]), returned_values.counterfactual_reach_probs)
        # The reach probabilities are always one, even though we have player 0
        # who only plays pass, because the unreachable nodes for player 0 are
        # terminal nodes: e.g. 'x x b b p' has a player 0 reach of 0, but it
        # is a terminal node, thus it does not appear in the tabular policy
        # states.
        np.testing.assert_array_equal(
            [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
            returned_values.player_reach_probs)
        np.testing.assert_array_almost_equal(
            np.asarray([
                np.array([-1 / 3, -1 / 6]),
                np.array([-1 / 6, -1 / 3]),
                np.array([-1 / 6, 1 / 6]),
                np.array([-1 / 6, 0.]),
                np.array([0., 0.5]),
                np.array([-1 / 6, 1 / 3]),
                np.array([0., 1 / 3]),
                np.array([0., 0.]),
                np.array([1 / 3, 1 / 3]),
                np.array([0., 0.]),
                np.array([-1 / 3, 1 / 3]),
                np.array([0., 0.])
            ]), returned_values.sum_cfr_reach_by_action_value)
def test_runs_with_uniform_policies(self, game_name, num_players):
    """Smoke test: runs on `game_name` configured for `num_players` players."""
    params = {"players": pyspiel.GameParameter(num_players)}
    game = pyspiel.load_game(game_name, params)
    calculator = action_value.TreeWalkCalculator(game)
    # The same uniform tabular policy object is reused for every player,
    # matching the original `[uniform_policy] * num_players` construction.
    uniform = policy.TabularPolicy(game)
    calculator.compute_all_states_action_values([uniform] * num_players)