예제 #1
0
def test_update_envs_env_update():
    """Verify that LocalSampler applies one env update per worker.

    One task update is sampled for each worker, so the collected rollouts
    are expected to differ in both mean reward and task goal. Supplying one
    more update than there are workers must raise ``ValueError``.
    """
    horizon = 16
    worker_count = 8
    env = TfEnv(PointEnv())
    spec = env.spec
    scripted = [env.action_space.sample() for _ in range(horizon)]
    policy = FixedPolicy(spec, scripted_actions=scripted)
    task_sampler = SetTaskSampler(PointEnv)
    factory = WorkerFactory(seed=100,
                            max_path_length=horizon,
                            n_workers=worker_count)
    sampler = LocalSampler.from_worker_factory(factory, policy, env)

    agent_params = np.asarray(policy.get_param_values())
    updates = task_sampler.sample(worker_count)
    batch = sampler.obtain_samples(0, 161, agent_params, env_update=updates)

    # Materialise the split so it can be traversed once per statistic.
    episodes = list(batch.split())
    mean_rewards = [episode.rewards.mean() for episode in episodes]
    goals = [episode.env_infos['task'][0]['goal'] for episode in episodes]
    assert np.var(mean_rewards) > 0
    assert np.var(goals) > 0

    # The arguments are built inside the context, exactly where the
    # ValueError is expected to surface.
    with pytest.raises(ValueError):
        sampler.obtain_samples(
            0,
            10,
            np.asarray(policy.get_param_values()),
            env_update=task_sampler.sample(worker_count + 1))
def test_pickle():
    """Check that a MultiprocessingSampler survives a pickle round trip.

    The sampler is serialised, the original is shut down, and the restored
    copy collects rollouts with one task update per worker. Those rollouts
    must vary in both mean reward and task goal.
    """
    horizon = 16
    worker_count = 8
    env = MetaRLEnv(PointEnv())
    spec = env.spec
    scripted = [env.action_space.sample() for _ in range(horizon)]
    policy = FixedPolicy(spec, scripted_actions=scripted)
    task_sampler = SetTaskSampler(PointEnv)
    factory = WorkerFactory(seed=100,
                            max_path_length=horizon,
                            n_workers=worker_count)
    original = MultiprocessingSampler.from_worker_factory(
        factory, policy, env)

    # Serialise first, then stop the original before restoring the copy.
    payload = pickle.dumps(original)
    original.shutdown_worker()
    restored = pickle.loads(payload)

    agent_params = np.asarray(policy.get_param_values())
    updates = task_sampler.sample(worker_count)
    batch = restored.obtain_samples(0, 161, agent_params, env_update=updates)

    episodes = list(batch.split())
    mean_rewards = [episode.rewards.mean() for episode in episodes]
    goals = [episode.env_infos['task'][0]['goal'] for episode in episodes]
    assert np.var(mean_rewards) > 0
    assert np.var(goals) > 0

    restored.shutdown_worker()
    env.close()
def test_init_with_crashed_worker():
    """Check that sampling still meets its quota when a policy crashes.

    The sampler is built with two policies, the second of which raises on
    every ``reset``. It must nevertheless return at least 160 steps.
    """
    horizon = 16
    worker_count = 2
    env = MetaRLEnv(PointEnv())
    spec = env.spec
    scripted = [env.action_space.sample() for _ in range(horizon)]
    good_policy = FixedPolicy(spec, scripted_actions=scripted)
    task_sampler = SetTaskSampler(lambda: MetaRLEnv(PointEnv()))
    factory = WorkerFactory(seed=100,
                            max_path_length=horizon,
                            n_workers=worker_count)

    class CrashingPolicy:
        """Policy stub whose reset always fails."""

        def reset(self, **kwargs):
            """Raise unconditionally to simulate a subprocess crash."""
            raise Exception('Intentional subprocess crash')

    #  This causes worker 2 to crash.
    policies = [good_policy, CrashingPolicy()]
    sampler = MultiprocessingSampler.from_worker_factory(
        factory, policies, envs=task_sampler.sample(worker_count))

    batch = sampler.obtain_samples(0, 160, None)
    assert sum(batch.lengths) >= 160

    sampler.shutdown_worker()
    env.close()
