def main(_):
    """Run the NeuRD solver configured by FLAGS and report exploitability.

    Loads ``FLAGS.game`` with ``FLAGS.players`` players, creates one
    ``neurd.DeepNeurdModel`` per player, and drives a
    ``neurd.CounterfactualNeurdSolver`` for ``FLAGS.iterations`` iterations.
    Every ``FLAGS.print_freq`` iterations the exploitability of the solver's
    average policy is printed.

    Args:
        _: Ignored positional argument.
    """
    game_params = {"players": pyspiel.GameParameter(FLAGS.players)}
    game = pyspiel.load_game(FLAGS.game, game_params)

    # One independent network per player, all sharing the same architecture
    # flags.
    models = [
        neurd.DeepNeurdModel(
            game,
            num_hidden_layers=FLAGS.num_hidden_layers,
            num_hidden_units=FLAGS.num_hidden_units,
            num_hidden_factors=FLAGS.num_hidden_factors,
            use_skip_connections=FLAGS.use_skip_connections,
            autoencode=FLAGS.autoencode)
        for _ in range(game.num_players())
    ]
    solver = neurd.CounterfactualNeurdSolver(game, models)

    def _train(model, data):
        """Training callback handed to the solver for each policy update."""
        # The autoencoder loss is only supplied when autoencoding is enabled.
        if FLAGS.autoencode:
            reconstruction_loss = tf.compat.v1.losses.huber_loss
        else:
            reconstruction_loss = None
        neurd.train(
            model,
            data,
            batch_size=FLAGS.batch_size,
            step_size=FLAGS.step_size,
            threshold=FLAGS.threshold,
            autoencoder_loss=reconstruction_loss)

    report_template = "Iteration {} exploitability {}"
    for iteration in range(FLAGS.iterations):
        solver.evaluate_and_update_policy(_train)
        if iteration % FLAGS.print_freq != 0:
            continue
        exploitability = pyspiel.exploitability(game, solver.average_policy())
        print(report_template.format(iteration, exploitability))
def _new_model():
    """Return a fresh ``neurd.DeepNeurdModel`` for ``_GAME``.

    The model uses a fixed small configuration: one hidden layer of 13 units,
    one hidden factor, skip connections enabled and autoencoding enabled.
    """
    hyperparameters = {
        "num_hidden_layers": 1,
        "num_hidden_units": 13,
        "num_hidden_factors": 1,
        "use_skip_connections": True,
        "autoencode": True,
    }
    return neurd.DeepNeurdModel(_GAME, **hyperparameters)
def neurd_train(unudes_arg):
    """Train a two-player NeuRD solver, save its average policy, plot progress.

    Loads ``FLAGS.game`` with two players, builds one ``neurd.DeepNeurdModel``
    per player and runs ``neurd.CounterfactualNeurdSolver`` for
    ``FLAGS.episodes`` episodes. Every 100 episodes the exploitability of the
    average policy is recorded and printed. After training, the average policy
    is written through ``policy_to_csv`` and the recorded exploitability
    history is plotted on log-log axes.

    Args:
        unudes_arg: Unused. The (misspelled) name is kept as-is so callers
            that pass it by keyword keep working.
    """
    # Use the compat.v1 entry point, matching the other TensorFlow usage in
    # this file; the top-level `tf.enable_eager_execution` alias is not
    # available in TensorFlow 2.x.
    tf.compat.v1.enable_eager_execution()

    game = pyspiel.load_game(FLAGS.game, {"players": pyspiel.GameParameter(2)})
    models = [
        neurd.DeepNeurdModel(
            game,
            num_hidden_layers=1,
            num_hidden_units=13,
            num_hidden_factors=8,
            use_skip_connections=True,
            autoencode=False)
        for _ in range(game.num_players())
    ]
    solver = neurd.CounterfactualNeurdSolver(game, models)

    def _train(model, data):
        """Training callback handed to the solver for each policy update."""
        neurd.train(
            model,
            data,
            batch_size=100,
            step_size=1,
            threshold=2,
            autoencoder_loss=None)

    num_episodes = FLAGS.episodes
    exploit_history = []
    for ep in range(num_episodes):
        solver.evaluate_and_update_policy(_train)
        if ep % 100 == 0:
            conv = pyspiel.exploitability(game, solver.average_policy())
            exploit_history.append(conv)
            print("Iteration {} exploitability {}".format(ep, conv))

    # The file name embeds the index of the last episode. Derive it from the
    # episode count instead of reading the loop variable afterwards, which is
    # unbound when no episode ran; in that case there is nothing to save.
    if num_episodes > 0:
        last_ep = num_episodes - 1
        timestamp = datetime.now().strftime("%m-%d-%Y_%H-%M")
        policy = solver.average_policy()
        agent_name = "neurd"
        # NOTE(review): the labels in the file names come out as 2 and 3
        # (pid + 1 for pid in [1, 2]) and both files receive the same
        # arguments apart from the name. Kept unchanged so existing file
        # names stay stable -- confirm whether 1 and 2 were intended.
        for pid in [1, 2]:
            policy_to_csv(
                game,
                policy,
                "policies/policy_" + timestamp + "_" + agent_name + "_"
                + str(pid + 1) + "_+" + str(last_ep) + "episodes.csv")

    # One sample was recorded every 100 episodes; the x axis is the sample
    # index, not the episode number.
    plt.plot(list(range(len(exploit_history))), exploit_history)
    plt.ylim(0.01, 1)
    plt.yticks([1, 0.1, 0.01])
    plt.yscale("log")
    plt.xscale("log")
    plt.show()