Example #1
    def test_dqn_fp_python_game(self):
        """Checks if fictitious play with DQN-based value function works."""
        game = crowd_modelling.MFGCrowdModellingGame()
        dfp = fictitious_play.FictitiousPlay(game)

        uniform_policy = policy.UniformRandomPolicy(game)
        dist = distribution.DistributionPolicy(game, uniform_policy)
        envs = [
            rl_environment.Environment(game,
                                       mfg_distribution=dist,
                                       mfg_population=p)
            for p in range(game.num_players())
        ]
        dqn_agent = dqn.DQN(
            0,
            state_representation_size=envs[0].observation_spec()["info_state"][0],
            num_actions=envs[0].action_spec()["num_actions"],
            hidden_layers_sizes=[256, 128, 64],
            replay_buffer_capacity=100,
            batch_size=5,
            epsilon_start=0.02,
            epsilon_end=0.01)

        for _ in range(10):
            dfp.iteration(rl_br_agent=dqn_agent)

        dfp_policy = dfp.get_policy()
        nash_conv_dfp = nash_conv.NashConv(game, dfp_policy)

        self.assertAlmostEqual(nash_conv_dfp.nash_conv(), 1.0558451955622807)
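
For context, a minimal sketch of the imports this test presumably relies on, following the standard OpenSpiel module layout (the exact paths are an assumption, since the snippet only shows the method body):

# Assumed imports (standard OpenSpiel layout); not shown in the snippet above.
from open_spiel.python import policy
from open_spiel.python import rl_environment
from open_spiel.python.jax import dqn
from open_spiel.python.mfg.algorithms import distribution
from open_spiel.python.mfg.algorithms import fictitious_play
from open_spiel.python.mfg.algorithms import nash_conv
from open_spiel.python.mfg.games import crowd_modelling
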
Example #2
    def test_simple_game(self):
        game = pyspiel.load_efg_game(SIMPLE_EFG_DATA)
        env = rl_environment.Environment(game=game)
        agent = dqn.DQN(
            0,
            state_representation_size=game.information_state_tensor_shape()[0],
            num_actions=game.num_distinct_actions(),
            hidden_layers_sizes=[16],
            replay_buffer_capacity=100,
            batch_size=5,
            epsilon_start=0.02,
            epsilon_end=0.01)
        total_reward = 0

        for _ in range(100):
            time_step = env.reset()
            while not time_step.last():
                agent_output = agent.step(time_step)
                time_step = env.step([agent_output.action])
                total_reward += time_step.rewards[0]
            agent.step(time_step)
        self.assertGreaterEqual(total_reward, -100)
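
As a usage note, OpenSpiel agents accept an `is_evaluation` argument in `step`, which disables exploration and learning. A minimal evaluation sketch, reusing the `env` and `agent` built in the test above:

# Hypothetical greedy rollout; `is_evaluation=True` leaves the replay buffer
# and the epsilon schedule untouched.
time_step = env.reset()
episode_reward = 0
while not time_step.last():
    agent_output = agent.step(time_step, is_evaluation=True)
    time_step = env.step([agent_output.action])
    episode_reward += time_step.rewards[0]
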
Example #3
    def test_run_tic_tac_toe(self):
        env = rl_environment.Environment("tic_tac_toe")
        state_size = env.observation_spec()["info_state"][0]
        num_actions = env.action_spec()["num_actions"]

        agents = [
            dqn.DQN(  # pylint: disable=g-complex-comprehension
                player_id,
                state_representation_size=state_size,
                num_actions=num_actions,
                hidden_layers_sizes=[16],
                replay_buffer_capacity=10,
                batch_size=5) for player_id in [0, 1]
        ]
        time_step = env.reset()
        while not time_step.last():
            current_player = time_step.observations["current_player"]
            current_agent = agents[current_player]
            agent_output = current_agent.step(time_step)
            time_step = env.step([agent_output.action])

        for agent in agents:
            agent.step(time_step)
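
If the trained agents need to be queried as a regular OpenSpiel policy (for example to measure NashConv), they can presumably be wrapped with `rl_agent_policy.JointRLAgentPolicy`, as the mean-field example further below does. A rough sketch reusing `env` and `agents` from the test above; the module paths and the diagnostic itself are assumptions, not part of the test:

from open_spiel.python import rl_agent_policy
from open_spiel.python.algorithms import exploitability

# Wrap the per-player DQN agents into a single joint policy object.
joint_policy = rl_agent_policy.JointRLAgentPolicy(
    env.game, {idx: agent for idx, agent in enumerate(agents)},
    env.use_observation)
# Hypothetical diagnostic: NashConv of the joint policy over the full game tree.
print(exploitability.nash_conv(env.game, joint_policy))
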
Example #4
    def test_run_hanabi(self):
        # Hanabi is an optional game, so check we have it before running the test.
        game = "hanabi"
        if game not in pyspiel.registered_names():
            return

        num_players = 3
        env_configs = {
            "players": num_players,
            "max_life_tokens": 1,
            "colors": 2,
            "ranks": 3,
            "hand_size": 2,
            "max_information_tokens": 3,
            "discount": 0.
        }
        env = rl_environment.Environment(game, **env_configs)
        state_size = env.observation_spec()["info_state"][0]
        num_actions = env.action_spec()["num_actions"]

        agents = [
            dqn.DQN(  # pylint: disable=g-complex-comprehension
                player_id,
                state_representation_size=state_size,
                num_actions=num_actions,
                hidden_layers_sizes=[16],
                replay_buffer_capacity=10,
                batch_size=5) for player_id in range(num_players)
        ]
        time_step = env.reset()
        while not time_step.last():
            current_player = time_step.observations["current_player"]
            agent_output = [agent.step(time_step) for agent in agents]
            time_step = env.step([agent_output[current_player].action])

        for agent in agents:
            agent.step(time_step)
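
If the underlying pyspiel game object is needed (for example to inspect sizes before building agents), the same settings can presumably be passed to `pyspiel.load_game` directly; note that `discount` is consumed by `rl_environment.Environment` itself rather than by the game, so it is stripped in this hypothetical sketch:

if "hanabi" in pyspiel.registered_names():
    game_params = dict(env_configs)
    game_params.pop("discount")  # Environment argument, not a Hanabi parameter.
    hanabi_game = pyspiel.load_game("hanabi", game_params)
    print(hanabi_game.num_distinct_actions())
    print(hanabi_game.max_game_length())
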
Example #5
def main(unused_argv):
    logging.info("Loading %s", FLAGS.game_name)
    game = pyspiel.load_game(FLAGS.game_name,
                             GAME_SETTINGS.get(FLAGS.game_name, {}))
    uniform_policy = policy.UniformRandomPolicy(game)
    mfg_dist = distribution.DistributionPolicy(game, uniform_policy)

    envs = [
        rl_environment.Environment(game,
                                   mfg_distribution=mfg_dist,
                                   mfg_population=p)
        for p in range(game.num_players())
    ]
    info_state_size = envs[0].observation_spec()["info_state"][0]
    num_actions = envs[0].action_spec()["num_actions"]

    hidden_layers_sizes = [int(l) for l in FLAGS.hidden_layers_sizes]
    kwargs = {
        "replay_buffer_capacity": FLAGS.replay_buffer_capacity,
        "min_buffer_size_to_learn": FLAGS.min_buffer_size_to_learn,
        "batch_size": FLAGS.batch_size,
        "learn_every": FLAGS.learn_every,
        "learning_rate": FLAGS.rl_learning_rate,
        "optimizer_str": FLAGS.optimizer_str,
        "loss_str": FLAGS.loss_str,
        "update_target_network_every": FLAGS.update_target_network_every,
        "discount_factor": FLAGS.discount_factor,
        "epsilon_decay_duration": FLAGS.epsilon_decay_duration,
        "epsilon_start": FLAGS.epsilon_start,
        "epsilon_end": FLAGS.epsilon_end,
    }

    # pylint: disable=g-complex-comprehension
    agents = [
        dqn.DQN(idx, info_state_size, num_actions, hidden_layers_sizes,
                **kwargs) for idx in range(game.num_players())
    ]
    joint_avg_policy = DQNPolicies(envs, agents)
    if FLAGS.use_checkpoints:
        for agent in agents:
            if agent.has_checkpoint(FLAGS.checkpoint_dir):
                agent.restore(FLAGS.checkpoint_dir)

    for ep in range(FLAGS.num_train_episodes):
        if (ep + 1) % FLAGS.eval_every == 0:
            losses = [agent.loss for agent in agents]
            logging.info("Losses: %s", losses)
            nash_conv_obj = nash_conv.NashConv(game, uniform_policy)
            print(
                str(ep + 1) + " Exact Best Response to Uniform " +
                str(nash_conv_obj.br_values()))
            pi_value = policy_value.PolicyValue(game, mfg_dist,
                                                joint_avg_policy)
            print(
                str(ep + 1) + " DQN Best Response to Uniform " + str([
                    pi_value.eval_state(state)
                    for state in game.new_initial_states()
                ]))
            if FLAGS.use_checkpoints:
                for agent in agents:
                    agent.save(FLAGS.checkpoint_dir)
            logging.info("_____________________________________________")

        for p in range(game.num_players()):
            time_step = envs[p].reset()
            while not time_step.last():
                agent_output = agents[p].step(time_step)
                action_list = [agent_output.action]
                time_step = envs[p].step(action_list)

            # Episode is over, step all agents with final info state.
            agents[p].step(time_step)
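
The `FLAGS` and `GAME_SETTINGS` referenced above are defined elsewhere in the script. Below is a hedged sketch of what the flag definitions might look like with absl; only the flag names are taken from the usages above, while the default values and help strings are placeholders of my own:

from absl import flags

FLAGS = flags.FLAGS

# Hypothetical flag definitions mirroring the usages in main().
flags.DEFINE_string("game_name", "python_mfg_crowd_modelling", "Game to load.")
flags.DEFINE_integer("num_train_episodes", int(1e5), "Number of training episodes.")
flags.DEFINE_integer("eval_every", 1000, "Episodes between evaluations.")
flags.DEFINE_list("hidden_layers_sizes", ["128", "128"], "Hidden layer sizes.")
flags.DEFINE_integer("replay_buffer_capacity", int(1e5), "Replay buffer capacity.")
flags.DEFINE_integer("min_buffer_size_to_learn", 1000, "Samples required before learning.")
flags.DEFINE_integer("batch_size", 128, "Training batch size.")
flags.DEFINE_integer("learn_every", 64, "Learning step frequency.")
flags.DEFINE_float("rl_learning_rate", 0.01, "DQN learning rate.")
flags.DEFINE_string("optimizer_str", "sgd", "'adam' or 'sgd'.")
flags.DEFINE_string("loss_str", "mse", "'mse' or 'huber'.")
flags.DEFINE_integer("update_target_network_every", 400, "Target network update period.")
flags.DEFINE_float("discount_factor", 1.0, "Discount factor for returns.")
flags.DEFINE_integer("epsilon_decay_duration", int(1e5), "Epsilon decay duration.")
flags.DEFINE_float("epsilon_start", 0.1, "Initial exploration epsilon.")
flags.DEFINE_float("epsilon_end", 0.1, "Final exploration epsilon.")
flags.DEFINE_bool("use_checkpoints", False, "Save/load agent parameters.")
flags.DEFINE_string("checkpoint_dir", "/tmp/dqn_mfg", "Directory for checkpoints.")
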
Example #6
  def __init__(self,
               player_id,
               state_representation_size,
               num_actions,
               hidden_layers_sizes,
               reservoir_buffer_capacity,
               anticipatory_param,
               batch_size=128,
               rl_learning_rate=0.01,
               sl_learning_rate=0.01,
               min_buffer_size_to_learn=1000,
               learn_every=64,
               optimizer_str="sgd",
               **kwargs):
    """Initialize the `NFSP` agent."""
    self.player_id = player_id
    self._num_actions = num_actions
    self._layer_sizes = hidden_layers_sizes
    self._batch_size = batch_size
    self._learn_every = learn_every
    self._anticipatory_param = anticipatory_param
    self._min_buffer_size_to_learn = min_buffer_size_to_learn

    self._reservoir_buffer = ReservoirBuffer(reservoir_buffer_capacity)
    self._prev_timestep = None
    self._prev_action = None

    # Step counter to keep track of learning.
    self._step_counter = 0

    # Inner RL agent
    kwargs.update({
        "batch_size": batch_size,
        "learning_rate": rl_learning_rate,
        "learn_every": learn_every,
        "min_buffer_size_to_learn": min_buffer_size_to_learn,
        "optimizer_str": optimizer_str,
    })
    self._rl_agent = dqn.DQN(player_id, state_representation_size,
                             num_actions, hidden_layers_sizes, **kwargs)

    # Keep track of the last training loss achieved in an update step.
    self._last_rl_loss_value = lambda: self._rl_agent.loss
    self._last_sl_loss_value = None

    # Average policy network.
    def network(x):
      mlp = hk.nets.MLP(self._layer_sizes + [num_actions])
      return mlp(x)

    self.hk_avg_network = hk.without_apply_rng(hk.transform(network))

    def avg_network_policy(param, info_state):
      action_values = self.hk_avg_network.apply(param, info_state)
      action_probs = jax.nn.softmax(action_values, axis=1)
      return action_values, action_probs

    self._avg_network_policy = jax.jit(avg_network_policy)

    rng = jax.random.PRNGKey(42)
    x = jnp.ones([1, state_representation_size])
    self.params_avg_network = self.hk_avg_network.init(rng, x)
    self.params_avg_network = jax.device_put(self.params_avg_network)

    self._savers = [
        ("q_network", self._rl_agent.params_q_network),
        ("avg_network", self.params_avg_network)
    ]

    if optimizer_str == "adam":
      opt_init, opt_update = optax.chain(
          optax.scale_by_adam(b1=0.9, b2=0.999, eps=1e-8),
          optax.scale(sl_learning_rate))
    elif optimizer_str == "sgd":
      opt_init, opt_update = optax.sgd(sl_learning_rate)
    else:
      raise ValueError("Not implemented. Choose from ['adam', 'sgd'].")
    self._opt_update_fn = self._get_update_func(opt_update)
    self._opt_state = opt_init(self.params_avg_network)
    self._loss_and_grad = jax.value_and_grad(self._loss_avg, has_aux=False)

    self._sample_episode_policy()
    self._jit_update = jax.jit(self.get_update())
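
For reference, a hypothetical construction of this agent (named NFSP, per its docstring) against an OpenSpiel rl_environment; the game name and hyperparameter values are placeholders of my own, and only the constructor signature comes from the code above:

from open_spiel.python import rl_environment

env = rl_environment.Environment("leduc_poker")
nfsp_agent = NFSP(
    player_id=0,
    state_representation_size=env.observation_spec()["info_state"][0],
    num_actions=env.action_spec()["num_actions"],
    hidden_layers_sizes=[128],
    reservoir_buffer_capacity=int(2e6),
    anticipatory_param=0.1,
    batch_size=128,
    rl_learning_rate=0.01,
    sl_learning_rate=0.01,
    optimizer_str="sgd",
    # Remaining keyword arguments are forwarded to the inner dqn.DQN agent.
    replay_buffer_capacity=int(2e5))
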
Example #7
def main(unused_argv):
    logging.info("Loading %s", FLAGS.game_name)
    game = pyspiel.load_game(FLAGS.game_name,
                             GAME_SETTINGS.get(FLAGS.game_name, {}))
    uniform_policy = policy.UniformRandomPolicy(game)
    mfg_dist = distribution.DistributionPolicy(game, uniform_policy)

    envs = [
        rl_environment.Environment(game,
                                   mfg_distribution=mfg_dist,
                                   mfg_population=p)
        for p in range(game.num_players())
    ]
    info_state_size = envs[0].observation_spec()["info_state"][0]
    num_actions = envs[0].action_spec()["num_actions"]

    hidden_layers_sizes = [int(l) for l in FLAGS.hidden_layers_sizes]
    kwargs = {
        "replay_buffer_capacity": FLAGS.replay_buffer_capacity,
        "min_buffer_size_to_learn": FLAGS.min_buffer_size_to_learn,
        "batch_size": FLAGS.batch_size,
        "learn_every": FLAGS.learn_every,
        "learning_rate": FLAGS.rl_learning_rate,
        "optimizer_str": FLAGS.optimizer_str,
        "loss_str": FLAGS.loss_str,
        "update_target_network_every": FLAGS.update_target_network_every,
        "discount_factor": FLAGS.discount_factor,
        "epsilon_decay_duration": FLAGS.epsilon_decay_duration,
        "epsilon_start": FLAGS.epsilon_start,
        "epsilon_end": FLAGS.epsilon_end,
    }

    # pylint: disable=g-complex-comprehension
    agents = [
        dqn.DQN(idx, info_state_size, num_actions, hidden_layers_sizes,
                **kwargs) for idx in range(game.num_players())
    ]
    joint_avg_policy = rl_agent_policy.JointRLAgentPolicy(
        game, {idx: agent
               for idx, agent in enumerate(agents)}, envs[0].use_observation)
    if FLAGS.use_checkpoints:
        for agent in agents:
            if agent.has_checkpoint(FLAGS.checkpoint_dir):
                agent.restore(FLAGS.checkpoint_dir)

    # Metrics writer will also log the metrics to stderr.
    just_logging = FLAGS.logdir is None or jax.host_id() > 0
    writer = metric_writers.create_default_writer(FLAGS.logdir,
                                                  just_logging=just_logging)

    # Save the parameters.
    writer.write_hparams(kwargs)

    for ep in range(1, FLAGS.num_train_episodes + 1):
        if ep % FLAGS.eval_every == 0:
            writer.write_scalars(
                ep, {
                    f"agent{i}/loss": float(agent.loss)
                    for i, agent in enumerate(agents)
                })

            initial_states = game.new_initial_states()

            # Exact best response to uniform.
            nash_conv_obj = nash_conv.NashConv(game, uniform_policy)
            writer.write_scalars(
                ep, {
                    f"exact_br/{state}": value
                    for state, value in zip(initial_states,
                                            nash_conv_obj.br_values())
                })

            # DQN best response to uniform.
            pi_value = policy_value.PolicyValue(game, mfg_dist,
                                                joint_avg_policy)
            writer.write_scalars(
                ep, {
                    f"dqn_br/{state}": pi_value.eval_state(state)
                    for state in initial_states
                })

            if FLAGS.use_checkpoints:
                for agent in agents:
                    agent.save(FLAGS.checkpoint_dir)

        for p in range(game.num_players()):
            time_step = envs[p].reset()
            while not time_step.last():
                agent_output = agents[p].step(time_step)
                action_list = [agent_output.action]
                time_step = envs[p].step(action_list)

            # Episode is over, step all agents with final info state.
            agents[p].step(time_step)

    # Make sure all values were written.
    writer.flush()
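
After training, the quality of the learned joint policy could presumably be checked directly, e.g. by computing its NashConv with the same `nash_conv` module used for the exact best response above. A brief sketch of such a diagnostic, which would sit at the end of main() and is not part of the original script:

# Hypothetical post-training check: NashConv of the learned DQN joint policy.
nash_conv_dqn = nash_conv.NashConv(game, joint_avg_policy)
logging.info("NashConv of the DQN joint policy: %s", nash_conv_dqn.nash_conv())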