コード例 #1
0
def reps_gym_cartpole(ctxt=None, seed=1):
    """Run REPS on the gym CartPole-v0 task inside a LocalTFRunner.

    Args:
        ctxt (garage.experiment.ExperimentContext): Experiment
            configuration, forwarded to LocalTFRunner as its
            ``snapshot_config``.
        seed (int): Value handed to ``set_seed`` before anything else is
            constructed, to make runs deterministic.

    """
    set_seed(seed)
    with LocalTFRunner(snapshot_config=ctxt) as tf_runner:
        cartpole = GarageEnv(gym.make('CartPole-v0'))

        mlp_policy = CategoricalMLPPolicy(
            env_spec=cartpole.spec,
            hidden_sizes=[32, 32],
        )
        linear_baseline = LinearFeatureBaseline(env_spec=cartpole.spec)

        reps = REPS(
            env_spec=cartpole.spec,
            policy=mlp_policy,
            baseline=linear_baseline,
            max_path_length=100,
            discount=0.99,
        )

        # 100 epochs of 4000 samples each, with plotting disabled.
        tf_runner.setup(reps, cartpole)
        tf_runner.train(n_epochs=100, batch_size=4000, plot=False)
コード例 #2
0
ファイル: reps_gym_cartpole.py プロジェクト: sumeromer/garage
def reps_gym_cartpole(ctxt=None, seed=1):
    """Run REPS on CartPole-v0 with a TFTrainer and a RaySampler.

    Args:
        ctxt (garage.experiment.ExperimentContext): Experiment
            configuration, forwarded to TFTrainer as its
            ``snapshot_config``.
        seed (int): Value handed to ``set_seed`` before anything else is
            constructed, to make runs deterministic.

    """
    set_seed(seed)
    with TFTrainer(snapshot_config=ctxt) as tf_trainer:
        cartpole = GymEnv('CartPole-v0')

        mlp_policy = CategoricalMLPPolicy(
            env_spec=cartpole.spec,
            hidden_sizes=[32, 32],
        )
        linear_baseline = LinearFeatureBaseline(env_spec=cartpole.spec)

        # Episode length limit is taken from the environment spec.
        ray_sampler = RaySampler(
            agents=mlp_policy,
            envs=cartpole,
            max_episode_length=cartpole.spec.max_episode_length,
            is_tf_worker=True,
        )

        reps = REPS(
            env_spec=cartpole.spec,
            policy=mlp_policy,
            baseline=linear_baseline,
            sampler=ray_sampler,
            discount=0.99,
        )

        tf_trainer.setup(reps, cartpole)
        tf_trainer.train(n_epochs=100, batch_size=4000, plot=False)
コード例 #3
0
ファイル: test_reps.py プロジェクト: XavierJingfeng/starter
    def test_reps_cartpole(self):
        """Test REPS with gym Cartpole environment.

        Trains for 10 epochs with a batch size of 4000 and asserts that the
        value returned by ``runner.train`` is greater than 5. The
        environment is closed even if training raises or the assertion
        fails, so a failing test does not leak it.
        """
        with LocalRunner(self.sess) as runner:
            env = TfEnv(gym.make('CartPole-v0'))
            try:
                policy = CategoricalMLPPolicy(env_spec=env.spec,
                                              hidden_sizes=[32, 32])

                baseline = LinearFeatureBaseline(env_spec=env.spec)

                algo = REPS(env_spec=env.spec,
                            policy=policy,
                            baseline=baseline,
                            batch_size=4000,
                            max_path_length=100,
                            n_itr=10,
                            discount=0.99,
                            max_kl_step=1e6)

                runner.setup(algo, env)

                last_avg_ret = runner.train(n_epochs=10, batch_size=4000)
                assert last_avg_ret > 5
            finally:
                # Release the environment on both success and failure.
                env.close()
