def test_sampling_in_simple_games(self):
  # The tests that have an explicit time limit can fail with very small
  # probability.
  matrix_mp_num_states = 1 + 2 + 4
  game = pyspiel.load_game_as_turn_based("matrix_mp")
  for n in range(1, matrix_mp_num_states + 1):
    states = sample_some_states.sample_some_states(game, max_states=n)
    self.assertLen(states, n)

  states = sample_some_states.sample_some_states(
      game, max_states=1, depth_limit=0)
  self.assertLen(states, 1)

  states = sample_some_states.sample_some_states(
      game, max_states=matrix_mp_num_states + 1, time_limit=0.1)
  self.assertLen(states, matrix_mp_num_states)

  states = sample_some_states.sample_some_states(
      game, include_terminals=False, time_limit=0.1,
      max_states=matrix_mp_num_states)
  self.assertLen(states, 3)

  states = sample_some_states.sample_some_states(
      game, depth_limit=1, time_limit=0.1, max_states=matrix_mp_num_states)
  self.assertLen(states, 3)

  coordinated_mp_num_states = 1 + 2 + 4 + 8
  game = pyspiel.load_game_as_turn_based("coordinated_mp")
  for n in range(1, coordinated_mp_num_states + 1):
    states = sample_some_states.sample_some_states(game, max_states=n)
    self.assertLen(states, n)

  states = sample_some_states.sample_some_states(
      game, max_states=coordinated_mp_num_states + 1, time_limit=0.1)
  self.assertLen(states, coordinated_mp_num_states)

  states = sample_some_states.sample_some_states(
      game, max_states=coordinated_mp_num_states,
      include_chance_states=False, time_limit=0.1)
  self.assertLen(states, coordinated_mp_num_states - 2)
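# For context, a minimal sketch (not part of the test suite) of where the
# 1 + 2 + 4 and 1 + 2 + 4 + 8 constants above come from: converting a one-shot
# simultaneous matrix game with load_game_as_turn_based yields a root, one
# state per first-player action, and one leaf per joint action (coordinated_mp
# adds a chance level). Only standard pyspiel State methods are used here.
import pyspiel

def count_states(game):
  """Counts all states of a turn-based game by exhaustive traversal."""
  stack = [game.new_initial_state()]
  total = 0
  while stack:
    state = stack.pop()
    total += 1
    if not state.is_terminal():
      # state.child(a) clones the state and applies action a; at chance
      # nodes, legal_actions() enumerates the chance outcomes.
      stack.extend(state.child(a) for a in state.legal_actions())
  return total

print(count_states(pyspiel.load_game_as_turn_based("matrix_mp")))       # 7
print(count_states(pyspiel.load_game_as_turn_based("coordinated_mp")))  # 15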
def __init__(self, env_name, env_seed=2, deltas=None, slow_oracle_kargs=None,
             fast_oracle_kargs=None):
  # Initialize the RL environment.
  from open_spiel.python import rl_environment
  import pyspiel

  self._num_players = 2
  game = pyspiel.load_game_as_turn_based(
      env_name, {"players": pyspiel.GameParameter(self._num_players)})
  self._env = rl_environment.Environment(game)

  # Each worker gets access to the shared noise table, with an independent
  # random stream for sampling from it.
  self.deltas = SharedNoiseTable(deltas, env_seed + 7)
  self._policies = [[] for _ in range(self._num_players)]

  self._slow_oracle_kargs = slow_oracle_kargs
  self._fast_oracle_kargs = fast_oracle_kargs
  self._delta_std = self._fast_oracle_kargs['noise']

  self._sess = tf.get_default_session()
  if self._sess is None:
    self._sess = tf.Session()
  if self._slow_oracle_kargs is not None:
    self._slow_oracle_kargs['session'] = self._sess
def main(argv):
  del argv
  game = pyspiel.load_game(FLAGS.game)
  game_type = game.get_type()

  if game_type.dynamics == pyspiel.GameType.Dynamics.SIMULTANEOUS:
    logging.warning("%s is not turn-based. Trying to reload game as "
                    "turn-based.", FLAGS.game)
    game = pyspiel.load_game_as_turn_based(FLAGS.game)
    game_type = game.get_type()

  if game_type.dynamics != pyspiel.GameType.Dynamics.SEQUENTIAL:
    raise ValueError("Game must be sequential, not {}".format(
        game_type.dynamics))

  if (game_type.utility == pyspiel.GameType.Utility.ZERO_SUM and
      game.num_players() == 2):
    logging.info("Game is zero-sum: only showing first-player's returns.")
    gametree = treeviz.GameTree(
        game,
        node_decorator=_zero_sum_node_decorator,
        group_infosets=FLAGS.group_infosets,
        group_terminal=FLAGS.group_terminal)
  else:
    gametree = treeviz.GameTree(game)  # Use default decorators.

  if FLAGS.verbose:
    logging.info("Game tree:\n%s", gametree.to_string())

  gametree.draw(FLAGS.out, prog=FLAGS.prog)
  logging.info("Game tree saved to file: %s", FLAGS.out)
def main(_):
  game = pyspiel.load_game_as_turn_based(game_)
  cfr_solver = cfr.CFRSolver(game)
  print("policy_initial:",
        cfr_solver.current_policy().action_probability_array)
  for i in range(FLAGS.iterations):
    if i % FLAGS.print_freq == 0:
      conv = exploitability.exploitability(game, cfr_solver.average_policy())
      print("Iteration {} exploitability {}".format(i, conv))
    print("Iteration {}".format(i))
    print("policy_av:", cfr_solver.average_policy().action_probability_array)
    print("policy_cr:", cfr_solver.current_policy().action_probability_array)
    cfr_solver.evaluate_and_update_policy()
    write_csv(dir_ + game_ + "_" + algo_name + "_av.csv",
              cfr_solver.average_policy().action_probability_array[0])
    write_csv(dir_ + game_ + "_" + algo_name + "_av.csv",
              cfr_solver.average_policy().action_probability_array[1])
    write_csv(dir_ + game_ + "_" + algo_name + "_cr.csv",
              cfr_solver.current_policy().action_probability_array[0])
    write_csv(dir_ + game_ + "_" + algo_name + "_cr.csv",
              cfr_solver.current_policy().action_probability_array[1])
def test_extensive_to_tensor_game_payoff_tensor(self):
  turn_based_game = pyspiel.load_game_as_turn_based(
      "blotto(players=3,coins=5)")
  tensor_game1 = pyspiel.extensive_to_tensor_game(turn_based_game)
  tensor_game2 = pyspiel.load_tensor_game("blotto(players=3,coins=5)")
  self.assertEqual(tensor_game1.shape(), tensor_game2.shape())
  s0 = turn_based_game.new_initial_state()
  self.assertEqual(tensor_game1.shape()[0], s0.num_distinct_actions())
  for a0 in range(s0.num_distinct_actions()):
    s1 = s0.child(a0)
    self.assertEqual(tensor_game1.shape()[1], s1.num_distinct_actions())
    for a1 in range(s1.num_distinct_actions()):
      s2 = s1.child(a1)
      self.assertEqual(tensor_game1.shape()[2], s2.num_distinct_actions())
      for a2 in range(s2.num_distinct_actions()):
        s3 = s2.child(a2)
        self.assertTrue(s3.is_terminal())
        for player in range(3):
          self.assertEqual(
              s3.returns()[player],
              tensor_game1.player_utility(player, (a0, a1, a2)))
          self.assertEqual(
              s3.returns()[player],
              tensor_game2.player_utility(player, (a0, a1, a2)))
def main(argv):
  if len(argv) > 1:
    raise app.UsageError("Too many command-line arguments.")

  if FLAGS.seed is None:
    seed = np.random.randint(low=0, high=1e5)
  else:
    seed = FLAGS.seed
  np.random.seed(seed)
  random.seed(seed)
  tf.set_random_seed(seed)

  game = pyspiel.load_game_as_turn_based(
      FLAGS.game_name, {"players": pyspiel.GameParameter(FLAGS.n_players)})
  env = rl_environment.Environment(game, seed=seed)
  env.reset()

  if not os.path.exists(FLAGS.root_result_folder):
    os.makedirs(FLAGS.root_result_folder)

  checkpoint_dir = ('tuning_ars' + str(FLAGS.iter_stop_dqn) + '_' +
                    FLAGS.game_name + str(FLAGS.n_players) +
                    '_sims_' + str(FLAGS.sims_per_entry) +
                    '_it_' + str(FLAGS.gpsro_iterations) +
                    '_ep_' + str(FLAGS.number_training_episodes) +
                    '_or_' + FLAGS.oracle_type +
                    '_arsnd_' + str(FLAGS.num_directions) +
                    '_se_' + str(seed) + '_' +
                    datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S'))
  checkpoint_dir = os.path.join(os.getcwd(), FLAGS.root_result_folder,
                                checkpoint_dir)
  writer = SummaryWriter(logdir=checkpoint_dir + '/log')
  if FLAGS.sbatch_run:
    sys.stdout = open(checkpoint_dir + '/stdout.txt', 'w+')

  # Initialize oracle and agents.
  with tf.Session() as sess:
    if FLAGS.oracle_type == "DQN":
      oracle, agents = init_dqn_responder(sess, env)
    elif FLAGS.oracle_type == "PG":
      oracle, agents = init_pg_responder(sess, env)
    elif FLAGS.oracle_type == "BR":
      oracle, agents = init_br_responder(env)
    elif FLAGS.oracle_type == "ARS":
      oracle, agents = init_ars_responder(sess, env)
    elif FLAGS.oracle_type == "ARS_parallel":
      oracle, agents = init_ars_parallel_responder(sess, env, None)
    sess.run(tf.global_variables_initializer())
    gpsro_looper(env, oracle, agents, writer,
                 quiesce=FLAGS.quiesce,
                 checkpoint_dir=checkpoint_dir,
                 seed=seed,
                 dqn_iters=FLAGS.iter_stop_dqn)
  writer.close()
def test_shapleys_game(self):
  game = pyspiel.load_game_as_turn_based("matrix_shapleys_game")
  xfp_solver = fictitious_play.XFPSolver(game)
  for i in range(1000):
    xfp_solver.iteration()
    if i % 10 == 0:
      conv = exploitability.nash_conv(game, xfp_solver.average_policy())
      print("FP in Shapley's Game. Iter: {}, NashConv: {}".format(i, conv))
def test_matching_pennies_3p(self): game = pyspiel.load_game_as_turn_based("matching_pennies_3p") xfp_solver = fictitious_play.XFPSolver(game) for i in range(1000): xfp_solver.iteration() if i % 10 == 0: conv = exploitability.nash_conv(game, xfp_solver.average_policy()) print("FP in Matching Pennies 3p. Iter: {}, NashConv: {}".format( i, conv))
def test_iigoofspiel4(self):
  game = pyspiel.load_game_as_turn_based("goofspiel", {
      "imp_info": True,
      "num_cards": 4,
      "points_order": "descending",
  })
  val1, val2, _, _ = sequence_form_lp.solve_zero_sum_game(game)
  # This is a symmetric game, so the value should be 0.
  self.assertAlmostEqual(val1, 0)
  self.assertAlmostEqual(val2, 0)
def test_states_lookup(self):
  # Test that there are two valid states, indexed as 0 and 1.
  game = pyspiel.load_game_as_turn_based("matrix_rps")
  state = game.new_initial_state()
  first_info_state = state.information_state_string()
  state.apply_action(state.legal_actions()[0])
  second_info_state = state.information_state_string()
  self.assertCountEqual(self.tabular_policy.state_lookup,
                        [first_info_state, second_info_state])
  self.assertCountEqual(self.tabular_policy.state_lookup.values(), [0, 1])
def test_extensive_to_matrix_game_payoff_matrix(self):
  turn_based_game = pyspiel.load_game_as_turn_based("matrix_pd")
  matrix_game = pyspiel.extensive_to_matrix_game(turn_based_game)
  orig_game = pyspiel.load_matrix_game("matrix_pd")
  for row in range(orig_game.num_rows()):
    for col in range(orig_game.num_cols()):
      for player in range(2):
        self.assertEqual(
            orig_game.player_utility(player, row, col),
            matrix_game.player_utility(player, row, col))
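# A minimal standalone sketch of the same round trip as the test above
# (matrix game -> turn-based game -> matrix game), outside the test harness:
import pyspiel

turn_based = pyspiel.load_game_as_turn_based("matrix_pd")
matrix_game = pyspiel.extensive_to_matrix_game(turn_based)
print(matrix_game.num_rows(), matrix_game.num_cols())  # 2 2
# Payoff to player 0 for the first joint action; matches the original game.
print(matrix_game.player_utility(0, 0, 0))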
def main(argv):
  if len(argv) > 1:
    raise app.UsageError("Too many command-line arguments.")

  np.random.seed(FLAGS.seed)
  game = pyspiel.load_game_as_turn_based(
      FLAGS.game_name, {"players": pyspiel.GameParameter(FLAGS.n_players)})
  env = rl_environment.Environment(game)

  # Initialize oracle and agents.
  sess = None
  oracle, agents = init_ars_responder(sess, env)
  gpsro_looper(env, oracle, agents)
def test_sampling_in_simple_games(self):
  matrix_mp_num_states = 1 + 2 + 4
  game = pyspiel.load_game_as_turn_based("matrix_mp")
  for n in range(1, matrix_mp_num_states + 1):
    states = sample_some_states.sample_some_states(game, max_states=n)
    self.assertLen(states, n)

  states = sample_some_states.sample_some_states(game, max_states=1)
  self.assertLen(states, 1)

  states = sample_some_states.sample_some_states(
      game, max_states=matrix_mp_num_states + 1)
  self.assertLen(states, matrix_mp_num_states)

  coordinated_mp_num_states = 1 + 2 + 4 + 8
  game = pyspiel.load_game_as_turn_based("coordinated_mp")
  for n in range(1, coordinated_mp_num_states + 1):
    states = sample_some_states.sample_some_states(game, max_states=n)
    self.assertLen(states, n)

  states = sample_some_states.sample_some_states(
      game, max_states=coordinated_mp_num_states + 1)
  self.assertLen(states, coordinated_mp_num_states)
def main(argv):
  del argv
  game = pyspiel.load_game(FLAGS.game)
  game_type = game.get_type()

  if game_type.dynamics == pyspiel.GameType.Dynamics.SIMULTANEOUS:
    logging.warning("%s is not turn-based. Trying to reload game as "
                    "turn-based.", FLAGS.game)
    game = pyspiel.load_game_as_turn_based(FLAGS.game)

  gametree = export_gambit(game)
  if FLAGS.print:
    print(gametree)
  else:
    with open(FLAGS.out, "w") as f:
      f.write(gametree)
    logging.info("Game tree for %s saved to file: %s", FLAGS.game, FLAGS.out)
def test_matching_pennies_3p(self):
  game = pyspiel.load_game_as_turn_based('matching_pennies_3p')
  deep_cfr_solver = deep_cfr.DeepCFRSolver(
      game,
      policy_network_layers=(16, 8),
      advantage_network_layers=(32, 16),
      num_iterations=2,
      num_traversals=2,
      learning_rate=1e-3,
      batch_size_advantage=None,
      batch_size_strategy=None,
      memory_capacity=1e7)
  deep_cfr_solver.solve()
  conv = pyspiel.nash_conv(
      game,
      policy.python_policy_to_pyspiel_policy(
          policy.tabular_policy_from_callable(
              game, deep_cfr_solver.action_probabilities)))
  logging.info('Deep CFR in Matching Pennies 3p. NashConv: %.2f', conv)
def test_matching_pennies_3p(self):
  # We don't expect Deep CFR to necessarily converge on 3-player games but
  # it's nonetheless interesting to see this result.
  game = pyspiel.load_game_as_turn_based('matching_pennies_3p')
  deep_cfr_solver = deep_cfr.DeepCFRSolver(
      game,
      policy_network_layers=(16, 8),
      advantage_network_layers=(32, 16),
      num_iterations=2,
      num_traversals=2,
      learning_rate=1e-3,
      batch_size_advantage=8,
      batch_size_strategy=8,
      memory_capacity=1e7)
  deep_cfr_solver.solve()
  conv = exploitability.nash_conv(
      game,
      policy.tabular_policy_from_callable(
          game, deep_cfr_solver.action_probabilities))
  print('Deep CFR in Matching Pennies 3p. NashConv: {}'.format(conv))
def main(argv):
  if len(argv) > 1:
    raise app.UsageError("Too many command-line arguments.")

  np.random.seed(FLAGS.seed)
  game = pyspiel.load_game_as_turn_based(
      FLAGS.game_name, {"players": pyspiel.GameParameter(FLAGS.n_players)})
  env = rl_environment.Environment(game)

  # Initialize oracle and agents.
  with tf.Session() as sess:
    if FLAGS.oracle_type == "DQN":
      oracle, agents = init_dqn_responder(sess, env)
    elif FLAGS.oracle_type == "PG":
      oracle, agents = init_pg_responder(sess, env)
    elif FLAGS.oracle_type == "BR":
      oracle, agents = init_br_responder(env)
    sess.run(tf.global_variables_initializer())
    gpsro_looper(env, oracle, agents)
def get_game(game_name):
  """Returns the game."""
  if game_name == "kuhn_poker_3p":
    game_name = "kuhn_poker"
    game_kwargs = {"players": int(3)}
  elif game_name == "trade_comm_2p_2i":
    game_name = "trade_comm"
    game_kwargs = {"num_items": int(2)}
  elif game_name == "sheriff_2p_gabriele":
    game_name = "sheriff"
    game_kwargs = {
        "item_penalty": float(1.0),
        "item_value": float(5.0),
        "max_bribe": int(2),
        "max_items": int(10),
        "num_rounds": int(2),
        "sheriff_penalty": float(1.0),
    }
  else:
    raise ValueError("Unrecognised game: %s" % game_name)
  return pyspiel.load_game_as_turn_based(game_name, game_kwargs)
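# A quick, hypothetical usage check for the helper above (the alias strings
# are exactly those handled in the if/elif chain):
game = get_game("kuhn_poker_3p")
print(game.num_players())        # 3
print(game.get_type().dynamics)  # Dynamics.SEQUENTIAL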
def setUpClass(cls):
  super(TabularRockPaperScissorsPolicyTest, cls).setUpClass()
  game = pyspiel.load_game_as_turn_based("matrix_rps")
  cls.tabular_policy = policy.TabularPolicy(game)
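# For reference, a minimal sketch of what the fixture above produces: a
# TabularPolicy is initialized uniform over legal actions, so turn-based RPS
# gives a 2 x 3 probability array (one row per information state, one column
# per action).
import pyspiel
from open_spiel.python import policy

game = pyspiel.load_game_as_turn_based("matrix_rps")
tabular_policy = policy.TabularPolicy(game)
print(tabular_policy.action_probability_array)
# [[0.33333333 0.33333333 0.33333333]
#  [0.33333333 0.33333333 0.33333333]]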
def test_rock_paper_scissors(self):
  game = pyspiel.load_game_as_turn_based("matrix_rps")
  val1, val2, _, _ = sequence_form_lp.solve_zero_sum_game(game)
  self.assertAlmostEqual(val1, 0)
  self.assertAlmostEqual(val2, 0)
def test_simultaneous_game_as_turn_based(self, game_info):
  converted_game = pyspiel.load_game_as_turn_based(game_info.short_name)
  self.sim_game(converted_game)
def test_observations_are_consistent_with_info_states(self, game_name):
  print(f"Testing observation <-> info_state consistency for '{game_name}'")
  game = pyspiel.load_game(game_name)
  game_type = game.get_type()

  if not game_type.provides_information_state_string \
      or not game_type.provides_observation_string:
    print(f"Skipping test for '{game_name}', as it doesn't provide both "
          "information_state_string and observation_string")
    return

  if game_type.dynamics == pyspiel.GameType.Dynamics.SIMULTANEOUS:
    logging.warning(
        "'%s' is not turn-based. Trying to reload game as turn-based.",
        game_name)
    game = pyspiel.load_game_as_turn_based(game_name)

  # Idea of the test: make rollouts in the game, and collect both
  # Action-Observation histories (AOH) and InformationState for different
  # ground states. Check that there is a unique bijection between them.
  #
  # Of course, this test does not exclude the possibility the game might
  # have a bug! But it is a fast way to discover a possible inconsistency
  # in a new implementation.
  aoh_is = dict()  # aoh -> info_state
  is_aoh = dict()  # info_state -> aoh
  aoh_histories = collections.defaultdict(set)  # aoh -> states
  is_histories = collections.defaultdict(set)  # info_states -> states

  # Some games have very long play-throughs.
  give_up_after = 100  # actions

  # Show a helpful error message for debugging the observations in a game.
  def show_error(histories, player, dump_collections=True):
    aohs = list()
    info_states = list()
    descriptions = list()

    # Emulate the histories to collect relevant lists.
    for history in histories:
      state = game.new_initial_state()
      aoh = [("obs", state.observation_string(player))]
      for action in history:
        state.apply_action(action)
        if state.current_player() == player:
          aoh.append(("action", action))
        aoh.append(("obs", state.observation_string(player)))
      aohs.append(aoh)
      info_states.append(state.information_state_string(player))
      descriptions.append(str(state))

    histories_str = "\n".join([str(history) for history in histories])
    descriptions_str = "\n".join(descriptions)
    aohs_str = "\n".join([str(aoh) for aoh in aohs])
    info_states_str = "\n".join([str(s) for s in info_states])

    if dump_collections:

      def format_dump(xs):
        return "\n".join(
            [f"{str(key)} -> {str(value)}" for key, value in xs.items()])

      # pylint: disable=g-backslash-continuation
      extras = "Dumping collections:\n" \
          f"aoh -> info_state:\n{format_dump(aoh_is)}\n\n" \
          f"info_state -> aoh:\n{format_dump(is_aoh)}\n\n" \
          f"aoh -> histories:\n{format_dump(aoh_histories)}\n\n" \
          f"info_state -> histories:\n{format_dump(is_histories)}\n\n"
    else:
      # pylint: disable=g-backslash-continuation
      extras = "Rerun this test with dump_collections=True " \
          "for extra information."

    # pylint: disable=g-backslash-continuation
    msg = \
        f"\n\n" \
        f"The action-observation histories (AOH) are not consistent with " \
        f"information states for player {player}.\n\n" \
        f"The conflicting set of states (histories) is:\n{histories_str}\n\n" \
        f"Their domain-specific descriptions are:\n{descriptions_str}\n\n" \
        f"The corresponding AOH are:\n{aohs_str}\n\n" \
        f"The corresponding info states are:\n{info_states_str}\n\n" \
        f"{extras}\n" \
        f"What to do to fix this? Consult the documentation to " \
        f"State::InformationStateString and State::ObservationString."
    return msg

  def collect_and_test_rollouts(player):
    random.seed(0)
    nonlocal aoh_is, is_aoh, aoh_histories, is_histories
    state = game.new_initial_state()
    aoh = [("obs", state.observation_string(player))]
    # TODO(author13): we want to check terminals for consistency too, but
    # info state string is not defined there and neither are observations
    # by design.
    while not state.is_terminal():
      if len(state.history()) > give_up_after:
        break

      # Do not collect over chance nodes.
      if not state.is_chance_node():
        info_state = state.information_state_string()
        aoh_histories[str(aoh)].add(tuple(state.history()))
        is_histories[info_state].add(tuple(state.history()))

        states = {tuple(state.history())}
        states = states.union(aoh_histories[str(aoh)])
        states = states.union(is_histories[info_state])
        if str(aoh) in aoh_is:
          states = states.union(is_histories[aoh_is[str(aoh)]])
          self.assertEqual(aoh_is[str(aoh)], info_state,
                           show_error(states, player))
        else:
          aoh_is[str(aoh)] = info_state
        if info_state in is_aoh:
          states = states.union(aoh_histories[str(is_aoh[info_state])])
          self.assertEqual(is_aoh[info_state], str(aoh),
                           show_error(states, player))
        else:
          is_aoh[info_state] = str(aoh)

      # Make random actions.
      action = random.choice(state.legal_actions(state.current_player()))
      if state.current_player() == player:
        aoh.append(("action", action))
      state.apply_action(action)
      aoh.append(("obs", state.observation_string(player)))

  # Run (very roughly!) for this many seconds. This very much depends on
  # the machine the test runs on, as some games take a long time to produce
  # a single rollout.
  time_limit = TIMEABLE_TEST_RUNTIME / game.num_players()
  start = time.time()
  is_time_out = lambda: time.time() - start > time_limit

  rollouts = 0
  for player in range(game.num_players()):
    aoh_is.clear()
    is_aoh.clear()
    aoh_histories.clear()
    is_histories.clear()
    while not is_time_out():
      collect_and_test_rollouts(player)
      rollouts += 1

  print(f"Test for {game_name} took {time.time()-start} seconds "
        f"to make {rollouts} rollouts.")
def do_something(game_name):
  game = pyspiel.load_game_as_turn_based(
      game_name, {"players": pyspiel.GameParameter(2)})
  env = rl_environment.Environment(game)
  return env.name
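# A minimal local sanity check for the helper above (no Ray workers involved;
# assumes pyspiel and rl_environment are importable in the current process).
print(do_something("kuhn_poker"))  # Prints the environment's name.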
from open_spiel.python import rl_environment
# from open_spiel.python.algorithms.psro_v2 import rl_oracle
from open_spiel.python.algorithms.psro_v2 import rl_policy
import ray
import cloudpickle
from open_spiel.python.algorithms.psro_v2.ars_ray.workers import worker
from open_spiel.python.algorithms.psro_v2.parallel.worker import do_something
import concurrent.futures

# redis_password = sys.argv[1]
# num_cpus = int(sys.argv[2])

game = pyspiel.load_game_as_turn_based(
    "kuhn_poker", {"players": pyspiel.GameParameter(2)})
env = rl_environment.Environment(game)
env.reset()

sess = None
info_state_size = env.observation_spec()["info_state"][0]
num_actions = env.action_spec()["num_actions"]
# print(info_state_size, num_actions)

agent_class = rl_policy.ARSPolicy_parallel
agent_kwargs = {
    "session": sess,
    "info_state_size": info_state_size,
    "num_actions": num_actions,
    "learning_rate": 0.03,
    "nb_directions": 32,
def get_game(game_name):
  """Returns the game."""
  if game_name == "kuhn_poker_2p":
    game_name = "kuhn_poker"
    game_kwargs = {"players": int(2)}
  elif game_name == "kuhn_poker_3p":
    game_name = "kuhn_poker"
    game_kwargs = {"players": int(3)}
  elif game_name == "kuhn_poker_4p":
    game_name = "kuhn_poker"
    game_kwargs = {"players": int(4)}
  elif game_name == "leduc_poker_2p":
    game_name = "leduc_poker"
    game_kwargs = {"players": int(2)}
  elif game_name == "leduc_poker_3p":
    game_name = "leduc_poker"
    game_kwargs = {"players": int(3)}
  elif game_name == "leduc_poker_4p":
    game_name = "leduc_poker"
    game_kwargs = {"players": int(4)}
  elif game_name == "trade_comm_2p_2i":
    game_name = "trade_comm"
    game_kwargs = {"num_items": int(2)}
  elif game_name == "trade_comm_2p_3i":
    game_name = "trade_comm"
    game_kwargs = {"num_items": int(3)}
  elif game_name == "trade_comm_2p_4i":
    game_name = "trade_comm"
    game_kwargs = {"num_items": int(4)}
  elif game_name == "trade_comm_2p_5i":
    game_name = "trade_comm"
    game_kwargs = {"num_items": int(5)}
  elif game_name == "tiny_bridge_2p":
    game_name = "tiny_bridge_2p"
    game_kwargs = {}
  elif game_name == "tiny_bridge_4p":
    game_name = "tiny_bridge_4p"
    game_kwargs = {}  # Too big game.
  elif game_name == "sheriff_2p_1r":
    game_name = "sheriff"
    game_kwargs = {"num_rounds": int(1)}
  elif game_name == "sheriff_2p_2r":
    game_name = "sheriff"
    game_kwargs = {"num_rounds": int(2)}
  elif game_name == "sheriff_2p_3r":
    game_name = "sheriff"
    game_kwargs = {"num_rounds": int(3)}
  elif game_name == "sheriff_2p_gabriele":
    game_name = "sheriff"
    game_kwargs = {
        "item_penalty": float(1.0),
        "item_value": float(5.0),
        "max_bribe": int(2),
        "max_items": int(10),
        "num_rounds": int(2),
        "sheriff_penalty": float(1.0),
    }
  elif game_name == "goofspiel_2p_3c_total":
    game_name = "goofspiel"
    game_kwargs = {
        "players": int(2),
        "returns_type": "total_points",
        "num_cards": int(3)
    }
  elif game_name == "goofspiel_2p_4c_total":
    game_name = "goofspiel"
    game_kwargs = {
        "players": int(2),
        "returns_type": "total_points",
        "num_cards": int(4)
    }
  elif game_name == "goofspiel_2p_5c_total":
    game_name = "goofspiel"
    game_kwargs = {
        "players": int(2),
        "returns_type": "total_points",
        "num_cards": int(5)
    }
  else:
    raise ValueError("Unrecognised game: %s" % game_name)
  return pyspiel.load_game_as_turn_based(game_name, game_kwargs)
def main(argv):
  if len(argv) > 1:
    raise app.UsageError("Too many command-line arguments.")

  if FLAGS.seed is None:
    seed = np.random.randint(low=0, high=1e5)
  else:
    seed = FLAGS.seed
  np.random.seed(seed)
  random.seed(seed)
  tf.set_random_seed(seed)

  game = pyspiel.load_game_as_turn_based(
      FLAGS.game_name, {"players": pyspiel.GameParameter(FLAGS.n_players)})
  env = rl_environment.Environment(game, seed=seed)
  env.reset()

  if FLAGS.heuristic_list:
    heuristic_list = FLAGS.heuristic_list
    if '_strategy' in heuristic_list[0]:
      FLAGS.meta_strategy_method = heuristic_list[0][
          :heuristic_list[0].index('_strategy')]
    else:
      FLAGS.meta_strategy_method = heuristic_list[0]
  else:
    heuristic_list = ["general_nash_strategy", "uniform_strategy"]

  if 'sp' in FLAGS.heuristic_to_add:
    heuristic_list.append("self_play_strategy")
  if 'weighted_ne' in FLAGS.heuristic_to_add:
    heuristic_list.append("weighted_NE_strategy")
  if 'prd' in FLAGS.heuristic_to_add:
    heuristic_list.append("prd_strategy")

  if not os.path.exists(FLAGS.root_result_folder):
    os.makedirs(FLAGS.root_result_folder)

  checkpoint_dir = ('se_' + FLAGS.game_name + str(FLAGS.n_players) +
                    '_sims_' + str(FLAGS.sims_per_entry) +
                    '_it_' + str(FLAGS.gpsro_iterations) +
                    '_ep_' + str(FLAGS.number_training_episodes) +
                    '_or_' + FLAGS.oracle_type)
  checkpoint_dir += '_msl_' + ",".join(heuristic_list)
  if FLAGS.switch_fast_slow:
    checkpoint_dir += ('_sfs_' + '_fp_' + str(FLAGS.fast_oracle_period) +
                       '_sp_' + str(FLAGS.slow_oracle_period) +
                       '_arslr_' + str(FLAGS.ars_learning_rate) +
                       '_arsn_' + str(FLAGS.noise) +
                       '_arsnd_' + str(FLAGS.num_directions) +
                       '_arsbd_' + str(FLAGS.num_best_directions) +
                       '_epars_' + str(FLAGS.number_training_episodes_ars))
  elif FLAGS.switch_heuristic_regardless_of_oracle:
    checkpoint_dir += '_switch_heuristics_'

  if FLAGS.oracle_type == 'BR':
    oracle_flag_str = ''
  else:
    oracle_flag_str = ('_hl_' + str(FLAGS.hidden_layer_size) +
                       '_bs_' + str(FLAGS.batch_size) +
                       '_nhl_' + str(FLAGS.n_hidden_layers))
    if FLAGS.oracle_type == 'DQN':
      oracle_flag_str += ('_dqnlr_' + str(FLAGS.dqn_learning_rate) +
                          '_tnuf_' + str(FLAGS.update_target_network_every) +
                          '_lf_' + str(FLAGS.learn_every))
    else:
      oracle_flag_str += ('_ls_' + str(FLAGS.loss_str) +
                          '_nqbp_' + str(FLAGS.num_q_before_pi) +
                          '_ec_' + str(FLAGS.entropy_cost) +
                          '_clr_' + str(FLAGS.critic_learning_rate) +
                          '_pilr_' + str(FLAGS.pi_learning_rate))
  checkpoint_dir = (checkpoint_dir + oracle_flag_str + '_se_' + str(seed) +
                    '_' +
                    datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S'))
  checkpoint_dir = os.path.join(os.getcwd(), FLAGS.root_result_folder,
                                checkpoint_dir)

  writer = SummaryWriter(logdir=checkpoint_dir + '/log')
  if FLAGS.sbatch_run:
    sys.stdout = open(checkpoint_dir + '/stdout.txt', 'w+')

  # Initialize oracle and agents.
  with tf.Session() as sess:
    if FLAGS.oracle_type == "DQN":
      slow_oracle, agents, agent_kwargs = init_dqn_responder(sess, env)
    elif FLAGS.oracle_type == "PG":
      slow_oracle, agents, agent_kwargs = init_pg_responder(sess, env)
    elif FLAGS.oracle_type == "BR":
      slow_oracle, agents = init_br_responder(env)
      agent_kwargs = None
    elif FLAGS.oracle_type == "ARS":
      slow_oracle, agents = init_ars_responder(sess, env)
      agent_kwargs = None
    sess.run(tf.global_variables_initializer())

    if FLAGS.switch_fast_slow:
      fast_oracle, agents = init_ars_responder(sess=None, env=env)
      oracle_list = [[], []]
      oracle_list[0].append(slow_oracle)
      oracle_list[0].append(fast_oracle)
      oracle_list[1] = [FLAGS.oracle_type, 'ARS']
    else:
      oracle_list = None

    gpsro_looper(env, slow_oracle, oracle_list, agents, writer,
                 quiesce=FLAGS.quiesce,
                 checkpoint_dir=checkpoint_dir,
                 seed=seed,
                 heuristic_list=heuristic_list)
  writer.close()
def main(argv):
  if len(argv) > 1:
    raise app.UsageError("Too many command-line arguments.")

  if FLAGS.seed is None:
    seed = np.random.randint(low=0, high=1e5)
  else:
    seed = FLAGS.seed
  np.random.seed(seed)
  random.seed(seed)
  tf.set_random_seed(seed)

  game_param = {"players": pyspiel.GameParameter(FLAGS.n_players)}
  checkpoint_dir = FLAGS.game_name
  if FLAGS.game_param is not None:
    for ele in FLAGS.game_param:
      ele_li = ele.split("=")
      game_param[ele_li[0]] = pyspiel.GameParameter(int(ele_li[1]))
      checkpoint_dir += '_' + ele_li[0] + '_' + ele_li[1]
  checkpoint_dir += '_'

  game = pyspiel.load_game_as_turn_based(FLAGS.game_name, game_param)
  env = rl_environment.Environment(game, seed=seed)
  env.reset()

  if not os.path.exists(FLAGS.root_result_folder):
    os.makedirs(FLAGS.root_result_folder)

  checkpoint_dir += (str(FLAGS.n_players) +
                     '_sims_' + str(FLAGS.sims_per_entry) +
                     '_it_' + str(FLAGS.gpsro_iterations) +
                     '_ep_' + str(FLAGS.number_training_episodes) +
                     '_or_' + FLAGS.oracle_type +
                     '_heur_' + FLAGS.meta_strategy_method)
  if FLAGS.oracle_type == 'ARS':
    oracle_flag_str = ('_arslr_' + str(FLAGS.ars_learning_rate) +
                       '_arsn_' + str(FLAGS.noise) +
                       '_arsnd_' + str(FLAGS.num_directions) +
                       '_arsbd_' + str(FLAGS.num_best_directions))
  elif FLAGS.oracle_type == 'BR':
    oracle_flag_str = ''
  else:
    oracle_flag_str = ('_hl_' + str(FLAGS.hidden_layer_size) +
                       '_bs_' + str(FLAGS.batch_size) +
                       '_nhl_' + str(FLAGS.n_hidden_layers))
    if FLAGS.oracle_type == 'DQN':
      oracle_flag_str += ('_dqnlr_' + str(FLAGS.dqn_learning_rate) +
                          '_tnuf_' + str(FLAGS.update_target_network_every) +
                          '_lf_' + str(FLAGS.learn_every))
    else:
      oracle_flag_str += ('_ls_' + str(FLAGS.loss_str) +
                          '_nqbp_' + str(FLAGS.num_q_before_pi) +
                          '_ec_' + str(FLAGS.entropy_cost) +
                          '_clr_' + str(FLAGS.critic_learning_rate) +
                          '_pilr_' + str(FLAGS.pi_learning_rate))
  checkpoint_dir = (checkpoint_dir + oracle_flag_str + '_se_' + str(seed) +
                    '_' +
                    datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S'))
  checkpoint_dir = os.path.join(os.getcwd(), FLAGS.root_result_folder,
                                checkpoint_dir)

  writer = SummaryWriter(logdir=checkpoint_dir + '/log')
  if FLAGS.sbatch_run:
    sys.stdout = open(checkpoint_dir + '/stdout.txt', 'w+')

  # Initialize oracle and agents.
  with tf.Session() as sess:
    if FLAGS.oracle_type == "DQN":
      oracle, agents = init_dqn_responder(sess, env)
    elif FLAGS.oracle_type == "PG":
      oracle, agents = init_pg_responder(sess, env)
    elif FLAGS.oracle_type == "BR":
      oracle, agents = init_br_responder(env)
    elif FLAGS.oracle_type == "ARS":
      oracle, agents = init_ars_responder(sess, env)
    elif FLAGS.oracle_type == "ARS_parallel":
      oracle, agents = init_ars_parallel_responder(sess, env)
    # sess.run(tf.global_variables_initializer())
    gpsro_looper(env, oracle, agents, writer,
                 quiesce=FLAGS.quiesce,
                 checkpoint_dir=checkpoint_dir,
                 seed=seed)
  writer.close()