def disabled_test_external_sampling_liars_dice_2p_simple(self):
    """Run external-sampling MCCFR (SIMPLE averaging) on Liar's Dice.

    Loads ``liars_dice`` with its default parameters, performs a single
    solver iteration, and asserts that the NashConv of the resulting average
    policy is below 2.

    The ``disabled_`` prefix means a loader that only collects ``test*``
    methods (the unittest default) will not run this case.
    """
    # Seed numpy's global RNG so the run is reproducible (presumably the
    # solver draws its samples from np.random -- TODO confirm).
    np.random.seed(SEED)
    liars_dice = pyspiel.load_game("liars_dice")
    solver = external_sampling_mccfr.ExternalSamplingSolver(
        liars_dice, external_sampling_mccfr.AverageType.SIMPLE)

    num_iterations = 1
    for _ in range(num_iterations):
        solver.iteration()

    avg_policy = solver.average_policy()
    conv = exploitability.nash_conv(liars_dice, avg_policy)
    print("Liar's dice, conv = {}".format(conv))
    self.assertLess(conv, 2)
def test_external_sampling_kuhn_3p_simple(self):
    """External-sampling MCCFR (SIMPLE averaging) on 3-player Kuhn poker.

    Runs 10 solver iterations and asserts that the NashConv of the average
    policy is below 2.
    """
    np.random.seed(SEED)
    kuhn3 = pyspiel.load_game("kuhn_poker", {"players": 3})
    solver = external_sampling_mccfr.ExternalSamplingSolver(
        kuhn3, external_sampling_mccfr.AverageType.SIMPLE)

    num_iterations = 10
    for _ in range(num_iterations):
        solver.iteration()

    avg_policy = solver.average_policy()
    conv = exploitability.nash_conv(kuhn3, avg_policy)
    print("Kuhn3P, conv = {}".format(conv))
    self.assertLess(conv, 2)
def test_ext_mccfr_on_turn_based_game_with_exploitability(self):
    """Check if external sampling MCCFR can be applied.

    Loads the Python dynamic routing game, converts it with
    ``pyspiel.convert_to_turn_based``, runs ``_NUM_ITERATION_CFR_TEST``
    external-sampling iterations (SIMPLE averaging), and computes the NashConv
    of the average policy. The NashConv value is deliberately not asserted:
    the test only checks that these calls complete without raising.
    """
    routing_game = pyspiel.load_game(
        "python_dynamic_routing(max_num_time_step=5,time_step_length=1.0)")
    turn_based_game = pyspiel.convert_to_turn_based(routing_game)
    solver = external_mccfr.ExternalSamplingSolver(
        turn_based_game, external_mccfr.AverageType.SIMPLE)

    for _ in range(_NUM_ITERATION_CFR_TEST):
        solver.iteration()

    # Smoke check only: the returned value is discarded.
    exploitability.nash_conv(turn_based_game, solver.average_policy())
def test_external_sampling_kuhn_2p_full(self):
    """External-sampling MCCFR (FULL averaging) on 2-player Kuhn poker.

    Runs 10 solver iterations and asserts that the NashConv of the average
    policy is below 1.
    """
    np.random.seed(SEED)
    kuhn2 = pyspiel.load_game("kuhn_poker")
    solver = external_sampling_mccfr.ExternalSamplingSolver(
        kuhn2, external_sampling_mccfr.AverageType.FULL)

    num_iterations = 10
    for _ in range(num_iterations):
        solver.iteration()

    avg_policy = solver.average_policy()
    conv = exploitability.nash_conv(kuhn2, avg_policy)
    print("Kuhn2P, conv = {}".format(conv))
    self.assertLess(conv, 1)
def main(_):
    """Train an MCCFR solver on ``FLAGS.game`` and periodically print NashConv.

    Uses external sampling (SIMPLE averaging) when ``FLAGS.sampling`` is
    ``"external"`` and outcome sampling otherwise. After every iteration whose
    index is a multiple of ``FLAGS.print_freq`` (starting with iteration 0),
    it computes the NashConv of the current average policy and prints it.

    NOTE(review): the printed label says "exploitability" but the value is
    ``exploitability.nash_conv``. A ``FLAGS.print_freq`` of 0 raises
    ZeroDivisionError on the first iteration.

    Args:
      _: Unused positional argument (presumably argv from ``app.run``).
    """
    game = pyspiel.load_game(FLAGS.game, {"players": FLAGS.players})

    use_external_sampling = FLAGS.sampling == "external"
    cfr_solver = (
        external_mccfr.ExternalSamplingSolver(
            game, external_mccfr.AverageType.SIMPLE)
        if use_external_sampling
        else outcome_mccfr.OutcomeSamplingSolver(game))

    for iteration in range(FLAGS.iterations):
        cfr_solver.iteration()
        if iteration % FLAGS.print_freq:
            continue  # Not a reporting iteration.
        conv = exploitability.nash_conv(game, cfr_solver.average_policy())
        print("Iteration {} exploitability {}".format(iteration, conv))