def test_meta_evaluator():
    """Evaluate twice and check the two CSV rows written by the logger.

    The evaluator is run two times with a ``CsvOutput`` attached to the
    logger. The resulting file must hold two rows whose iteration counters
    are 0 and 1, and whose first row has consistent return statistics.
    """
    set_seed(100)
    tasks = SetTaskSampler(lambda: MetaRLEnv(PointEnv()))
    max_path_length = 200
    with tempfile.TemporaryDirectory() as log_dir_name:
        runner = LocalRunner(
            SnapshotConfig(snapshot_dir=log_dir_name,
                           snapshot_mode='last',
                           snapshot_gap=1))
        env = MetaRLEnv(PointEnv())
        algo = OptimalActionInference(env=env, max_path_length=max_path_length)
        runner.setup(algo, env)
        meta_eval = MetaEvaluator(test_task_sampler=tasks,
                                  max_path_length=max_path_length,
                                  n_test_tasks=10)
        # The context manager closes (and deletes) the temporary log file
        # even when an assertion or the evaluation itself fails.
        with tempfile.NamedTemporaryFile() as log_file:
            csv_output = CsvOutput(log_file.name)
            logger.add_output(csv_output)
            try:
                meta_eval.evaluate(algo)
                logger.log(tabular)
                meta_eval.evaluate(algo)
                logger.log(tabular)
                logger.dump_output_type(CsvOutput)
            finally:
                # Always detach the output so a failure here cannot leave
                # it registered on the shared logger for later tests.
                logger.remove_output_type(CsvOutput)
            with open(log_file.name, 'r') as file:
                rows = list(csv.DictReader(file))
        assert len(rows) == 2
        assert float(rows[0]['MetaTest/__unnamed_task__/CompletionRate']) < 1.0
        assert float(rows[0]['MetaTest/__unnamed_task__/Iteration']) == 0
        assert (float(rows[0]['MetaTest/__unnamed_task__/MaxReturn']) >= float(
            rows[0]['MetaTest/__unnamed_task__/AverageReturn']))
        assert (float(rows[0]['MetaTest/__unnamed_task__/AverageReturn']) >=
                float(rows[0]['MetaTest/__unnamed_task__/MinReturn']))
        assert float(rows[1]['MetaTest/__unnamed_task__/Iteration']) == 1
def test_meta_evaluator_with_tf():
    """Evaluate a TF algorithm, pickle it, and train the restored copy.

    The algorithm is evaluated inside a first ``LocalTFRunner``, serialised
    with cloudpickle, and then restored and trained inside a second runner
    after the default TF graph has been reset.
    """
    set_seed(100)
    tasks = SetTaskSampler(lambda: MetaRLEnv(PointEnv()))
    max_path_length = 200
    env = MetaRLEnv(PointEnv())
    n_traj = 3
    with tempfile.TemporaryDirectory() as log_dir_name:
        ctxt = SnapshotConfig(snapshot_dir=log_dir_name,
                              snapshot_mode='none',
                              snapshot_gap=1)
        # The context manager closes the temporary log file on every exit
        # path instead of leaving the handle open.
        with tempfile.NamedTemporaryFile() as log_file:
            try:
                with LocalTFRunner(ctxt) as runner:
                    meta_eval = MetaEvaluator(test_task_sampler=tasks,
                                              max_path_length=max_path_length,
                                              n_test_tasks=10,
                                              n_exploration_traj=n_traj)
                    policy = GaussianMLPPolicy(env.spec)
                    algo = MockTFAlgo(env, policy, max_path_length, n_traj,
                                      meta_eval)
                    runner.setup(algo, env)
                    csv_output = CsvOutput(log_file.name)
                    logger.add_output(csv_output)
                    meta_eval.evaluate(algo)
                    algo_pickle = cloudpickle.dumps(algo)
                tf.compat.v1.reset_default_graph()
                with LocalTFRunner(ctxt) as runner:
                    algo2 = cloudpickle.loads(algo_pickle)
                    runner.setup(algo2, env)
                    runner.train(10, 0)
            finally:
                # Detach the CSV output only at the very end: it stays
                # registered during training as before, but no longer
                # leaks into later tests through the shared logger.
                logger.remove_output_type(CsvOutput)
