def test_dqn_fp_python_game(self):
  """Checks if fictitious play with DQN-based value function works."""
  game = crowd_modelling.MFGCrowdModellingGame()
  dfp = fictitious_play.FictitiousPlay(game)
  uniform_policy = policy.UniformRandomPolicy(game)
  dist = distribution.DistributionPolicy(game, uniform_policy)
  envs = [
      rl_environment.Environment(
          game, mfg_distribution=dist, mfg_population=p)
      for p in range(game.num_players())
  ]
  dqn_agent = dqn.DQN(
      0,
      state_representation_size=envs[0].observation_spec()["info_state"][0],
      num_actions=envs[0].action_spec()["num_actions"],
      hidden_layers_sizes=[256, 128, 64],
      replay_buffer_capacity=100,
      batch_size=5,
      epsilon_start=0.02,
      epsilon_end=0.01)

  for _ in range(10):
    dfp.iteration(rl_br_agent=dqn_agent)

  dfp_policy = dfp.get_policy()
  nash_conv_dfp = nash_conv.NashConv(game, dfp_policy)

  self.assertAlmostEqual(nash_conv_dfp.nash_conv(), 1.0558451955622807)
def test_simple_game(self):
  game = pyspiel.load_efg_game(SIMPLE_EFG_DATA)
  env = rl_environment.Environment(game=game)
  agent = dqn.DQN(
      0,
      state_representation_size=game.information_state_tensor_shape()[0],
      num_actions=game.num_distinct_actions(),
      hidden_layers_sizes=[16],
      replay_buffer_capacity=100,
      batch_size=5,
      epsilon_start=0.02,
      epsilon_end=0.01)
  total_reward = 0

  for _ in range(100):
    time_step = env.reset()
    while not time_step.last():
      agent_output = agent.step(time_step)
      time_step = env.step([agent_output.action])
      total_reward += time_step.rewards[0]
    agent.step(time_step)

  self.assertGreaterEqual(total_reward, -100)
def test_run_tic_tac_toe(self):
  env = rl_environment.Environment("tic_tac_toe")
  state_size = env.observation_spec()["info_state"][0]
  num_actions = env.action_spec()["num_actions"]

  agents = [
      dqn.DQN(  # pylint: disable=g-complex-comprehension
          player_id,
          state_representation_size=state_size,
          num_actions=num_actions,
          hidden_layers_sizes=[16],
          replay_buffer_capacity=10,
          batch_size=5) for player_id in [0, 1]
  ]
  time_step = env.reset()
  while not time_step.last():
    current_player = time_step.observations["current_player"]
    current_agent = agents[current_player]
    agent_output = current_agent.step(time_step)
    time_step = env.step([agent_output.action])

  for agent in agents:
    agent.step(time_step)
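
# The trained tic-tac-toe agents above can also be queried without updating
# them. This is a minimal sketch, not part of the test file: the helper name
# `play_eval_episode` is ours, and it assumes the `is_evaluation` flag of
# `DQN.step`, which selects actions greedily and skips learning.
def play_eval_episode(env, agents):
  """Plays one greedy episode and returns the final rewards."""
  time_step = env.reset()
  while not time_step.last():
    current_player = time_step.observations["current_player"]
    agent_output = agents[current_player].step(time_step, is_evaluation=True)
    time_step = env.step([agent_output.action])
  return time_step.rewards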
def test_run_hanabi(self):
  # Hanabi is an optional game, so check we have it before running the test.
  game = "hanabi"
  if game not in pyspiel.registered_names():
    return

  num_players = 3
  env_configs = {
      "players": num_players,
      "max_life_tokens": 1,
      "colors": 2,
      "ranks": 3,
      "hand_size": 2,
      "max_information_tokens": 3,
      "discount": 0.
  }
  env = rl_environment.Environment(game, **env_configs)
  state_size = env.observation_spec()["info_state"][0]
  num_actions = env.action_spec()["num_actions"]

  agents = [
      dqn.DQN(  # pylint: disable=g-complex-comprehension
          player_id,
          state_representation_size=state_size,
          num_actions=num_actions,
          hidden_layers_sizes=[16],
          replay_buffer_capacity=10,
          batch_size=5) for player_id in range(num_players)
  ]
  time_step = env.reset()
  while not time_step.last():
    current_player = time_step.observations["current_player"]
    agent_output = [agent.step(time_step) for agent in agents]
    time_step = env.step([agent_output[current_player].action])

  for agent in agents:
    agent.step(time_step)
def main(unused_argv):
  logging.info("Loading %s", FLAGS.game_name)
  game = pyspiel.load_game(FLAGS.game_name,
                           GAME_SETTINGS.get(FLAGS.game_name, {}))
  uniform_policy = policy.UniformRandomPolicy(game)
  mfg_dist = distribution.DistributionPolicy(game, uniform_policy)

  envs = [
      rl_environment.Environment(
          game, mfg_distribution=mfg_dist, mfg_population=p)
      for p in range(game.num_players())
  ]
  info_state_size = envs[0].observation_spec()["info_state"][0]
  num_actions = envs[0].action_spec()["num_actions"]

  hidden_layers_sizes = [int(l) for l in FLAGS.hidden_layers_sizes]
  kwargs = {
      "replay_buffer_capacity": FLAGS.replay_buffer_capacity,
      "min_buffer_size_to_learn": FLAGS.min_buffer_size_to_learn,
      "batch_size": FLAGS.batch_size,
      "learn_every": FLAGS.learn_every,
      "learning_rate": FLAGS.rl_learning_rate,
      "optimizer_str": FLAGS.optimizer_str,
      "loss_str": FLAGS.loss_str,
      "update_target_network_every": FLAGS.update_target_network_every,
      "discount_factor": FLAGS.discount_factor,
      "epsilon_decay_duration": FLAGS.epsilon_decay_duration,
      "epsilon_start": FLAGS.epsilon_start,
      "epsilon_end": FLAGS.epsilon_end,
  }

  # pylint: disable=g-complex-comprehension
  agents = [
      dqn.DQN(idx, info_state_size, num_actions, hidden_layers_sizes, **kwargs)
      for idx in range(game.num_players())
  ]
  joint_avg_policy = DQNPolicies(envs, agents)
  if FLAGS.use_checkpoints:
    for agent in agents:
      if agent.has_checkpoint(FLAGS.checkpoint_dir):
        agent.restore(FLAGS.checkpoint_dir)

  for ep in range(FLAGS.num_train_episodes):
    if (ep + 1) % FLAGS.eval_every == 0:
      losses = [agent.loss for agent in agents]
      logging.info("Losses: %s", losses)
      nash_conv_obj = nash_conv.NashConv(game, uniform_policy)
      print(
          str(ep + 1) + " Exact Best Response to Uniform " +
          str(nash_conv_obj.br_values()))
      pi_value = policy_value.PolicyValue(game, mfg_dist, joint_avg_policy)
      print(
          str(ep + 1) + " DQN Best Response to Uniform " + str([
              pi_value.eval_state(state)
              for state in game.new_initial_states()
          ]))
      if FLAGS.use_checkpoints:
        for agent in agents:
          agent.save(FLAGS.checkpoint_dir)
      logging.info("_____________________________________________")

    for p in range(game.num_players()):
      time_step = envs[p].reset()
      while not time_step.last():
        agent_output = agents[p].step(time_step)
        action_list = [agent_output.action]
        time_step = envs[p].step(action_list)

      # Episode is over, step all agents with final info state.
      agents[p].step(time_step)
def __init__(self,
             player_id,
             state_representation_size,
             num_actions,
             hidden_layers_sizes,
             reservoir_buffer_capacity,
             anticipatory_param,
             batch_size=128,
             rl_learning_rate=0.01,
             sl_learning_rate=0.01,
             min_buffer_size_to_learn=1000,
             learn_every=64,
             optimizer_str="sgd",
             **kwargs):
  """Initialize the `NFSP` agent."""
  self.player_id = player_id
  self._num_actions = num_actions
  self._layer_sizes = hidden_layers_sizes
  self._batch_size = batch_size
  self._learn_every = learn_every
  self._anticipatory_param = anticipatory_param
  self._min_buffer_size_to_learn = min_buffer_size_to_learn

  self._reservoir_buffer = ReservoirBuffer(reservoir_buffer_capacity)
  self._prev_timestep = None
  self._prev_action = None

  # Step counter to keep track of learning.
  self._step_counter = 0

  # Inner RL agent kwargs.
  kwargs.update({
      "batch_size": batch_size,
      "learning_rate": rl_learning_rate,
      "learn_every": learn_every,
      "min_buffer_size_to_learn": min_buffer_size_to_learn,
      "optimizer_str": optimizer_str,
  })
  self._rl_agent = dqn.DQN(player_id, state_representation_size,
                           num_actions, hidden_layers_sizes, **kwargs)

  # Keep track of the last training loss achieved in an update step.
  self._last_rl_loss_value = lambda: self._rl_agent.loss
  self._last_sl_loss_value = None

  # Average policy network.
  def network(x):
    mlp = hk.nets.MLP(self._layer_sizes + [num_actions])
    return mlp(x)

  self.hk_avg_network = hk.without_apply_rng(hk.transform(network))

  def avg_network_policy(param, info_state):
    action_values = self.hk_avg_network.apply(param, info_state)
    action_probs = jax.nn.softmax(action_values, axis=1)
    return action_values, action_probs

  self._avg_network_policy = jax.jit(avg_network_policy)

  rng = jax.random.PRNGKey(42)
  x = jnp.ones([1, state_representation_size])
  self.params_avg_network = self.hk_avg_network.init(rng, x)
  self.params_avg_network = jax.device_put(self.params_avg_network)

  self._savers = [("q_network", self._rl_agent.params_q_network),
                  ("avg_network", self.params_avg_network)]

  if optimizer_str == "adam":
    opt_init, opt_update = optax.chain(
        optax.scale_by_adam(b1=0.9, b2=0.999, eps=1e-8),
        optax.scale(sl_learning_rate))
  elif optimizer_str == "sgd":
    opt_init, opt_update = optax.sgd(sl_learning_rate)
  else:
    raise ValueError("Not implemented. Choose from ['adam', 'sgd'].")

  self._opt_update_fn = self._get_update_func(opt_update)
  self._opt_state = opt_init(self.params_avg_network)
  self._loss_and_grad = jax.value_and_grad(self._loss_avg, has_aux=False)

  self._sample_episode_policy()
  self._jit_update = jax.jit(self.get_update())
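
# A minimal construction sketch for the agent defined above, shown as a
# comment so it does not interfere with the class body. The numeric values
# are illustrative placeholders, not taken from any OpenSpiel example; extra
# keyword arguments (e.g. `replay_buffer_capacity`) are forwarded to the
# inner `dqn.DQN` best-response agent via `**kwargs`.
#
#   agent = NFSP(
#       player_id=0,
#       state_representation_size=info_state_size,
#       num_actions=num_actions,
#       hidden_layers_sizes=[128],
#       reservoir_buffer_capacity=int(2e4),
#       anticipatory_param=0.1,
#       rl_learning_rate=0.01,
#       sl_learning_rate=0.01,
#       replay_buffer_capacity=int(1e4))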
def main(unused_argv):
  logging.info("Loading %s", FLAGS.game_name)
  game = pyspiel.load_game(FLAGS.game_name,
                           GAME_SETTINGS.get(FLAGS.game_name, {}))
  uniform_policy = policy.UniformRandomPolicy(game)
  mfg_dist = distribution.DistributionPolicy(game, uniform_policy)

  envs = [
      rl_environment.Environment(
          game, mfg_distribution=mfg_dist, mfg_population=p)
      for p in range(game.num_players())
  ]
  info_state_size = envs[0].observation_spec()["info_state"][0]
  num_actions = envs[0].action_spec()["num_actions"]

  hidden_layers_sizes = [int(l) for l in FLAGS.hidden_layers_sizes]
  kwargs = {
      "replay_buffer_capacity": FLAGS.replay_buffer_capacity,
      "min_buffer_size_to_learn": FLAGS.min_buffer_size_to_learn,
      "batch_size": FLAGS.batch_size,
      "learn_every": FLAGS.learn_every,
      "learning_rate": FLAGS.rl_learning_rate,
      "optimizer_str": FLAGS.optimizer_str,
      "loss_str": FLAGS.loss_str,
      "update_target_network_every": FLAGS.update_target_network_every,
      "discount_factor": FLAGS.discount_factor,
      "epsilon_decay_duration": FLAGS.epsilon_decay_duration,
      "epsilon_start": FLAGS.epsilon_start,
      "epsilon_end": FLAGS.epsilon_end,
  }

  # pylint: disable=g-complex-comprehension
  agents = [
      dqn.DQN(idx, info_state_size, num_actions, hidden_layers_sizes, **kwargs)
      for idx in range(game.num_players())
  ]
  joint_avg_policy = rl_agent_policy.JointRLAgentPolicy(
      game, {idx: agent for idx, agent in enumerate(agents)},
      envs[0].use_observation)
  if FLAGS.use_checkpoints:
    for agent in agents:
      if agent.has_checkpoint(FLAGS.checkpoint_dir):
        agent.restore(FLAGS.checkpoint_dir)

  # Metrics writer will also log the metrics to stderr.
  just_logging = FLAGS.logdir is None or jax.host_id() > 0
  writer = metric_writers.create_default_writer(
      FLAGS.logdir, just_logging=just_logging)

  # Save the parameters.
  writer.write_hparams(kwargs)

  for ep in range(1, FLAGS.num_train_episodes + 1):
    if ep % FLAGS.eval_every == 0:
      writer.write_scalars(ep, {
          f"agent{i}/loss": float(agent.loss)
          for i, agent in enumerate(agents)
      })

      initial_states = game.new_initial_states()

      # Exact best response to uniform.
      nash_conv_obj = nash_conv.NashConv(game, uniform_policy)
      writer.write_scalars(
          ep, {
              f"exact_br/{state}": value
              for state, value in zip(initial_states,
                                      nash_conv_obj.br_values())
          })

      # DQN best response to uniform.
      pi_value = policy_value.PolicyValue(game, mfg_dist, joint_avg_policy)
      writer.write_scalars(
          ep, {
              f"dqn_br/{state}": pi_value.eval_state(state)
              for state in initial_states
          })

      if FLAGS.use_checkpoints:
        for agent in agents:
          agent.save(FLAGS.checkpoint_dir)

    for p in range(game.num_players()):
      time_step = envs[p].reset()
      while not time_step.last():
        agent_output = agents[p].step(time_step)
        action_list = [agent_output.action]
        time_step = envs[p].step(action_list)

      # Episode is over, step all agents with final info state.
      agents[p].step(time_step)

  # Make sure all values were written.
  writer.flush()