コード例 #1
0
    def test_sampling_in_simple_games(self):
        """Checks how many states are sampled from two small games.

        The calls that pass an explicit time limit can fail with very small
        probability.
        """
        sample = sample_some_states.sample_some_states

        # --- matrix_mp, loaded as a turn-based game -------------------------
        mp_total = 1 + 2 + 4
        mp_game = pyspiel.load_game_as_turn_based("matrix_mp")
        for wanted in range(1, mp_total + 1):
            sampled = sample(mp_game, max_states=wanted)
            self.assertLen(sampled, wanted)

        sampled = sample(mp_game, max_states=1, depth_limit=0)
        self.assertLen(sampled, 1)

        # Asking for more states than exist yields all of them.
        sampled = sample(mp_game, max_states=mp_total + 1, time_limit=0.1)
        self.assertLen(sampled, mp_total)

        sampled = sample(mp_game,
                         include_terminals=False,
                         time_limit=0.1,
                         max_states=mp_total)
        self.assertLen(sampled, 3)

        sampled = sample(mp_game,
                         depth_limit=1,
                         time_limit=0.1,
                         max_states=mp_total)
        self.assertLen(sampled, 3)

        # --- coordinated_mp, loaded as a turn-based game --------------------
        cmp_total = 1 + 2 + 4 + 8
        cmp_game = pyspiel.load_game_as_turn_based("coordinated_mp")
        for wanted in range(1, cmp_total + 1):
            sampled = sample(cmp_game, max_states=wanted)
            self.assertLen(sampled, wanted)

        sampled = sample(cmp_game, max_states=cmp_total + 1, time_limit=0.1)
        self.assertLen(sampled, cmp_total)

        sampled = sample(cmp_game,
                         max_states=cmp_total,
                         include_chance_states=False,
                         time_limit=0.1)
        self.assertLen(sampled, cmp_total - 2)
コード例 #2
0
    def __init__(self,
                 env_name,
                 env_seed=2,
                 deltas=None,
                 slow_oracle_kargs=None,
                 fast_oracle_kargs=None):
        """Sets up the RL environment, the noise table and the TF session.

        Args:
          env_name: name of the game; it is loaded as a turn-based game with
            the "players" parameter set to 2.
          env_seed: integer; the noise table is seeded with `env_seed + 7`.
          deltas: forwarded to `SharedNoiseTable`.
          slow_oracle_kargs: optional dict. When given, the session is stored
            in it under the 'session' key, so the caller's dict is modified.
          fast_oracle_kargs: dict whose 'noise' entry is read.
            NOTE(review): with the default of None that lookup fails --
            confirm whether callers always pass a dict.
        """
        # Environment construction; the imports are local to this method.
        from open_spiel.python import rl_environment

        import pyspiel

        player_count = 2
        self._num_players = player_count
        game_params = {"players": pyspiel.GameParameter(player_count)}
        turn_based_game = pyspiel.load_game_as_turn_based(env_name, game_params)
        self._env = rl_environment.Environment(turn_based_game)

        # Shared noise table, sampled through a stream seeded per worker.
        self.deltas = SharedNoiseTable(deltas, env_seed + 7)

        self._policies = [[] for _ in range(player_count)]
        self._slow_oracle_kargs = slow_oracle_kargs
        self._fast_oracle_kargs = fast_oracle_kargs
        self._delta_std = fast_oracle_kargs['noise']

        # Reuse the default session when there is one, else open a new one.
        session = tf.get_default_session()
        if session is None:
            session = tf.Session()
        self._sess = session

        if slow_oracle_kargs is not None:
            slow_oracle_kargs['session'] = session
コード例 #3
0
def main(argv):
  """Draws the game tree of `FLAGS.game` and saves it to `FLAGS.out`.

  A simultaneous-move game is reloaded as a turn-based game first. For
  two-player zero-sum games a custom node decorator is used together with the
  grouping flags; every other game uses the default `treeviz.GameTree`
  settings.

  Args:
    argv: unused command-line arguments.

  Raises:
    ValueError: if the (possibly reloaded) game is not sequential.
  """
  del argv

  game = pyspiel.load_game(FLAGS.game)
  game_type = game.get_type()

  if game_type.dynamics == pyspiel.GameType.Dynamics.SIMULTANEOUS:
    # `warn` is a deprecated alias; `warning` is the supported spelling.
    logging.warning(
        "%s is not turn-based. Trying to reload game as turn-based.",
        FLAGS.game)
    game = pyspiel.load_game_as_turn_based(FLAGS.game)
    game_type = game.get_type()

  if game_type.dynamics != pyspiel.GameType.Dynamics.SEQUENTIAL:
    raise ValueError("Game must be sequential, not {}".format(
        game_type.dynamics))

  if (game_type.utility == pyspiel.GameType.Utility.ZERO_SUM and
      game.num_players() == 2):
    logging.info("Game is zero-sum: only showing first-player's returns.")
    gametree = treeviz.GameTree(
        game,
        node_decorator=_zero_sum_node_decorator,
        group_infosets=FLAGS.group_infosets,
        group_terminal=FLAGS.group_terminal)
  else:
    gametree = treeviz.GameTree(game)  # use default decorators

  if FLAGS.verbose:
    logging.info("Game tree:\n%s", gametree.to_string())

  gametree.draw(FLAGS.out, prog=FLAGS.prog)
  logging.info("Game tree saved to file: %s", FLAGS.out)
