def test_outcome_sampling_kuhn_3p(self):
    """Outcome-sampling MCCFR on 3-player Kuhn poker gets NashConv below 0.22.

    Seeds NumPy with SEED, runs 10000 solver iterations, then measures the
    NashConv of the solver's average policy.
    """
    np.random.seed(SEED)
    num_players = 3
    kuhn_game = pyspiel.load_game("kuhn_poker", {"players": num_players})
    solver = outcome_sampling_mccfr.OutcomeSamplingSolver(kuhn_game)
    num_iterations = 10000
    for _ in range(num_iterations):
        solver.iteration()
    avg_policy = solver.average_policy()
    conv = exploitability.nash_conv(kuhn_game, avg_policy)
    print("Kuhn3P, conv = {}".format(conv))
    self.assertLess(conv, 0.22)
Exemplo n.º 2
0
 def test_int_mccfr_on_turn_based_game_with_exploitability(self):
   """Check if outcome sampling MCCFR can be applied.

   Loads the dynamic routing game, converts it with
   `pyspiel.convert_to_turn_based`, and runs `_NUM_ITERATION_CFR_TEST`
   solver iterations. The NashConv of the average policy is computed but not
   asserted, so this only checks that the pipeline runs without raising.
   """
   base_game = pyspiel.load_game(
       "python_dynamic_routing(max_num_time_step=5,time_step_length=1.0)")
   turn_based_game = pyspiel.convert_to_turn_based(base_game)
   solver = outcome_mccfr.OutcomeSamplingSolver(turn_based_game)
   iteration_count = _NUM_ITERATION_CFR_TEST
   for _ in range(iteration_count):
     solver.iteration()
   avg_policy = solver.average_policy()
   exploitability.nash_conv(turn_based_game, avg_policy)
    def test_outcome_sampling_leduc_2p(self):
        """Outcome-sampling MCCFR on 2-player Leduc poker gets NashConv below 3.07.

        Seeds NumPy with SEED and runs 10000 solver iterations before
        evaluating the average policy.
        """
        np.random.seed(SEED)
        leduc_game = pyspiel.load_game("leduc_poker")
        solver = outcome_sampling_mccfr.OutcomeSamplingSolver(leduc_game)
        num_iterations = 10000
        for _ in range(num_iterations):
            solver.iteration()
        avg_policy = solver.average_policy()
        conv = exploitability.nash_conv(leduc_game, avg_policy)
        print("Leduc2P, conv = {}".format(conv))
        self.assertLess(conv, 3.07)
Exemplo n.º 4
0
 def test_outcome_sampling_kuhn_2p(self):
     """After 1000 iterations on 2-player Kuhn, NashConv lies in (0.2, 0.3).

     The solver's callable average policy is wrapped in
     `policy.PolicyFromCallable` before NashConv is computed.
     """
     np.random.seed(SEED)
     kuhn_game = pyspiel.load_game("kuhn_poker")
     solver = outcome_sampling_mccfr.OutcomeSamplingSolver(kuhn_game)
     num_iterations = 1000
     for _ in range(num_iterations):
         solver.iteration()
     avg_callable = solver.callable_avg_policy()
     wrapped_policy = policy.PolicyFromCallable(kuhn_game, avg_callable)
     conv = exploitability.nash_conv(kuhn_game, wrapped_policy)
     print("Kuhn2P, conv = {}".format(conv))
     # Bracket the value on both sides: a regression in either direction fails.
     self.assertGreater(conv, 0.2)
     self.assertLess(conv, 0.3)