def test_pickle_meta_evaluator():
    """Check that a MetaEvaluator still evaluates after a pickle round trip.

    The evaluator is run once, serialised and restored with cloudpickle,
    and the restored copy is run against the same algorithm.
    """
    set_seed(100)
    tasks = SetTaskSampler(lambda: MetaRLEnv(PointEnv()))
    max_path_length = 200
    env = MetaRLEnv(PointEnv())
    n_traj = 3
    with tempfile.TemporaryDirectory() as log_dir_name:
        runner = LocalRunner(
            SnapshotConfig(snapshot_dir=log_dir_name,
                           snapshot_mode='last',
                           snapshot_gap=1))
        meta_eval = MetaEvaluator(test_task_sampler=tasks,
                                  max_path_length=max_path_length,
                                  n_test_tasks=10,
                                  n_exploration_traj=n_traj)
        policy = RandomPolicy(env.spec.action_space)
        algo = MockAlgo(env, policy, max_path_length, n_traj, meta_eval)
        runner.setup(algo, env)
        # The context manager closes the temporary log file on every exit
        # path instead of leaving the handle open.
        with tempfile.NamedTemporaryFile() as log_file:
            csv_output = CsvOutput(log_file.name)
            logger.add_output(csv_output)
            try:
                meta_eval.evaluate(algo)
                meta_eval_pickle = cloudpickle.dumps(meta_eval)
                meta_eval2 = cloudpickle.loads(meta_eval_pickle)
                meta_eval2.evaluate(algo)
            finally:
                # Detach the output so it does not stay registered on the
                # shared logger once this test is over.
                logger.remove_output_type(CsvOutput)
    def test_ppo_pendulum(self):
        """Train MAMLVPG with a HalfCheetahDir meta-evaluator.

        NOTE(review): the method name mentions PPO and Pendulum, but the
        body builds MAMLVPG and samples HalfCheetahDirEnv tasks — confirm
        whether the name is stale.
        """
        deterministic.set_seed(0)

        horizon = 100
        episodes_per_task = 5

        test_tasks = SetTaskSampler(lambda: MetaRLEnv(
            normalize(HalfCheetahDirEnv(), expected_action_scale=10.)))
        evaluator = MetaEvaluator(test_task_sampler=test_tasks,
                                  max_path_length=horizon,
                                  n_test_tasks=1,
                                  n_test_rollouts=10)

        runner = LocalRunner(snapshot_config)
        algo_kwargs = dict(
            env=self.env,
            policy=self.policy,
            value_function=self.value_function,
            max_path_length=horizon,
            meta_batch_size=5,
            discount=0.99,
            gae_lambda=1.,
            inner_lr=0.1,
            num_grad_updates=1,
            meta_evaluator=evaluator,
        )
        algo = MAMLVPG(**algo_kwargs)
        runner.setup(algo, self.env)

        # Each epoch collects `episodes_per_task` full-length episodes.
        steps_per_epoch = episodes_per_task * horizon
        last_avg_ret = runner.train(n_epochs=10, batch_size=steps_per_epoch)

        assert last_avg_ret > -5
예제 #8
0
def test_init_with_env_updates():
    """Check that LocalSampler can be built from sampled env updates.

    The sampler receives one sampled task per worker through ``envs`` and
    must return at least the 160 requested steps.
    """
    horizon = 16
    worker_count = 8
    env = TfEnv(PointEnv())
    spec = env.spec
    scripted = [env.action_space.sample() for _ in range(horizon)]
    policy = FixedPolicy(spec, scripted_actions=scripted)
    task_sampler = SetTaskSampler(lambda: TfEnv(PointEnv()))
    factory = WorkerFactory(seed=100,
                            max_path_length=horizon,
                            n_workers=worker_count)
    sampler = LocalSampler.from_worker_factory(
        factory, policy, envs=task_sampler.sample(worker_count))

    batch = sampler.obtain_samples(0, 160, policy)
    assert sum(batch.lengths) >= 160
