示例#1
0
def _new_model():
  """Build a small `rcfr.DeepRcfrModel` for the module-level `_GAME`.

  Returns:
    A `rcfr.DeepRcfrModel` configured with num_hidden_layers=1,
    num_hidden_units=13, num_hidden_factors=1 and skip connections enabled.
  """
  hyperparams = dict(
      num_hidden_layers=1,
      num_hidden_units=13,
      num_hidden_factors=1,
      use_skip_connections=True,
  )
  return rcfr.DeepRcfrModel(_GAME, **hyperparams)
def rcfr_train(unused_arg):
    """Train RCFR on `FLAGS.game`, log exploitability, and save the results.

    Builds one `rcfr.DeepRcfrModel` per player for a 2-player instance of
    `FLAGS.game` (64 hidden units for "leduc_poker", 13 otherwise) and runs
    `FLAGS.episodes` RCFR iterations. Every 100 iterations it records and
    prints the exploitability of the current average policy. At the end it
    pickles `[iterations, exploitabilities]` to
    `<game>_rcfr_<episodes>.dat` and writes the final average policy to CSV
    files under `policies/` via `policy_to_csv`, one file per player.

    Args:
      unused_arg: Ignored (presumably the argv passed by an `app.run`-style
        launcher -- confirm against the caller).
    """
    tf.enable_eager_execution()
    game = pyspiel.load_game(FLAGS.game, {"players": pyspiel.GameParameter(2)})
    models = [
        rcfr.DeepRcfrModel(
            game,
            num_hidden_layers=1,
            num_hidden_units=64 if FLAGS.game == "leduc_poker" else 13,
            num_hidden_factors=1,
            use_skip_connections=True) for _ in range(game.num_players())
    ]
    # NOTE(review): positional flags kept as in the original call; presumably
    # truncate_negative=False, bootstrap=True -- confirm against RcfrSolver.
    patient = rcfr.RcfrSolver(game, models, False, True)
    exploit_history = []
    exploit_idx = []

    def _train(model, data):
        """Fit `model` for one shuffled pass over `data` (batches of 12)."""
        data = data.shuffle(1000)
        data = data.batch(12)
        optimizer = tf.keras.optimizers.Adam(lr=0.005, amsgrad=True)
        for x, y in data:
            # The lambda is consumed immediately by `minimize`, so capturing
            # the loop variables is safe here.
            optimizer.minimize(
                lambda: tf.losses.huber_loss(y, model(x)),  # pylint: disable=cell-var-from-loop
                model.trainable_variables)

    agent_name = "rcfr"
    checkpoint = datetime.now()
    for iteration in range(FLAGS.episodes):
        if iteration % 100 == 0:
            delta = datetime.now() - checkpoint
            conv = pyspiel.exploitability(game, patient.average_policy())
            exploit_idx.append(iteration)
            exploit_history.append(conv)
            # total_seconds() keeps the day component that `.seconds` drops.
            print(
                "[RCFR] Iteration {} exploitability {} - {} seconds since last checkpoint"
                .format(iteration, conv, int(delta.total_seconds())))
            checkpoint = datetime.now()
        patient.evaluate_and_update_policy(_train)

    # Use a context manager so the history file is flushed and closed.
    history_path = (FLAGS.game + "_" + agent_name + "_" + str(FLAGS.episodes) +
                    ".dat")
    with open(history_path, "wb") as history_file:
        pickle.dump([exploit_idx, exploit_history], history_file)

    now = datetime.now()
    policy = patient.average_policy()

    # Player ids are 0-based; filenames use 1-based numbering (1, 2, ...).
    # NOTE(review): the same average policy is written for every player and
    # only the filename differs -- confirm this is what policy_to_csv expects.
    for pid in range(game.num_players()):
        policy_to_csv(
            game, policy, "policies/policy_" +
            now.strftime("%m-%d-%Y_%H-%M") + "_" + agent_name + "_" +
            str(pid + 1) + "_+" + str(FLAGS.episodes) + "episodes.csv")
def main(_):
    """Run RCFR on `FLAGS.game` and periodically print exploitability.

    One `rcfr.DeepRcfrModel` per player is built from the network-size flags.
    A positive `FLAGS.buffer_size` selects `rcfr.ReservoirRcfrSolver` with that
    buffer size; otherwise `rcfr.RcfrSolver` is used with `FLAGS.bootstrap`.
    Each of the `FLAGS.iterations` iterations updates the policy through
    `evaluate_and_update_policy`, and whenever the iteration index is a
    multiple of `FLAGS.print_freq` the average policy's exploitability is
    printed.
    """
    game = pyspiel.load_game(
        FLAGS.game, {"players": pyspiel.GameParameter(FLAGS.players)})

    models = [
        rcfr.DeepRcfrModel(
            game,
            num_hidden_layers=FLAGS.num_hidden_layers,
            num_hidden_units=FLAGS.num_hidden_units,
            num_hidden_factors=FLAGS.num_hidden_factors,
            use_skip_connections=FLAGS.use_skip_connections)
        for _ in range(game.num_players())
    ]

    use_reservoir = FLAGS.buffer_size > 0
    if not use_reservoir:
        solver = rcfr.RcfrSolver(
            game,
            models,
            truncate_negative=FLAGS.truncate_negative,
            bootstrap=FLAGS.bootstrap)
    else:
        solver = rcfr.ReservoirRcfrSolver(
            game,
            models,
            FLAGS.buffer_size,
            truncate_negative=FLAGS.truncate_negative)

    def _train_fn(model, data):
        """Train `model` on `data` with AMSGrad Adam and a Huber loss."""
        pipeline = (data.shuffle(FLAGS.batch_size * 10)
                    .batch(FLAGS.batch_size)
                    .repeat(FLAGS.num_epochs))

        optimizer = tf.keras.optimizers.Adam(lr=FLAGS.step_size, amsgrad=True)

        # Wrapped in tf.function so the dataset loop is traced and compiled.
        @tf.function
        def _run_epochs():
            for features, targets in pipeline:
                optimizer.minimize(
                    lambda: tf.compat.v1.losses.huber_loss(
                        targets, model(features), delta=0.01),  # pylint: disable=cell-var-from-loop
                    model.trainable_variables)

        _run_epochs()

    for step in range(FLAGS.iterations):
        solver.evaluate_and_update_policy(_train_fn)
        if step % FLAGS.print_freq != 0:
            continue
        exploitability = pyspiel.exploitability(game, solver.average_policy())
        print("Iteration {} exploitability {}".format(step, exploitability))