示例#1
0
    def test_policy(self):
        """A freshly built CFR+ solver starts from the uniform policy on Kuhn poker."""
        kuhn = pyspiel.load_game("kuhn_poker")
        cfr_plus = cfr.CFRPlusSolver(kuhn)

        # NOTE(review): assumes solver.policy() returns a mapping of
        # info state -> {action: probability}; confirm against the installed
        # OpenSpiel version.
        initial_policy = cfr_plus.policy()
        self.assertLen(initial_policy, 12)
        uniform = {0: 0.5, 1: 0.5}
        for action_probs in initial_policy.values():
            self.assertEqual(uniform, action_probs)
示例#2
0
  def test_policy(self):
    """A fresh CFR+ solver's tabular policy is uniform at every info state."""
    kuhn = pyspiel.load_game("kuhn_poker")
    cfr_plus = cfr.CFRPlusSolver(kuhn)

    initial_policy = cfr_plus.policy()
    lookup = initial_policy.state_lookup
    self.assertLen(lookup, 12)
    # Every info-state key should map to a 50/50 split over the two actions.
    uniform = np.asarray([0.5, 0.5])
    for key in lookup:
      np.testing.assert_equal(uniform, initial_policy.policy_for_key(key))
示例#3
0
 def test_cfr_plus_kuhn_poker(self):
   """After 200 CFR+ iterations the average policy attains the Nash values."""
   game = pyspiel.load_game("kuhn_poker")
   solver = cfr.CFRPlusSolver(game)
   num_iterations = 200
   for _ in range(num_iterations):
     solver.evaluate_and_update_policy()
   avg_policy = solver.average_policy()
   # Both seats play the same averaged policy.
   values = expected_game_score.policy_value(
       game.new_initial_state(), [avg_policy, avg_policy])
   # Nash value: -1/18 for player 0, +1/18 for player 1.
   # See https://en.wikipedia.org/wiki/Kuhn_poker
   nash_values = [-1 / 18, 1 / 18]
   np.testing.assert_allclose(values, nash_values, atol=1e-3)
示例#4
0
 def test_cfr_plus_solver_best_response_mdp(self):
     """The CFR+ average policy after 200 iterations is at most 0.001-exploitable."""
     game = pyspiel.load_game("kuhn_poker")
     solver = cfr.CFRPlusSolver(game)
     num_iterations = 200
     for _ in range(num_iterations):
         solver.evaluate_and_update_policy()
     avg_policy = solver.average_policy()
     # The best-response computer is handed the converted pyspiel policy,
     # not the Python policy object.
     cpp_policy = policy.python_policy_to_pyspiel_policy(avg_policy)
     best_response = pyspiel.TabularBestResponseMDP(game, cpp_policy)
     info = best_response.exploitability()
     self.assertLessEqual(info.exploitability, 0.001)
示例#5
0
def CFRPlus_Solving(game, iterations, save_every=0, save_prefix='base',
                    alternating_updates=True, linear_averaging=True):
    """Run CFR+ on ``game`` and save snapshots of its average policy.

    Each snapshot is written with ``policy_handler.save_to_tabular_policy`` to
    ``policies/CFRPlus/<save_prefix>/<it>``, where ``<it>`` is exactly the
    number of ``evaluate_and_update_policy`` calls completed at save time.

    Args:
        game: the game to solve (as accepted by ``cfr.CFRPlusSolver``).
        iterations: total number of solver updates to perform.
        save_every: if non-zero, also save a snapshot after 0, save_every,
            2 * save_every, ... updates (strictly fewer than ``iterations``).
            A final snapshot after ``iterations`` updates is always saved.
        save_prefix: sub-directory under ``policies/CFRPlus`` for snapshots.
        alternating_updates: solver flag; only honoured via ``cfr._CFRSolver``
            when it (or ``linear_averaging``) is False.
        linear_averaging: solver flag; see ``alternating_updates``.

    Raises:
        ValueError: if ``iterations`` is negative.
    """
    if iterations < 0:
        raise ValueError(
            "iterations must be non-negative, got {}".format(iterations))

    if alternating_updates and linear_averaging:
        cfr_solver = cfr.CFRPlusSolver(game)
    else:
        # CFRPlusSolver only takes the game, so non-default flags can only be
        # applied through the private _CFRSolver constructor.
        cfr_solver = cfr._CFRSolver(
            game,
            regret_matching_plus=True,
            alternating_updates=alternating_updates,
            linear_averaging=linear_averaging)

    def save_cfrplus(label):
        """Save the solver's current average policy under ``label``."""
        avg_policy = cfr_solver.average_policy()
        # Use each key's stored row index instead of assuming dict order
        # matches the row order of action_probability_array.
        probs_by_state = {
            key: avg_policy.action_probability_array[index]
            for key, index in avg_policy.state_lookup.items()
        }
        policy_handler.save_to_tabular_policy(
            game, probs_by_state,
            "policies/CFRPlus/{}/{}".format(save_prefix, label))

    for it in range(iterations):
        # Save before updating so snapshot ``it`` holds exactly ``it`` updates.
        if save_every != 0 and it % save_every == 0:
            save_cfrplus(it)
        cfr_solver.evaluate_and_update_policy()
    # Final snapshot after exactly ``iterations`` updates, so training for
    # 20K iterations produces a file labelled 20000 (not 19999) with 20000
    # updates in it.
    save_cfrplus(iterations)