# Module-level imports assumed by the functions below; `FLAGS`, `_GAME`, and
# `policy_to_csv` are expected to be defined/imported elsewhere in this file.
import pickle
from datetime import datetime

import pyspiel
import tensorflow as tf  # TF 1.x API (eager execution is enabled explicitly below).

from open_spiel.python.algorithms import rcfr


def _new_model():
  """Builds a small DeepRcfrModel for the module-level `_GAME`."""
  return rcfr.DeepRcfrModel(
      _GAME,
      num_hidden_layers=1,
      num_hidden_units=13,
      num_hidden_factors=1,
      use_skip_connections=True)
def rcfr_train(unused_arg):
  """Trains RCFR on FLAGS.game, logging exploitability every 100 iterations."""
  tf.enable_eager_execution()
  game = pyspiel.load_game(FLAGS.game,
                           {"players": pyspiel.GameParameter(2)})
  # One regret model per player; Leduc poker gets a wider hidden layer.
  models = [
      rcfr.DeepRcfrModel(
          game,
          num_hidden_layers=1,
          num_hidden_units=64 if FLAGS.game == "leduc_poker" else 13,
          num_hidden_factors=1,
          use_skip_connections=True) for _ in range(game.num_players())
  ]
  patient = rcfr.RcfrSolver(game, models, False, True)

  exploit_history = list()
  exploit_idx = list()

  def _train(model, data):
    """Fits `model` to one iteration's worth of regret data."""
    data = data.shuffle(1000)
    data = data.batch(12)
    # data = data.repeat(1)
    optimizer = tf.keras.optimizers.Adam(lr=0.005, amsgrad=True)
    for x, y in data:
      optimizer.minimize(
          lambda: tf.losses.huber_loss(y, model(x)),  # pylint: disable=cell-var-from-loop
          model.trainable_variables)

  agent_name = "rcfr"
  checkpoint = datetime.now()
  for iteration in range(FLAGS.episodes):
    if (iteration % 100) == 0:
      delta = datetime.now() - checkpoint
      conv = pyspiel.exploitability(game, patient.average_policy())
      exploit_idx.append(iteration)
      exploit_history.append(conv)
      print(
          "[RCFR] Iteration {} exploitability {} - {} seconds since last checkpoint"
          .format(iteration, conv, delta.seconds))
      checkpoint = datetime.now()
    patient.evaluate_and_update_policy(_train)

  # Persist the exploitability curve for later analysis.
  pickle.dump(
      [exploit_idx, exploit_history],
      open(FLAGS.game + "_" + agent_name + "_" + str(FLAGS.episodes) + ".dat",
           "wb"))

  # Export the final average policy, one CSV file per player.
  now = datetime.now()
  policy = patient.average_policy()
  for pid in [1, 2]:
    policy_to_csv(
        game, policy, "policies/policy_" + now.strftime("%m-%d-%Y_%H-%M") +
        "_" + agent_name + "_" + str(pid) + "_+" + str(FLAGS.episodes) +
        "episodes.csv")
def main(_):
  game = pyspiel.load_game(FLAGS.game,
                           {"players": pyspiel.GameParameter(FLAGS.players)})

  models = []
  for _ in range(game.num_players()):
    models.append(
        rcfr.DeepRcfrModel(
            game,
            num_hidden_layers=FLAGS.num_hidden_layers,
            num_hidden_units=FLAGS.num_hidden_units,
            num_hidden_factors=FLAGS.num_hidden_factors,
            use_skip_connections=FLAGS.use_skip_connections))

  if FLAGS.buffer_size > 0:
    solver = rcfr.ReservoirRcfrSolver(
        game,
        models,
        FLAGS.buffer_size,
        truncate_negative=FLAGS.truncate_negative)
  else:
    solver = rcfr.RcfrSolver(
        game,
        models,
        truncate_negative=FLAGS.truncate_negative,
        bootstrap=FLAGS.bootstrap)

  def _train_fn(model, data):
    """Train `model` on `data`."""
    data = data.shuffle(FLAGS.batch_size * 10)
    data = data.batch(FLAGS.batch_size)
    data = data.repeat(FLAGS.num_epochs)

    optimizer = tf.keras.optimizers.Adam(lr=FLAGS.step_size, amsgrad=True)

    @tf.function
    def _train():
      for x, y in data:
        optimizer.minimize(
            lambda: tf.compat.v1.losses.huber_loss(y, model(x), delta=0.01),  # pylint: disable=cell-var-from-loop
            model.trainable_variables)

    _train()
    # End of _train_fn

  for i in range(FLAGS.iterations):
    solver.evaluate_and_update_policy(_train_fn)
    if i % FLAGS.print_freq == 0:
      conv = pyspiel.exploitability(game, solver.average_policy())
      print("Iteration {} exploitability {}".format(i, conv))
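
# The functions above reference absl-style FLAGS and an application entry
# point that are not shown in this section. The sketch below is an assumption
# about how those flags might be declared; the default values are illustrative
# only and are not taken from the original file.
from absl import app
from absl import flags

FLAGS = flags.FLAGS

flags.DEFINE_string("game", "kuhn_poker", "Name of the game.")
flags.DEFINE_integer("players", 2, "Number of players.")
flags.DEFINE_integer("episodes", 1000, "RCFR iterations used by rcfr_train.")
flags.DEFINE_integer("iterations", 100, "Solver iterations used by main.")
flags.DEFINE_integer("print_freq", 10, "How often to log exploitability.")
flags.DEFINE_integer("num_hidden_layers", 1, "Number of hidden layers.")
flags.DEFINE_integer("num_hidden_units", 13, "Hidden units per layer.")
flags.DEFINE_integer("num_hidden_factors", 1, "Hidden factors per layer.")
flags.DEFINE_bool("use_skip_connections", True, "Use skip connections.")
flags.DEFINE_integer("buffer_size", -1, "Reservoir buffer size (<= 0 disables it).")
flags.DEFINE_bool("truncate_negative", False, "Truncate negative regrets to zero.")
flags.DEFINE_bool("bootstrap", False, "Bootstrap regret targets from the previous model.")
flags.DEFINE_integer("batch_size", 100, "Training batch size.")
flags.DEFINE_integer("num_epochs", 200, "Training epochs per solver iteration.")
flags.DEFINE_float("step_size", 0.01, "Adam learning rate.")


if __name__ == "__main__":
  app.run(main)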