Exemplo n.º 5
0
def main(_):
    """Run external- or outcome-sampling MCCFR and periodically report NashConv.

    Loads `FLAGS.game` with `FLAGS.players` players. If `FLAGS.sampling` is
    "external", it uses `ExternalSamplingSolver` with `AverageType.SIMPLE`;
    any other value selects `OutcomeSamplingSolver`. It then runs
    `FLAGS.iterations` iterations. Every `FLAGS.print_freq` iterations
    (starting at iteration 0) it prints the NashConv of the average policy.
    A `print_freq` of 0 disables periodic printing instead of raising
    ZeroDivisionError.
    """
    game = pyspiel.load_game(FLAGS.game, {"players": FLAGS.players})
    if FLAGS.sampling == "external":
        cfr_solver = external_mccfr.ExternalSamplingSolver(
            game, external_mccfr.AverageType.SIMPLE)
    else:
        cfr_solver = outcome_mccfr.OutcomeSamplingSolver(game)
    print_freq = FLAGS.print_freq
    for i in range(FLAGS.iterations):
        cfr_solver.iteration()
        # Guard against `i % 0`, which would raise ZeroDivisionError.
        if print_freq and i % print_freq == 0:
            # NOTE: the value printed as "exploitability" is NashConv.
            conv = exploitability.nash_conv(game, cfr_solver.average_policy())
            print("Iteration {} exploitability {}".format(i, conv))
Exemplo n.º 6
0
 def test_outcome_sampling_leduc_2p(self):
   """After 1000 iterations on 2-player Leduc, NashConv lies in (4.5, 4.6).

   The callable average policy is converted with
   `policy.tabular_policy_from_callable` before evaluation.
   """
   np.random.seed(SEED)
   leduc_game = pyspiel.load_game("leduc_poker")
   solver = outcome_sampling_mccfr.OutcomeSamplingSolver(leduc_game)
   num_iterations = 1000
   for _ in range(num_iterations):
     solver.iteration()
   avg_callable = solver.callable_avg_policy()
   tabular = policy.tabular_policy_from_callable(leduc_game, avg_callable)
   conv = exploitability.nash_conv(leduc_game, tabular)
   print("Leduc2P, conv = {}".format(conv))
   self.assertGreater(conv, 4.5)
   self.assertLess(conv, 4.6)
 def test_outcome_sampling_kuhn_2p(self):
     """Outcome-sampling MCCFR on 2-player Kuhn gets NashConv below 0.17.

     Also checks that `to_tabular()` on the returned average policy yields a
     policy with the same NashConv.
     """
     np.random.seed(SEED)
     game = pyspiel.load_game("kuhn_poker")
     os_solver = outcome_sampling_mccfr.OutcomeSamplingSolver(game)
     for _ in range(10000):
         os_solver.iteration()
     conv = exploitability.nash_conv(game, os_solver.average_policy())
     print("Kuhn2P, conv = {}".format(conv))
     self.assertLess(conv, 0.17)
     # ensure that to_tabular() works on the returned policy
     # and the tabular policy is equivalent
     tabular_policy = os_solver.average_policy().to_tabular()
     conv2 = exploitability.nash_conv(game, tabular_policy)
     # The two NashConv values come from different policy representations,
     # so compare with a tolerance rather than exact float equality.
     self.assertAlmostEqual(conv, conv2)
Exemplo n.º 8
0
 def test_outcome_sampling_kuhn_3p(self):
   """After 1000 iterations on 3-player Kuhn, NashConv lies in (0.3, 0.4).

   The player count is passed as a `pyspiel.GameParameter`, and the callable
   average policy is converted to tabular form before evaluation.
   """
   np.random.seed(SEED)
   game_params = {"players": pyspiel.GameParameter(3)}
   kuhn_game = pyspiel.load_game("kuhn_poker", game_params)
   solver = outcome_sampling_mccfr.OutcomeSamplingSolver(kuhn_game)
   num_iterations = 1000
   for _ in range(num_iterations):
     solver.iteration()
   avg_callable = solver.callable_avg_policy()
   tabular = policy.tabular_policy_from_callable(kuhn_game, avg_callable)
   conv = exploitability.nash_conv(kuhn_game, tabular)
   print("Kuhn3P, conv = {}".format(conv))
   self.assertGreater(conv, 0.3)
   self.assertLess(conv, 0.4)