def maml_vpg_half_cheetah_dir(ctxt, seed, epochs, rollouts_per_task,
                              meta_batch_size):
    """Train MAMLVPG on the normalized HalfCheetahDir environment.

    Args:
        ctxt (metarl.experiment.ExperimentContext): The experiment
            configuration used by LocalRunner to create the snapshotter.
        seed (int): Used to seed the random number generator to produce
            determinism.
        epochs (int): Number of training epochs.
        rollouts_per_task (int): Number of rollouts per epoch per task
            for training.
        meta_batch_size (int): Number of tasks sampled per batch.

    """
    set_seed(seed)
    horizon = 100

    env = MetaRLEnv(normalize(HalfCheetahDirEnv(), expected_action_scale=10.))
    spec = env.spec

    # Both networks share the same nonlinearity settings. They are built
    # in the original order (policy first, then value function).
    activation = dict(hidden_nonlinearity=torch.tanh,
                      output_nonlinearity=None)
    policy = GaussianMLPPolicy(env_spec=spec,
                               hidden_sizes=(64, 64),
                               **activation)
    value_function = GaussianMLPValueFunction(env_spec=spec,
                                              hidden_sizes=(32, 32),
                                              **activation)

    test_tasks = SetTaskSampler(lambda: MetaRLEnv(
        normalize(HalfCheetahDirEnv(), expected_action_scale=10.)))
    evaluator = MetaEvaluator(test_task_sampler=test_tasks,
                              max_path_length=horizon,
                              n_test_tasks=1,
                              n_test_rollouts=10)

    runner = LocalRunner(ctxt)
    algo_kwargs = dict(
        env=env,
        policy=policy,
        value_function=value_function,
        max_path_length=horizon,
        meta_batch_size=meta_batch_size,
        discount=0.99,
        gae_lambda=1.,
        inner_lr=0.1,
        num_grad_updates=1,
        meta_evaluator=evaluator,
    )
    algo = MAMLVPG(**algo_kwargs)

    runner.setup(algo, env)
    steps_per_epoch = rollouts_per_task * horizon
    runner.train(n_epochs=epochs, batch_size=steps_per_epoch)
예제 #10
0
def test_init_with_env_updates(ray_local_session_fixture):
    """Check that RaySampler can be built from sampled env updates.

    The fixture argument is discarded right away; the test only asserts
    that ray is initialised before sampling. The sampler must then return
    at least the 160 requested steps.
    """
    del ray_local_session_fixture
    assert ray.is_initialized()

    horizon = 16
    worker_count = 8
    env = MetaRLEnv(PointEnv())
    spec = env.spec
    scripted = [env.action_space.sample() for _ in range(horizon)]
    policy = FixedPolicy(spec, scripted_actions=scripted)
    task_sampler = SetTaskSampler(lambda: MetaRLEnv(PointEnv()))
    factory = WorkerFactory(seed=100,
                            max_path_length=horizon,
                            n_workers=worker_count)
    sampler = RaySampler.from_worker_factory(
        factory, policy, envs=task_sampler.sample(worker_count))

    batch = sampler.obtain_samples(0, 160, policy)
    assert sum(batch.lengths) >= 160
    def test_pickling(self):
        """Check that a PEARL instance survives a pickle round trip.

        The restored object must expose both replay-buffer attributes and
        report that it is resuming.
        """
        width = 10
        n_tasks = 5
        latent_dim = 5

        train_sampler = SetTaskSampler(PointEnv)
        env = train_sampler.sample(n_tasks)
        test_env_sampler = SetTaskSampler(PointEnv)

        # A fresh environment is instantiated for each spec, as before.
        augmented_env = PEARL.augment_env_spec(env[0](), latent_dim)
        qf = ContinuousMLPQFunction(env_spec=augmented_env,
                                    hidden_sizes=[width] * 3)

        vf_env = PEARL.get_env_spec(env[0](), latent_dim, 'vf')
        vf = ContinuousMLPQFunction(env_spec=vf_env,
                                    hidden_sizes=[width] * 3)

        inner_policy = TanhGaussianMLPPolicy(env_spec=augmented_env,
                                             hidden_sizes=[width] * 3)

        pearl_kwargs = dict(
            env=env,
            inner_policy=inner_policy,
            qf=qf,
            vf=vf,
            num_train_tasks=n_tasks,
            num_test_tasks=n_tasks,
            latent_dim=latent_dim,
            encoder_hidden_sizes=[10, 10],
            test_env_sampler=test_env_sampler,
        )
        pearl = PEARL(**pearl_kwargs)

        # This line is just to improve coverage
        pearl.to()

        restored = pickle.loads(pickle.dumps(pearl))

        for attribute in ('_replay_buffers', '_context_replay_buffers'):
            assert hasattr(restored, attribute)
        assert restored._is_resuming