コード例 #4
0
def main(_):
    """Runs CFR for `FLAGS.iterations` steps and dumps policy rows to CSV.

    Every `FLAGS.print_freq` iterations the exploitability of the average
    policy and both policy arrays are printed. After each solver update, rows
    0 and 1 of the average and of the current policy arrays are passed to
    `write_csv`.
    """
    game = pyspiel.load_game_as_turn_based(game_)
    solver = cfr.CFRSolver(game)

    print("policy_initial:", solver.current_policy().action_probability_array)
    for step in range(FLAGS.iterations):

        if step % FLAGS.print_freq == 0:
            exploit = exploitability.exploitability(game,
                                                    solver.average_policy())
            print("Iteration {} exploitability {}".format(step, exploit))
            print("Iteration{}".format(step))
            print("policy_av:", solver.average_policy().action_probability_array)
            print("policy_cr:", solver.current_policy().action_probability_array)

        solver.evaluate_and_update_policy()

        # Average policy first, then current policy; rows 0 and 1 of each go
        # to the same file for that policy.
        for suffix, fetch_policy in (("_av.csv", solver.average_policy),
                                     ("_cr.csv", solver.current_policy)):
            for row in (0, 1):
                write_csv(dir_ + game_ + "_" + algo_name + suffix,
                          fetch_policy().action_probability_array[row])
コード例 #5
0
 def test_extensive_to_tensor_game_payoff_tensor(self):
     """Checks 3-player blotto payoffs agree across three representations.

     The turn-based game is walked three actions deep; at every terminal
     state the returns must match both the tensor game converted from the
     turn-based game and the tensor game loaded directly.
     """
     spec = "blotto(players=3,coins=5)"
     tree_game = pyspiel.load_game_as_turn_based(spec)
     converted = pyspiel.extensive_to_tensor_game(tree_game)
     loaded = pyspiel.load_tensor_game(spec)
     self.assertEqual(converted.shape(), loaded.shape())

     root = tree_game.new_initial_state()
     self.assertEqual(converted.shape()[0], root.num_distinct_actions())
     for a0 in range(root.num_distinct_actions()):
         after_first = root.child(a0)
         self.assertEqual(converted.shape()[1],
                          after_first.num_distinct_actions())
         for a1 in range(after_first.num_distinct_actions()):
             after_second = after_first.child(a1)
             self.assertEqual(converted.shape()[2],
                              after_second.num_distinct_actions())
             for a2 in range(after_second.num_distinct_actions()):
                 leaf = after_second.child(a2)
                 self.assertTrue(leaf.is_terminal())
                 joint_action = (a0, a1, a2)
                 for player in range(3):
                     # Converted game first, then the directly loaded one.
                     for tensor_game in (converted, loaded):
                         self.assertEqual(
                             leaf.returns()[player],
                             tensor_game.player_utility(player, joint_action))
コード例 #6
0
def main(argv):
    """Runs a PSRO experiment configured through command-line flags.

    Seeds numpy, `random` and TensorFlow, loads the game as a turn-based
    game, derives a timestamped checkpoint directory from the flags, builds
    the oracle selected by `FLAGS.oracle_type` and hands everything to
    `gpsro_looper`.

    Args:
      argv: positional command-line arguments; only the program name is
        accepted.

    Raises:
      app.UsageError: if extra positional arguments are given.
      ValueError: if `FLAGS.oracle_type` is not one of the supported oracles.
    """
    if len(argv) > 1:
        raise app.UsageError("Too many command-line arguments.")

    if FLAGS.seed is None:
        seed = np.random.randint(low=0, high=1e5)
    else:
        seed = FLAGS.seed
    np.random.seed(seed)
    random.seed(seed)
    tf.set_random_seed(seed)
    game = pyspiel.load_game_as_turn_based(
        FLAGS.game_name, {"players": pyspiel.GameParameter(FLAGS.n_players)})
    env = rl_environment.Environment(game, seed=seed)
    env.reset()

    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(FLAGS.root_result_folder, exist_ok=True)

    # Run name: the flags that identify the experiment, the seed, a timestamp.
    run_name = ''.join([
        'tuning_ars', str(FLAGS.iter_stop_dqn),
        '_', FLAGS.game_name, str(FLAGS.n_players),
        '_sims_', str(FLAGS.sims_per_entry),
        '_it_', str(FLAGS.gpsro_iterations),
        '_ep_', str(FLAGS.number_training_episodes),
        '_or_', FLAGS.oracle_type,
        '_arsnd_', str(FLAGS.num_directions),
        '_se_', str(seed),
        '_', datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S'),
    ])
    checkpoint_dir = os.path.join(os.getcwd(), FLAGS.root_result_folder,
                                  run_name)

    writer = SummaryWriter(logdir=checkpoint_dir + '/log')
    try:
        if FLAGS.sbatch_run:
            # Redirects stdout to a file for the remainder of the process.
            sys.stdout = open(checkpoint_dir + '/stdout.txt', 'w+')

        # Initialize oracle and agents
        with tf.Session() as sess:
            if FLAGS.oracle_type == "DQN":
                oracle, agents = init_dqn_responder(sess, env)
            elif FLAGS.oracle_type == "PG":
                oracle, agents = init_pg_responder(sess, env)
            elif FLAGS.oracle_type == "BR":
                oracle, agents = init_br_responder(env)
            elif FLAGS.oracle_type == "ARS":
                oracle, agents = init_ars_responder(sess, env)
            elif FLAGS.oracle_type == "ARS_parallel":
                oracle, agents = init_ars_parallel_responder(sess, env, None)
            else:
                # Without this, `oracle` would be unbound further down.
                raise ValueError("Unsupported oracle type: {}".format(
                    FLAGS.oracle_type))
            sess.run(tf.global_variables_initializer())

            gpsro_looper(env,
                         oracle,
                         agents,
                         writer,
                         quiesce=FLAGS.quiesce,
                         checkpoint_dir=checkpoint_dir,
                         seed=seed,
                         dqn_iters=FLAGS.iter_stop_dqn)
    finally:
        # Close the summary writer even when the run fails.
        writer.close()
