def test_train_pong(self):
    """Smoke-test model-free training using the "mfrl_original" hparams set.

    The registered hparams are looked up, three of their values are
    overridden (batch_size, ppo_epochs_num, ppo_epoch_length), the output
    directory flag is pointed at the TensorFlow test temp dir, and
    ``trainer_model_free.train`` is invoked. Nothing is asserted here; the
    test succeeds as long as the call returns without raising.
    """
    params = registry.hparams("mfrl_original")
    # NOTE(review): the small values presumably keep the run short -- confirm.
    overrides = (
        ("batch_size", 2),
        ("ppo_epochs_num", 2),
        ("ppo_epoch_length", 3),
    )
    for attr_name, attr_value in overrides:
        setattr(params, attr_name, attr_value)

    FLAGS.output_dir = tf.test.get_temp_dir()
    trainer_model_free.train(params, FLAGS.output_dir)
def test_train_pong(self):
    """Smoke-test model-free training using the "rlmf_original" hparams set.

    batch_size and eval_sampling_temps are overwritten on the looked-up
    hparams, while ppo_epochs_num and ppo_epoch_length are registered via
    ``add_hparam``. Training then runs with the output directory flag set
    to the TensorFlow test temp dir. Nothing is asserted here; the test
    succeeds as long as ``trainer_model_free.train`` returns without raising.
    """
    params = registry.hparams("rlmf_original")
    params.batch_size = 2
    params.eval_sampling_temps = [0.0, 1.0]

    # These two go through add_hparam rather than plain assignment.
    # NOTE(review): presumably because the set does not define them -- confirm.
    for new_name, new_value in (("ppo_epochs_num", 2), ("ppo_epoch_length", 3)):
        params.add_hparam(new_name, new_value)

    FLAGS.output_dir = tf.test.get_temp_dir()
    trainer_model_free.train(params, FLAGS.output_dir)
def test_train_tictactoe(self):
    """Smoke-test model-free training using the "rlmf_tictactoe" hparams set.

    batch_size and eval_sampling_temps are overwritten, ppo_epochs_num and
    ppo_epoch_length are registered via ``add_hparam``, and epochs_num /
    eval_every_epochs are set to 100 / 25. Training then runs with the
    output directory flag set to the TensorFlow test temp dir. Nothing is
    asserted here; the test succeeds as long as
    ``trainer_model_free.train`` returns without raising.
    """
    params = registry.hparams("rlmf_tictactoe")
    params.batch_size = 2
    params.eval_sampling_temps = [0.0, 1.0]

    # These two go through add_hparam rather than plain assignment.
    # NOTE(review): presumably because the set does not define them -- confirm.
    for new_name, new_value in (("ppo_epochs_num", 2), ("ppo_epoch_length", 3)):
        params.add_hparam(new_name, new_value)

    params.epochs_num = 100
    params.eval_every_epochs = 25

    FLAGS.output_dir = tf.test.get_temp_dir()
    trainer_model_free.train(params, FLAGS.output_dir)
def _test_hparams_set(self, hparams_set):
    """Run model-free training once with the named hparams set, unmodified.

    Args:
      hparams_set: name passed to ``registry.hparams`` to look up the
        hyperparameters.

    The output directory flag is set to the TensorFlow test temp dir and
    ``trainer_model_free.train`` is called with ``env_problem_name=None``.
    Nothing is asserted; the helper succeeds if the call does not raise.
    """
    loaded_hparams = registry.hparams(hparams_set)
    FLAGS.output_dir = tf.test.get_temp_dir()
    train_kwargs = {"env_problem_name": None}
    trainer_model_free.train(loaded_hparams, FLAGS.output_dir, **train_kwargs)