def test_cem_cartpole(self):
    """Test CEM with the CartPole-v1 environment."""
    with LocalTFRunner(snapshot_config) as runner:
        env = GymEnv('CartPole-v1')
        policy = CategoricalMLPPolicy(name='policy',
                                      env_spec=env.spec,
                                      hidden_sizes=(32, 32))
        baseline = LinearFeatureBaseline(env_spec=env.spec)
        n_samples = 10
        # Each epoch, CEM samples n_samples policy-parameter vectors and
        # keeps the top best_frac fraction to refit its search distribution.
        algo = CEM(env_spec=env.spec,
                   policy=policy,
                   baseline=baseline,
                   best_frac=0.1,
                   n_samples=n_samples)
        runner.setup(algo, env, sampler_cls=LocalSampler)
        rtn = runner.train(n_epochs=10, batch_size=2048)
        assert rtn > 40
        env.close()
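# A minimal NumPy sketch of the cross-entropy update that CEM performs each
# epoch: draw n_samples parameter vectors from a Gaussian, score each one by
# its return, keep the top best_frac fraction, and refit the Gaussian to those
# elites. This is an illustration, not garage's implementation; the `evaluate`
# callable and the `cem_update` helper are hypothetical.
import numpy as np


def cem_update(mean, std, evaluate, n_samples=10, best_frac=0.1):
    """One CEM epoch: sample, score, and refit to the elite samples."""
    samples = np.random.normal(mean, std, size=(n_samples, mean.size))
    returns = np.array([evaluate(theta) for theta in samples])
    n_best = max(1, int(n_samples * best_frac))
    elites = samples[np.argsort(-returns)[:n_best]]
    return elites.mean(axis=0), elites.std(axis=0)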
def run_task(snapshot_config, *_):
    """Train CEM with the Swimmer-v2 environment."""
    with LocalTFRunner(snapshot_config=snapshot_config) as runner:
        env = TfEnv(env_name='Swimmer-v2')
        policy = GaussianMLPPolicy(name='policy',
                                   env_spec=env.spec,
                                   hidden_sizes=(32, 32))
        baseline = LinearFeatureBaseline(env_spec=env.spec)
        n_samples = 20
        algo = CEM(env_spec=env.spec,
                   policy=policy,
                   baseline=baseline,
                   best_frac=0.05,
                   max_path_length=100,
                   n_samples=n_samples)
        runner.setup(algo, env, sampler_cls=OnPolicyVectorizedSampler)
        # NOTE: make sure that n_epoch_cycles == n_samples, so that every
        # sampled parameter vector is evaluated once per epoch.
        runner.train(n_epochs=100, batch_size=1000, n_epoch_cycles=n_samples)
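# Launching the example: in the garage release this snippet targets, run_task
# was typically handed to garage's run_experiment helper, which passes in the
# snapshot_config. The import path and arguments below are an assumption and
# may differ between garage versions.
from garage.experiment import run_experiment

run_experiment(
    run_task,
    snapshot_mode='last',
    seed=1,
)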