def test_rcfr_with_buffer(self):
    """Check that reservoir-buffer RCFR lowers NashConv after a few updates.

    One model is created per player of `_GAME` and handed to a
    `ReservoirRcfrSolver`. The average policy's NashConv must be above 0.91
    before any update and below 0.91 after `num_iterations` updates.
    """
    buffer_size = 12
    num_epochs = 100
    num_iterations = 2
    models = [_new_model() for _ in range(_GAME.num_players())]

    patient = rcfr.ReservoirRcfrSolver(_GAME, models, buffer_size=buffer_size)

    def _train(model, data):
        """Fit `model` to `data` by minimizing Huber loss with Adam.

        Args:
          model: Callable model exposing `trainable_variables`.
          data: Dataset yielding `(x, y)` pairs; presumably a
            `tf.data.Dataset`, since it is shuffled, batched and repeated.
        """
        data = data.shuffle(12)
        data = data.batch(12)
        data = data.repeat(num_epochs)

        # `learning_rate` is the supported keyword; the old `lr` alias is
        # deprecated and is not reliably honored by newer Keras optimizers.
        optimizer = tf.keras.optimizers.Adam(learning_rate=0.005, amsgrad=True)

        # NOTE(review): `tf.losses.huber_loss` is the TF1-style name; confirm
        # it resolves under the `tf` alias this file imports.
        for x, y in data:
            optimizer.minimize(
                lambda: tf.losses.huber_loss(y, model(x)),  # pylint: disable=cell-var-from-loop
                model.trainable_variables)

    # Baseline: the untrained average policy should still be exploitable.
    average_policy = patient.average_policy()
    self.assertGreater(pyspiel.nash_conv(_GAME, average_policy), 0.91)

    for _ in range(num_iterations):
        patient.evaluate_and_update_policy(_train)

    # After training, the average policy should have improved.
    average_policy = patient.average_policy()
    self.assertLess(pyspiel.nash_conv(_GAME, average_policy), 0.91)
def main(_):
    """Run RCFR on the flag-selected game and periodically print exploitability.

    Loads the game named by `FLAGS.game`, builds one `DeepRcfrModel` per
    player, and picks a solver: `ReservoirRcfrSolver` when
    `FLAGS.buffer_size > 0`, otherwise `RcfrSolver`. It then runs
    `FLAGS.iterations` policy updates, printing the average policy's
    exploitability every `FLAGS.print_freq` iterations.

    Args:
      _: Unused positional argument.
    """
    game = pyspiel.load_game(
        FLAGS.game, {"players": pyspiel.GameParameter(FLAGS.players)})

    models = [
        rcfr.DeepRcfrModel(
            game,
            num_hidden_layers=FLAGS.num_hidden_layers,
            num_hidden_units=FLAGS.num_hidden_units,
            num_hidden_factors=FLAGS.num_hidden_factors,
            use_skip_connections=FLAGS.use_skip_connections)
        for _ in range(game.num_players())
    ]

    if FLAGS.buffer_size > 0:
        solver = rcfr.ReservoirRcfrSolver(
            game,
            models,
            buffer_size=FLAGS.buffer_size,
            truncate_negative=FLAGS.truncate_negative)
    else:
        solver = rcfr.RcfrSolver(
            game,
            models,
            truncate_negative=FLAGS.truncate_negative,
            bootstrap=FLAGS.bootstrap)

    def _train_fn(model, data):
        """Train `model` on `data`."""
        data = data.shuffle(FLAGS.batch_size * 10)
        data = data.batch(FLAGS.batch_size)
        data = data.repeat(FLAGS.num_epochs)

        # `learning_rate` is the supported keyword; the old `lr` alias is
        # deprecated and is not reliably honored by newer Keras optimizers.
        optimizer = tf.keras.optimizers.Adam(
            learning_rate=FLAGS.step_size, amsgrad=True)

        @tf.function
        def _train():
            for x, y in data:
                optimizer.minimize(
                    lambda: tf.compat.v1.losses.huber_loss(
                        y, model(x), delta=0.01),  # pylint: disable=cell-var-from-loop
                    model.trainable_variables)

        _train()

    # End of _train_fn

    for i in range(FLAGS.iterations):
        solver.evaluate_and_update_policy(_train_fn)
        if i % FLAGS.print_freq == 0:
            conv = pyspiel.exploitability(game, solver.average_policy())
            print("Iteration {} exploitability {}".format(i, conv))