コード例 #4
0
ファイル: test_reps.py プロジェクト: ziyiwu9494/garage
    def test_reps_cartpole(self):
        """Test REPS with gym Cartpole environment.

        Uses a LocalSampler, trains for 10 epochs with a batch size of 4000
        and asserts that the value returned by ``trainer.train`` is greater
        than 5. The environment is closed even if training raises or the
        assertion fails, so a failing test does not leak it.
        """
        with TFTrainer(snapshot_config, sess=self.sess) as trainer:
            env = GymEnv('CartPole-v0')
            try:
                policy = CategoricalMLPPolicy(env_spec=env.spec,
                                              hidden_sizes=[32, 32])

                baseline = LinearFeatureBaseline(env_spec=env.spec)

                sampler = LocalSampler(
                    agents=policy,
                    envs=env,
                    max_episode_length=env.spec.max_episode_length,
                    is_tf_worker=True)

                algo = REPS(env_spec=env.spec,
                            policy=policy,
                            baseline=baseline,
                            sampler=sampler,
                            discount=0.99)

                trainer.setup(algo, env)

                last_avg_ret = trainer.train(n_epochs=10, batch_size=4000)
                assert last_avg_ret > 5
            finally:
                # Release the environment on both success and failure.
                env.close()
コード例 #5
0
ファイル: reps_gym_cartpole.py プロジェクト: Kelvinson/garage
def run_task(*_):
    """Build a REPS learner for CartPole-v0 and train it.

    Args:
        *_: Positional arguments are accepted and ignored.

    """
    cartpole = TfEnv(gym.make("CartPole-v0"))

    mlp_policy = CategoricalMLPPolicy(
        env_spec=cartpole.spec,
        hidden_sizes=[32, 32],
    )
    linear_baseline = LinearFeatureBaseline(env_spec=cartpole.spec)

    # All training settings are given to the algorithm itself here;
    # train() is then called without arguments.
    reps = REPS(
        env=cartpole,
        policy=mlp_policy,
        baseline=linear_baseline,
        batch_size=4000,
        max_path_length=100,
        n_itr=100,
        discount=0.99,
        plot=False,
    )
    reps.train()
コード例 #6
0
    def test_reps_cartpole(self):
        """Test REPS with gym Cartpole environment.

        Resets the logger, trains REPS for ``n_itr=10`` iterations in
        ``self.sess`` and asserts that the value returned by ``algo.train``
        is greater than 5. The environment is closed even if training
        raises or the assertion fails, so a failing test does not leak it.
        """
        logger.reset()
        env = TfEnv(gym.make("CartPole-v0"))
        try:
            policy = CategoricalMLPPolicy(env_spec=env.spec,
                                          hidden_sizes=[32, 32])

            baseline = LinearFeatureBaseline(env_spec=env.spec)

            algo = REPS(env=env,
                        policy=policy,
                        baseline=baseline,
                        batch_size=4000,
                        max_path_length=100,
                        n_itr=10,
                        discount=0.99,
                        max_kl_step=1e6,
                        plot=False)

            last_avg_ret = algo.train(sess=self.sess)
            assert last_avg_ret > 5
        finally:
            # Release the environment on both success and failure.
            env.close()
コード例 #7
0
def run_task(*_):
    """Train REPS on CartPole-v0 inside a LocalRunner.

    Args:
        *_: Positional arguments are accepted and ignored.

    """
    with LocalRunner() as local_runner:
        cartpole = TfEnv(gym.make("CartPole-v0"))

        mlp_policy = CategoricalMLPPolicy(
            env_spec=cartpole.spec,
            hidden_sizes=[32, 32],
        )
        linear_baseline = LinearFeatureBaseline(env_spec=cartpole.spec)

        reps = REPS(
            env=cartpole,
            policy=mlp_policy,
            baseline=linear_baseline,
            max_path_length=100,
            discount=0.99,
        )

        # 100 epochs of 4000 samples each, with plotting disabled.
        local_runner.setup(reps, cartpole)
        local_runner.train(n_epochs=100, batch_size=4000, plot=False)
