def test_train_pong(self):
    """Run model-free training with the "mfrl_original" hparams set.

    Overrides batch_size, ppo_epochs_num and ppo_epoch_length with small
    values (presumably to keep the run short -- confirm) and trains into
    the TensorFlow test temp directory. There are no assertions: the test
    passes when `trainer_model_free.train` returns without raising.
    """
    params = registry.hparams("mfrl_original")
    overrides = (
        ("batch_size", 2),
        ("ppo_epochs_num", 2),
        ("ppo_epoch_length", 3),
    )
    # Plain attribute assignment, applied in the listed order.
    for name, value in overrides:
        setattr(params, name, value)
    FLAGS.output_dir = tf.test.get_temp_dir()
    trainer_model_free.train(params, FLAGS.output_dir)
Пример #2
0
 def test_train_pong(self):
   """Run model-free training with the "rlmf_original" hparams set.

   Sets batch_size and eval_sampling_temps on the loaded hparams, registers
   ppo_epochs_num and ppo_epoch_length through `add_hparam`, then trains
   into the TensorFlow test temp directory. There are no assertions: the
   test passes when `trainer_model_free.train` returns without raising.
   """
   params = registry.hparams("rlmf_original")
   params.batch_size = 2
   params.eval_sampling_temps = [0.0, 1.0]
   # These two are added via add_hparam rather than assigned.
   extra_hparams = (("ppo_epochs_num", 2), ("ppo_epoch_length", 3))
   for name, value in extra_hparams:
     params.add_hparam(name, value)
   FLAGS.output_dir = tf.test.get_temp_dir()
   trainer_model_free.train(params, FLAGS.output_dir)
    def test_train_tictactoe(self):
        """Run model-free training with the "rlmf_tictactoe" hparams set.

        Sets batch_size and eval_sampling_temps, registers ppo_epochs_num
        and ppo_epoch_length through `add_hparam`, sets epochs_num to 100
        and eval_every_epochs to 25, then trains into the TensorFlow test
        temp directory. There are no assertions: the test passes when
        `trainer_model_free.train` returns without raising.
        """
        params = registry.hparams("rlmf_tictactoe")
        params.batch_size = 2
        params.eval_sampling_temps = [0.0, 1.0]

        # These two are added via add_hparam rather than assigned.
        for name, value in (("ppo_epochs_num", 2), ("ppo_epoch_length", 3)):
            params.add_hparam(name, value)

        # Training schedule overrides, assigned after the add_hparam calls.
        params.epochs_num = 100
        params.eval_every_epochs = 25

        FLAGS.output_dir = tf.test.get_temp_dir()
        trainer_model_free.train(params, FLAGS.output_dir)
 def _test_hparams_set(self, hparams_set):
   """Run model-free training for the hparams set named `hparams_set`.

   Args:
     hparams_set: name passed to `registry.hparams` to look up the hparams.

   The hparams are used unmodified, training writes to the TensorFlow test
   temp directory, and `env_problem_name` is passed explicitly as None.
   """
   params = registry.hparams(hparams_set)
   FLAGS.output_dir = tf.test.get_temp_dir()
   trainer_model_free.train(params, FLAGS.output_dir, env_problem_name=None)