Example #1
def maml_trpo(ctxt, seed, epochs, rollouts_per_task, meta_batch_size):
    """Set up environment and algorithm and run the task.

    Args:
        ctxt (garage.experiment.ExperimentContext): The experiment
            configuration used by LocalRunner to create the snapshotter.
        seed (int): Used to seed the random number generator to produce
            determinism.
        epochs (int): Number of training epochs.
        rollouts_per_task (int): Number of rollouts per epoch per task
            for training.
        meta_batch_size (int): Number of tasks sampled per batch.

    """
    set_seed(seed)
    env = GarageEnv(
        normalize(ML10.get_train_tasks(), expected_action_scale=10.))

    policy = GaussianMLPPolicy(
        env_spec=env.spec,
        hidden_sizes=(100, 100),
        hidden_nonlinearity=torch.tanh,
        output_nonlinearity=None,
    )

    value_function = GaussianMLPValueFunction(env_spec=env.spec,
                                              hidden_sizes=(32, 32),
                                              hidden_nonlinearity=torch.tanh,
                                              output_nonlinearity=None)

    max_path_length = 100

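    # Build one environment per ML10 test task; MetaEvaluator rolls out the
    # adapted policy on this held-out pool during training.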
    test_task_names = ML10.get_test_tasks().all_task_names
    test_tasks = [
        GarageEnv(normalize(ML10.from_task(task), expected_action_scale=10.))
        for task in test_task_names
    ]
    test_sampler = EnvPoolSampler(test_tasks)

    meta_evaluator = MetaEvaluator(test_task_sampler=test_sampler,
                                   max_path_length=max_path_length,
                                   n_test_tasks=len(test_task_names))

    runner = LocalRunner(ctxt)
    algo = MAMLTRPO(env=env,
                    policy=policy,
                    value_function=value_function,
                    max_path_length=max_path_length,
                    meta_batch_size=meta_batch_size,
                    discount=0.99,
                    gae_lambda=1.,
                    inner_lr=0.1,
                    num_grad_updates=1,
                    meta_evaluator=meta_evaluator)

    runner.setup(algo, env)
    runner.train(n_epochs=epochs,
                 batch_size=rollouts_per_task * max_path_length)
Example #2
def test_task_name():
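    # The ML10 test split should expose exactly the task names registered in
    # MEDIUM_MODE_CLS_DICT['test'], and step() should report the active task
    # in its info dict.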
    task_names = MEDIUM_MODE_CLS_DICT['test'].keys()
    env = ML10.get_test_tasks()
    assert sorted(env.all_task_names) == sorted(task_names)

    _, _, _, info = env.step(env.action_space.sample())
    assert info['task_name'] in task_names
Example #3
def test_all_ml10():
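    # Smoke-test: sample tasks from both ML10 splits and step each one briefly.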
    ml10_train_env = ML10.get_train_tasks()
    train_tasks = ml10_train_env.sample_tasks(11)
    for t in train_tasks:
        ml10_train_env.set_task(t)
        step_env(ml10_train_env, max_path_length=3)

    ml10_train_env.close()
    del ml10_train_env

    ml10_test_env = ML10.get_test_tasks()
    test_tasks = ml10_test_env.sample_tasks(11)
    for t in test_tasks:
        ml10_test_env.set_task(t)
        step_env(ml10_test_env, max_path_length=3)

    ml10_test_env.close()
    del ml10_test_env
Example #4
def get_metaworld_tasks(env_id: str = 'ml10'):
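    # NOTE: relies on a module-level ``args`` namespace (e.g. parsed command-line
    # arguments) providing ``args.mltest`` and ``args.task_idx``.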
    def _extract_tasks(env_, skip_task_idxs=()):
        # Sample until one task dict per (non-skipped) task index has been collected.
        task_idxs = set()
        tasks = [None for _ in range(env_.num_tasks - len(skip_task_idxs))]
        while len(task_idxs) < env_.num_tasks - len(skip_task_idxs):
            task_dict = env_.sample_tasks(1)[0]
            task_idx = task_dict['task']
            if task_idx not in task_idxs and task_idx not in skip_task_idxs:
                task_idxs.add(task_idx)
                tasks[task_idx - len(skip_task_idxs)] = task_dict
        return tasks

    if env_id == 'ml10':
        from metaworld.benchmarks import ML10
        if args.mltest:
            env = ML10.get_test_tasks()
            tasks = _extract_tasks(env)
        else:
            env = ML10.get_train_tasks()
            tasks = _extract_tasks(env, skip_task_idxs=[])

        if args.task_idx is not None:
            tasks = [tasks[args.task_idx]]

        env.tasks = tasks
        print(tasks)

        def set_task_idx(idx):
            env.set_task(tasks[idx])

        def task_description(batch: int = None, one_hot: bool = True):
            one_hot = env.active_task_one_hot.astype(np.float32)
            if batch:
                one_hot = one_hot[None, :].repeat(batch, 0)
            return one_hot

        env.set_task_idx = set_task_idx
        env.task_description = task_description
        env.task_description_dim = lambda: len(env.tasks)
        env._max_episode_steps = 150

        return env
    else:
        raise NotImplementedError()
Example #5
    def test_rl2_ppo_ml10(self):
        # pylint: disable=import-outside-toplevel
        from metaworld.benchmarks import ML10
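        # One RL2-wrapped environment per ML10 train task.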
        ML_train_envs = [
            RL2Env(ML10.from_task(task_name))
            for task_name in ML10.get_train_tasks().all_task_names
        ]
        tasks = task_sampler.EnvPoolSampler(ML_train_envs)
        tasks.grow_pool(self.meta_batch_size)

        env_spec = ML_train_envs[0].spec
        policy = GaussianGRUPolicy(env_spec=env_spec,
                                   hidden_dim=64,
                                   state_include_action=False,
                                   name='policy')
        baseline = LinearFeatureBaseline(env_spec=env_spec)
        with LocalTFRunner(snapshot_config, sess=self.sess) as runner:
            algo = RL2PPO(rl2_max_path_length=self.max_path_length,
                          meta_batch_size=self.meta_batch_size,
                          task_sampler=tasks,
                          env_spec=env_spec,
                          policy=policy,
                          baseline=baseline,
                          discount=0.99,
                          gae_lambda=0.95,
                          lr_clip_range=0.2,
                          stop_entropy_gradient=True,
                          entropy_method='max',
                          policy_ent_coeff=0.02,
                          center_adv=False,
                          max_path_length=self.max_path_length *
                          self.episode_per_task)

            runner.setup(
                algo,
                tasks.sample(self.meta_batch_size),
                sampler_cls=LocalSampler,
                n_workers=self.meta_batch_size,
                worker_class=RL2Worker,
                worker_args=dict(n_paths_per_trial=self.episode_per_task))

            runner.train(n_epochs=1,
                         batch_size=self.episode_per_task *
                         self.max_path_length * self.meta_batch_size)
Example #6
def test_env_pool_sampler():
    # Import and construct environments here to avoid using up too many
    # resources if this test isn't run.
    # pylint: disable=import-outside-toplevel
    from metaworld.benchmarks import ML10
    train_tasks = ML10.get_train_tasks().all_task_names
    ML10_train_envs = [
        ML10.from_task(train_task) for train_task in train_tasks
    ]
    tasks = task_sampler.EnvPoolSampler(ML10_train_envs)
    assert tasks.n_tasks == 10
    updates = tasks.sample(10)
    for env in ML10_train_envs:
        assert any(env is update() for update in updates)
    with pytest.raises(ValueError):
        tasks.sample(10, with_replacement=True)
    with pytest.raises(ValueError):
        tasks.sample(11)
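    # After growing the pool, sampling more tasks than the 10 base envs succeeds.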
    tasks.grow_pool(20)
    tasks.sample(20)
Example #7
def test_construct_envs_sampler_ml10():
    # pylint: disable=import-outside-toplevel
    from metaworld.benchmarks import ML10
    train_tasks = ML10.get_train_tasks().all_task_names
    ML10_constructors = [
        functools.partial(ML10.from_task, train_task)
        for train_task in train_tasks
    ]
    tasks = task_sampler.ConstructEnvsSampler(ML10_constructors)
    assert tasks.n_tasks == 10
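    # Unlike EnvPoolSampler, ConstructEnvsSampler can hand out more samples than
    # it has constructors; each update() builds a fresh environment.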
    updates = tasks.sample(15)
    envs = [update() for update in updates]
    action = envs[0].action_space.sample()
    rewards = [env.step(action)[1] for env in envs]
    assert np.var(rewards) > 0
    env = envs[0]
    env.close = unittest.mock.MagicMock(name='env.close')
    updates[-1](env)
    env.close.assert_called_with()
Example #8
def maml_trpo(ctxt, seed, epochs, rollouts_per_task, meta_batch_size):
    """Set up environment and algorithm and run the task.

    Args:
        ctxt (garage.experiment.ExperimentContext): The experiment
            configuration used by LocalRunner to create the snapshotter.
        seed (int): Used to seed the random number generator to produce
            determinism.
        epochs (int): Number of training epochs.
        rollouts_per_task (int): Number of rollouts per epoch per task
            for training.
        meta_batch_size (int): Number of tasks sampled per batch.

    """
    set_seed(seed)
    env = GarageEnv(
        normalize(ML10.get_train_tasks(), expected_action_scale=10.))

    policy = GaussianMLPPolicy(
        env_spec=env.spec,
        hidden_sizes=(100, 100),
        hidden_nonlinearity=torch.tanh,
        output_nonlinearity=None,
    )

    baseline = LinearFeatureBaseline(env_spec=env.spec)

    max_path_length = 100

    runner = LocalRunner(ctxt)
    algo = MAMLTRPO(env=env,
                    policy=policy,
                    baseline=baseline,
                    max_path_length=max_path_length,
                    meta_batch_size=meta_batch_size,
                    discount=0.99,
                    gae_lambda=1.,
                    inner_lr=0.1,
                    num_grad_updates=1)

    runner.setup(algo, env)
    runner.train(n_epochs=epochs,
                 batch_size=rollouts_per_task * max_path_length)
Example #9
def rl2_ppo_ml10(ctxt, seed, max_path_length, meta_batch_size, n_epochs,
                 episode_per_task):
    """Train PPO with ML10 environment.

    Args:
        ctxt (garage.experiment.ExperimentContext): The experiment
            configuration used by LocalRunner to create the snapshotter.
        seed (int): Used to seed the random number generator to produce
            determinism.
        max_path_length (int): Maximum length of a single rollout.
        meta_batch_size (int): Meta batch size.
        n_epochs (int): Total number of epochs for training.
        episode_per_task (int): Number of training episodes per task.

    """
    set_seed(seed)
    with LocalTFRunner(snapshot_config=ctxt) as runner:
        ML_train_envs = [
            RL2Env(ML10.from_task(task_name))
            for task_name in ML10.get_train_tasks().all_task_names
        ]
        tasks = task_sampler.EnvPoolSampler(ML_train_envs)
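        # Grow the environment pool so meta_batch_size envs can be sampled below.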
        tasks.grow_pool(meta_batch_size)

        env_spec = ML_train_envs[0].spec
        policy = GaussianGRUPolicy(name='policy',
                                   hidden_dim=64,
                                   env_spec=env_spec,
                                   state_include_action=False)

        baseline = LinearFeatureBaseline(env_spec=env_spec)

        algo = RL2PPO(rl2_max_path_length=max_path_length,
                      meta_batch_size=meta_batch_size,
                      task_sampler=tasks,
                      env_spec=env_spec,
                      policy=policy,
                      baseline=baseline,
                      discount=0.99,
                      gae_lambda=0.95,
                      lr_clip_range=0.2,
                      optimizer_args=dict(
                          batch_size=32,
                          max_epochs=10,
                      ),
                      stop_entropy_gradient=True,
                      entropy_method='max',
                      policy_ent_coeff=0.02,
                      center_adv=False,
                      max_path_length=max_path_length * episode_per_task)

        runner.setup(algo,
                     tasks.sample(meta_batch_size),
                     sampler_cls=LocalSampler,
                     n_workers=meta_batch_size,
                     worker_class=RL2Worker,
                     worker_args=dict(n_paths_per_trial=episode_per_task))

        runner.train(n_epochs=n_epochs,
                     batch_size=episode_per_task * max_path_length *
                     meta_batch_size)
Example #10
def run_metarl(env, test_env, seed, log_dir):
    """Create metarl model and training."""

    deterministic.set_seed(seed)
    snapshot_config = SnapshotConfig(snapshot_dir=log_dir,
                                     snapshot_mode='gap',
                                     snapshot_gap=10)
    runner = LocalRunner(snapshot_config)

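    # Infer flat observation/action dimensions from the first training environment.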
    obs_dim = int(np.prod(env[0]().observation_space.shape))
    action_dim = int(np.prod(env[0]().action_space.shape))
    reward_dim = 1

    # instantiate networks
    encoder_in_dim = obs_dim + action_dim + reward_dim
    encoder_out_dim = params['latent_size'] * 2
    net_size = params['net_size']

    context_encoder = MLPEncoder(input_dim=encoder_in_dim,
                                 output_dim=encoder_out_dim,
                                 hidden_sizes=[200, 200, 200])

    space_a = akro.Box(low=-1,
                       high=1,
                       shape=(obs_dim + params['latent_size'], ),
                       dtype=np.float32)
    space_b = akro.Box(low=-1, high=1, shape=(action_dim, ), dtype=np.float32)
    augmented_env = EnvSpec(space_a, space_b)

    qf1 = ContinuousMLPQFunction(env_spec=augmented_env,
                                 hidden_sizes=[net_size, net_size, net_size])

    qf2 = ContinuousMLPQFunction(env_spec=augmented_env,
                                 hidden_sizes=[net_size, net_size, net_size])

    obs_space = akro.Box(low=-1, high=1, shape=(obs_dim, ), dtype=np.float32)
    action_space = akro.Box(low=-1,
                            high=1,
                            shape=(params['latent_size'], ),
                            dtype=np.float32)
    vf_env = EnvSpec(obs_space, action_space)

    vf = ContinuousMLPQFunction(env_spec=vf_env,
                                hidden_sizes=[net_size, net_size, net_size])

    policy = TanhGaussianMLPPolicy2(
        env_spec=augmented_env, hidden_sizes=[net_size, net_size, net_size])

    context_conditioned_policy = ContextConditionedPolicy(
        latent_dim=params['latent_size'],
        context_encoder=context_encoder,
        policy=policy,
        use_ib=params['use_information_bottleneck'],
        use_next_obs=params['use_next_obs_in_context'],
    )

    train_task_names = ML10.get_train_tasks()._task_names
    test_task_names = ML10.get_test_tasks()._task_names

    pearlsac = PEARLSAC(
        env=env,
        test_env=test_env,
        policy=context_conditioned_policy,
        qf1=qf1,
        qf2=qf2,
        vf=vf,
        num_train_tasks=params['num_train_tasks'],
        num_test_tasks=params['num_test_tasks'],
        latent_dim=params['latent_size'],
        meta_batch_size=params['meta_batch_size'],
        num_steps_per_epoch=params['num_steps_per_epoch'],
        num_initial_steps=params['num_initial_steps'],
        num_tasks_sample=params['num_tasks_sample'],
        num_steps_prior=params['num_steps_prior'],
        num_extra_rl_steps_posterior=params['num_extra_rl_steps_posterior'],
        num_evals=params['num_evals'],
        num_steps_per_eval=params['num_steps_per_eval'],
        batch_size=params['batch_size'],
        embedding_batch_size=params['embedding_batch_size'],
        embedding_mini_batch_size=params['embedding_mini_batch_size'],
        max_path_length=params['max_path_length'],
        reward_scale=params['reward_scale'],
        train_task_names=train_task_names,
        test_task_names=test_task_names,
    )

    tu.set_gpu_mode(params['use_gpu'], gpu_id=0)
    if params['use_gpu']:
        pearlsac.to()

    tabular_log_file = osp.join(log_dir, 'progress.csv')
    tensorboard_log_dir = osp.join(log_dir)
    dowel_logger.add_output(dowel.StdOutput())
    dowel_logger.add_output(dowel.CsvOutput(tabular_log_file))
    dowel_logger.add_output(dowel.TensorBoardOutput(tensorboard_log_dir))

    runner.setup(algo=pearlsac,
                 env=env,
                 sampler_cls=PEARLSampler,
                 sampler_args=dict(max_path_length=params['max_path_length']))
    runner.train(n_epochs=params['num_epochs'],
                 batch_size=params['batch_size'])

    dowel_logger.remove_all()

    return tabular_log_file
Example #11
def test_random_init_train():
    """Test that random_init == True for all envs."""
    env = ML10(env_type='train')
    assert len(env._task_envs) == 10
    for task_env in env._task_envs:
        assert task_env.random_init
Example #12
    def _thunk():
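        # Factory closure: build and wrap a single environment instance
        # (typically used when constructing vectorized environments).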
        if env_id.startswith("dm"):
            _, domain, task = env_id.split('.')
            env = dm_control2gym.make(domain_name=domain, task_name=task)
        elif env_id.startswith('metaworld_'):
            world_bench = env_id.split('_')[1]
            if world_bench.startswith('ml1.'):
                world_task = world_bench.split('.')[1]
                env = ML1.get_train_tasks(world_task)
            elif world_bench == 'ml10':
                env = ML10.get_train_tasks()
            elif world_bench == 'mt10':
                env = MT10.get_train_tasks()
            else:
                raise NotImplementedError(
                    'This code only supports metaworld ml1, ml10 or mt10.')

            env = MetaworldWrapper(env)
        else:
            env = gym.make(env_id)

        if obs_keys is not None:
            env = gym.wrappers.FlattenDictWrapper(env, dict_keys=obs_keys)

        is_atari = hasattr(gym.envs, 'atari') and isinstance(
            env.unwrapped, gym.envs.atari.atari_env.AtariEnv)
        if is_atari:
            env = make_atari(env_id)

        env.seed(seed + rank)

        obs_shape = env.observation_space.shape

        # if str(env.__class__.__name__).find('TimeLimit') >= 0:
        #     env = TimeLimitMask(env)

        if log_dir is not None:
            if save_video:
                env = bench.Monitor(env,
                                    os.path.join(log_dir + '/eval/monitor',
                                                 str(rank)),
                                    allow_early_resets=allow_early_resets)

                env = gym.wrappers.Monitor(env,
                                           os.path.join(
                                               log_dir + '/eval/video',
                                               str(rank)),
                                           force=True)
            else:
                env = bench.Monitor(env,
                                    os.path.join(log_dir + '/monitor',
                                                 str(rank)),
                                    allow_early_resets=allow_early_resets)

        if is_atari:
            if len(env.observation_space.shape) == 3:
                env = wrap_deepmind(env)
        elif len(env.observation_space.shape) == 3:
            raise NotImplementedError(
                "CNN models work only for atari,\n"
                "please use a custom wrapper for a custom pixel input env.\n"
                "See wrap_deepmind for an example.")

        # If the input has shape (W,H,3), wrap for PyTorch convolutions
        obs_shape = env.observation_space.shape
        if len(obs_shape) == 3 and obs_shape[2] in [1, 3]:
            env = TransposeImage(env, op=[2, 0, 1])

        return env
Example #13
def run_task(snapshot_config, *_):
    """Set up environment and algorithm and run the task.

    Args:
        snapshot_config (metarl.experiment.SnapshotConfig): The snapshot
            configuration used by LocalRunner to create the snapshotter.
            If None, it will create one with default settings.
        _ : Unused parameters

    """
    # create multi-task environment and sample tasks
    ML_train_envs = [
        TaskIdWrapper(MetaRLEnv(
            normalize(
                env(*ML10_ARGS['train'][task]['args'],
                    **ML10_ARGS['train'][task]['kwargs']))),
                      task_id=task_id,
                      task_name=task)
        for (task_id, (task, env)) in enumerate(ML10_ENVS['train'].items())
    ]

    ML_test_envs = [
        TaskIdWrapper(MetaRLEnv(
            normalize(
                env(*ML10_ARGS['test'][task]['args'],
                    **ML10_ARGS['test'][task]['kwargs']))),
                      task_id=task_id,
                      task_name=task)
        for (task_id, (task, env)) in enumerate(ML10_ENVS['test'].items())
    ]

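    # ML10 train/test task names, passed to PEARLSAC below.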
    train_task_names = ML10.get_train_tasks()._task_names
    test_task_names = ML10.get_test_tasks()._task_names

    env_sampler = EnvPoolSampler(ML_train_envs)
    env = env_sampler.sample(params['num_train_tasks'])
    test_env_sampler = EnvPoolSampler(ML_test_envs)
    test_env = test_env_sampler.sample(params['num_test_tasks'])

    runner = LocalRunner(snapshot_config)
    obs_dim = int(np.prod(env[0]().observation_space.shape))
    action_dim = int(np.prod(env[0]().action_space.shape))
    reward_dim = 1

    # instantiate networks
    encoder_in_dim = obs_dim + action_dim + reward_dim
    encoder_out_dim = params['latent_size'] * 2
    net_size = params['net_size']

    context_encoder = MLPEncoder(input_dim=encoder_in_dim,
                                 output_dim=encoder_out_dim,
                                 hidden_sizes=[200, 200, 200])

    space_a = akro.Box(low=-1,
                       high=1,
                       shape=(obs_dim + params['latent_size'], ),
                       dtype=np.float32)
    space_b = akro.Box(low=-1, high=1, shape=(action_dim, ), dtype=np.float32)
    augmented_env = EnvSpec(space_a, space_b)

    qf1 = ContinuousMLPQFunction(env_spec=augmented_env,
                                 hidden_sizes=[net_size, net_size, net_size])

    qf2 = ContinuousMLPQFunction(env_spec=augmented_env,
                                 hidden_sizes=[net_size, net_size, net_size])

    obs_space = akro.Box(low=-1, high=1, shape=(obs_dim, ), dtype=np.float32)
    action_space = akro.Box(low=-1,
                            high=1,
                            shape=(params['latent_size'], ),
                            dtype=np.float32)
    vf_env = EnvSpec(obs_space, action_space)

    vf = ContinuousMLPQFunction(env_spec=vf_env,
                                hidden_sizes=[net_size, net_size, net_size])

    policy = TanhGaussianMLPPolicy2(
        env_spec=augmented_env, hidden_sizes=[net_size, net_size, net_size])

    context_conditioned_policy = ContextConditionedPolicy(
        latent_dim=params['latent_size'],
        context_encoder=context_encoder,
        policy=policy,
        use_ib=params['use_information_bottleneck'],
        use_next_obs=params['use_next_obs_in_context'],
    )

    pearlsac = PEARLSAC(
        env=env,
        test_env=test_env,
        policy=context_conditioned_policy,
        qf1=qf1,
        qf2=qf2,
        vf=vf,
        num_train_tasks=params['num_train_tasks'],
        num_test_tasks=params['num_test_tasks'],
        latent_dim=params['latent_size'],
        meta_batch_size=params['meta_batch_size'],
        num_steps_per_epoch=params['num_steps_per_epoch'],
        num_initial_steps=params['num_initial_steps'],
        num_tasks_sample=params['num_tasks_sample'],
        num_steps_prior=params['num_steps_prior'],
        num_extra_rl_steps_posterior=params['num_extra_rl_steps_posterior'],
        num_evals=params['num_evals'],
        num_steps_per_eval=params['num_steps_per_eval'],
        batch_size=params['batch_size'],
        embedding_batch_size=params['embedding_batch_size'],
        embedding_mini_batch_size=params['embedding_mini_batch_size'],
        max_path_length=params['max_path_length'],
        reward_scale=params['reward_scale'],
        train_task_names=train_task_names,
        test_task_names=test_task_names,
    )

    tu.set_gpu_mode(params['use_gpu'], gpu_id=0)
    if params['use_gpu']:
        pearlsac.to()

    runner.setup(algo=pearlsac,
                 env=env,
                 sampler_cls=PEARLSampler,
                 sampler_args=dict(max_path_length=params['max_path_length']))
    runner.train(n_epochs=params['num_epochs'],
                 batch_size=params['batch_size'])
Example #14
def torch_pearl_ml10(ctxt=None,
                     seed=1,
                     num_epochs=1000,
                     num_train_tasks=10,
                     num_test_tasks=5,
                     latent_size=7,
                     encoder_hidden_size=200,
                     net_size=300,
                     meta_batch_size=16,
                     num_steps_per_epoch=4000,
                     num_initial_steps=4000,
                     num_tasks_sample=15,
                     num_steps_prior=750,
                     num_extra_rl_steps_posterior=750,
                     batch_size=256,
                     embedding_batch_size=64,
                     embedding_mini_batch_size=64,
                     max_path_length=150,
                     reward_scale=10.,
                     use_gpu=False):
    """Train PEARL with ML10 environments.

    Args:
        ctxt (garage.experiment.ExperimentContext): The experiment
            configuration used by LocalRunner to create the snapshotter.
        seed (int): Used to seed the random number generator to produce
            determinism.
        num_epochs (int): Number of training epochs.
        num_train_tasks (int): Number of tasks for training.
        num_test_tasks (int): Number of tasks for testing.
        latent_size (int): Size of latent context vector.
        encoder_hidden_size (int): Output dimension of dense layer of the
            context encoder.
        net_size (int): Output dimension of a dense layer of Q-function and
            value function.
        meta_batch_size (int): Meta batch size.
        num_steps_per_epoch (int): Number of iterations per epoch.
        num_initial_steps (int): Number of transitions obtained per task before
            training.
        num_tasks_sample (int): Number of random tasks to obtain data for each
            iteration.
        num_steps_prior (int): Number of transitions to obtain per task with
            z ~ prior.
        num_extra_rl_steps_posterior (int): Number of additional transitions
            to obtain per task with z ~ posterior that are only used to train
            the policy and NOT the encoder.
        batch_size (int): Number of transitions in RL batch.
        embedding_batch_size (int): Number of transitions in context batch.
        embedding_mini_batch_size (int): Number of transitions in mini context
            batch; should be same as embedding_batch_size for non-recurrent
            encoder.
        max_path_length (int): Maximum path length.
        reward_scale (int): Reward scale.
        use_gpu (bool): Whether or not to use GPU for training.

    """
    set_seed(seed)
    encoder_hidden_sizes = (encoder_hidden_size, encoder_hidden_size,
                            encoder_hidden_size)
    # create multi-task environment and sample tasks
    ML_train_envs = [
        GarageEnv(normalize(ML10.from_task(task_name)))
        for task_name in ML10.get_train_tasks().all_task_names
    ]

    ML_test_envs = [
        GarageEnv(normalize(ML10.from_task(task_name)))
        for task_name in ML10.get_test_tasks().all_task_names
    ]

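    # Pool samplers hand out one environment per task for the train/test splits.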
    env_sampler = EnvPoolSampler(ML_train_envs)
    env = env_sampler.sample(num_train_tasks)
    test_env_sampler = EnvPoolSampler(ML_test_envs)

    runner = LocalRunner(ctxt)

    # instantiate networks
    augmented_env = PEARL.augment_env_spec(env[0](), latent_size)
    qf = ContinuousMLPQFunction(env_spec=augmented_env,
                                hidden_sizes=[net_size, net_size, net_size])

    vf_env = PEARL.get_env_spec(env[0](), latent_size, 'vf')
    vf = ContinuousMLPQFunction(env_spec=vf_env,
                                hidden_sizes=[net_size, net_size, net_size])

    inner_policy = TanhGaussianMLPPolicy(
        env_spec=augmented_env, hidden_sizes=[net_size, net_size, net_size])

    pearl = PEARL(
        env=env,
        policy_class=ContextConditionedPolicy,
        encoder_class=MLPEncoder,
        inner_policy=inner_policy,
        qf=qf,
        vf=vf,
        num_train_tasks=num_train_tasks,
        num_test_tasks=num_test_tasks,
        latent_dim=latent_size,
        encoder_hidden_sizes=encoder_hidden_sizes,
        test_env_sampler=test_env_sampler,
        meta_batch_size=meta_batch_size,
        num_steps_per_epoch=num_steps_per_epoch,
        num_initial_steps=num_initial_steps,
        num_tasks_sample=num_tasks_sample,
        num_steps_prior=num_steps_prior,
        num_extra_rl_steps_posterior=num_extra_rl_steps_posterior,
        batch_size=batch_size,
        embedding_batch_size=embedding_batch_size,
        embedding_mini_batch_size=embedding_mini_batch_size,
        max_path_length=max_path_length,
        reward_scale=reward_scale,
    )

    tu.set_gpu_mode(use_gpu, gpu_id=0)
    if use_gpu:
        pearl.to()

    runner.setup(algo=pearl,
                 env=env[0](),
                 sampler_cls=LocalSampler,
                 sampler_args=dict(max_path_length=max_path_length),
                 n_workers=1,
                 worker_class=PEARLWorker)

    runner.train(n_epochs=num_epochs, batch_size=batch_size)