コード例 #7
0
 def test_shapleys_game(self):
   """Runs 1000 XFP iterations on Shapley's game, reporting every 10th."""
   shapleys = pyspiel.load_game_as_turn_based("matrix_shapleys_game")
   solver = fictitious_play.XFPSolver(shapleys)
   for step in range(1000):
     solver.iteration()
     if step % 10:
       continue
     nash_conv = exploitability.nash_conv(shapleys, solver.average_policy())
     print("FP in Shapley's Game. Iter: {}, NashConv: {}".format(
         step, nash_conv))
コード例 #8
0
 def test_matching_pennies_3p(self):
   """Runs 1000 XFP iterations on matching_pennies_3p, reporting every 10th."""
   pennies = pyspiel.load_game_as_turn_based("matching_pennies_3p")
   solver = fictitious_play.XFPSolver(pennies)
   for step in range(1000):
     solver.iteration()
     if step % 10:
       continue
     nash_conv = exploitability.nash_conv(pennies, solver.average_policy())
     print("FP in Matching Pennies 3p. Iter: {}, NashConv: {}".format(
         step, nash_conv))
コード例 #9
0
 def test_iigoofspiel4(self):
   """Solves 4-card imperfect-information goofspiel; both values must be 0."""
   params = {
       "imp_info": True,
       "num_cards": 4,
       "points_order": "descending",
   }
   game = pyspiel.load_game_as_turn_based("goofspiel", params)
   first_value, second_value, _, _ = sequence_form_lp.solve_zero_sum_game(game)
   # symmetric game, should be 0
   for value in (first_value, second_value):
     self.assertAlmostEqual(value, 0)
コード例 #10
0
ファイル: policy_test.py プロジェクト: ngrupen/open_spiel
 def test_states_lookup(self):
   """Checks the policy's state lookup holds two states, indexed 0 and 1."""
   game = pyspiel.load_game_as_turn_based("matrix_rps")
   state = game.new_initial_state()
   # Information state before and after the first legal action is applied.
   expected_keys = [state.information_state_string()]
   state.apply_action(state.legal_actions()[0])
   expected_keys.append(state.information_state_string())
   self.assertCountEqual(self.tabular_policy.state_lookup, expected_keys)
   self.assertCountEqual(self.tabular_policy.state_lookup.values(), [0, 1])
コード例 #11
0
    def test_extensive_to_matrix_game_payoff_matrix(self):
        """Compares converted matrix_pd payoffs with the directly loaded game."""
        as_turn_based = pyspiel.load_game_as_turn_based("matrix_pd")
        converted = pyspiel.extensive_to_matrix_game(as_turn_based)
        reference = pyspiel.load_matrix_game("matrix_pd")

        # Every cell of the payoff matrix, for both players.
        for r in range(reference.num_rows()):
            for c in range(reference.num_cols()):
                for who in range(2):
                    expected = reference.player_utility(who, r, c)
                    actual = converted.player_utility(who, r, c)
                    self.assertEqual(expected, actual)
コード例 #12
0
def main(argv):
    """Runs the PSRO loop with an ARS responder on the configured game."""
    if len(argv) > 1:
        raise app.UsageError("Too many command-line arguments.")

    np.random.seed(FLAGS.seed)

    game_params = {"players": pyspiel.GameParameter(FLAGS.n_players)}
    game = pyspiel.load_game_as_turn_based(FLAGS.game_name, game_params)
    env = rl_environment.Environment(game)

    # Initialize oracle and agents; None is passed in place of a session.
    oracle, agents = init_ars_responder(None, env)
    gpsro_looper(env, oracle, agents)
コード例 #13
0
    def test_sampling_in_simple_games(self):
        """Checks how many states are sampled from two small games."""
        sample = sample_some_states.sample_some_states

        # matrix_mp, loaded as a turn-based game.
        mp_total = 1 + 2 + 4
        mp_game = pyspiel.load_game_as_turn_based("matrix_mp")
        for wanted in range(1, mp_total + 1):
            sampled = sample(mp_game, max_states=wanted)
            self.assertLen(sampled, wanted)

        sampled = sample(mp_game, max_states=1)
        self.assertLen(sampled, 1)

        # Asking for one more state than exists yields all of them.
        sampled = sample(mp_game, max_states=mp_total + 1)
        self.assertLen(sampled, mp_total)

        # coordinated_mp, loaded as a turn-based game.
        cmp_total = 1 + 2 + 4 + 8
        cmp_game = pyspiel.load_game_as_turn_based("coordinated_mp")
        for wanted in range(1, cmp_total + 1):
            sampled = sample(cmp_game, max_states=wanted)
            self.assertLen(sampled, wanted)

        sampled = sample(cmp_game, max_states=cmp_total + 1)
        self.assertLen(sampled, cmp_total)
コード例 #14
0
ファイル: gambit_example.py プロジェクト: ngrupen/open_spiel
def main(argv):
  """Exports the game tree of `FLAGS.game` through `export_gambit`.

  A simultaneous-move game is reloaded as a turn-based game first. The
  exported tree is printed when `FLAGS.print` is set, otherwise it is written
  to `FLAGS.out`.

  Args:
    argv: unused command-line arguments.
  """
  del argv

  game = pyspiel.load_game(FLAGS.game)
  game_type = game.get_type()

  if game_type.dynamics == pyspiel.GameType.Dynamics.SIMULTANEOUS:
    # `warn` is a deprecated alias; `warning` is the supported spelling.
    logging.warning(
        "%s is not turn-based. Trying to reload game as turn-based.",
        FLAGS.game)
    game = pyspiel.load_game_as_turn_based(FLAGS.game)

  gametree = export_gambit(game)  # use default decorators
  if FLAGS.print:
    print(gametree)
  else:
    with open(FLAGS.out, "w") as f:
      f.write(gametree)
    logging.info("Game tree for %s saved to file: %s", FLAGS.game, FLAGS.out)