コード例 #8
0
ファイル: reps_gym_cartpole.py プロジェクト: JoleProject/Jole
def run_task(snapshot_config, *_):
    """Train REPS on CartPole-v0 inside a LocalTFRunner.

    Args:
        snapshot_config: Forwarded to LocalTFRunner as its
            ``snapshot_config`` keyword argument.
        *_: Any further positional arguments are accepted and ignored.

    """
    with LocalTFRunner(snapshot_config=snapshot_config) as tf_runner:
        cartpole = TfEnv(gym.make('CartPole-v0'))

        mlp_policy = CategoricalMLPPolicy(
            env_spec=cartpole.spec,
            hidden_sizes=[32, 32],
        )
        linear_baseline = LinearFeatureBaseline(env_spec=cartpole.spec)

        reps = REPS(
            env_spec=cartpole.spec,
            policy=mlp_policy,
            baseline=linear_baseline,
            max_path_length=100,
            discount=0.99,
        )

        # 100 epochs of 4000 samples each, with plotting disabled.
        tf_runner.setup(reps, cartpole)
        tf_runner.train(n_epochs=100, batch_size=4000, plot=False)
コード例 #9
0
    def test_reps_cartpole(self):
        """Test REPS with gym Cartpole environment.

        Sets the runner up with ``sampler_cls=LocalSampler``, trains for 10
        epochs with a batch size of 4000 and asserts that the value returned
        by ``runner.train`` is greater than 5. The environment is closed
        even if training raises or the assertion fails, so a failing test
        does not leak it.
        """
        with LocalTFRunner(snapshot_config, sess=self.sess) as runner:
            env = GymEnv('CartPole-v0')
            try:
                policy = CategoricalMLPPolicy(env_spec=env.spec,
                                              hidden_sizes=[32, 32])

                baseline = LinearFeatureBaseline(env_spec=env.spec)

                algo = REPS(env_spec=env.spec,
                            policy=policy,
                            baseline=baseline,
                            discount=0.99)

                runner.setup(algo, env, sampler_cls=LocalSampler)

                last_avg_ret = runner.train(n_epochs=10, batch_size=4000)
                assert last_avg_ret > 5
            finally:
                # Release the environment on both success and failure.
                env.close()
コード例 #10
0
    def test_on_policy_vectorized_sampler_n_envs(self, cpus, n_envs,
                                                 expected_n_envs):
        """Check the sampler type and its environment count after setup.

        The runner is created with ``max_cpus=cpus`` and set up with
        ``sampler_args=dict(n_envs=n_envs)``; the test then asserts that the
        runner's sampler is an OnPolicyVectorizedSampler whose ``_n_envs``
        equals ``expected_n_envs``. The environment is closed even if an
        assertion fails, so a failing case does not leak it.

        Args:
            cpus: Passed to LocalTFRunner as ``max_cpus``.
            n_envs: Requested number of environments, passed through
                ``sampler_args``.
            expected_n_envs: Value the sampler's ``_n_envs`` must equal.

        NOTE(review): the arguments are presumably supplied by a pytest
        parametrization outside this view -- confirm.
        """
        with LocalTFRunner(snapshot_config, sess=self.sess,
                           max_cpus=cpus) as runner:
            env = TfEnv(gym.make('CartPole-v0'))
            try:
                policy = CategoricalMLPPolicy(env_spec=env.spec,
                                              hidden_sizes=[32, 32])

                baseline = LinearFeatureBaseline(env_spec=env.spec)

                algo = REPS(env_spec=env.spec,
                            policy=policy,
                            baseline=baseline,
                            max_path_length=100,
                            discount=0.99)

                runner.setup(algo, env, sampler_args=dict(n_envs=n_envs))

                assert isinstance(runner._sampler, OnPolicyVectorizedSampler)
                assert runner._sampler._n_envs == expected_n_envs
            finally:
                # Release the environment on both success and failure.
                env.close()