def trpo_swimmer(ctxt=None, seed=1, batch_size=4000):
    """Train TRPO with Swimmer-v2 environment.

    Args:
        ctxt (metarl.experiment.ExperimentContext): The experiment
            configuration used by LocalRunner to create the snapshotter.
        seed (int): Used to seed the random number generator to produce
            determinism.
        batch_size (int): Number of timesteps to use in each training step.

    """
    set_seed(seed)
    with LocalTFRunner(ctxt) as runner:
        env = MetaRLEnv(gym.make('Swimmer-v2'))

        policy = GaussianMLPPolicy(env_spec=env.spec, hidden_sizes=(32, 32))

        baseline = LinearFeatureBaseline(env_spec=env.spec)

        algo = TRPO(env_spec=env.spec,
                    policy=policy,
                    baseline=baseline,
                    max_path_length=500,
                    discount=0.99,
                    max_kl_step=0.01)

        runner.setup(algo, env)
        runner.train(n_epochs=40, batch_size=batch_size)
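The snippets in this listing are shown without their imports. A hedged sketch of what the first example most likely assumes; the module paths mirror the garage project that metarl is derived from and may differ in your metarl version:

# Assumed import paths (mirroring garage); verify against your metarl install.
import gym

from metarl.envs import MetaRLEnv
from metarl.experiment import LocalTFRunner
from metarl.experiment.deterministic import set_seed
from metarl.np.baselines import LinearFeatureBaseline
from metarl.tf.algos import TRPO
from metarl.tf.policies import GaussianMLPPolicy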
Example #2
def run_task(snapshot_config, *_):
    """Run task.

    Args:
        snapshot_config (metarl.experiment.SnapshotConfig): Configuration
            values for snapshotting.
        *_ (object): Hyperparameters (unused).

    """
    seed = 1  # assumed seed for the Ray sampler; run_task does not receive one
    with LocalTFRunner(snapshot_config=snapshot_config) as runner:
        env = TfEnv(gym.make('Swimmer-v2'))

        policy = GaussianMLPPolicy(env_spec=env.spec, hidden_sizes=(32, 32))

        baseline = LinearFeatureBaseline(env_spec=env.spec)

        algo = TRPO(env_spec=env.spec,
                    policy=policy,
                    baseline=baseline,
                    max_path_length=500,
                    discount=0.99,
                    max_kl_step=0.01)

        runner.setup(algo,
                     env,
                     sampler_cls=RaySampler,
                     sampler_args={'seed': seed})
        runner.train(n_epochs=40, batch_size=4000)
Example #3
def run_task(snapshot_config, *_):
    """Run the job.

    Args:
        snapshot_config (metarl.experiment.SnapshotConfig): Configuration
            values for snapshotting.
        *_ (object): Hyperparameters (unused).

    """
    with LocalTFRunner(snapshot_config=snapshot_config) as runner:
        env = TfEnv(normalize(gym.make('InvertedPendulum-v2')))

        policy = GaussianMLPPolicy(env_spec=env.spec, hidden_sizes=(32, 32))

        baseline = LinearFeatureBaseline(env_spec=env.spec)

        algo = TRPO(env_spec=env.spec,
                    policy=policy,
                    baseline=baseline,
                    max_path_length=100,
                    discount=0.99,
                    max_kl_step=0.01)

        runner.setup(algo,
                     env,
                     sampler_cls=ISSampler,
                     sampler_args=dict(n_backtrack=1))
        runner.train(n_epochs=200, batch_size=4000)
Example #4
def multi_env_trpo(ctxt=None, seed=1):
    """Train TRPO on two different PointEnv instances.

    Args:
        ctxt (metarl.experiment.ExperimentContext): The experiment
            configuration used by LocalRunner to create the snapshotter.
        seed (int): Used to seed the random number generator to produce
            determinism.

    """
    set_seed(seed)
    with LocalTFRunner(ctxt) as runner:
        env1 = MetaRLEnv(normalize(PointEnv(goal=(-1., 0.))))
        env2 = MetaRLEnv(normalize(PointEnv(goal=(1., 0.))))
        env = MultiEnvWrapper([env1, env2])

        policy = GaussianMLPPolicy(env_spec=env.spec)

        baseline = LinearFeatureBaseline(env_spec=env.spec)

        algo = TRPO(env_spec=env.spec,
                    policy=policy,
                    baseline=baseline,
                    max_path_length=100,
                    discount=0.99,
                    gae_lambda=0.95,
                    lr_clip_range=0.2,
                    policy_ent_coeff=0.0)

        runner.setup(algo, env)
        runner.train(n_epochs=40, batch_size=2048, plot=False)
Example #5
    def test_dm_control_tf_policy(self):
        task = ALL_TASKS[0]

        with LocalTFRunner(snapshot_config, sess=self.sess) as runner:
            env = TfEnv(DmControlEnv.from_suite(*task))

            policy = GaussianMLPPolicy(
                env_spec=env.spec,
                hidden_sizes=(32, 32),
            )

            baseline = LinearFeatureBaseline(env_spec=env.spec)

            algo = TRPO(
                env_spec=env.spec,
                policy=policy,
                baseline=baseline,
                max_path_length=5,
                discount=0.99,
                max_kl_step=0.01,
            )

            runner.setup(algo, env)
            runner.train(n_epochs=1, batch_size=10)

            env.close()
Example #6
def trpois_inverted_pendulum(ctxt=None, seed=1):
    """Train TRPO on InvertedPendulum-v2 with importance sampling.

    Args:
        ctxt (metarl.experiment.ExperimentContext): The experiment
            configuration used by LocalRunner to create the snapshotter.
        seed (int): Used to seed the random number generator to produce
            determinism.

    """
    set_seed(seed)
    with LocalTFRunner(ctxt) as runner:
        env = MetaRLEnv(normalize(gym.make('InvertedPendulum-v2')))

        policy = GaussianMLPPolicy(env_spec=env.spec, hidden_sizes=(32, 32))

        baseline = LinearFeatureBaseline(env_spec=env.spec)

        algo = TRPO(env_spec=env.spec,
                    policy=policy,
                    baseline=baseline,
                    max_path_length=100,
                    discount=0.99,
                    max_kl_step=0.01)

        runner.setup(algo,
                     env,
                     sampler_cls=ISSampler,
                     sampler_args=dict(n_backtrack=1))
        runner.train(n_epochs=200, batch_size=4000)
Example #7
    def test_trpo_cnn_cubecrash(self):
        with LocalTFRunner(snapshot_config, sess=self.sess) as runner:
            env = MetaRLEnv(normalize(gym.make('CubeCrash-v0')))

            policy = CategoricalCNNPolicy(env_spec=env.spec,
                                          filters=((32, (8, 8)), (64, (4, 4))),
                                          strides=(4, 2),
                                          padding='VALID',
                                          hidden_sizes=(32, 32))

            baseline = GaussianCNNBaseline(
                env_spec=env.spec,
                regressor_args=dict(filters=((32, (8, 8)), (64, (4, 4))),
                                    strides=(4, 2),
                                    padding='VALID',
                                    hidden_sizes=(32, 32),
                                    use_trust_region=True))

            algo = TRPO(env_spec=env.spec,
                        policy=policy,
                        baseline=baseline,
                        max_path_length=100,
                        discount=0.99,
                        gae_lambda=0.98,
                        max_kl_step=0.01,
                        policy_ent_coeff=0.0,
                        flatten_input=False)

            runner.setup(algo, env)
            last_avg_ret = runner.train(n_epochs=10, batch_size=2048)
            assert last_avg_ret > -1.5

            env.close()
Example #8
def run_task(snapshot_config, *_):
    """Run task."""
    # max_path_length and n_envs are module-level settings in the full
    # example; the values below are assumed.
    max_path_length = 100
    n_envs = 4000 // max_path_length
    with LocalTFRunner(snapshot_config=snapshot_config,
                       max_cpus=n_envs) as runner:
        env = TfEnv(env_name='CartPole-v1')

        policy = CategoricalMLPPolicy(name='policy',
                                      env_spec=env.spec,
                                      hidden_sizes=(32, 32))

        baseline = LinearFeatureBaseline(env_spec=env.spec)

        algo = TRPO(env_spec=env.spec,
                    policy=policy,
                    baseline=baseline,
                    max_path_length=max_path_length,
                    discount=0.99,
                    max_kl_step=0.01)

        runner.setup(algo=algo,
                     env=env,
                     sampler_cls=BatchSampler,
                     sampler_args={'n_envs': n_envs})

        runner.train(n_epochs=100, batch_size=4000, plot=False)
Example #9
def run_task(snapshot_config, *_):
    """Run task.

    Args:
        snapshot_config (metarl.experiment.SnapshotConfig): The snapshot
            configuration used by LocalRunner to create the snapshotter.

        *_ (object): Ignored by this function.

    """
    with LocalTFRunner(snapshot_config=snapshot_config) as runner:
        env1 = TfEnv(normalize(PointEnv(goal=(-1., 0.))))
        env2 = TfEnv(normalize(PointEnv(goal=(1., 0.))))
        env = MultiEnvWrapper([env1, env2])

        policy = GaussianMLPPolicy(env_spec=env.spec)

        baseline = LinearFeatureBaseline(env_spec=env.spec)

        algo = TRPO(env_spec=env.spec,
                    policy=policy,
                    baseline=baseline,
                    max_path_length=100,
                    discount=0.99,
                    gae_lambda=0.95,
                    lr_clip_range=0.2,
                    policy_ent_coeff=0.0)

        runner.setup(algo, env)
        runner.train(n_epochs=40, batch_size=2048, plot=False)
Example #10
def trpo_mt50(ctxt=None, seed=1):
    """Train TRPO on the MetaWorld MT50 benchmark."""
    set_seed(seed)
    with LocalTFRunner(snapshot_config=ctxt) as runner:
        # MT50_envs, env_ids and round_robin_strategy are expected to be
        # defined at module level (e.g. built from the MetaWorld MT50 tasks).
        env = MultiEnvWrapper(MT50_envs,
                              env_ids,
                              sample_strategy=round_robin_strategy)

        policy = GaussianMLPPolicy(env_spec=env.spec, hidden_sizes=(64, 64))

        # baseline = LinearFeatureBaseline(env_spec=env.spec)
        baseline = GaussianMLPBaseline(
            env_spec=env.spec,
            regressor_args=dict(
                hidden_sizes=(64, 64),
                use_trust_region=False,
            ),
        )

        algo = TRPO(env_spec=env.spec,
                    policy=policy,
                    baseline=baseline,
                    max_path_length=150,
                    discount=0.99,
                    gae_lambda=0.97,
                    max_kl_step=0.01)

        runner.setup(algo, env)
        # batch_size = (number of envs) x 10 paths x 150 steps per path.
        runner.train(n_epochs=1500, batch_size=len(MT50_envs) * 10 * 150)
Example #11
def trpo_ml1(ctxt=None, seed=1):
    """Train TRPO on a MetaWorld ML1 task."""
    set_seed(seed)
    with LocalTFRunner(snapshot_config=ctxt) as runner:
        # env_id, get_ML1_envs_test and MTMetaWorldWrapper are expected to be
        # defined at module level.
        Ml1_reach_envs = get_ML1_envs_test(env_id)
        env = MTMetaWorldWrapper(Ml1_reach_envs)

        policy = GaussianMLPPolicy(
            env_spec=env.spec,
            hidden_sizes=(64, 64),
            hidden_nonlinearity=tf.nn.tanh,
            output_nonlinearity=None,
        )

        baseline = GaussianMLPBaseline(
            env_spec=env.spec,
            regressor_args=dict(hidden_sizes=(64, 64), use_trust_region=False),
        )

        algo = TRPO(env_spec=env.spec,
                    policy=policy,
                    baseline=baseline,
                    max_path_length=150,
                    discount=0.99,
                    gae_lambda=0.97,
                    max_kl_step=0.01)

        timesteps = 6000000
        batch_size = 150 * env.num_tasks
        epochs = timesteps // batch_size

        print(f'epochs: {epochs}, batch_size: {batch_size}')

        runner.setup(algo, env, sampler_args={'n_envs': 1})
        runner.train(n_epochs=epochs, batch_size=batch_size, plot=False)
Example #12
def trpo_cartpole_recurrent(ctxt, seed, n_epochs, batch_size, plot):
    """Train TRPO with a recurrent policy on CartPole.

    Args:
        ctxt (metarl.experiment.ExperimentContext): The experiment
            configuration used by LocalRunner to create the snapshotter.
        seed (int): Used to seed the random number generator to produce
            determinism.
        n_epochs (int): Number of epochs for training.
        batch_size (int): Batch size used for training.
        plot (bool): Whether to plot or not.

    """
    set_seed(seed)
    with LocalTFRunner(snapshot_config=ctxt) as runner:
        env = MetaRLEnv(env_name='CartPole-v1')

        policy = CategoricalLSTMPolicy(name='policy', env_spec=env.spec)

        baseline = LinearFeatureBaseline(env_spec=env.spec)

        algo = TRPO(env_spec=env.spec,
                    policy=policy,
                    baseline=baseline,
                    max_path_length=100,
                    discount=0.99,
                    max_kl_step=0.01,
                    optimizer=ConjugateGradientOptimizer,
                    optimizer_args=dict(hvp_approach=FiniteDifferenceHvp(
                        base_eps=1e-5)))

        runner.setup(algo, env)
        runner.train(n_epochs=n_epochs, batch_size=batch_size, plot=plot)
Example #13
def run_task(snapshot_config, *_):
    """Defines the main experiment routine.

    Args:
        snapshot_config (metarl.experiment.SnapshotConfig): Configuration
            values for snapshotting.
        *_ (object): Hyperparameters (unused).

    """
    with LocalTFRunner(snapshot_config=snapshot_config) as runner:
        env = TfEnv(env_name='CartPole-v1')

        policy = CategoricalLSTMPolicy(name='policy', env_spec=env.spec)

        baseline = LinearFeatureBaseline(env_spec=env.spec)

        algo = TRPO(env_spec=env.spec,
                    policy=policy,
                    baseline=baseline,
                    max_path_length=100,
                    discount=0.99,
                    max_kl_step=0.01,
                    optimizer=ConjugateGradientOptimizer,
                    optimizer_args=dict(hvp_approach=FiniteDifferenceHvp(
                        base_eps=1e-5)))

        runner.setup(algo, env)
        runner.train(n_epochs=100, batch_size=4000)
Example #14
def trpo_gym_tf_cartpole(ctxt=None, seed=1):
    """Train TRPO with CartPole-v0 environment.

    Args:
        ctxt (metarl.experiment.ExperimentContext): The experiment
            configuration used by LocalRunner to create the snapshotter.
        seed (int): Used to seed the random number generator to produce
            determinism.

    """
    set_seed(seed)
    with LocalTFRunner(snapshot_config=ctxt) as runner:
        env = MetaRLEnv(gym.make('CartPole-v0'))

        policy = CategoricalMLPPolicy(name='policy',
                                      env_spec=env.spec,
                                      hidden_sizes=(32, 32))

        baseline = LinearFeatureBaseline(env_spec=env.spec)

        algo = TRPO(
            env_spec=env.spec,
            policy=policy,
            baseline=baseline,
            max_path_length=200,
            discount=0.99,
            max_kl_step=0.01,
        )

        runner.setup(algo, env)
        runner.train(n_epochs=120, batch_size=4000)
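The functions above that take a ctxt argument are meant to be launched through an experiment wrapper rather than called bare. A hedged usage sketch, assuming metarl exposes a wrap_experiment decorator like garage does (the import path and decorator name are assumptions):

# Hypothetical launcher for a ctxt-style experiment function such as
# trpo_gym_tf_cartpole above; wrap_experiment is assumed to fill in ctxt.
from metarl import wrap_experiment  # assumed path, mirrors garage

@wrap_experiment
def my_trpo_cartpole(ctxt=None, seed=1):
    # ... same body as trpo_gym_tf_cartpole above ...
    pass

if __name__ == '__main__':
    my_trpo_cartpole(seed=1)  # ctxt is supplied by the decorator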
Example #15
    def test_trpo_unknown_kl_constraint(self):
        """Test TRPO with an unknown KL constraint."""
        with pytest.raises(ValueError, match='Invalid kl_constraint'):
            TRPO(
                env_spec=self.env.spec,
                policy=self.policy,
                baseline=self.baseline,
                max_path_length=100,
                discount=0.99,
                gae_lambda=0.98,
                policy_ent_coeff=0.0,
                kl_constraint='random kl_constraint',
            )
Example #16
    def test_trpo_pendulum(self):
        """Test TRPO with Pendulum environment."""
        with LocalTFRunner(snapshot_config, sess=self.sess) as runner:
            algo = TRPO(env_spec=self.env.spec,
                        policy=self.policy,
                        baseline=self.baseline,
                        max_path_length=100,
                        discount=0.99,
                        gae_lambda=0.98,
                        policy_ent_coeff=0.0)
            runner.setup(algo, self.env)
            last_avg_ret = runner.train(n_epochs=10, batch_size=2048)
            assert last_avg_ret > 40
Example #17
    def test_trpo_soft_kl_constraint(self):
        """Test TRPO with a soft KL constraint."""
        with LocalTFRunner(snapshot_config, sess=self.sess) as runner:
            algo = TRPO(env_spec=self.env.spec,
                        policy=self.policy,
                        baseline=self.baseline,
                        max_path_length=100,
                        discount=0.99,
                        gae_lambda=0.98,
                        policy_ent_coeff=0.0,
                        kl_constraint='soft')
            runner.setup(algo, self.env)
            last_avg_ret = runner.train(n_epochs=10, batch_size=2048)
            assert last_avg_ret > 45
Example #18
def run_task(snapshot_config, *_):
    """Run task."""
    with LocalTFRunner(snapshot_config=snapshot_config) as runner:
        env = TfEnv(gym.make('Swimmer-v2'))

        policy = GaussianMLPPolicy(env_spec=env.spec, hidden_sizes=(32, 32))

        baseline = LinearFeatureBaseline(env_spec=env.spec)

        algo = TRPO(env_spec=env.spec,
                    policy=policy,
                    baseline=baseline,
                    max_path_length=500,
                    discount=0.99,
                    max_kl_step=0.01)

        runner.setup(algo, env)
        runner.train(n_epochs=40, batch_size=4000)
Example #19
def run_task(snapshot_config, *_):
    """Run task."""
    with LocalTFRunner(snapshot_config=snapshot_config) as runner:
        env = TfEnv(env_name='CartPole-v1')

        policy = CategoricalMLPPolicy(name='policy',
                                      env_spec=env.spec,
                                      hidden_sizes=(32, 32))

        baseline = LinearFeatureBaseline(env_spec=env.spec)

        algo = TRPO(env_spec=env.spec,
                    policy=policy,
                    baseline=baseline,
                    max_path_length=100,
                    discount=0.99,
                    max_kl_step=0.01)

        runner.setup(algo, env)
        runner.train(n_epochs=100, batch_size=4000)
Example #20
def trpo_cubecrash(ctxt=None, seed=1, batch_size=4000):
    """Train TRPO with CubeCrash-v0 environment.

    Args:
        ctxt (metarl.experiment.ExperimentContext): The experiment
            configuration used by LocalRunner to create the snapshotter.
        seed (int): Used to seed the random number generator to produce
            determinism.
        batch_size (int): Number of timesteps to use in each training step.

    """
    set_seed(seed)
    with LocalTFRunner(ctxt) as runner:
        env = MetaRLEnv(normalize(gym.make('CubeCrash-v0')))
        policy = CategoricalCNNPolicy(env_spec=env.spec,
                                      filters=((32, (8, 8)), (64, (4, 4))),
                                      strides=(4, 2),
                                      padding='VALID',
                                      hidden_sizes=(32, 32))

        baseline = GaussianCNNBaseline(
            env_spec=env.spec,
            regressor_args=dict(filters=((32, (8, 8)), (64, (4, 4))),
                                strides=(4, 2),
                                padding='VALID',
                                hidden_sizes=(32, 32),
                                use_trust_region=True))

        algo = TRPO(env_spec=env.spec,
                    policy=policy,
                    baseline=baseline,
                    max_path_length=100,
                    discount=0.99,
                    gae_lambda=0.95,
                    lr_clip_range=0.2,
                    policy_ent_coeff=0.0,
                    flatten_input=False)

        runner.setup(algo, env)
        runner.train(n_epochs=100, batch_size=batch_size)
Example #21
def run_task(snapshot_config, variant_data, *_):
    """Run task.

    Args:
        snapshot_config (metarl.experiment.SnapshotConfig): The snapshot
            configuration used by LocalRunner to create the snapshotter.

        variant_data (dict): Custom arguments for the task.

        *_ (object): Ignored by this function.

    """
    with LocalTFRunner(snapshot_config=snapshot_config) as runner:
        env = TfEnv(normalize(gym.make('CubeCrash-v0')))
        policy = CategoricalCNNPolicy(env_spec=env.spec,
                                      conv_filters=(32, 64),
                                      conv_filter_sizes=(8, 4),
                                      conv_strides=(4, 2),
                                      conv_pad='VALID',
                                      hidden_sizes=(32, 32))

        baseline = GaussianCNNBaseline(env_spec=env.spec,
                                       regressor_args=dict(
                                           num_filters=(32, 64),
                                           filter_dims=(8, 4),
                                           strides=(4, 2),
                                           padding='VALID',
                                           hidden_sizes=(32, 32),
                                           use_trust_region=True))

        algo = TRPO(env_spec=env.spec,
                    policy=policy,
                    baseline=baseline,
                    max_path_length=100,
                    discount=0.99,
                    max_kl_step=0.01,
                    flatten_input=False)

        runner.setup(algo, env)
        runner.train(n_epochs=100, batch_size=variant_data['batch_size'])
Example #22
def trpo_cartpole_batch_sampler(ctxt=None,
                                seed=1,
                                batch_size=4000,
                                max_path_length=100):
    """Train TRPO with CartPole-v1 environment.

    Args:
        ctxt (metarl.experiment.ExperimentContext): The experiment
            configuration used by LocalRunner to create the snapshotter.
        seed (int): Used to seed the random number generator to produce
            determinism.
        batch_size (int): Number of timesteps to use in each training step.
        max_path_length (int): Number of timesteps to truncate paths to.

    """
    set_seed(seed)
    n_envs = batch_size // max_path_length
    with LocalTFRunner(ctxt, max_cpus=n_envs) as runner:
        env = MetaRLEnv(env_name='CartPole-v1')

        policy = CategoricalMLPPolicy(name='policy',
                                      env_spec=env.spec,
                                      hidden_sizes=(32, 32))

        baseline = LinearFeatureBaseline(env_spec=env.spec)

        algo = TRPO(env_spec=env.spec,
                    policy=policy,
                    baseline=baseline,
                    max_path_length=max_path_length,
                    discount=0.99,
                    max_kl_step=0.01)

        runner.setup(algo=algo,
                     env=env,
                     sampler_cls=BatchSampler,
                     sampler_args={'n_envs': n_envs})

        runner.train(n_epochs=100, batch_size=4000, plot=False)
Example #23
def trpo_swimmer_ray_sampler(ctxt=None, seed=1):
    """Train TRPO on Swimmer-v2, collecting samples with RaySampler.

    Args:
        ctxt (metarl.experiment.ExperimentContext): The experiment
            configuration used by LocalRunner to create the snapshotter.
        seed (int): Used to seed the random number generator to produce
            determinism.

    """
    # Since this is an example, Ray is started in a reduced state.
    # Remove these limits to run Ray at full capacity.
    ray.init(memory=52428800,
             object_store_memory=78643200,
             ignore_reinit_error=True,
             log_to_driver=False,
             include_webui=False)
    with LocalTFRunner(snapshot_config=ctxt) as runner:
        set_seed(seed)
        env = MetaRLEnv(gym.make('Swimmer-v2'))

        policy = GaussianMLPPolicy(env_spec=env.spec, hidden_sizes=(32, 32))

        baseline = LinearFeatureBaseline(env_spec=env.spec)

        algo = TRPO(env_spec=env.spec,
                    policy=policy,
                    baseline=baseline,
                    max_path_length=500,
                    discount=0.99,
                    max_kl_step=0.01)

        runner.setup(algo,
                     env,
                     sampler_cls=RaySampler,
                     sampler_args={'seed': seed})
        runner.train(n_epochs=40, batch_size=4000)
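On top of the imports sketched after the first example, the RaySampler example also needs Ray and the sampler class. The paths below mirror garage and are assumptions:

# Assumed additional imports for the RaySampler example; verify the paths
# against your metarl version.
import ray

from metarl.sampler import RaySampler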
Example #24
    def test_trpo_lstm_cartpole(self):
        with LocalTFRunner(snapshot_config, sess=self.sess) as runner:
            env = MetaRLEnv(normalize(gym.make('CartPole-v1')))

            policy = CategoricalLSTMPolicy(name='policy', env_spec=env.spec)

            baseline = LinearFeatureBaseline(env_spec=env.spec)

            algo = TRPO(env_spec=env.spec,
                        policy=policy,
                        baseline=baseline,
                        max_path_length=100,
                        discount=0.99,
                        max_kl_step=0.01,
                        optimizer_args=dict(hvp_approach=FiniteDifferenceHvp(
                            base_eps=1e-5)))

            snapshotter.snapshot_dir = './'
            runner.setup(algo, env)
            last_avg_ret = runner.train(n_epochs=10, batch_size=2048)
            assert last_avg_ret > 80

            env.close()
Example #25
    def test_gaussian_policies(self, policy_cls):
        with LocalTFRunner(snapshot_config, sess=self.sess) as runner:
            env = MetaRLEnv(normalize(gym.make('Pendulum-v0')))

            policy = policy_cls(name='policy', env_spec=env.spec)

            baseline = LinearFeatureBaseline(env_spec=env.spec)

            algo = TRPO(
                env_spec=env.spec,
                policy=policy,
                baseline=baseline,
                max_path_length=100,
                discount=0.99,
                max_kl_step=0.01,
                optimizer=ConjugateGradientOptimizer,
                optimizer_args=dict(hvp_approach=FiniteDifferenceHvp(
                    base_eps=1e-5)),
            )

            runner.setup(algo, env)
            runner.train(n_epochs=1, batch_size=4000)
            env.close()
Example #26
def trpo_metarl_tf(ctxt, env_id, seed):
    """Create metarl Tensorflow TROI model and training.

    Args:
        ctxt (metarl.experiment.ExperimentContext): The experiment
            configuration used by LocalRunner to create the
            snapshotter.
        env_id (str): Environment id of the task.
        seed (int): Random positive integer for the trial.

    """
    deterministic.set_seed(seed)

    with LocalTFRunner(ctxt) as runner:
        env = MetaRLEnv(normalize(gym.make(env_id)))

        policy = GaussianMLPPolicy(
            env_spec=env.spec,
            hidden_sizes=hyper_parameters['hidden_sizes'],
            hidden_nonlinearity=tf.nn.tanh,
            output_nonlinearity=None,
        )

        baseline = LinearFeatureBaseline(env_spec=env.spec)

        algo = TRPO(env_spec=env.spec,
                    policy=policy,
                    baseline=baseline,
                    max_path_length=hyper_parameters['max_path_length'],
                    discount=hyper_parameters['discount'],
                    gae_lambda=hyper_parameters['gae_lambda'],
                    max_kl_step=hyper_parameters['max_kl'])

        runner.setup(algo, env)
        runner.train(n_epochs=hyper_parameters['n_epochs'],
                     batch_size=hyper_parameters['batch_size'])
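The benchmark-style function above reads its settings from a module-level hyper_parameters dict that is not shown here. A hypothetical example with placeholder values, covering exactly the keys the snippet reads:

# Hypothetical hyper_parameters dict; only the keys are taken from the
# snippet above, the values are placeholders.
hyper_parameters = {
    'hidden_sizes': (32, 32),
    'max_path_length': 100,
    'discount': 0.99,
    'gae_lambda': 0.97,
    'max_kl': 0.01,
    'n_epochs': 100,
    'batch_size': 4000,
}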
Example #27
    def test_is_sampler(self):
        with LocalTFRunner(snapshot_config, sess=self.sess) as runner:
            env = MetaRLEnv(normalize(gym.make('InvertedPendulum-v2')))
            policy = GaussianMLPPolicy(env_spec=env.spec,
                                       hidden_sizes=(32, 32))
            baseline = LinearFeatureBaseline(env_spec=env.spec)
            algo = TRPO(env_spec=env.spec,
                        policy=policy,
                        baseline=baseline,
                        max_path_length=100,
                        discount=0.99,
                        max_kl_step=0.01)

            runner.setup(algo,
                         env,
                         sampler_cls=ISSampler,
                         sampler_args=dict(n_backtrack=1, init_is=1))
            runner._start_worker()

            paths = runner._sampler.obtain_samples(1)
            assert paths == [], 'Should return empty paths if no history'

            # test that importance sampling and live sampling are called alternately
            with unittest.mock.patch.object(ISSampler,
                                            '_obtain_is_samples') as mocked:
                assert runner._sampler.obtain_samples(2, 20)
                mocked.assert_not_called()

                assert runner._sampler.obtain_samples(3)
                mocked.assert_called_once_with(3, None, True)

            # test importance sampling for first n_is_pretrain iterations
            with unittest.mock.patch.object(ISSampler,
                                            '_obtain_is_samples') as mocked:
                runner._sampler.n_is_pretrain = 5
                runner._sampler.n_backtrack = None
                runner._sampler.obtain_samples(4)

                mocked.assert_called_once_with(4, None, True)

            runner._sampler.obtain_samples(5)

            # test random draw of importance samples
            runner._sampler.randomize_draw = True
            assert runner._sampler.obtain_samples(6, 1)
            runner._sampler.randomize_draw = False

            runner._sampler.obtain_samples(7, 30)

            # test ess_threshold use
            runner._sampler.ess_threshold = 500
            paths = runner._sampler.obtain_samples(8, 30)
            assert paths == [], (
                'Should return empty paths when ess_threshold is large')
            runner._sampler.ess_threshold = 0

            # test random sample selection when len(paths) > batch size
            runner._sampler.n_is_pretrain = 15
            runner._sampler.obtain_samples(9, 10)
            runner._sampler.obtain_samples(10, 1)

            runner._shutdown_worker()