示例#1
0
    def test_erwr_cartpole(self):
        """Test ERWR with Cartpole-v1 environment.

        Trains for 10 epochs with a batch size of 10000 and asserts that the
        value returned by ``trainer.train`` is greater than 60.
        """
        with TFTrainer(snapshot_config, sess=self.sess) as trainer:
            deterministic.set_seed(1)
            env = GymEnv('CartPole-v1')
            # Close the environment even when construction, training or the
            # assertion fails, so a failing test does not leak it.
            try:
                policy = CategoricalMLPPolicy(name='policy',
                                              env_spec=env.spec,
                                              hidden_sizes=(32, 32))

                baseline = LinearFeatureBaseline(env_spec=env.spec)

                sampler = LocalSampler(
                    agents=policy,
                    envs=env,
                    max_episode_length=env.spec.max_episode_length,
                    is_tf_worker=True)

                algo = ERWR(env_spec=env.spec,
                            policy=policy,
                            baseline=baseline,
                            sampler=sampler,
                            discount=0.99)

                trainer.setup(algo, env)

                last_avg_ret = trainer.train(n_epochs=10, batch_size=10000)
                assert last_avg_ret > 60
            finally:
                env.close()
示例#2
0
def erwr_cartpole(ctxt=None, seed=1):
    """Run ERWR training on the CartPole-v1 gym environment.

    Args:
        ctxt (garage.experiment.ExperimentContext): Experiment configuration
            passed to LocalTFRunner as its snapshot configuration.
        seed (int): Value handed to set_seed before anything else is built.

    """
    set_seed(seed)
    with LocalTFRunner(snapshot_config=ctxt) as tf_runner:
        cartpole = GymEnv('CartPole-v1')

        actor = CategoricalMLPPolicy(
            name='policy',
            env_spec=cartpole.spec,
            hidden_sizes=(32, 32),
        )
        value_fn = LinearFeatureBaseline(env_spec=cartpole.spec)

        erwr = ERWR(
            env_spec=cartpole.spec,
            policy=actor,
            baseline=value_fn,
            max_episode_length=100,
            discount=0.99,
        )

        # 100 epochs of 10000-sample batches, plotting disabled.
        tf_runner.setup(algo=erwr, env=cartpole)
        tf_runner.train(n_epochs=100, batch_size=10000, plot=False)
示例#3
0
def run_task(*_):
    """Build ERWR on a normalized CartpoleEnv and train it.

    Any positional arguments are accepted and ignored.
    """
    cartpole = TfEnv(normalize(CartpoleEnv()))

    actor = GaussianMLPPolicy(
        name="policy",
        env_spec=cartpole.spec,
        hidden_sizes=(32, 32),
    )
    value_fn = LinearFeatureBaseline(env_spec=cartpole.spec)

    # Collect the algorithm settings in one place before constructing it.
    erwr_settings = dict(
        env=cartpole,
        policy=actor,
        baseline=value_fn,
        batch_size=10000,
        max_path_length=100,
        n_itr=40,
        discount=0.99,
    )
    ERWR(**erwr_settings).train()
示例#4
0
def run_task(*_):
    """Build ERWR on the CartPole-v1 environment and train it with plotting.

    Any positional arguments are accepted and ignored.
    """
    cartpole = TfEnv(env_name="CartPole-v1")

    actor = CategoricalMLPPolicy(
        name="policy",
        env_spec=cartpole.spec,
        hidden_sizes=(32, 32),
    )
    value_fn = LinearFeatureBaseline(env_spec=cartpole.spec)

    erwr = ERWR(
        env=cartpole,
        policy=actor,
        baseline=value_fn,
        batch_size=10000,
        max_path_length=100,
        n_itr=100,
        plot=True,
        discount=0.99,
    )
    erwr.train()
示例#5
0
    def test_erwr_cartpole(self):
        """Check that ERWR on a normalized CartpoleEnv returns more than 100.

        The logger is reset first; training then runs for 10 iterations in
        the test's own session.
        """
        logger.reset()
        cartpole = TfEnv(normalize(CartpoleEnv()))

        actor = GaussianMLPPolicy(name="policy",
                                  env_spec=cartpole.spec,
                                  hidden_sizes=(32, 32))
        value_fn = LinearFeatureBaseline(env_spec=cartpole.spec)

        erwr = ERWR(env=cartpole,
                    policy=actor,
                    baseline=value_fn,
                    batch_size=10000,
                    max_path_length=100,
                    n_itr=10,
                    discount=0.99)

        final_return = erwr.train(sess=self.sess)
        assert final_return > 100
