Example No. 1
def main(_):
    game = pyspiel.load_game(FLAGS.game)
    evaluator = pyspiel.RandomRolloutEvaluator(1, SEED)
    min_expl = game.max_utility() - game.min_utility()

    print("{:>5} {:>10} {:>50} {:>20}".format("max_sims", "uct_c",
                                              "final_policy_type",
                                              "exploitability"))
    for max_simulations in [10, 100, 1000, 10000]:
        for uct_c in [0.2, 0.5, 1.0, 2.0, 4.0]:  # These values are for Kuhn.
            for final_policy_type in [
                    pyspiel.ISMCTSFinalPolicyType.NORMALIZED_VISIT_COUNT,
                    pyspiel.ISMCTSFinalPolicyType.MAX_VISIT_COUNT,
                    pyspiel.ISMCTSFinalPolicyType.MAX_VALUE
            ]:
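                # Fill a tabular policy from IS-MCTS search and measure its exploitability.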
                tabular_policy = policy.TabularPolicy(game)
                bot = pyspiel.ISMCTSBot(SEED, evaluator, uct_c,
                                        max_simulations, -1, final_policy_type,
                                        False, False)
                searched = {}
                construct_is_mcts_policy(game, game.new_initial_state(),
                                         tabular_policy, bot, searched)
                expl = exploitability.exploitability(game, tabular_policy)
                print("{:>5} {:>10} {:>50} {:>20}".format(
                    max_simulations, uct_c, str(final_policy_type), expl))
                if expl < min_expl:
                    min_expl = expl
    print("Min expl: {}".format(min_expl))
Example No. 2
def main(_):
    game = pyspiel.load_game_as_turn_based(game_)
    cfr_solver = cfr.CFRSolver(game)

    print("policy_initial:",
          cfr_solver.current_policy().action_probability_array)
    for i in range(FLAGS.iterations):

        if i % FLAGS.print_freq == 0:
            conv = exploitability.exploitability(game,
                                                 cfr_solver.average_policy())
            print("Iteration {} exploitability {}".format(i, conv))
            print("Iteration{}".format(i))
            print("policy_av:",
                  cfr_solver.average_policy().action_probability_array)
            print("policy_cr:",
                  cfr_solver.current_policy().action_probability_array)

        cfr_solver.evaluate_and_update_policy()
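        # Append the first two rows of the average and current policy tables to CSV files.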
        write_csv(dir_ + game_ + "_" + algo_name + "_av.csv",
                  cfr_solver.average_policy().action_probability_array[0])
        write_csv(dir_ + game_ + "_" + algo_name + "_av.csv",
                  cfr_solver.average_policy().action_probability_array[1])
        write_csv(dir_ + game_ + "_" + algo_name + "_cr.csv",
                  cfr_solver.current_policy().action_probability_array[0])
        write_csv(dir_ + game_ + "_" + algo_name + "_cr.csv",
                  cfr_solver.current_policy().action_probability_array[1])
Example No. 3
def nfsp_measure_exploitability_nonlstm(rllib_policies: List[Policy],
                                        poker_game_version: str,
                                        open_spiel_env_config: dict = None):
    if open_spiel_env_config is None:
        if poker_game_version in ["kuhn_poker", "leduc_poker"]:
            open_spiel_env_config = {
                "players": pyspiel.GameParameter(2)
            }
        else:
            open_spiel_env_config = {}

    open_spiel_env_config = {
        k: pyspiel.GameParameter(v)
        if not isinstance(v, pyspiel.GameParameter) else v
        for k, v in open_spiel_env_config.items()
    }

    openspiel_game = pyspiel.load_game(poker_game_version, open_spiel_env_config)
    if poker_game_version == "oshi_zumo":
        openspiel_game = pyspiel.convert_to_turn_based(openspiel_game)

    opnsl_policies = []
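    # Convert each RLlib policy to an OpenSpiel policy so joint exploitability can be computed.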
    for rllib_policy in rllib_policies:
        openspiel_policy = openspiel_policy_from_nonlstm_rllib_policy(
            openspiel_game=openspiel_game,
            rllib_policy=rllib_policy,
            game_version=poker_game_version,
            game_parameters=open_spiel_env_config,
        )
        opnsl_policies.append(openspiel_policy)

    nfsp_policy = JointPlayerPolicy(game=openspiel_game, policies=opnsl_policies)

    # Exploitability is NashConv / num_players
    if poker_game_version == "universal_poker":
        print("Measuring exploitability for universal_poker policy. This will take a while...")
    exploitability_result = exploitability(game=openspiel_game, policy=nfsp_policy)
    return exploitability_result
Example No. 4
 def test_exploitability_is_zero_on_nash(self, alpha):
     # A similar test exists in:
     # open_spiel/python/algorithms/exploitability_test.py
     game = pyspiel.load_game("kuhn_poker")
     policy = data.kuhn_nash_equilibrium(alpha=alpha)
     expl = exploitability.exploitability(game, policy)
     self.assertAlmostEqual(0, expl)
Example No. 5
def print_algorithm_results(game, callable_policy, algorithm_name):
    print(algorithm_name.upper())
    tabular_policy = tabular_policy_from_callable(game, callable_policy)
    policy_exploitability = exploitability(game, tabular_policy)
    policy_nashconv = nash_conv(game, tabular_policy)
    print("exploitability = {}".format(policy_exploitability))
    print("nashconv = {}".format(policy_nashconv))
Example No. 6
    def solve(self):
        """Solution logic for Deep CFR."""
        advantage_losses = collections.defaultdict(list)
        start = datetime.now()
        expl_idx = []
        expl_hist = []
        for it in range(self._num_iterations):
            if (it % self._eval_freq == 0) and it != 0:
                conv = self.get_exploitabilitiy()
                elapsed = datetime.now() - start
                print(
                    "Episode {}/{}, running for {} seconds - Exploitability = {}"
                    .format(it, self._num_iterations, elapsed.seconds, conv))
                expl_idx.append(it)
                expl_hist.append(conv)
            for p in range(self._num_players):
                for _ in range(self._num_traversals):
                    self._traverse_game_tree(self._root_node, p)
                # Re-initialize advantage networks and train from scratch.
                self.reinitialize_advantage_networks()
                advantage_losses[p].append(self._learn_advantage_network(p))
            self._iteration += 1
        # Train policy network.
        policy_loss = self._learn_strategy_network()

        conv = exploitability.exploitability(
            self._game,
            policy.PolicyFromCallable(self._game, self.action_probabilities))
        print("Final exploitability: {}".format(conv))
        return self._policy_network, advantage_losses, policy_loss, expl_idx, expl_hist
Example No. 7
def nfsp_measure_exploitability_nonlstm(rllib_p0_and_p1_policies,
                                        poker_game_version):
    if poker_game_version in [KUHN_POKER, LEDUC_POKER]:
        open_spiel_env_config = {"players": pyspiel.GameParameter(2)}
    else:
        open_spiel_env_config = {}

    openspiel_game = pyspiel.load_game(poker_game_version,
                                       open_spiel_env_config)
    openspiel_env = Environment(poker_game_version, open_spiel_env_config)

    openspiel_policies = []

    for rllib_policy in rllib_p0_and_p1_policies:

        if not isinstance(rllib_policy, OSPolicy):
            openspiel_policy = openspiel_policy_from_nonlstm_rllib_policy(
                openspiel_game=openspiel_game,
                poker_game_version=poker_game_version,
                rllib_policy=rllib_policy)
        else:
            openspiel_policy = rllib_policy

        openspiel_policies.append(openspiel_policy)

    nfsp_os_policy = NFSPPolicies(env=openspiel_env,
                                  nfsp_policies=openspiel_policies)

    # Exploitability is NashConv / num_players
    exploitability_result = exploitability(game=openspiel_game,
                                           policy=nfsp_os_policy)
    return exploitability_result
Example No. 8
def cfr_train(unused_arg):
    exploit_history = list()
    exploit_idx = list()

    tf.enable_eager_execution()
    game = pyspiel.load_game(FLAGS.game, {"players": pyspiel.GameParameter(2)})
    agent_name = "cfr"
    cfr_solver = cfr.CFRSolver(game)
    checkpoint = datetime.now()
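    # Run CFR, logging the exploitability of the average policy every 100 iterations.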
    for ep in range(FLAGS.episodes):
        cfr_solver.evaluate_and_update_policy()
        if ep % 100 == 0:
            delta = datetime.now() - checkpoint
            conv = exploitability.exploitability(game,
                                                 cfr_solver.average_policy())
            exploit_idx.append(ep)
            exploit_history.append(conv)
            print(
                "Iteration {} exploitability {} - {} seconds since last checkpoint"
                .format(ep, conv, delta.seconds))
            checkpoint = datetime.now()

    pickle.dump([exploit_idx, exploit_history],
                open(
                    FLAGS.game + "_" + agent_name + "_" + str(FLAGS.episodes) +
                    ".dat", "wb"))

    now = datetime.now()
    policy = cfr_solver.average_policy()
    agent_name = "cfr"
    for pid in [1, 2]:
        policy_to_csv(
            game, policy,
            "policies/policy_" + now.strftime("%m-%d-%Y_%H-%M") + "_" +
            agent_name + "_" + str(pid) + "_+" + str(ep) + "episodes.csv")
Example No. 9
def xfsp_train(_):
    exploit_history = list()
    exploit_idx = list()
    game = pyspiel.load_game(FLAGS.game, {"players": pyspiel.GameParameter(2)})
    fsp_solver = fictitious_play.XFPSolver(game)
    checkpoint = datetime.now()
    for ep in range(FLAGS.episodes):
        if (ep % 1000) == 0:
            delta = datetime.now() - checkpoint
            pol = policy.PolicyFromCallable(
                game, fsp_solver.average_policy_callable())
            conv = exploitability.exploitability(game, pol)
            exploit_history.append(conv)
            exploit_idx.append(ep)
            print(
                "[XFSP] Iteration {} exploitability {} - {} seconds since last checkpoint"
                .format(ep, conv, delta.seconds))
            checkpoint = datetime.now()

        fsp_solver.iteration()

    agent_name = "xfsp"
    pickle.dump([exploit_idx, exploit_history],
                open(
                    FLAGS.game + "_" + agent_name + "_" + str(FLAGS.episodes) +
                    ".dat", "wb"))

    now = datetime.now()
    pol = policy.PolicyFromCallable(game, fsp_solver.average_policy_callable())
    for pid in [1, 2]:
        policy_to_csv(
            game, pol, "policies/policy_" + now.strftime("%m-%d-%Y_%H-%M") +
            "_" + agent_name + "_" + str(pid) + "_+" +
            str(FLAGS.episodes) + "episodes.csv")
Example No. 10
def get_algo_metrics(algo_policies, game):
    print("Extracting metrics...")
    algo_exploitabilities = {}
    algo_nashconvs = {}
    for key in algo_policies:
        algo_exploitabilities[key] = exploitability(game, algo_policies[key])
        algo_nashconvs[key] = nash_conv(game, algo_policies[key])
    return algo_exploitabilities, algo_nashconvs
Example No. 11
    def get_exploitabilitiy(self):
        # Define placeholders.
        iter_ph = tf.placeholder(shape=[None, 1],
                                 dtype=tf.float32,
                                 name="iter_ph")
        action_probs_ph = tf.placeholder(shape=[None, self._num_actions],
                                         dtype=tf.float32,
                                         name="action_probs_ph")
        info_state_ph = tf.placeholder(shape=[None, self._embedding_size],
                                       dtype=tf.float32,
                                       name="info_state_ph")

        policy_network = snt.nets.MLP(
            list(self._policy_network_layers) + [self._num_actions])
        action_logits = policy_network(info_state_ph)
        # Illegal actions are handled in the traversal code where expected payoff
        # and sampled regret is computed from the advantage networks.
        action_probs = tf.nn.softmax(action_logits)
        loss_policy = tf.reduce_mean(
            tf.losses.mean_squared_error(
                labels=tf.math.sqrt(iter_ph) * action_probs_ph,
                predictions=tf.math.sqrt(iter_ph) * action_probs))
        optimizer_policy = tf.train.AdamOptimizer(
            learning_rate=self._learning_rate)
        learn_step_policy = optimizer_policy.minimize(loss_policy)

        self._session.run(tf.global_variables_initializer())

        def _local_action_probabilities(state):
            """Returns action probabilities dict for a single batch."""
            cur_player = state.current_player()
            legal_actions = state.legal_actions(cur_player)
            info_state_vector = np.array(state.information_state_tensor())
            if len(info_state_vector.shape) == 1:
                info_state_vector = np.expand_dims(info_state_vector, axis=0)
            probs = self._session.run(
                action_probs, feed_dict={info_state_ph: info_state_vector})
            return {action: probs[0][action] for action in legal_actions}

        info_states_l = []
        action_probs_l = []
        iterations_l = []
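        # Sample strategy memories, take one policy-network training step, then
        # score the resulting policy's exploitability.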
        for s in self._strategy_memories.sample(self._batch_size_strategy):
            info_states_l.append(s.info_state)
            action_probs_l.append(s.strategy_action_probs)
            iterations_l.append([s.iteration])
        self._session.run(
            [loss_policy, learn_step_policy],
            feed_dict={
                info_state_ph: np.array(info_states_l),
                action_probs_ph: np.array(np.squeeze(action_probs_l)),
                iter_ph: np.array(iterations_l),
            })

        conv = exploitability.exploitability(
            self._game,
            policy.PolicyFromCallable(self._game, _local_action_probabilities))
        return conv
Example No. 12
 def test_exploitability_on_kuhn_poker_uniform_random(self):
     # NashConv of the uniform random test_policy (value found on Google Books):
     # https://link.springer.com/chapter/10.1007/978-3-319-75931-9_5
     game = pyspiel.load_game("kuhn_poker")
     test_policy = policy.UniformRandomPolicy(game)
     expected_nash_conv = 11 / 12
     self.assertAlmostEqual(
         exploitability.exploitability(game, test_policy),
         expected_nash_conv / 2)
Example No. 13
def main(_):
    game = pyspiel.load_game(FLAGS.game, {"players": FLAGS.players})
    xfp_solver = fictitious_play.XFPSolver(game)
    for i in range(FLAGS.iterations):
        xfp_solver.iteration()
        conv = exploitability.exploitability(game, xfp_solver.average_policy())
        if i % FLAGS.print_freq == 0:
            print("Iteration: {} Conv: {}".format(i, conv))
            sys.stdout.flush()
Example No. 14
def main(unused_argv):
  logging.info("Loading %s", FLAGS.game_name)
  game = FLAGS.game_name
  num_players = FLAGS.num_players

  env_configs = {"players": num_players}
  env = rl_environment.Environment(game, **env_configs)
  info_state_size = env.observation_spec()["info_state"][0]
  num_actions = env.action_spec()["num_actions"]

  hidden_layers_sizes = [int(l) for l in FLAGS.hidden_layers_sizes]
  kwargs = {
      "replay_buffer_capacity": FLAGS.replay_buffer_capacity,
      "reservoir_buffer_capacity": FLAGS.reservoir_buffer_capacity,
      "min_buffer_size_to_learn": FLAGS.min_buffer_size_to_learn,
      "anticipatory_param": FLAGS.anticipatory_param,
      "batch_size": FLAGS.batch_size,
      "learn_every": FLAGS.learn_every,
      "rl_learning_rate": FLAGS.rl_learning_rate,
      "sl_learning_rate": FLAGS.sl_learning_rate,
      "optimizer_str": FLAGS.optimizer_str,
      "loss_str": FLAGS.loss_str,
      "update_target_network_every": FLAGS.update_target_network_every,
      "discount_factor": FLAGS.discount_factor,
      "epsilon_decay_duration": FLAGS.epsilon_decay_duration,
      "epsilon_start": FLAGS.epsilon_start,
      "epsilon_end": FLAGS.epsilon_end,
  }

  with tf.Session() as sess:
    # pylint: disable=g-complex-comprehension
    agents = [
        nfsp.NFSP(sess, idx, info_state_size, num_actions, hidden_layers_sizes,
                  **kwargs) for idx in range(num_players)
    ]
    expl_policies_avg = NFSPPolicies(env, agents, nfsp.MODE.average_policy)

    sess.run(tf.global_variables_initializer())
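    # Training loop: periodically log losses and the exploitability of the
    # agents' average policies.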
    for ep in range(FLAGS.num_train_episodes):
      if (ep + 1) % FLAGS.eval_every == 0:
        losses = [agent.loss for agent in agents]
        logging.info("Losses: %s", losses)
        expl = exploitability.exploitability(env.game, expl_policies_avg)
        logging.info("[%s] Exploitability AVG %s", ep + 1, expl)
        logging.info("_____________________________________________")

      time_step = env.reset()
      while not time_step.last():
        player_id = time_step.observations["current_player"]
        agent_output = agents[player_id].step(time_step)
        action_list = [agent_output.action]
        time_step = env.step(action_list)

      # Episode is over, step all agents with final info state.
      for agent in agents:
        agent.step(time_step)
Example No. 15
def main(_):
  game = pyspiel.load_game(FLAGS.game,
                           {"players": pyspiel.GameParameter(FLAGS.players)})
  cfr_solver = cfr.CFRSolver(game)

  for i in range(FLAGS.iterations):
    cfr_solver.evaluate_and_update_policy()
    if i % FLAGS.print_freq == 0:
      conv = exploitability.exploitability(game, cfr_solver.average_policy())
      print("Iteration {} exploitability {}".format(i, conv))
Example No. 16
def main(_):
    game = pyspiel.load_game(FLAGS.game)
    discounted_cfr_solver = discounted_cfr.DCFRSolver(game)

    for i in range(FLAGS.iterations):
        discounted_cfr_solver.evaluate_and_update_policy()
        if i % FLAGS.print_freq == 0:
            conv = exploitability.exploitability(
                game, discounted_cfr_solver.average_policy())
            print("Iteration {} exploitability {}".format(i, conv))
Example No. 17
def nxdo_snfsp_measure_exploitability_nonlstm(
        br_checkpoint_path_tuple_list: List[Tuple[str, str]],
        set_policy_weights_fn: Callable, rllib_policies: List[Policy],
        restricted_game_convertors: Union[
            List[RestrictedToBaseGameActionSpaceConverter],
            List[AgentRestrictedGameOpenSpielObsConversions]],
        poker_game_version: str):
    if poker_game_version in ["kuhn_poker", "leduc_poker"]:
        open_spiel_env_config = {"players": pyspiel.GameParameter(2)}
    else:
        open_spiel_env_config = {}

    openspiel_game = pyspiel.load_game(poker_game_version,
                                       open_spiel_env_config)

    def policy_iterable():
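        # For each best-response checkpoint tuple, load the weights into the
        # RLlib policies and yield the corresponding OpenSpiel policies.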
        for checkpoint_path_tuple in br_checkpoint_path_tuple_list:
            openspiel_policies = []
            assert isinstance(restricted_game_convertors, list)
            for player, (restricted_game_convertor,
                         player_rllib_policy) in enumerate(
                             zip(restricted_game_convertors, rllib_policies)):
                checkpoint_path = checkpoint_path_tuple[player]
                set_policy_weights_fn(player_rllib_policy, checkpoint_path)

                single_openspiel_policy = openspiel_policy_from_nonlstm_rllib_nxdo_policy(
                    openspiel_game=openspiel_game,
                    rllib_policy=player_rllib_policy,
                    restricted_game_convertor=restricted_game_convertor)
                openspiel_policies.append(single_openspiel_policy)
            yield openspiel_policies

    num_players = 2
    weights = np.ones(shape=(len(br_checkpoint_path_tuple_list),
                             num_players)) / len(br_checkpoint_path_tuple_list)

    print(f"weights: {weights}")

    avg_policies = tabular_policies_from_weighted_policies(
        game=openspiel_game,
        policy_iterable=policy_iterable(),
        weights=weights)

    print(f"avg_policies: {avg_policies}")

    nfsp_policy = JointPlayerPolicy(game=openspiel_game, policies=avg_policies)

    # Exploitability is NashConv / num_players
    if poker_game_version == "universal_poker":
        print(
            "Measuring exploitability for universal_poker policy. This will take a while..."
        )
    exploitability_result = exploitability(game=openspiel_game,
                                           policy=nfsp_policy)
    return exploitability_result
Example No. 18
def run_agents(sess, env, agents, expl_policies_avg):
    agent_name = "nfsp"
    write_policy_at = [1e4, 1e5, 1e6, 3e6, 5e6]
    sess.run(tf.global_variables_initializer())
    exploit_idx = list()
    exploit_history = list()
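    # Train the NFSP agents, recording exploitability every 10,000 episodes and
    # exporting policies at selected milestones.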
    for ep in range(FLAGS.episodes):
        if (ep + 1) % 10000 == 0:
            expl = exploitability.exploitability(env.game, expl_policies_avg)
            exploit_idx.append(ep)
            exploit_history.append(expl)
            with open("exploitabilities.txt", "a") as f:
                f.write(str(expl) + "\n")
            losses = [agent.loss for agent in agents]
            msg = "-" * 80 + "\n"
            msg += "{}: {}\n{}\n".format(ep + 1, expl, losses)
            logging.info("%s", msg)

        if ep in write_policy_at:
            for pid, agent in enumerate(agents):
                policy_to_csv(
                    env.game, expl_policies_avg,
                    f"policies/policy_" + agent_name + "_" +
                    datetime.now().strftime("%m-%d-%Y_%H-%M") + "_" +
                    str(pid + 1) + "_" + str(ep) + "episodes.csv")

        time_step = env.reset()
        while not time_step.last():
            player_id = time_stcfr_trainep.observations["current_player"]
            agent_output = agents[player_id].step(time_step)
            action_list = [agent_output.action]
            time_step = env.step(action_list)

        # Episode is over, step all agents with final info state.
        for agent in agents:
            agent.step(time_step)

    pickle.dump([exploit_idx, exploit_history],
                open(
                    FLAGS.game + "_" + agent_name + "_" + str(FLAGS.episodes) +
                    ".dat", "wb"))

    now = datetime.now()
    for pid, agent in enumerate(agents):
        policy_to_csv(
            env.game, expl_policies_avg,
            f"policies/policy_" + now.strftime("%m-%d-%Y_%H-%M") + "_" +
            agent_name + "_" + str(pid + 1) + "_+" + str(ep) + "episodes.csv")

    plt.plot([i for i in range(len(exploit_history))], exploit_history)
    plt.ylim(0.01, 1)
    plt.yticks([1, 0.1, 0.01])
    plt.yscale("log")
    plt.xscale("log")
    plt.show()
Example No. 19
def nxdo_nfsp_measure_exploitability_nonlstm(
        rllib_policies: List[Policy],
        use_delegate_policy_exploration: bool,
        restricted_game_convertors: Union[
            List[RestrictedToBaseGameActionSpaceConverter],
            List[AgentRestrictedGameOpenSpielObsConversions]],
        poker_game_version: str,
        open_spiel_env_config: dict = None):
    if open_spiel_env_config is None:
        if poker_game_version in ["kuhn_poker", "leduc_poker"]:
            open_spiel_env_config = {"players": pyspiel.GameParameter(2)}
        elif poker_game_version in ["oshi_zumo_tiny"]:
            poker_game_version = "oshi_zumo"
            open_spiel_env_config = {
                "coins": pyspiel.GameParameter(6),
                "size": pyspiel.GameParameter(2),
                "horizon": pyspiel.GameParameter(8),
            }
        else:
            open_spiel_env_config = {}

    open_spiel_env_config = {
        k: pyspiel.GameParameter(v)
        if not isinstance(v, pyspiel.GameParameter) else v
        for k, v in open_spiel_env_config.items()
    }

    openspiel_game = pyspiel.load_game(poker_game_version,
                                       open_spiel_env_config)

    opnsl_policies = []
    assert isinstance(restricted_game_convertors, list)
    for action_space_converter, rllib_policy in zip(restricted_game_convertors,
                                                    rllib_policies):
        openspiel_policy = openspiel_policy_from_nonlstm_rllib_nxdo_policy(
            openspiel_game=openspiel_game,
            rllib_policy=rllib_policy,
            restricted_game_convertor=action_space_converter,
            use_delegate_policy_exploration=use_delegate_policy_exploration)
        opnsl_policies.append(openspiel_policy)

    nfsp_policy = JointPlayerPolicy(game=openspiel_game,
                                    policies=opnsl_policies)

    # Exploitability is NashConv / num_players
    if poker_game_version == "universal_poker":
        print(
            "Measuring exploitability for universal_poker policy. This will take a while..."
        )
    exploitability_result = exploitability(game=openspiel_game,
                                           policy=nfsp_policy)
    return exploitability_result
Example No. 20
def main(unused_argv):
    game = pyspiel.load_game("kuhn_poker")
    cfr_solver = cfr.CFRSolver(game)

    episodes = []
    exploits = []
    nashes = []

    # Train the agent for a specific amount of episodes
    for ep in range(FLAGS.num_train_episodes):
        print("Running episode {} of {}".format(ep, FLAGS.num_train_episodes))
        cfr_solver.evaluate_and_update_policy()
        avg_pol = cfr_solver.average_policy()

        # Calculate the exploitability and nash convergence
        expl = exploitability.exploitability(game, avg_pol)
        nash = exploitability.nash_conv(game, avg_pol)

        exploits.append(expl)
        nashes.append(nash)
        episodes.append(ep)

    # Get the average policy
    average_policy = cfr_solver.average_policy()
    average_policy_values = expected_game_score.policy_value(
        game.new_initial_state(), [average_policy] * 2)
    cur_pol = cfr_solver.current_policy()

    # Plot the exploitability
    plt.plot(episodes, exploits, "-r", label="Exploitability")
    plt.xscale("log")
    plt.yscale("log")
    plt.xlim(FLAGS.eval_every, FLAGS.num_train_episodes)
    plt.legend(loc="upper right")
    plt.savefig("cfr_expl.png")
    plt.show()

    plt.figure()

    # Plot the nash convergence
    plt.plot(episodes, nashes, "-r", label="NashConv")
    plt.xscale("log")
    plt.yscale("log")
    plt.xlim(FLAGS.eval_every, FLAGS.num_train_episodes)
    plt.legend(loc="upper right")
    plt.savefig("cfr_nash.png")
    plt.show()

    print(average_policy)
    print(average_policy_values)
    policy_to_csv(game, average_policy, "./kuhn_policy.csv")
Example No. 21
def get_xdo_restricted_game_meta_Nash(game,
                                      br_list,
                                      br_conv_threshold=1e-2,
                                      seed=1):
    episode = 0
    num_infostates = 0
    start_time = time.time()
    cfr_psro_times = []
    cfr_psro_exps = []
    cfr_psro_episodes = []
    cfr_psro_infostates = []

    cfr_br_solver = cfr_br_actions.CFRSolver(game, br_list)

    for j in range(int(1e10)):
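        # Run CFR on the restricted game, logging exploitability every 50
        # iterations until the restricted-game gap falls below the threshold.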
        cfr_br_solver.evaluate_and_update_policy()
        episode += 1
        if j % 50 == 0:
            br_list_conv = exploitability_br_actions.exploitability(
                game, br_list, cfr_br_solver.average_policy())
            print("Br list conv: ", br_list_conv, j)
            conv = exploitability.exploitability(
                game, cfr_br_solver.average_policy())
            print("Iteration {} exploitability {}".format(j, conv))
            elapsed_time = time.time() - start_time
            print('Total elapsed time: ', elapsed_time)
            num_infostates = cfr_br_solver.num_infostates_expanded
            print('Num infostates expanded (mil): ', num_infostates / 1e6)
            cfr_psro_times.append(elapsed_time)
            cfr_psro_exps.append(conv)
            cfr_psro_episodes.append(episode)
            cfr_psro_infostates.append(num_infostates)

            save_prefix = './results/fixed_pop/XDO/num_pop_' + str(
                len(br_list)) + '_seed_' + str(seed)
            ensure_dir(save_prefix)
            print(f"saving to: {save_prefix + '_times.npy'}")
            np.save(save_prefix + '_times.npy', np.array(cfr_psro_times))
            print(f"saving to: {save_prefix + '_exps.npy'}")
            np.save(save_prefix + '_exps.npy', np.array(cfr_psro_exps))
            print(f"saving to: {save_prefix + '_episodes.npy'}")
            np.save(save_prefix + '_episodes.npy', np.array(cfr_psro_episodes))
            print(f"saving to: {save_prefix + '_infostates.npy'}")
            np.save(save_prefix + '_infostates.npy',
                    np.array(cfr_psro_infostates))
            if br_list_conv < br_conv_threshold:
                print("Done")
                break
Example No. 22
def get_psro_meta_Nash(game, br_list, num_episodes=100, seed=1):
    psro_br_list = []
    psro_br_list.append([br[0] for br in br_list])
    psro_br_list.append([1 / len(br_list) for _ in range(len(br_list))])
    psro_br_list.append([br[1] for br in br_list])
    psro_br_list.append([1 / len(br_list) for _ in range(len(br_list))])

    solver = psro_oracle.PSRO(game, psro_br_list, num_episodes=num_episodes)
    solver.evaluate()
    conv = exploitability.exploitability(game, solver._current_policy)
    save_path = './results/fixed_pop/PSRO/num_pop_' + str(
        len(br_list)) + '_seed_' + str(seed) + '_exp.npy'
    print(f"saved to: {save_path}")
    ensure_dir(save_path)
    np.save(save_path, np.array(conv))
    print("PSRO Exploitability: ", conv)
Example No. 23
def main(_):
    game = "kuhn_poker"
    num_players = 2

    env_configs = {"players": num_players}
    env = rl_environment.Environment(game, **env_configs)
    info_state_size = env.observation_spec()["info_state"][0]
    num_actions = env.action_spec()["num_actions"]

    with tf.Session() as sess:
        # pylint: disable=g-complex-comprehension
        agents = [
            policy_gradient.PolicyGradient(sess,
                                           idx,
                                           info_state_size,
                                           num_actions,
                                           loss_str=FLAGS.loss_str,
                                           hidden_layers_sizes=(128, ))
            for idx in range(num_players)
        ]
        expl_policies_avg = PolicyGradientPolicies(env, agents)

        sess.run(tf.global_variables_initializer())
        for ep in range(FLAGS.num_episodes):

            if (ep + 1) % FLAGS.eval_every == 0:
                losses = [agent.loss for agent in agents]
                expl = exploitability.exploitability(env.game,
                                                     expl_policies_avg)
                msg = "-" * 80 + "\n"
                msg += "{}: {}\n{}\n".format(ep + 1, expl, losses)
                logging.info("%s", msg)

            time_step = env.reset()
            while not time_step.last():
                player_id = time_step.observations["current_player"]
                agent_output = agents[player_id].step(time_step)
                action_list = [agent_output.action]
                time_step = env.step(action_list)

            # Episode is over, step all agents with final info state.
            for agent in agents:
                agent.step(time_step)

        for pid, agent in enumerate(agents):
            policy_to_csv(env.game, expl_policies_avg,
                          f"{FLAGS.modeldir}/test_p{pid+1}.csv")
Example No. 24
 def _get_exploitability(self):
     tabular_policy = policy.TabularPolicy(self._game)
     for player_id in range(2):
         for info_state, state_policy in self.average_policy_tables(
         )[player_id].items():
             policy_to_update_tabular = tabular_policy.policy_for_key(
                 info_state)
             for action, probability in state_policy.items():
                 policy_to_update_tabular[action] = probability
     average_policy_values = expected_game_score.policy_value(
         self._game.new_initial_state(), [tabular_policy, tabular_policy])
     #         print("Kuhn 2P average values after %s iterations" %iters)
     #         print("P0: {}".format(average_policy_values[0]))
     #         print("P1: {}".format(average_policy_values[1]))
     exp = exploitability.exploitability(self._game, tabular_policy)
     print("exploitability: {}".format(exp))
     return exp
Example No. 25
    def run(solver, iterations):
        start_time = time.time()
        times = []
        exps = []
        episodes = []
        cfr_infostates = []
        for i in range(iterations):
            if algorithm == 'cfr':
                solver.evaluate_and_update_policy()
            else:
                solver.iteration()
            if i % 5 == 0:
                print(algorithm)
                if algorithm == 'cfr':
                    average_policy = solver.average_policy()
                elif algorithm == 'xfp':
                    average_policy = solver.average_policy()
                elif algorithm == 'psro':
                    average_policy = solver._current_policy
                else:
                    raise ValueError(f"Unknown algorithm name: {algorithm}")

                conv = exploitability.exploitability(game, average_policy)
                print("Iteration {} exploitability {}".format(i, conv))
                elapsed_time = time.time() - start_time
                print(elapsed_time)
                times.append(elapsed_time)
                exps.append(conv)
                episodes.append(i)
                save_prefix = './results/' + algorithm + '_' + game_name + '_random_br_' + str(
                    random_max_br) + extra_info
                ensure_dir(save_prefix)
                print(f"saving to: {save_prefix + '_times.npy'}")
                np.save(save_prefix + '_times', np.array(times))
                print(f"saving to: {save_prefix + '_exps.npy'}")
                np.save(save_prefix + '_exps', np.array(exps))
                print(f"saving to: {save_prefix + '_episodes.npy'}")
                np.save(save_prefix + '_episodes', np.array(episodes))
                if algorithm == 'cfr':
                    cfr_infostates.append(solver.num_infostates_expanded)
                    print("Num infostates expanded (mil): ",
                          solver.num_infostates_expanded / 1e6)
                    print(f"saving to: {save_prefix + '_infostates.npy'}")
                    np.save(save_prefix + '_infostates',
                            np.array(cfr_infostates))
Example No. 26
def measure_exploitability_nonlstm(rllib_policy,
                                   poker_game_version,
                                   policy_mixture_dict=None,
                                   set_policy_weights_fn=None):
    if poker_game_version in [KUHN_POKER, LEDUC_POKER]:
        open_spiel_env_config = {"players": pyspiel.GameParameter(2)}
    else:
        open_spiel_env_config = {}

    openspiel_game = pyspiel.load_game(poker_game_version,
                                       open_spiel_env_config)

    if policy_mixture_dict is None:
        openspiel_policy = openspiel_policy_from_nonlstm_rllib_policy(
            openspiel_game=openspiel_game,
            poker_game_version=poker_game_version,
            rllib_policy=rllib_policy)
    else:
        if set_policy_weights_fn is None:
            raise ValueError(
                "If policy_mixture_dict is passed a value, a set_policy_weights_fn must be passed as well."
            )

        def policy_iterable():
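            # Load each mixture component's weights and yield it as an OpenSpiel policy.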
            for weights_key in policy_mixture_dict.keys():
                set_policy_weights_fn(weights_key)

                single_openspiel_policy = openspiel_policy_from_nonlstm_rllib_policy(
                    openspiel_game=openspiel_game,
                    poker_game_version=poker_game_version,
                    rllib_policy=rllib_policy)
                yield single_openspiel_policy

        openspiel_policy = tabular_policy_from_weighted_policies(
            game=openspiel_game,
            policy_iterable=policy_iterable(),
            weights=policy_mixture_dict.values())
    # Exploitability is NashConv / num_players
    exploitability_result = exploitability(game=openspiel_game,
                                           policy=openspiel_policy)
    return exploitability_result
Example No. 27
        num_infostates = 0
        for i in range(iterations):
            print('Iteration: ', i)
            cfr_br_solver = cfr_br_actions.CFRSolver(game, br_list)

            for j in range(xdo_iterations):
                cfr_br_solver.evaluate_and_update_policy()
                episode += 1
                if j % 50 == 0:
                    br_list_conv = exploitability_br_actions.exploitability(
                        game, br_list, cfr_br_solver.average_policy())
                    print("Br list conv: ", br_list_conv, j)
                    if br_list_conv < br_conv_threshold:
                        break

            conv = exploitability.exploitability(
                game, cfr_br_solver.average_policy())
            print("Iteration {} exploitability {}".format(i, conv))
            if conv < br_conv_threshold:
                br_conv_threshold /= 2
                print("new br threshold: ", br_conv_threshold)

            elapsed_time = time.time() - start_time
            print('Total elapsed time: ', elapsed_time)
            num_infostates += cfr_br_solver.num_infostates_expanded
            print('Num infostates expanded (mil): ', num_infostates / 1e6)
            xdo_times.append(elapsed_time)
            xdo_exps.append(conv)
            xdo_episodes.append(episode)
            xdo_infostates.append(num_infostates)

            brs = []
Example No. 28
def psro_measure_exploitability_nonlstm(
        br_checkpoint_path_tuple_list: List[Tuple[str, str]],
        metanash_weights: List[Tuple[float, float]],
        set_policy_weights_fn: Callable,
        rllib_policies: List[Policy],
        poker_game_version: str,
        open_spiel_env_config: dict = None):
    if open_spiel_env_config is None:
        if poker_game_version in ["kuhn_poker", "leduc_poker"]:
            open_spiel_env_config = {"players": pyspiel.GameParameter(2)}
        else:
            open_spiel_env_config = {}

    open_spiel_env_config = {
        k: pyspiel.GameParameter(v)
        if not isinstance(v, pyspiel.GameParameter) else v
        for k, v in open_spiel_env_config.items()
    }

    openspiel_game = pyspiel.load_game(poker_game_version,
                                       open_spiel_env_config)
    if poker_game_version == "oshi_zumo":
        openspiel_game = pyspiel.convert_to_turn_based(openspiel_game)

    def policy_iterable():
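        # Load each best-response checkpoint (using the cache when enabled) and
        # yield the per-player OpenSpiel policies.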
        for checkpoint_path_tuple in br_checkpoint_path_tuple_list:
            openspiel_policies = []
            for player, player_rllib_policy in enumerate(rllib_policies):
                checkpoint_path = checkpoint_path_tuple[player]
                if checkpoint_path not in _psro_tabular_policies_cache:
                    set_policy_weights_fn(player_rllib_policy,
                                          checkpoint_path=checkpoint_path)
                    single_openspiel_policy = openspiel_policy_from_nonlstm_rllib_policy(
                        openspiel_game=openspiel_game,
                        rllib_policy=player_rllib_policy,
                        game_version=poker_game_version,
                        game_parameters=open_spiel_env_config,
                    )
                    if CACHE_PSRO_TABULAR_POLICIES:
                        _psro_tabular_policies_cache[
                            checkpoint_path] = single_openspiel_policy
                else:
                    single_openspiel_policy = _psro_tabular_policies_cache[
                        checkpoint_path]

                openspiel_policies.append(single_openspiel_policy)
            yield openspiel_policies

    avg_policies = tabular_policies_from_weighted_policies(
        game=openspiel_game,
        policy_iterable=policy_iterable(),
        weights=metanash_weights)

    joint_player_policy = JointPlayerPolicy(game=openspiel_game,
                                            policies=avg_policies)

    # Exploitability is NashConv / num_players
    if poker_game_version == "universal_poker":
        print(
            "Measuring exploitability for universal_poker policy. This will take a while..."
        )
    exploitability_result = exploitability(game=openspiel_game,
                                           policy=joint_player_policy)
    return exploitability_result
Example No. 29
def main(_):
  game = pyspiel.load_game(FLAGS.game)
  expl = exploitability.exploitability(game, policy.UniformRandomPolicy(game))
  print("Exploitability: {}".format(expl))
Example No. 30
def main(unused_argv):
    logging.info("Loading %s", FLAGS.game_name)
    game = FLAGS.game_name
    num_players = FLAGS.num_players

    env_configs = {"players": num_players}
    env = rl_environment.Environment(game, **env_configs)
    info_state_size = env.observation_spec()["info_state"][0]
    num_actions = env.action_spec()["num_actions"]

    hidden_layers_sizes = [int(l) for l in FLAGS.hidden_layers_sizes]
    kwargs = {
        "replay_buffer_capacity": FLAGS.replay_buffer_capacity,
        "reservoir_buffer_capacity": FLAGS.reservoir_buffer_capacity,
        "min_buffer_size_to_learn": FLAGS.min_buffer_size_to_learn,
        "anticipatory_param": FLAGS.anticipatory_param,
        "batch_size": FLAGS.batch_size,
        "learn_every": FLAGS.learn_every,
        "rl_learning_rate": FLAGS.rl_learning_rate,
        "sl_learning_rate": FLAGS.sl_learning_rate,
        "optimizer_str": FLAGS.optimizer_str,
        "loss_str": FLAGS.loss_str,
        "update_target_network_every": FLAGS.update_target_network_every,
        "discount_factor": FLAGS.discount_factor,
        "epsilon_decay_duration": FLAGS.epsilon_decay_duration,
        "epsilon_start": FLAGS.epsilon_start,
        "epsilon_end": FLAGS.epsilon_end,
    }

    with tf.Session() as sess:
        # pylint: disable=g-complex-comprehension
        agents = [
            nfsp.NFSP(sess, idx, info_state_size, num_actions,
                      hidden_layers_sizes, **kwargs)
            for idx in range(num_players)
        ]
        joint_avg_policy = NFSPPolicies(env, agents, nfsp.MODE.average_policy)

        sess.run(tf.global_variables_initializer())

        if FLAGS.use_checkpoints:
            for agent in agents:
                if agent.has_checkpoint(FLAGS.checkpoint_dir):
                    agent.restore(FLAGS.checkpoint_dir)

        for ep in range(FLAGS.num_train_episodes):
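            # Periodically evaluate (exploitability or NashConv), save checkpoints,
            # then play one self-play episode.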
            if (ep + 1) % FLAGS.eval_every == 0:
                losses = [agent.loss for agent in agents]
                logging.info("Losses: %s", losses)
                if FLAGS.evaluation_metric == "exploitability":
                    # Avg exploitability is implemented only for 2 players constant-sum
                    # games, use nash_conv otherwise.
                    expl = exploitability.exploitability(
                        env.game, joint_avg_policy)
                    logging.info("[%s] Exploitability AVG %s", ep + 1, expl)
                elif FLAGS.evaluation_metric == "nash_conv":
                    nash_conv = exploitability.nash_conv(
                        env.game, joint_avg_policy)
                    logging.info("[%s] NashConv %s", ep + 1, nash_conv)
                else:
                    raise ValueError(" ".join(
                        ("Invalid evaluation metric, choose from",
                         "'exploitability', 'nash_conv'.")))
                if FLAGS.use_checkpoints:
                    for agent in agents:
                        agent.save(FLAGS.checkpoint_dir)
                logging.info("_____________________________________________")

            time_step = env.reset()
            while not time_step.last():
                player_id = time_step.observations["current_player"]
                agent_output = agents[player_id].step(time_step)
                action_list = [agent_output.action]
                time_step = env.step(action_list)

            # Episode is over, step all agents with final info state.
            for agent in agents:
                agent.step(time_step)