コード例 #15
0
 def test_matching_pennies_3p(self):
     """Solves matching_pennies_3p with Deep CFR and logs the NashConv."""
     game = pyspiel.load_game_as_turn_based('matching_pennies_3p')
     solver_options = dict(
         policy_network_layers=(16, 8),
         advantage_network_layers=(32, 16),
         num_iterations=2,
         num_traversals=2,
         learning_rate=1e-3,
         batch_size_advantage=None,
         batch_size_strategy=None,
         memory_capacity=1e7,
     )
     solver = deep_cfr.DeepCFRSolver(game, **solver_options)
     solver.solve()

     # Tabular python policy -> pyspiel policy, as pyspiel.nash_conv is used.
     tabular = policy.tabular_policy_from_callable(
         game, solver.action_probabilities)
     nash_conv = pyspiel.nash_conv(
         game, policy.python_policy_to_pyspiel_policy(tabular))
     logging.info('Deep CFR in Matching Pennies 3p. NashConv: %.2f', nash_conv)
コード例 #16
0
 def test_matching_pennies_3p(self):
     """Solves matching_pennies_3p with Deep CFR and prints the NashConv.

     Deep CFR is not expected to necessarily converge on 3-player games, but
     the result is nonetheless interesting to see.
     """
     game = pyspiel.load_game_as_turn_based('matching_pennies_3p')
     solver_options = dict(
         policy_network_layers=(16, 8),
         advantage_network_layers=(32, 16),
         num_iterations=2,
         num_traversals=2,
         learning_rate=1e-3,
         batch_size_advantage=8,
         batch_size_strategy=8,
         memory_capacity=1e7,
     )
     solver = deep_cfr.DeepCFRSolver(game, **solver_options)
     solver.solve()

     tabular = policy.tabular_policy_from_callable(
         game, solver.action_probabilities)
     nash_conv = exploitability.nash_conv(game, tabular)
     print('Deep CFR in Matching Pennies 3p. NashConv: {}'.format(nash_conv))
コード例 #17
0
def main(argv):
    """Runs the PSRO loop with the oracle selected by `FLAGS.oracle_type`.

    Args:
      argv: positional command-line arguments; only the program name is
        accepted.

    Raises:
      app.UsageError: if extra positional arguments are given.
      ValueError: if `FLAGS.oracle_type` is not "DQN", "PG" or "BR".
    """
    if len(argv) > 1:
        raise app.UsageError("Too many command-line arguments.")

    np.random.seed(FLAGS.seed)

    game = pyspiel.load_game_as_turn_based(
        FLAGS.game_name, {"players": pyspiel.GameParameter(FLAGS.n_players)})
    env = rl_environment.Environment(game)

    # Initialize oracle and agents
    with tf.Session() as sess:
        if FLAGS.oracle_type == "DQN":
            oracle, agents = init_dqn_responder(sess, env)
        elif FLAGS.oracle_type == "PG":
            oracle, agents = init_pg_responder(sess, env)
        elif FLAGS.oracle_type == "BR":
            oracle, agents = init_br_responder(env)
        else:
            # Without this, `oracle` would be unbound at the looper call.
            raise ValueError("Unsupported oracle type: {}".format(
                FLAGS.oracle_type))
        sess.run(tf.global_variables_initializer())
        gpsro_looper(env, oracle, agents)
コード例 #18
0
def get_game(game_name):
    """Returns the turn-based game registered under the alias `game_name`.

    Raises:
      ValueError: if `game_name` is not one of the known aliases.
    """
    # (alias, game to load, parameters). Rebuilt on every call, so each load
    # receives a fresh parameter dict.
    known_games = (
        ("kuhn_poker_3p", "kuhn_poker", {"players": 3}),
        ("trade_comm_2p_2i", "trade_comm", {"num_items": 2}),
        ("sheriff_2p_gabriele", "sheriff", {
            "item_penalty": 1.0,
            "item_value": 5.0,
            "max_bribe": 2,
            "max_items": 10,
            "num_rounds": 2,
            "sheriff_penalty": 1.0,
        }),
    )
    for alias, real_name, game_kwargs in known_games:
        if game_name == alias:
            return pyspiel.load_game_as_turn_based(real_name, game_kwargs)
    raise ValueError("Unrecognised game: %s" % game_name)
コード例 #19
0
 def setUpClass(cls):
     """Builds the tabular policy for turn-based matrix_rps used by the tests."""
     super(TabularRockPaperScissorsPolicyTest, cls).setUpClass()
     rps_game = pyspiel.load_game_as_turn_based("matrix_rps")
     cls.tabular_policy = policy.TabularPolicy(rps_game)
コード例 #20
0
 def test_rock_paper_scissors(self):
     """Solves turn-based matrix_rps; both returned values must be 0."""
     rps_game = pyspiel.load_game_as_turn_based("matrix_rps")
     first_value, second_value, _, _ = (
         sequence_form_lp.solve_zero_sum_game(rps_game))
     for value in (first_value, second_value):
         self.assertAlmostEqual(value, 0)
コード例 #21
0
 def test_simultaneous_game_as_turn_based(self, game_info):
   """Simulates the game named by `game_info` after loading it as turn-based."""
   short_name = game_info.short_name
   self.sim_game(pyspiel.load_game_as_turn_based(short_name))