예제 #12
0
    def test_benchmark_pearl(self):
        """Run the metarl PEARL benchmark on ML1 ``reach-v1``.

        One seed is drawn per trial, ``run_metarl`` is executed for each
        and its CSV collected. The CSVs are then plotted and summarised
        into a JSON result that is written out.
        """
        train_sampler = SetTaskSampler(
            lambda: MetaRLEnv(normalize(ML1.get_train_tasks('reach-v1'))))
        env = train_sampler.sample(params['num_train_tasks'])
        test_sampler = SetTaskSampler(
            lambda: MetaRLEnv(normalize(ML1.get_test_tasks('reach-v1'))))
        # NOTE(review): test tasks are sampled with 'num_train_tasks' —
        # confirm this is intended rather than a test-task count.
        test_env = test_sampler.sample(params['num_train_tasks'])

        env_id = 'reach-v1'
        metric = 'Test/Average/SuccessRate'
        label = 'metarl_pearl'

        timestamp = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S-%f')
        benchmark_dir = osp.join(os.getcwd(), 'data', 'local', 'benchmarks',
                                 'pearl', timestamp)
        n_trials = params['n_trials']
        seeds = random.sample(range(100), n_trials)
        task_dir = osp.join(benchmark_dir, env_id)
        plt_file = osp.join(benchmark_dir, '{}_benchmark.png'.format(env_id))

        metarl_csvs = []
        # Trial directories are numbered from 1.
        for trial_number, seed in enumerate(seeds, start=1):
            trial_dir = task_dir + '/trial_%d_seed_%d' % (trial_number, seed)
            metarl_dir = trial_dir + '/metarl'
            metarl_csvs.append(run_metarl(env, test_env, seed, metarl_dir))

        # NOTE(review): `env` is whatever `sample()` returned above;
        # confirm that object actually provides `close()`.
        env.close()

        benchmark_helper.plot_average_over_trials(
            [metarl_csvs],
            ys=[metric],
            plt_file=plt_file,
            env_id=env_id,
            x_label='TotalEnvSteps',
            y_label=metric,
            names=[label],
        )

        factor_val = params['meta_batch_size'] * params['max_path_length']
        summary = benchmark_helper.create_json([metarl_csvs],
                                               seeds=seeds,
                                               trials=params['n_trials'],
                                               xs=['TotalEnvSteps'],
                                               ys=[metric],
                                               factors=[factor_val],
                                               names=[label])
        result_json = {env_id: summary}

        Rh.write_file(result_json, 'PEARL')
예제 #13
0
def test_meta_evaluator_n_traj():
    """Run a MetaEvaluator configured with several exploration trajectories.

    The test passes when a single evaluation with ``n_exploration_traj``
    set to 3 completes without raising.
    """
    set_seed(100)
    tasks = SetTaskSampler(PointEnv)
    max_path_length = 200
    env = MetaRLEnv(PointEnv())
    n_traj = 3
    with tempfile.TemporaryDirectory() as log_dir_name:
        runner = LocalRunner(
            SnapshotConfig(snapshot_dir=log_dir_name,
                           snapshot_mode='last',
                           snapshot_gap=1))
        algo = MockAlgo(env, max_path_length, n_traj)
        runner.setup(algo, env)
        meta_eval = MetaEvaluator(runner,
                                  test_task_sampler=tasks,
                                  max_path_length=max_path_length,
                                  n_test_tasks=10,
                                  n_exploration_traj=n_traj)
        # The context manager closes the temporary log file on every exit
        # path instead of leaving the handle open.
        with tempfile.NamedTemporaryFile() as log_file:
            csv_output = CsvOutput(log_file.name)
            logger.add_output(csv_output)
            try:
                meta_eval.evaluate(algo)
            finally:
                # Detach the output so it does not stay registered on the
                # shared logger once this test is over.
                logger.remove_output_type(CsvOutput)