def test_external_sampling_kuhn_2p_simple(self):
    """External-sampling MCCFR (SIMPLE averaging) on 2-player Kuhn poker.

    Runs 10 solver iterations, turns the solver's callable average policy
    into a tabular policy via ``policy.tabular_policy_from_callable``, and
    asserts that its NashConv is below 1.
    """
    np.random.seed(SEED)
    kuhn2 = pyspiel.load_game("kuhn_poker")
    solver = external_sampling_mccfr.ExternalSamplingSolver(
        kuhn2, external_sampling_mccfr.AverageType.SIMPLE)

    num_iterations = 10
    for _ in range(num_iterations):
        solver.iteration()

    tabular_avg = policy.tabular_policy_from_callable(
        kuhn2, solver.callable_avg_policy())
    conv = exploitability.nash_conv(kuhn2, tabular_avg)
    print("Kuhn2P, conv = {}".format(conv))
    self.assertLess(conv, 1)
def test_external_sampling_kuhn_3p_full(self):
    """External-sampling MCCFR (FULL averaging) on 3-player Kuhn poker.

    Runs 10 solver iterations, turns the solver's callable average policy
    into a tabular policy via ``policy.tabular_policy_from_callable``, and
    asserts that its NashConv is below 2.
    """
    np.random.seed(SEED)
    # Pass the player count as a plain Python int rather than wrapping it in
    # the legacy ``pyspiel.GameParameter`` type.
    game = pyspiel.load_game("kuhn_poker", {"players": 3})
    es_solver = external_sampling_mccfr.ExternalSamplingSolver(
        game, external_sampling_mccfr.AverageType.FULL)
    for _ in range(10):
        es_solver.iteration()
    tabular_avg = policy.tabular_policy_from_callable(
        game, es_solver.callable_avg_policy())
    conv = exploitability.nash_conv(game, tabular_avg)
    print("Kuhn3P, conv = {}".format(conv))
    self.assertLess(conv, 2)
def test_external_sampling_leduc_2p_simple(self):
    """External-sampling MCCFR (SIMPLE averaging) on 2-player Leduc poker.

    Runs 10 solver iterations and asserts that the NashConv of the average
    policy is below 5. It then converts that same policy with
    ``to_tabular()`` and checks that the tabular policy has the same NashConv.
    """
    np.random.seed(SEED)
    game = pyspiel.load_game("leduc_poker")
    es_solver = external_sampling_mccfr.ExternalSamplingSolver(
        game, external_sampling_mccfr.AverageType.SIMPLE)
    for _ in range(10):
        es_solver.iteration()
    # Build the average policy once; no iterations run between its two uses.
    average_policy = es_solver.average_policy()
    conv = exploitability.nash_conv(game, average_policy)
    print("Leduc2P, conv = {}".format(conv))
    self.assertLess(conv, 5)
    # Ensure that to_tabular() works on the returned policy and that the
    # tabular policy is equivalent. Both NashConv values are floats that may
    # be computed through different lookup paths, so compare with a tolerance
    # instead of exact equality.
    tabular_policy = average_policy.to_tabular()
    conv2 = exploitability.nash_conv(game, tabular_policy)
    self.assertAlmostEqual(conv, conv2)
def external_sampling_monte_carlo_counterfactual_regret_minimization(
        seq_game, number_of_iterations, compute_metrics=False):
    """Run external-sampling MCCFR (SIMPLE averaging) and time the training loop.

    Args:
      seq_game: Game to solve (presumably a turn-based/sequential game, per
        the name -- verify against callers).
      number_of_iterations: Number of ``ExternalSamplingSolver.iteration()``
        calls to perform.
      compute_metrics: If True, also compute the NashConv of the average
        policy.

    Returns:
      ``(timing, average_policy)``, or ``(timing, average_policy, nash_conv)``
      when ``compute_metrics`` is True. ``timing`` is the elapsed time in
      seconds of the iteration loop only. Solver construction, average-policy
      extraction and the NashConv computation are excluded.
    """
    cfr_solver = external_mccfr.ExternalSamplingSolver(
        seq_game, external_mccfr.AverageType.SIMPLE)
    # perf_counter is monotonic and high-resolution, so it is the right clock
    # for measuring a duration (time.time() can jump with system clock changes).
    tick_time = time.perf_counter()
    for _ in range(number_of_iterations):
        cfr_solver.iteration()
    timing = time.perf_counter() - tick_time
    # Build the average policy once and reuse it for both the metric and the
    # return value instead of reconstructing it.
    average_policy = cfr_solver.average_policy()
    if compute_metrics:
        nash_conv = exploitability.nash_conv(seq_game, average_policy)
        return timing, average_policy, nash_conv
    return timing, average_policy