コード例 #22
0
  def test_observations_are_consistent_with_info_states(self, game_name):
    """Checks action-observation histories map one-to-one to info states.

    Random rollouts are played for each player until a time budget is used
    up. Along each rollout the player's action-observation history (AOH) and
    the information state string are recorded, and the test fails as soon as
    one AOH is paired with two different info states or vice versa.

    Games that do not provide both string forms are skipped; simultaneous
    games are reloaded as turn-based games.

    Args:
      game_name: name of the game to load and test.
    """
    print(f"Testing observation <-> info_state consistency for '{game_name}'")
    game = pyspiel.load_game(game_name)
    game_type = game.get_type()

    if not game_type.provides_information_state_string \
      or not game_type.provides_observation_string:
      print(f"Skipping test for '{game_name}', as it doesn't provide both "
            "information_state_string and observation_string")
      return

    if game_type.dynamics == pyspiel.GameType.Dynamics.SIMULTANEOUS:
      logging.warning(
          "'%s' is not turn-based. Trying to reload game as turn-based.",
          game_name)
      game = pyspiel.load_game_as_turn_based(game_name)

    # Idea of the test: make rollouts in the game, and collect both
    # Action-Observation histories (AOH) and InformationState for different
    # ground states. Check that there is a unique bijection between them.
    #
    # Of course, this test does not exclude the possibility the game might
    # have a bug! But it is a fast way to discover a possible inconsistency
    # in a new implementation.
    aoh_is = dict()  # aoh -> info_state
    is_aoh = dict()  # info_state -> aoh
    aoh_histories = collections.defaultdict(set)  # aoh -> states
    is_histories = collections.defaultdict(set)  # info_states -> states

    # Some games have very long play-throughs.
    give_up_after = 100  # actions

    # Show a helpful error message for debugging the observations in a game.
    def show_error(histories, player, dump_collections=True):
      """Builds the failure message for a set of conflicting histories.

      Each history is replayed from a fresh initial state to rebuild the
      player's AOH, info state string and the state's string description.
      """
      aohs = list()
      info_states = list()
      descriptions = list()
      # Emulate the histories to collect relevant lists.
      for history in histories:
        state = game.new_initial_state()
        aoh = [("obs", state.observation_string(player))]
        for action in history:
          state.apply_action(action)
          if state.current_player() == player:
            aoh.append(("action", action))
          aoh.append(("obs", state.observation_string(player)))
        aohs.append(aoh)
        info_states.append(state.information_state_string(player))
        descriptions.append(str(state))

      histories_str = "\n".join([str(history) for history in histories])
      descriptions_str = "\n".join(descriptions)
      aohs_str = "\n".join([str(aoh) for aoh in aohs])
      info_states_str = "\n".join([str(s) for s in info_states])

      if dump_collections:
        def format_dump(xs):
          """Formats a mapping as one 'key  ->  value' line per entry."""
          return "\n".join([f"{str(key)}  ->  {str(value)}"
                            for key, value in xs.items()])
        # pylint: disable=g-backslash-continuation
        extras = "Dumping colections:\n" \
                 f"aoh -> info_state:\n{format_dump(aoh_is)}\n\n" \
                 f"info_state -> aoh:\n{format_dump(is_aoh)}\n\n" \
                 f"aoh -> histories:\n{format_dump(aoh_histories)}\n\n" \
                 f"info_state -> histories:\n{format_dump(is_histories)}\n\n"
      else:
        # pylint: disable=g-backslash-continuation
        extras = "Rerun this test with dump_collections=True " \
                 "for extra information."

      # pylint: disable=g-backslash-continuation
      msg = \
        f"\n\n" \
        f"The action-observation histories (AOH) are not consistent with " \
        f"information states for player {player}.\n\n" \
        f"The conflicting set of states (histories) is:\n{histories_str}\n\n" \
        f"Their domain-specific descriptions are:\n{descriptions_str}\n\n" \
        f"The corresponding AOH are:\n{aohs_str}\n\n" \
        f"The corresponding info states are:\n{info_states_str}\n\n" \
        f"{extras}\n" \
        f"What to do to fix this? Consult the documentation to " \
        f"State::InformationStateString and State::ObservationString."
      return msg

    def collect_and_test_rollouts(player):
      """Plays one random rollout and checks the AOH <-> info state maps.

      The `random` module is reseeded with 0 at the start of every call.
      """
      random.seed(0)
      nonlocal aoh_is, is_aoh, aoh_histories, is_histories
      state = game.new_initial_state()
      aoh = [("obs", state.observation_string(player))]

      # TODO(author13): we want to check terminals for consistency too, but info
      # state string is not defined there and neither are observations by
      # design.
      while not state.is_terminal():
        # Stop rollouts that exceed the action budget set above.
        if len(state.history()) > give_up_after:
          break

        # Do not collect over chance nodes.
        if not state.is_chance_node():
          # NOTE(review): no player argument is passed here, whereas
          # show_error() queries information_state_string(player) -- confirm
          # this difference is intended.
          info_state = state.information_state_string()
          aoh_histories[str(aoh)].add(tuple(state.history()))
          is_histories[info_state].add(tuple(state.history()))

          # Histories gathered for the error message, should a check fail.
          states = {tuple(state.history())}
          states = states.union(aoh_histories[str(aoh)])
          states = states.union(is_histories[info_state])
          # Forward direction: an AOH seen before must map to this info state.
          if str(aoh) in aoh_is:
            states = states.union(is_histories[aoh_is[str(aoh)]])
            self.assertEqual(aoh_is[str(aoh)], info_state,
                             show_error(states, player))
          else:
            aoh_is[str(aoh)] = info_state
          # Reverse direction: an info state seen before must map to this AOH.
          if info_state in is_aoh:
            states = states.union(aoh_histories[str(is_aoh[info_state])])
            self.assertEqual(is_aoh[info_state], str(aoh),
                             show_error(states, player))
          else:
            is_aoh[info_state] = str(aoh)

        # Make random actions.
        action = random.choice(state.legal_actions(state.current_player()))
        # The action is recorded in the AOH only when `player` is the one
        # acting; the observation after the action is always recorded.
        if state.current_player() == player:
          aoh.append(("action", action))
        state.apply_action(action)
        aoh.append(("obs", state.observation_string(player)))

    # Run (very roughly!) for this many seconds. This very much depends on the
    # machine the test runs on, as some games take a long time to produce
    # a single rollout.
    time_limit = TIMEABLE_TEST_RUNTIME / game.num_players()
    start = time.time()
    # `start` is taken once, so the elapsed time covers all players together.
    is_time_out = lambda: time.time() - start > time_limit

    rollouts = 0
    for player in range(game.num_players()):
      # The maps are per player: start each player from empty collections.
      aoh_is.clear()
      is_aoh.clear()
      aoh_histories.clear()
      is_histories.clear()
      while not is_time_out():
        collect_and_test_rollouts(player)
        rollouts += 1

    print(f"Test for {game_name} took {time.time()-start} seconds "
          f"to make {rollouts} rollouts.")
