def run_task(*_):
    """Build a REPS agent for gym's CartPole-v0 and train it.

    Any positional arguments are accepted and ignored (presumably so the
    function can be handed to an experiment launcher that passes extra
    arguments -- the caller is not visible here).

    The environment is always closed after training, whether ``train``
    returns normally or raises, so the underlying gym environment is not
    leaked.
    """
    env = TfEnv(gym.make("CartPole-v0"))
    try:
        policy = CategoricalMLPPolicy(env_spec=env.spec, hidden_sizes=[32, 32])
        baseline = LinearFeatureBaseline(env_spec=env.spec)
        algo = REPS(
            env=env,
            policy=policy,
            baseline=baseline,
            batch_size=4000,
            max_path_length=100,
            n_itr=100,
            discount=0.99,
            plot=False,
        )
        algo.train()
    finally:
        # Release the environment even if construction or training fails.
        env.close()
def test_reps_cartpole(self):
    """Test REPS with gym Cartpole environment.

    Resets the logger, then trains REPS on CartPole-v0 for 10 iterations
    using the test's ``self.sess``. It asserts that the value returned by
    ``algo.train`` (presumably the last average return) exceeds 5.

    The environment is closed in a ``finally`` clause, so it is released
    even when training raises or the assertion fails.
    """
    logger.reset()
    env = TfEnv(gym.make("CartPole-v0"))
    try:
        policy = CategoricalMLPPolicy(env_spec=env.spec, hidden_sizes=[32, 32])
        baseline = LinearFeatureBaseline(env_spec=env.spec)
        algo = REPS(
            env=env,
            policy=policy,
            baseline=baseline,
            batch_size=4000,
            max_path_length=100,
            n_itr=10,
            discount=0.99,
            # NOTE(review): a huge max_kl_step presumably makes the KL
            # constraint non-binding for this smoke test -- confirm.
            max_kl_step=1e6,
            plot=False,
        )
        last_avg_ret = algo.train(sess=self.sess)
        assert last_avg_ret > 5
    finally:
        # Previously skipped whenever the assert failed, leaking the env.
        env.close()