示例#6
0
    def test_erwr_cartpole(self):
        """Test ERWR with Cartpole-v1 environment.

        Trains for 10 iterations in the test's session and asserts that the
        value returned by ``algo.train`` is greater than 80.
        """
        env = TfEnv(env_name="CartPole-v1")
        # Close the environment even when construction, training or the
        # assertion fails, so a failing test does not leak it.
        try:
            policy = CategoricalMLPPolicy(name="policy",
                                          env_spec=env.spec,
                                          hidden_sizes=(32, 32))

            baseline = LinearFeatureBaseline(env_spec=env.spec)

            algo = ERWR(env=env,
                        policy=policy,
                        baseline=baseline,
                        batch_size=10000,
                        max_path_length=100,
                        n_itr=10,
                        discount=0.99)

            last_avg_ret = algo.train(sess=self.sess)
            assert last_avg_ret > 80
        finally:
            env.close()
示例#7
0
def run_task(*_):
    """Train ERWR on CartPole-v1 through a LocalRunner with plotting on.

    Any positional arguments are accepted and ignored.
    """
    with LocalRunner() as local_runner:
        cartpole = TfEnv(env_name="CartPole-v1")

        actor = CategoricalMLPPolicy(
            name="policy",
            env_spec=cartpole.spec,
            hidden_sizes=(32, 32),
        )
        value_fn = LinearFeatureBaseline(env_spec=cartpole.spec)

        erwr = ERWR(
            env=cartpole,
            policy=actor,
            baseline=value_fn,
            max_path_length=100,
            discount=0.99,
        )

        # 100 epochs of 10000-sample batches.
        local_runner.setup(algo=erwr, env=cartpole)
        local_runner.train(n_epochs=100, batch_size=10000, plot=True)
示例#8
0
def run_task(snapshot_config, *_):
    """Train ERWR on CartPole-v1 through a LocalTFRunner with plotting on.

    Args:
        snapshot_config: Forwarded to LocalTFRunner as ``snapshot_config``.
        *_: Extra positional arguments, accepted and ignored.

    """
    with LocalTFRunner(snapshot_config=snapshot_config) as tf_runner:
        cartpole = TfEnv(env_name='CartPole-v1')

        actor = CategoricalMLPPolicy(
            name='policy',
            env_spec=cartpole.spec,
            hidden_sizes=(32, 32),
        )
        value_fn = LinearFeatureBaseline(env_spec=cartpole.spec)

        erwr = ERWR(
            env_spec=cartpole.spec,
            policy=actor,
            baseline=value_fn,
            max_path_length=100,
            discount=0.99,
        )

        # 100 epochs of 10000-sample batches.
        tf_runner.setup(algo=erwr, env=cartpole)
        tf_runner.train(n_epochs=100, batch_size=10000, plot=True)
示例#9
0
    def test_erwr_cartpole(self):
        """Test ERWR with Cartpole-v1 environment.

        Trains for 10 epochs with a batch size of 10000 using LocalSampler
        and asserts that the value returned by ``runner.train`` is greater
        than 60.
        """
        with LocalTFRunner(snapshot_config, sess=self.sess) as runner:
            deterministic.set_seed(1)
            env = GymEnv('CartPole-v1')
            # Close the environment even when construction, training or the
            # assertion fails, so a failing test does not leak it.
            try:
                policy = CategoricalMLPPolicy(name='policy',
                                              env_spec=env.spec,
                                              hidden_sizes=(32, 32))

                baseline = LinearFeatureBaseline(env_spec=env.spec)

                algo = ERWR(env_spec=env.spec,
                            policy=policy,
                            baseline=baseline,
                            discount=0.99)

                runner.setup(algo, env, sampler_cls=LocalSampler)

                last_avg_ret = runner.train(n_epochs=10, batch_size=10000)
                assert last_avg_ret > 60
            finally:
                env.close()
示例#10
0
    def test_erwr_cartpole(self):
        """Test ERWR with Cartpole-v1 environment.

        Trains for 10 epochs with a batch size of 10000 and asserts that the
        value returned by ``runner.train`` is greater than 80.
        """
        with LocalRunner(self.sess) as runner:
            env = TfEnv(env_name="CartPole-v1")
            # Close the environment even when construction, training or the
            # assertion fails, so a failing test does not leak it.
            try:
                policy = CategoricalMLPPolicy(name="policy",
                                              env_spec=env.spec,
                                              hidden_sizes=(32, 32))

                baseline = LinearFeatureBaseline(env_spec=env.spec)

                algo = ERWR(env=env,
                            policy=policy,
                            baseline=baseline,
                            max_path_length=100,
                            discount=0.99)

                runner.setup(algo, env)

                last_avg_ret = runner.train(n_epochs=10, batch_size=10000)
                assert last_avg_ret > 80
            finally:
                env.close()