コード例 #23
0
def do_something(game_name):
    """Returns the name of an RL environment built on 2-player `game_name`."""
    two_player_params = {"players": pyspiel.GameParameter(2)}
    turn_based = pyspiel.load_game_as_turn_based(game_name, two_player_params)
    return rl_environment.Environment(turn_based).name
コード例 #24
0
from open_spiel.python import rl_environment
# from open_spiel.python.algorithms.psro_v2 import rl_oracle
from open_spiel.python.algorithms.psro_v2 import rl_policy

import ray
import cloudpickle

from open_spiel.python.algorithms.psro_v2.ars_ray.workers import worker
from open_spiel.python.algorithms.psro_v2.parallel.worker import do_something
import concurrent.futures

# redis_password = sys.argv[1]
# num_cpus = int(sys.argv[2])

# Kuhn poker with the "players" parameter set to 2, loaded as a turn-based
# game and wrapped in an RL environment.
game = pyspiel.load_game_as_turn_based("kuhn_poker",
                                       {"players": pyspiel.GameParameter(2)})
env = rl_environment.Environment(game)
env.reset()

# No TensorFlow session is created; None is what gets stored under "session"
# in the agent keyword arguments below.
sess = None

# Sizes read from the environment specs and forwarded to the agent.
info_state_size = env.observation_spec()["info_state"][0]
num_actions = env.action_spec()["num_actions"]
# print(info_state_size, num_actions)
agent_class = rl_policy.ARSPolicy_parallel
agent_kwargs = {
    "session": sess,
    "info_state_size": info_state_size,
    "num_actions": num_actions,
    "learning_rate": 0.03,
    "nb_directions": 32,
コード例 #25
0
def get_game(game_name):
  """Returns the turn-based game registered under the alias `game_name`.

  Raises:
    ValueError: if `game_name` is not one of the known aliases.
  """
  # (alias, game to load, parameters). The table is rebuilt on every call, so
  # each load receives a fresh parameter dict.
  known_games = (
      ("kuhn_poker_2p", "kuhn_poker", {"players": 2}),
      ("kuhn_poker_3p", "kuhn_poker", {"players": 3}),
      ("kuhn_poker_4p", "kuhn_poker", {"players": 4}),

      ("leduc_poker_2p", "leduc_poker", {"players": 2}),
      ("leduc_poker_3p", "leduc_poker", {"players": 3}),
      ("leduc_poker_4p", "leduc_poker", {"players": 4}),

      ("trade_comm_2p_2i", "trade_comm", {"num_items": 2}),
      ("trade_comm_2p_3i", "trade_comm", {"num_items": 3}),
      ("trade_comm_2p_4i", "trade_comm", {"num_items": 4}),
      ("trade_comm_2p_5i", "trade_comm", {"num_items": 5}),

      ("tiny_bridge_2p", "tiny_bridge_2p", {}),
      ("tiny_bridge_4p", "tiny_bridge_4p", {}),  # Too big game.

      ("sheriff_2p_1r", "sheriff", {"num_rounds": 1}),
      ("sheriff_2p_2r", "sheriff", {"num_rounds": 2}),
      ("sheriff_2p_3r", "sheriff", {"num_rounds": 3}),
      ("sheriff_2p_gabriele", "sheriff", {
          "item_penalty": 1.0,
          "item_value": 5.0,
          "max_bribe": 2,
          "max_items": 10,
          "num_rounds": 2,
          "sheriff_penalty": 1.0,
      }),

      ("goofspiel_2p_3c_total", "goofspiel", {
          "players": 2,
          "returns_type": "total_points",
          "num_cards": 3,
      }),
      ("goofspiel_2p_4c_total", "goofspiel", {
          "players": 2,
          "returns_type": "total_points",
          "num_cards": 4,
      }),
      ("goofspiel_2p_5c_total", "goofspiel", {
          "players": 2,
          "returns_type": "total_points",
          "num_cards": 5,
      }),
  )

  for alias, real_name, game_kwargs in known_games:
    if game_name == alias:
      return pyspiel.load_game_as_turn_based(real_name, game_kwargs)

  raise ValueError("Unrecognised game: %s" % game_name)
コード例 #26
0
def main(argv):
  """Runs a PSRO experiment that can switch meta-strategy heuristics.

  Seeds numpy, `random` and TensorFlow, loads the game as a turn-based game,
  assembles the list of heuristics, derives a timestamped checkpoint directory
  from the flags, builds the oracle selected by `FLAGS.oracle_type` (plus an
  ARS "fast" oracle when `FLAGS.switch_fast_slow` is set) and hands everything
  to `gpsro_looper`.

  Args:
    argv: positional command-line arguments; only the program name is
      accepted.

  Raises:
    app.UsageError: if extra positional arguments are given.
    ValueError: if `FLAGS.oracle_type` is not "DQN", "PG", "BR" or "ARS".
  """
  if len(argv) > 1:
    raise app.UsageError("Too many command-line arguments.")

  if FLAGS.seed is None:
    seed = np.random.randint(low=0, high=1e5)
  else:
    seed = FLAGS.seed
  np.random.seed(seed)
  random.seed(seed)
  tf.set_random_seed(seed)
  game = pyspiel.load_game_as_turn_based(
      FLAGS.game_name, {"players": pyspiel.GameParameter(FLAGS.n_players)})
  env = rl_environment.Environment(game, seed=seed)
  env.reset()

  if FLAGS.heuristic_list:
    # NOTE(review): this is the flag's own value, not a copy, so the appends
    # below presumably extend FLAGS.heuristic_list as well -- confirm.
    heuristic_list = FLAGS.heuristic_list
    # The first heuristic names the meta strategy: keep the text before the
    # first '_strategy', or the whole name when that marker is absent.
    FLAGS.meta_strategy_method = heuristic_list[0].partition('_strategy')[0]
  else:
    heuristic_list = ["general_nash_strategy", "uniform_strategy"]

  if 'sp' in FLAGS.heuristic_to_add:
    heuristic_list.append("self_play_strategy")
  if 'weighted_ne' in FLAGS.heuristic_to_add:
    heuristic_list.append("weighted_NE_strategy")
  if 'prd' in FLAGS.heuristic_to_add:
    heuristic_list.append("prd_strategy")

  # exist_ok avoids the check-then-create race of exists() + makedirs().
  os.makedirs(FLAGS.root_result_folder, exist_ok=True)

  checkpoint_dir = ('se_' + FLAGS.game_name + str(FLAGS.n_players)
                    + '_sims_' + str(FLAGS.sims_per_entry)
                    + '_it_' + str(FLAGS.gpsro_iterations)
                    + '_ep_' + str(FLAGS.number_training_episodes)
                    + '_or_' + FLAGS.oracle_type)

  checkpoint_dir += '_msl_' + ",".join(heuristic_list)

  if FLAGS.switch_fast_slow:
    checkpoint_dir += ('_sfs_' + '_fp_' + str(FLAGS.fast_oracle_period)
                       + '_sp_' + str(FLAGS.slow_oracle_period)
                       + '_arslr_' + str(FLAGS.ars_learning_rate)
                       + '_arsn_' + str(FLAGS.noise)
                       + '_arsnd_' + str(FLAGS.num_directions)
                       + '_arsbd_' + str(FLAGS.num_best_directions)
                       + '_epars_' + str(FLAGS.number_training_episodes_ars))
  elif FLAGS.switch_heuristic_regardless_of_oracle:
    checkpoint_dir += '_switch_heuristics_'

  if FLAGS.oracle_type == 'BR':
    oracle_flag_str = ''
  else:
    oracle_flag_str = ('_hl_' + str(FLAGS.hidden_layer_size)
                       + '_bs_' + str(FLAGS.batch_size)
                       + '_nhl_' + str(FLAGS.n_hidden_layers))
    if FLAGS.oracle_type == 'DQN':
      oracle_flag_str += ('_dqnlr_' + str(FLAGS.dqn_learning_rate)
                          + '_tnuf_' + str(FLAGS.update_target_network_every)
                          + '_lf_' + str(FLAGS.learn_every))
    else:
      oracle_flag_str += ('_ls_' + str(FLAGS.loss_str)
                          + '_nqbp_' + str(FLAGS.num_q_before_pi)
                          + '_ec_' + str(FLAGS.entropy_cost)
                          + '_clr_' + str(FLAGS.critic_learning_rate)
                          + '_pilr_' + str(FLAGS.pi_learning_rate))
  checkpoint_dir = (checkpoint_dir + oracle_flag_str + '_se_' + str(seed)
                    + '_'
                    + datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S'))
  checkpoint_dir = os.path.join(os.getcwd(), FLAGS.root_result_folder,
                                checkpoint_dir)

  writer = SummaryWriter(logdir=checkpoint_dir + '/log')
  try:
    if FLAGS.sbatch_run:
      # Redirects stdout to a file for the remainder of the process.
      sys.stdout = open(checkpoint_dir + '/stdout.txt', 'w+')

    # Initialize oracle and agents
    with tf.Session() as sess:
      if FLAGS.oracle_type == "DQN":
        slow_oracle, agents, agent_kwargs = init_dqn_responder(sess, env)
      elif FLAGS.oracle_type == "PG":
        slow_oracle, agents, agent_kwargs = init_pg_responder(sess, env)
      elif FLAGS.oracle_type == "BR":
        slow_oracle, agents = init_br_responder(env)
        agent_kwargs = None
      elif FLAGS.oracle_type == "ARS":
        slow_oracle, agents = init_ars_responder(sess, env)
        agent_kwargs = None
      else:
        # Without this, `slow_oracle` would be unbound further down.
        raise ValueError("Unsupported oracle type: {}".format(
            FLAGS.oracle_type))

      sess.run(tf.global_variables_initializer())

      if FLAGS.switch_fast_slow:
        # The fast ARS oracle is built without a session; its agents replace
        # the ones returned for the slow oracle.
        fast_oracle, agents = init_ars_responder(sess=None, env=env)
        # [oracles, matching oracle type names]
        oracle_list = [[slow_oracle, fast_oracle],
                       [FLAGS.oracle_type, 'ARS']]
      else:
        oracle_list = None

      gpsro_looper(env,
                   slow_oracle,
                   oracle_list,
                   agents,
                   writer,
                   quiesce=FLAGS.quiesce,
                   checkpoint_dir=checkpoint_dir,
                   seed=seed,
                   heuristic_list=heuristic_list)
  finally:
    # Close the summary writer even when the run fails.
    writer.close()
コード例 #27
0
def main(argv):
  """Runs a PSRO experiment on a game with extra integer game parameters.

  Seeds numpy, `random` and TensorFlow, loads the game as a turn-based game
  with "players" plus every `name=value` entry of `FLAGS.game_param`, derives
  a timestamped checkpoint directory from the flags, builds the oracle
  selected by `FLAGS.oracle_type` and hands everything to `gpsro_looper`.

  Args:
    argv: positional command-line arguments; only the program name is
      accepted.

  Raises:
    app.UsageError: if extra positional arguments are given.
    ValueError: if `FLAGS.oracle_type` is not a supported oracle, or if a
      game parameter value cannot be converted with `int`.
  """
  if len(argv) > 1:
    raise app.UsageError("Too many command-line arguments.")

  if FLAGS.seed is None:
    seed = np.random.randint(low=0, high=1e5)
  else:
    seed = FLAGS.seed
  np.random.seed(seed)
  random.seed(seed)

  tf.set_random_seed(seed)

  game_param = {"players": pyspiel.GameParameter(FLAGS.n_players)}
  checkpoint_dir = FLAGS.game_name
  if FLAGS.game_param is not None:
    # Each entry is "name=value"; the value is converted with int() and the
    # pair is also appended to the checkpoint directory name.
    for entry in FLAGS.game_param:
      pieces = entry.split("=")
      game_param[pieces[0]] = pyspiel.GameParameter(int(pieces[1]))
      checkpoint_dir += '_' + pieces[0] + '_' + pieces[1]
    checkpoint_dir += '_'
  game = pyspiel.load_game_as_turn_based(FLAGS.game_name, game_param)

  env = rl_environment.Environment(game, seed=seed)
  env.reset()

  # exist_ok avoids the check-then-create race of exists() + makedirs().
  os.makedirs(FLAGS.root_result_folder, exist_ok=True)

  checkpoint_dir += (str(FLAGS.n_players)
                     + '_sims_' + str(FLAGS.sims_per_entry)
                     + '_it_' + str(FLAGS.gpsro_iterations)
                     + '_ep_' + str(FLAGS.number_training_episodes)
                     + '_or_' + FLAGS.oracle_type
                     + '_heur_' + FLAGS.meta_strategy_method)
  if FLAGS.oracle_type == 'ARS':
    oracle_flag_str = ('_arslr_' + str(FLAGS.ars_learning_rate)
                       + '_arsn_' + str(FLAGS.noise)
                       + '_arsnd_' + str(FLAGS.num_directions)
                       + '_arsbd_' + str(FLAGS.num_best_directions))
  elif FLAGS.oracle_type == 'BR':
    oracle_flag_str = ''
  else:
    oracle_flag_str = ('_hl_' + str(FLAGS.hidden_layer_size)
                       + '_bs_' + str(FLAGS.batch_size)
                       + '_nhl_' + str(FLAGS.n_hidden_layers))
    if FLAGS.oracle_type == 'DQN':
      oracle_flag_str += ('_dqnlr_' + str(FLAGS.dqn_learning_rate)
                          + '_tnuf_' + str(FLAGS.update_target_network_every)
                          + '_lf_' + str(FLAGS.learn_every))
    else:
      oracle_flag_str += ('_ls_' + str(FLAGS.loss_str)
                          + '_nqbp_' + str(FLAGS.num_q_before_pi)
                          + '_ec_' + str(FLAGS.entropy_cost)
                          + '_clr_' + str(FLAGS.critic_learning_rate)
                          + '_pilr_' + str(FLAGS.pi_learning_rate))

  checkpoint_dir = (checkpoint_dir + oracle_flag_str + '_se_' + str(seed)
                    + '_'
                    + datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S'))
  checkpoint_dir = os.path.join(os.getcwd(), FLAGS.root_result_folder,
                                checkpoint_dir)

  writer = SummaryWriter(logdir=checkpoint_dir + '/log')
  try:
    if FLAGS.sbatch_run:
      # Redirects stdout to a file for the remainder of the process.
      sys.stdout = open(checkpoint_dir + '/stdout.txt', 'w+')

    # Initialize oracle and agents
    with tf.Session() as sess:
      if FLAGS.oracle_type == "DQN":
        oracle, agents = init_dqn_responder(sess, env)
      elif FLAGS.oracle_type == "PG":
        oracle, agents = init_pg_responder(sess, env)
      elif FLAGS.oracle_type == "BR":
        oracle, agents = init_br_responder(env)
      elif FLAGS.oracle_type == "ARS":
        oracle, agents = init_ars_responder(sess, env)
      elif FLAGS.oracle_type == "ARS_parallel":
        oracle, agents = init_ars_parallel_responder(sess, env)
      else:
        # Without this, `oracle` would be unbound at the looper call.
        raise ValueError("Unsupported oracle type: {}".format(
            FLAGS.oracle_type))
      # NOTE: tf.global_variables_initializer() is not run in this variant.
      gpsro_looper(env, oracle, agents, writer, quiesce=FLAGS.quiesce,
                   checkpoint_dir=checkpoint_dir, seed=seed)
  finally:
    # Close the summary writer even when the run fails.
    writer.close()