예제 #14
0
def run_task(snapshot_config, *_):
    """Assemble PEARL-SAC for ML1 ``push-v1`` and train it.

    All hyperparameters are read from the externally defined ``params``
    mapping.

    Args:
        snapshot_config (metarl.experiment.SnapshotConfig): The snapshot
            configuration used by LocalRunner to create the snapshotter.
            If None, it will create one with default settings.
        _ : Unused parameters

    """
    # Sample the training and test task environments.
    train_sampler = SetTaskSampler(
        lambda: MetaRLEnv(normalize(ML1.get_train_tasks('push-v1'))))
    env = train_sampler.sample(params['num_train_tasks'])
    test_sampler = SetTaskSampler(
        lambda: MetaRLEnv(normalize(ML1.get_test_tasks('push-v1'))))
    # NOTE(review): test tasks are sampled with 'num_train_tasks' —
    # confirm this is intended rather than 'num_test_tasks'.
    test_env = test_sampler.sample(params['num_train_tasks'])

    runner = LocalRunner(snapshot_config)

    # A separate environment is instantiated for each dimension lookup,
    # as in the original flow.
    obs_dim = int(np.prod(env[0]().observation_space.shape))
    action_dim = int(np.prod(env[0]().action_space.shape))
    reward_dim = 1

    latent_size = params['latent_size']
    net_size = params['net_size']

    def unit_box(dim):
        """Return a float32 Box with bounds -1 and 1 and shape (dim, )."""
        return akro.Box(low=-1, high=1, shape=(dim, ), dtype=np.float32)

    def hidden_layers():
        """Return a fresh list of three hidden sizes, each net_size."""
        return [net_size] * 3

    # The encoder consumes concatenated (obs, action, reward) and emits
    # twice the latent size.
    context_encoder = MLPEncoder(input_dim=obs_dim + action_dim + reward_dim,
                                 output_dim=latent_size * 2,
                                 hidden_sizes=[200, 200, 200])

    # Spec whose observation space is widened by the latent size.
    augmented_env = EnvSpec(unit_box(obs_dim + latent_size),
                            unit_box(action_dim))
    qf1 = ContinuousMLPQFunction(env_spec=augmented_env,
                                 hidden_sizes=hidden_layers())
    qf2 = ContinuousMLPQFunction(env_spec=augmented_env,
                                 hidden_sizes=hidden_layers())

    # Spec for the value function: plain observations, with the latent
    # size used as the action dimension.
    vf_env = EnvSpec(unit_box(obs_dim), unit_box(latent_size))
    vf = ContinuousMLPQFunction(env_spec=vf_env,
                                hidden_sizes=hidden_layers())

    policy = TanhGaussianMLPPolicy2(env_spec=augmented_env,
                                    hidden_sizes=hidden_layers())

    context_conditioned_policy = ContextConditionedPolicy(
        latent_dim=latent_size,
        context_encoder=context_encoder,
        policy=policy,
        use_ib=params['use_information_bottleneck'],
        use_next_obs=params['use_next_obs_in_context'],
    )

    # These settings are forwarded under the same name they have in params.
    forwarded = (
        'num_train_tasks',
        'num_test_tasks',
        'meta_batch_size',
        'num_steps_per_epoch',
        'num_initial_steps',
        'num_tasks_sample',
        'num_steps_prior',
        'num_extra_rl_steps_posterior',
        'num_evals',
        'num_steps_per_eval',
        'batch_size',
        'embedding_batch_size',
        'embedding_mini_batch_size',
        'max_path_length',
        'reward_scale',
    )
    pearlsac = PEARLSAC(env=env,
                        test_env=test_env,
                        policy=context_conditioned_policy,
                        qf1=qf1,
                        qf2=qf2,
                        vf=vf,
                        latent_dim=latent_size,
                        **{key: params[key] for key in forwarded})

    use_gpu = params['use_gpu']
    tu.set_gpu_mode(use_gpu, gpu_id=0)
    if use_gpu:
        pearlsac.to()

    runner.setup(algo=pearlsac,
                 env=env,
                 sampler_cls=PEARLSampler,
                 sampler_args=dict(max_path_length=params['max_path_length']))
    runner.train(n_epochs=params['num_epochs'],
                 batch_size=params['batch_size'])
