コード例 #1
0
ファイル: test_cem.py プロジェクト: geyang/garage
    def test_cem_cartpole(self):
        """Test CEM with Cartpole-v1 environment.

        Builds a CategoricalMLPPolicy with a LinearFeatureBaseline, trains it
        with CEM (10 candidate samples, best 10% kept) for 10 epochs of batch
        size 2048 using a LocalSampler, and asserts that the value returned by
        ``runner.train`` exceeds 40. That value is presumably the average
        return of the final epoch; confirm against the runner implementation.
        """
        with LocalTFRunner(snapshot_config) as runner:
            env = GymEnv('CartPole-v1')
            # Close the environment even if training raises or the assertion
            # fails; previously env.close() was skipped on any failure.
            try:
                policy = CategoricalMLPPolicy(name='policy',
                                              env_spec=env.spec,
                                              hidden_sizes=(32, 32))
                baseline = LinearFeatureBaseline(env_spec=env.spec)

                # Number of candidate parameter vectors CEM samples per epoch.
                n_samples = 10

                algo = CEM(env_spec=env.spec,
                           policy=policy,
                           baseline=baseline,
                           best_frac=0.1,
                           n_samples=n_samples)

                runner.setup(algo, env, sampler_cls=LocalSampler)
                rtn = runner.train(n_epochs=10, batch_size=2048)
                assert rtn > 40
            finally:
                env.close()
コード例 #2
0
def run_task(snapshot_config, *_):
    """Train CEM with a Gaussian MLP policy on Swimmer-v2.

    Args:
        snapshot_config: Snapshot configuration handed to LocalTFRunner.
        *_: Extra positional arguments are accepted and ignored.
    """
    # Candidate policies sampled per CEM iteration. This same value is also
    # used for n_epoch_cycles below, which must stay equal to it.
    num_candidates = 20

    with LocalTFRunner(snapshot_config=snapshot_config) as runner:
        env = TfEnv(env_name='Swimmer-v2')

        policy = GaussianMLPPolicy(name='policy',
                                   env_spec=env.spec,
                                   hidden_sizes=(32, 32))
        baseline = LinearFeatureBaseline(env_spec=env.spec)

        cem_options = {
            'env_spec': env.spec,
            'policy': policy,
            'baseline': baseline,
            'best_frac': 0.05,  # keep the top 5% of sampled candidates
            'max_path_length': 100,
            'n_samples': num_candidates,
        }
        algo = CEM(**cem_options)

        runner.setup(algo, env, sampler_cls=OnPolicyVectorizedSampler)
        # n_epoch_cycles must equal the CEM sample count. Presumably CEM
        # evaluates one candidate per epoch cycle; confirm in CEM's source.
        runner.train(n_epochs=100,
                     batch_size=1000,
                     n_epoch_cycles=num_candidates)