def pearl_metaworld_ml1_push(ctxt=None,
                             seed=1,
                             num_epochs=1000,
                             num_train_tasks=50,
                             num_test_tasks=10,
                             latent_size=7,
                             encoder_hidden_size=200,
                             net_size=300,
                             meta_batch_size=16,
                             num_steps_per_epoch=4000,
                             num_initial_steps=4000,
                             num_tasks_sample=15,
                             num_steps_prior=750,
                             num_extra_rl_steps_posterior=750,
                             batch_size=256,
                             embedding_batch_size=64,
                             embedding_mini_batch_size=64,
                             max_path_length=150,
                             reward_scale=10.,
                             use_gpu=False):
    """Train PEARL on the ML1 ``push-v1`` task family.

    Args:
        ctxt (metarl.experiment.ExperimentContext): The experiment
            configuration used by LocalRunner to create the snapshotter.
        seed (int): Used to seed the random number generator to produce
            determinism.
        num_epochs (int): Number of training epochs.
        num_train_tasks (int): Number of tasks for training.
        num_test_tasks (int): Number of tasks for testing.
        latent_size (int): Size of latent context vector.
        encoder_hidden_size (int): Output dimension of dense layer of the
            context encoder; used for all three encoder hidden layers.
        net_size (int): Output dimension of a dense layer of Q-function and
            value function.
        meta_batch_size (int): Meta batch size.
        num_steps_per_epoch (int): Number of iterations per epoch.
        num_initial_steps (int): Number of transitions obtained per task before
            training.
        num_tasks_sample (int): Number of random tasks to obtain data for each
            iteration.
        num_steps_prior (int): Number of transitions to obtain per task with
            z ~ prior.
        num_extra_rl_steps_posterior (int): Number of additional transitions
            to obtain per task with z ~ posterior that are only used to train
            the policy and NOT the encoder.
        batch_size (int): Number of transitions in RL batch.
        embedding_batch_size (int): Number of transitions in context batch.
        embedding_mini_batch_size (int): Number of transitions in mini context
            batch; should be same as embedding_batch_size for non-recurrent
            encoder.
        max_path_length (int): Maximum path length.
        reward_scale (float): Reward scale.
        use_gpu (bool): Whether or not to use GPU for training.

    """
    set_seed(seed)
    encoder_hidden_sizes = (encoder_hidden_size, ) * 3

    # Sample the training tasks and prepare the test task sampler.
    train_sampler = SetTaskSampler(lambda: MetaRLEnv(
        normalize(mwb.ML1.get_train_tasks('push-v1'))))
    env = train_sampler.sample(num_train_tasks)
    test_env_sampler = SetTaskSampler(lambda: MetaRLEnv(
        normalize(mwb.ML1.get_test_tasks('push-v1'))))

    runner = LocalRunner(ctxt)

    # Build the networks. A fresh environment is instantiated for each
    # spec, as before.
    augmented_env = PEARL.augment_env_spec(env[0](), latent_size)
    qf = ContinuousMLPQFunction(env_spec=augmented_env,
                                hidden_sizes=[net_size] * 3)

    vf_env = PEARL.get_env_spec(env[0](), latent_size, 'vf')
    vf = ContinuousMLPQFunction(env_spec=vf_env,
                                hidden_sizes=[net_size] * 3)

    inner_policy = TanhGaussianMLPPolicy(env_spec=augmented_env,
                                         hidden_sizes=[net_size] * 3)

    pearl_kwargs = dict(
        env=env,
        policy_class=ContextConditionedPolicy,
        encoder_class=MLPEncoder,
        inner_policy=inner_policy,
        qf=qf,
        vf=vf,
        num_train_tasks=num_train_tasks,
        num_test_tasks=num_test_tasks,
        latent_dim=latent_size,
        encoder_hidden_sizes=encoder_hidden_sizes,
        test_env_sampler=test_env_sampler,
        meta_batch_size=meta_batch_size,
        num_steps_per_epoch=num_steps_per_epoch,
        num_initial_steps=num_initial_steps,
        num_tasks_sample=num_tasks_sample,
        num_steps_prior=num_steps_prior,
        num_extra_rl_steps_posterior=num_extra_rl_steps_posterior,
        batch_size=batch_size,
        embedding_batch_size=embedding_batch_size,
        embedding_mini_batch_size=embedding_mini_batch_size,
        max_path_length=max_path_length,
        reward_scale=reward_scale,
    )
    pearl = PEARL(**pearl_kwargs)

    set_gpu_mode(use_gpu, gpu_id=0)
    if use_gpu:
        pearl.to()

    runner.setup(algo=pearl,
                 env=env[0](),
                 sampler_cls=LocalSampler,
                 sampler_args=dict(max_path_length=max_path_length),
                 n_workers=1,
                 worker_class=PEARLWorker)

    runner.train(n_epochs=num_epochs, batch_size=batch_size)
    def test_pearl_ml1_push(self):
        """Run one short PEARL training epoch on ML1 ``push-v1``.

        The test passes when setup and a single training epoch complete
        without raising; no return value is asserted.
        """
        params = dict(seed=1,
                      num_epochs=1,
                      num_train_tasks=5,
                      num_test_tasks=1,
                      latent_size=7,
                      encoder_hidden_sizes=[10, 10, 10],
                      net_size=30,
                      meta_batch_size=16,
                      num_steps_per_epoch=40,
                      num_initial_steps=40,
                      num_tasks_sample=15,
                      num_steps_prior=15,
                      num_extra_rl_steps_posterior=15,
                      batch_size=256,
                      embedding_batch_size=8,
                      embedding_mini_batch_size=8,
                      max_path_length=50,
                      reward_scale=10.,
                      use_information_bottleneck=True,
                      use_next_obs_in_context=False,
                      use_gpu=False)

        width = params['net_size']
        latent_size = params['latent_size']
        set_seed(params['seed'])

        train_sampler = SetTaskSampler(
            lambda: MetaRLEnv(normalize(ML1.get_train_tasks('push-v1'))))
        env = train_sampler.sample(params['num_train_tasks'])
        test_env_sampler = SetTaskSampler(
            lambda: MetaRLEnv(normalize(ML1.get_test_tasks('push-v1'))))

        # A fresh environment is instantiated for each spec, as before.
        augmented_env = PEARL.augment_env_spec(env[0](), latent_size)
        qf = ContinuousMLPQFunction(env_spec=augmented_env,
                                    hidden_sizes=[width] * 3)

        vf_env = PEARL.get_env_spec(env[0](), latent_size, 'vf')
        vf = ContinuousMLPQFunction(env_spec=vf_env,
                                    hidden_sizes=[width] * 3)

        inner_policy = TanhGaussianMLPPolicy(env_spec=augmented_env,
                                             hidden_sizes=[width] * 3)

        # These settings are forwarded under the same name they have in
        # params.
        forwarded = (
            'num_train_tasks',
            'num_test_tasks',
            'encoder_hidden_sizes',
            'meta_batch_size',
            'num_steps_per_epoch',
            'num_initial_steps',
            'num_tasks_sample',
            'num_steps_prior',
            'num_extra_rl_steps_posterior',
            'batch_size',
            'embedding_batch_size',
            'embedding_mini_batch_size',
            'max_path_length',
            'reward_scale',
        )
        pearl = PEARL(env=env,
                      policy_class=ContextConditionedPolicy,
                      encoder_class=MLPEncoder,
                      inner_policy=inner_policy,
                      qf=qf,
                      vf=vf,
                      latent_dim=latent_size,
                      test_env_sampler=test_env_sampler,
                      **{key: params[key] for key in forwarded})

        use_gpu = params['use_gpu']
        set_gpu_mode(use_gpu, gpu_id=0)
        if use_gpu:
            pearl.to()

        runner = LocalRunner(snapshot_config)
        runner.setup(
            algo=pearl,
            env=env[0](),
            sampler_cls=LocalSampler,
            sampler_args=dict(max_path_length=params['max_path_length']),
            n_workers=1,
            worker_class=PEARLWorker)

        runner.train(n_epochs=params['num_epochs'],
                     batch_size=params['batch_size'])