Example #1
def test_pickle_meta_evaluator():
    set_seed(100)
    tasks = SetTaskSampler(lambda: GarageEnv(PointEnv()))
    max_path_length = 200
    env = GarageEnv(PointEnv())
    n_traj = 3
    with tempfile.TemporaryDirectory() as log_dir_name:
        runner = LocalRunner(
            SnapshotConfig(snapshot_dir=log_dir_name,
                           snapshot_mode='last',
                           snapshot_gap=1))
        meta_eval = MetaEvaluator(test_task_sampler=tasks,
                                  max_path_length=max_path_length,
                                  n_test_tasks=10,
                                  n_exploration_traj=n_traj)
        policy = RandomPolicy(env.spec.action_space)
        algo = MockAlgo(env, policy, max_path_length, n_traj, meta_eval)
        runner.setup(algo, env)
        log_file = tempfile.NamedTemporaryFile()
        csv_output = CsvOutput(log_file.name)
        logger.add_output(csv_output)
        meta_eval.evaluate(algo)
        meta_eval_pickle = cloudpickle.dumps(meta_eval)
        meta_eval2 = cloudpickle.loads(meta_eval_pickle)
        meta_eval2.evaluate(algo)
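
The test above relies on two helpers from garage's test suite that are not shown here, MockAlgo and RandomPolicy. The sketch below is a hypothetical reconstruction of the minimum those stubs need to provide for MetaEvaluator.evaluate() to work (an exploration policy plus an adapt_policy hook); the real fixtures in garage's tests may differ.

from garage.sampler import LocalSampler  # assumed import


class RandomPolicy:
    """Hypothetical stand-in: samples uniformly from the action space."""

    def __init__(self, action_space):
        self._action_space = action_space

    def reset(self):
        """A random policy has no state to reset."""

    def get_action(self, observation):
        del observation  # the action does not depend on the observation
        return self._action_space.sample(), {}


class MockAlgo:
    """Hypothetical meta-RL stub whose adapted policy is its exploration policy."""

    sampler_cls = LocalSampler  # assumption: runner.setup() reads this

    def __init__(self, env, policy, max_path_length, n_exploration_traj,
                 meta_eval):
        self.env = env
        self.policy = policy
        self.max_path_length = max_path_length
        self.n_exploration_traj = n_exploration_traj
        self.meta_eval = meta_eval

    def get_exploration_policy(self):
        return self.policy

    def adapt_policy(self, exploration_policy, exploration_trajectories):
        del exploration_trajectories  # a mock algorithm does not adapt
        return exploration_policy

    def train(self, runner):
        for _ in runner.step_epochs():
            runner.step_itr += 1
            self.meta_eval.evaluate(self)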
Example #2
def test_meta_evaluator_with_tf():
    set_seed(100)
    tasks = SetTaskSampler(PointEnv, wrapper=set_length)
    max_episode_length = 200
    env = PointEnv()
    n_eps = 3
    with tempfile.TemporaryDirectory() as log_dir_name:
        ctxt = SnapshotConfig(snapshot_dir=log_dir_name,
                              snapshot_mode='none',
                              snapshot_gap=1)
        with TFTrainer(ctxt) as trainer:
            meta_eval = MetaEvaluator(test_task_sampler=tasks,
                                      n_test_tasks=10,
                                      n_exploration_eps=n_eps)
            policy = GaussianMLPPolicy(env.spec)
            algo = MockAlgo(env, policy, max_episode_length, n_eps, meta_eval)
            trainer.setup(algo, env)
            log_file = tempfile.NamedTemporaryFile()
            csv_output = CsvOutput(log_file.name)
            logger.add_output(csv_output)
            meta_eval.evaluate(algo)
            algo_pickle = cloudpickle.dumps(algo)
        tf.compat.v1.reset_default_graph()
        with TFTrainer(ctxt) as trainer:
            algo2 = cloudpickle.loads(algo_pickle)
            trainer.setup(algo2, env)
            trainer.train(10, 0)
Example #3
def test_meta_evaluator_with_tf():
    set_seed(100)
    tasks = SetTaskSampler(lambda: GarageEnv(PointEnv()))
    max_path_length = 200
    env = GarageEnv(PointEnv())
    n_traj = 3
    with tempfile.TemporaryDirectory() as log_dir_name:
        ctxt = SnapshotConfig(snapshot_dir=log_dir_name,
                              snapshot_mode='none',
                              snapshot_gap=1)
        with LocalTFRunner(ctxt) as runner:
            meta_eval = MetaEvaluator(test_task_sampler=tasks,
                                      max_path_length=max_path_length,
                                      n_test_tasks=10,
                                      n_exploration_traj=n_traj)
            policy = GaussianMLPPolicy(env.spec)
            algo = MockAlgo(env, policy, max_path_length, n_traj, meta_eval)
            runner.setup(algo, env)
            log_file = tempfile.NamedTemporaryFile()
            csv_output = CsvOutput(log_file.name)
            logger.add_output(csv_output)
            meta_eval.evaluate(algo)
            algo_pickle = cloudpickle.dumps(algo)
        with tf.Graph().as_default():
            with LocalTFRunner(ctxt) as runner:
                algo2 = cloudpickle.loads(algo_pickle)
                runner.setup(algo2, env)
                runner.train(10, 0)
Example #4
    def test_ppo_pendulum(self):
        """Test PPO with Pendulum environment."""
        deterministic.set_seed(0)

        rollouts_per_task = 5
        max_path_length = 100

        task_sampler = SetTaskSampler(lambda: GarageEnv(
            normalize(HalfCheetahDirEnv(), expected_action_scale=10.)))

        meta_evaluator = MetaEvaluator(test_task_sampler=task_sampler,
                                       max_path_length=max_path_length,
                                       n_test_tasks=1,
                                       n_test_rollouts=10)

        runner = LocalRunner(snapshot_config)
        algo = MAMLVPG(env=self.env,
                       policy=self.policy,
                       value_function=self.value_function,
                       max_path_length=max_path_length,
                       meta_batch_size=5,
                       discount=0.99,
                       gae_lambda=1.,
                       inner_lr=0.1,
                       num_grad_updates=1,
                       meta_evaluator=meta_evaluator)

        runner.setup(algo, self.env)
        last_avg_ret = runner.train(n_epochs=10,
                                    batch_size=rollouts_per_task *
                                    max_path_length)

        assert last_avg_ret > -5
Example #5
def maml_trpo_metaworld_ml10(ctxt, seed, epochs, rollouts_per_task,
                             meta_batch_size):
    """Set up environment and algorithm and run the task.

    Args:
        ctxt (garage.experiment.ExperimentContext): The experiment
            configuration used by LocalRunner to create the snapshotter.
        seed (int): Used to seed the random number generator to produce
            determinism.
        epochs (int): Number of training epochs.
        rollouts_per_task (int): Number of rollouts per epoch per task
            for training.
        meta_batch_size (int): Number of tasks sampled per batch.

    """
    set_seed(seed)
    env = GarageEnv(
        normalize(mwb.ML10.get_train_tasks(), expected_action_scale=10.))

    policy = GaussianMLPPolicy(
        env_spec=env.spec,
        hidden_sizes=(100, 100),
        hidden_nonlinearity=torch.tanh,
        output_nonlinearity=None,
    )

    value_function = GaussianMLPValueFunction(env_spec=env.spec,
                                              hidden_sizes=(32, 32),
                                              hidden_nonlinearity=torch.tanh,
                                              output_nonlinearity=None)

    max_episode_length = 100

    test_task_names = mwb.ML10.get_test_tasks().all_task_names
    test_tasks = [
        GarageEnv(
            normalize(mwb.ML10.from_task(task), expected_action_scale=10.))
        for task in test_task_names
    ]
    test_sampler = EnvPoolSampler(test_tasks)

    meta_evaluator = MetaEvaluator(test_task_sampler=test_sampler,
                                   max_episode_length=max_episode_length,
                                   n_test_tasks=len(test_task_names))

    runner = LocalRunner(ctxt)
    algo = MAMLTRPO(env=env,
                    policy=policy,
                    value_function=value_function,
                    max_episode_length=max_episode_length,
                    meta_batch_size=meta_batch_size,
                    discount=0.99,
                    gae_lambda=1.,
                    inner_lr=0.1,
                    num_grad_updates=1,
                    meta_evaluator=meta_evaluator)

    runner.setup(algo, env)
    runner.train(n_epochs=epochs,
                 batch_size=rollouts_per_task * max_episode_length)
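
A launcher such as maml_trpo_metaworld_ml10 above does not call itself: in garage it is normally wired to the command line with click options and the wrap_experiment decorator, which constructs and passes the ctxt argument. The sketch below shows only that wiring; the option names mirror the function's parameters, but the default values are illustrative assumptions.

import click

from garage import wrap_experiment


@click.command()
@click.option('--seed', default=1)
@click.option('--epochs', default=300)
@click.option('--rollouts_per_task', default=10)
@click.option('--meta_batch_size', default=20)
@wrap_experiment
def maml_trpo_metaworld_ml10(ctxt, seed, epochs, rollouts_per_task,
                             meta_batch_size):
    # Body as in Example #5 above.
    ...


if __name__ == '__main__':
    # click supplies the option values; wrap_experiment supplies ctxt.
    maml_trpo_metaworld_ml10()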
Example #6
    def test_ppo_pendulum(self):
        """Test PPO with Pendulum environment."""
        deterministic.set_seed(0)

        episodes_per_task = 5
        max_episode_length = self.env.spec.max_episode_length

        task_sampler = SetTaskSampler(
            HalfCheetahDirEnv,
            wrapper=lambda env, _: normalize(GymEnv(
                env, max_episode_length=max_episode_length),
                                             expected_action_scale=10.))

        meta_evaluator = MetaEvaluator(test_task_sampler=task_sampler,
                                       n_test_tasks=1,
                                       n_test_episodes=10)

        trainer = Trainer(snapshot_config)
        algo = MAMLVPG(env=self.env,
                       policy=self.policy,
                       task_sampler=self.task_sampler,
                       value_function=self.value_function,
                       meta_batch_size=5,
                       discount=0.99,
                       gae_lambda=1.,
                       inner_lr=0.1,
                       num_grad_updates=1,
                       meta_evaluator=meta_evaluator)

        trainer.setup(algo, self.env, sampler_cls=LocalSampler)
        last_avg_ret = trainer.train(n_epochs=10,
                                     batch_size=episodes_per_task *
                                     max_episode_length)

        assert last_avg_ret > -5
Example #7
def maml_trpo_metaworld_ml1_push(ctxt, seed, epochs, rollouts_per_task,
                                 meta_batch_size):
    """Set up environment and algorithm and run the task.

    Args:
        ctxt (garage.experiment.ExperimentContext): The experiment
            configuration used by Trainer to create the snapshotter.
        seed (int): Used to seed the random number generator to produce
            determinism.
        epochs (int): Number of training epochs.
        rollouts_per_task (int): Number of rollouts per epoch per task
            for training.
        meta_batch_size (int): Number of tasks sampled per batch.

    """
    set_seed(seed)

    ml1 = metaworld.ML1('push-v1')
    tasks = MetaWorldTaskSampler(ml1, 'train')
    env = tasks.sample(1)[0]()
    test_sampler = SetTaskSampler(MetaWorldSetTaskEnv,
                                  env=MetaWorldSetTaskEnv(ml1, 'test'))

    policy = GaussianMLPPolicy(
        env_spec=env.spec,
        hidden_sizes=(100, 100),
        hidden_nonlinearity=torch.tanh,
        output_nonlinearity=None,
    )

    value_function = GaussianMLPValueFunction(env_spec=env.spec,
                                              hidden_sizes=[32, 32],
                                              hidden_nonlinearity=torch.tanh,
                                              output_nonlinearity=None)

    meta_evaluator = MetaEvaluator(test_task_sampler=test_sampler,
                                   n_test_tasks=1,
                                   n_exploration_eps=rollouts_per_task)

    sampler = RaySampler(agents=policy,
                         envs=env,
                         max_episode_length=env.spec.max_episode_length,
                         n_workers=meta_batch_size)

    trainer = Trainer(ctxt)
    algo = MAMLTRPO(env=env,
                    policy=policy,
                    sampler=sampler,
                    task_sampler=tasks,
                    value_function=value_function,
                    meta_batch_size=meta_batch_size,
                    discount=0.99,
                    gae_lambda=1.,
                    inner_lr=0.1,
                    num_grad_updates=1,
                    meta_evaluator=meta_evaluator)

    trainer.setup(algo, env)
    trainer.train(n_epochs=epochs,
                  batch_size=rollouts_per_task * env.spec.max_episode_length)
Example #8
def load_mamlvpg(env_name="MountainCarContinuous-v0"):
    """Return an instance of the MAML-VPG algorithm."""
    env = GarageEnv(env_name=env_name)
    policy = DeterministicMLPPolicy(name='policy',
                                    env_spec=env.spec,
                                    hidden_sizes=[64, 64])
    vfunc = GaussianMLPValueFunction(env_spec=env.spec)

    task_sampler = SetTaskSampler(
        lambda: GarageEnv(normalize(env, expected_action_scale=10.)))

    max_path_length = 100
    meta_evaluator = MetaEvaluator(test_task_sampler=task_sampler,
                                   max_path_length=max_path_length,
                                   n_test_tasks=1,
                                   n_test_rollouts=10)
    algo = MAMLVPG(env=env,
                   policy=policy,
                   value_function=vfunc,
                   max_path_length=max_path_length,
                   meta_batch_size=20,
                   discount=0.99,
                   gae_lambda=1.,
                   inner_lr=0.1,
                   num_grad_updates=1,
                   meta_evaluator=meta_evaluator)
    return algo
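
load_mamlvpg above only constructs the algorithm; actually training it still needs a snapshot directory and a runner. A minimal driver, assuming the same older garage API (GarageEnv, LocalRunner, SnapshotConfig) used throughout these examples, could look like this; the log directory and epoch count are arbitrary.

from garage.envs import GarageEnv
from garage.experiment import LocalRunner, SnapshotConfig


def train_mamlvpg(log_dir='/tmp/mamlvpg_demo', n_epochs=5):
    """Sketch of a driver for load_mamlvpg(); not part of the snippet above."""
    algo = load_mamlvpg()
    snapshot_config = SnapshotConfig(snapshot_dir=log_dir,
                                     snapshot_mode='last',
                                     snapshot_gap=1)
    runner = LocalRunner(snapshot_config)
    runner.setup(algo, GarageEnv(env_name='MountainCarContinuous-v0'))
    # batch_size = meta_batch_size (20) * max_path_length (100) from above
    return runner.train(n_epochs=n_epochs, batch_size=20 * 100)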
Example #9
def maml_trpo_half_cheetah_dir(ctxt, seed, epochs, episodes_per_task,
                               meta_batch_size):
    """Set up environment and algorithm and run the task.

    Args:
        ctxt (ExperimentContext): The experiment configuration used by
            :class:`~Trainer` to create the :class:`~Snapshotter`.
        seed (int): Used to seed the random number generator to produce
            determinism.
        epochs (int): Number of training epochs.
        episodes_per_task (int): Number of episodes per epoch per task for
            training.
        meta_batch_size (int): Number of tasks sampled per batch.

    """
    set_seed(seed)
    max_episode_length = 100
    env = normalize(GymEnv(HalfCheetahDirEnv(),
                           max_episode_length=max_episode_length),
                    expected_action_scale=10.)

    policy = GaussianMLPPolicy(
        env_spec=env.spec,
        hidden_sizes=[64, 64],
        hidden_nonlinearity=torch.tanh,
        output_nonlinearity=None,
    )

    value_function = GaussianMLPValueFunction(env_spec=env.spec,
                                              hidden_sizes=[32, 32],
                                              hidden_nonlinearity=torch.tanh,
                                              output_nonlinearity=None)

    task_sampler = SetTaskSampler(
        HalfCheetahDirEnv,
        wrapper=lambda env, _: normalize(GymEnv(
            env, max_episode_length=max_episode_length),
                                         expected_action_scale=10.))

    meta_evaluator = MetaEvaluator(test_task_sampler=task_sampler,
                                   n_test_tasks=1,
                                   n_test_episodes=10)

    trainer = Trainer(ctxt)
    algo = MAMLTRPO(env=env,
                    policy=policy,
                    task_sampler=task_sampler,
                    value_function=value_function,
                    meta_batch_size=meta_batch_size,
                    discount=0.99,
                    gae_lambda=1.,
                    inner_lr=0.1,
                    num_grad_updates=1,
                    meta_evaluator=meta_evaluator)

    trainer.setup(algo, env)
    trainer.train(n_epochs=epochs,
                  batch_size=episodes_per_task * env.spec.max_episode_length)
Example #10
def maml_trpo_metaworld_ml45(ctxt, seed, epochs, episodes_per_task,
                             meta_batch_size):
    """Set up environment and algorithm and run the task.

    Args:
        ctxt (ExperimentContext): The experiment configuration used by
            :class:`~Trainer` to create the :class:`~Snapshotter`.
        seed (int): Used to seed the random number generator to produce
            determinism.
        epochs (int): Number of training epochs.
        episodes_per_task (int): Number of episodes per epoch per task
            for training.
        meta_batch_size (int): Number of tasks sampled per batch.

    """
    set_seed(seed)
    ml45 = metaworld.ML45()

    # pylint: disable=missing-return-doc,missing-return-type-doc
    def wrap(env, _):
        return normalize(env, expected_action_scale=10.0)

    train_task_sampler = MetaWorldTaskSampler(ml45, 'train', wrap)
    test_env = wrap(MetaWorldSetTaskEnv(ml45, 'test'), None)
    test_task_sampler = SetTaskSampler(MetaWorldSetTaskEnv,
                                       env=test_env,
                                       wrapper=wrap)
    env = train_task_sampler.sample(45)[0]()

    policy = GaussianMLPPolicy(
        env_spec=env.spec,
        hidden_sizes=(100, 100),
        hidden_nonlinearity=torch.tanh,
        output_nonlinearity=None,
    )

    value_function = GaussianMLPValueFunction(env_spec=env.spec,
                                              hidden_sizes=(32, 32),
                                              hidden_nonlinearity=torch.tanh,
                                              output_nonlinearity=None)

    meta_evaluator = MetaEvaluator(test_task_sampler=test_task_sampler)

    trainer = Trainer(ctxt)
    algo = MAMLTRPO(env=env,
                    task_sampler=train_task_sampler,
                    policy=policy,
                    value_function=value_function,
                    meta_batch_size=meta_batch_size,
                    discount=0.99,
                    gae_lambda=1.,
                    inner_lr=0.1,
                    num_grad_updates=1,
                    meta_evaluator=meta_evaluator)

    trainer.setup(algo, env, n_workers=meta_batch_size)
    trainer.train(n_epochs=epochs,
                  batch_size=episodes_per_task * env.spec.max_episode_length)
Example #11
def maml_trpo(ctxt, seed, epochs, rollouts_per_task, meta_batch_size):
    """Set up environment and algorithm and run the task.

    Args:
        ctxt (garage.experiment.ExperimentContext): The experiment
            configuration used by LocalRunner to create the snapshotter.
        seed (int): Used to seed the random number generator to produce
            determinism.
        epochs (int): Number of training epochs.
        rollouts_per_task (int): Number of rollouts per epoch per task
            for training.
        meta_batch_size (int): Number of tasks sampled per batch.

    """
    set_seed(seed)
    # @TODO blowing up here...
    env = GarageEnv(normalize(HalfCheetahDirEnv(), expected_action_scale=10.))

    policy = GaussianMLPPolicy(
        env_spec=env.spec,
        hidden_sizes=[64, 64],
        hidden_nonlinearity=torch.tanh,
        output_nonlinearity=None,
    )

    value_function = GaussianMLPValueFunction(env_spec=env.spec,
                                              hidden_sizes=[32, 32],
                                              hidden_nonlinearity=torch.tanh,
                                              output_nonlinearity=None)

    max_path_length = 100

    task_sampler = SetTaskSampler(lambda: GarageEnv(
        normalize(HalfCheetahDirEnv(), expected_action_scale=10.)))

    meta_evaluator = MetaEvaluator(test_task_sampler=task_sampler,
                                   max_path_length=max_path_length,
                                   n_test_tasks=1,
                                   n_test_rollouts=10)

    runner = LocalRunner(ctxt)
    algo = MAMLTRPO(env=env,
                    policy=policy,
                    value_function=value_function,
                    max_path_length=max_path_length,
                    meta_batch_size=meta_batch_size,
                    discount=0.99,
                    gae_lambda=1.,
                    inner_lr=0.1,
                    num_grad_updates=1,
                    meta_evaluator=meta_evaluator)

    runner.setup(algo, env)
    runner.train(n_epochs=epochs,
                 batch_size=rollouts_per_task * max_path_length)
Example #12
def test_meta_evaluator():
    set_seed(100)
    tasks = SetTaskSampler(PointEnv, wrapper=set_length)
    max_episode_length = 200
    with tempfile.TemporaryDirectory() as log_dir_name:
        trainer = Trainer(
            SnapshotConfig(snapshot_dir=log_dir_name,
                           snapshot_mode='last',
                           snapshot_gap=1))
        env = PointEnv(max_episode_length=max_episode_length)
        algo = OptimalActionInference(env=env,
                                      max_episode_length=max_episode_length)
        trainer.setup(algo, env)
        meta_eval = MetaEvaluator(test_task_sampler=tasks, n_test_tasks=10)
        log_file = tempfile.NamedTemporaryFile()
        csv_output = CsvOutput(log_file.name)
        logger.add_output(csv_output)
        meta_eval.evaluate(algo)
        logger.log(tabular)
        meta_eval.evaluate(algo)
        logger.log(tabular)
        logger.dump_output_type(CsvOutput)
        logger.remove_output_type(CsvOutput)
        with open(log_file.name, 'r') as file:
            rows = list(csv.DictReader(file))
        assert len(rows) == 2
        assert float(
            rows[0]['MetaTest/__unnamed_task__/TerminationRate']) < 1.0
        assert float(rows[0]['MetaTest/__unnamed_task__/Iteration']) == 0
        assert (float(rows[0]['MetaTest/__unnamed_task__/MaxReturn']) >= float(
            rows[0]['MetaTest/__unnamed_task__/AverageReturn']))
        assert (float(rows[0]['MetaTest/__unnamed_task__/AverageReturn']) >=
                float(rows[0]['MetaTest/__unnamed_task__/MinReturn']))
        assert float(rows[1]['MetaTest/__unnamed_task__/Iteration']) == 1
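
As the assertions above show, MetaEvaluator logs its statistics under CSV columns named 'MetaTest/<task name>/<statistic>'. The helper below is illustrative only (it is not part of garage or of the test): it collects those columns from the file written by CsvOutput, one dict per evaluation row.

import csv


def read_meta_test_stats(csv_path, prefix='MetaTest/__unnamed_task__/'):
    """Return one {statistic: float} dict per logged evaluation row."""
    with open(csv_path, 'r') as csv_file:
        rows = list(csv.DictReader(csv_file))
    return [{key[len(prefix):]: float(value)
             for key, value in row.items()
             if key.startswith(prefix) and value != ''}
            for row in rows]


# e.g. read_meta_test_stats(log_file.name)[0]['AverageReturn']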
Example #13
    def test_pearl_ml1_push(self):
        """Test PEARL with ML1 Push environment."""
        params = dict(seed=1,
                      num_epochs=1,
                      num_train_tasks=5,
                      num_test_tasks=1,
                      latent_size=7,
                      encoder_hidden_sizes=[10, 10, 10],
                      net_size=30,
                      meta_batch_size=16,
                      num_steps_per_epoch=40,
                      num_initial_steps=40,
                      num_tasks_sample=15,
                      num_steps_prior=15,
                      num_extra_rl_steps_posterior=15,
                      batch_size=256,
                      embedding_batch_size=8,
                      embedding_mini_batch_size=8,
                      max_path_length=50,
                      reward_scale=10.,
                      use_information_bottleneck=True,
                      use_next_obs_in_context=False,
                      use_gpu=False)

        net_size = params['net_size']
        set_seed(params['seed'])
        env_sampler = SetTaskSampler(
            lambda: GarageEnv(normalize(ML1.get_train_tasks('push-v1'))))
        env = env_sampler.sample(params['num_train_tasks'])

        test_env_sampler = SetTaskSampler(
            lambda: GarageEnv(normalize(ML1.get_test_tasks('push-v1'))))

        augmented_env = PEARL.augment_env_spec(env[0](), params['latent_size'])
        qf = ContinuousMLPQFunction(
            env_spec=augmented_env,
            hidden_sizes=[net_size, net_size, net_size])

        vf_env = PEARL.get_env_spec(env[0](), params['latent_size'], 'vf')
        vf = ContinuousMLPQFunction(
            env_spec=vf_env, hidden_sizes=[net_size, net_size, net_size])

        inner_policy = TanhGaussianMLPPolicy(
            env_spec=augmented_env,
            hidden_sizes=[net_size, net_size, net_size])

        pearl = PEARL(
            env=env,
            policy_class=ContextConditionedPolicy,
            encoder_class=MLPEncoder,
            inner_policy=inner_policy,
            qf=qf,
            vf=vf,
            num_train_tasks=params['num_train_tasks'],
            num_test_tasks=params['num_test_tasks'],
            latent_dim=params['latent_size'],
            encoder_hidden_sizes=params['encoder_hidden_sizes'],
            meta_batch_size=params['meta_batch_size'],
            num_steps_per_epoch=params['num_steps_per_epoch'],
            num_initial_steps=params['num_initial_steps'],
            num_tasks_sample=params['num_tasks_sample'],
            num_steps_prior=params['num_steps_prior'],
            num_extra_rl_steps_posterior=params[
                'num_extra_rl_steps_posterior'],
            batch_size=params['batch_size'],
            embedding_batch_size=params['embedding_batch_size'],
            embedding_mini_batch_size=params['embedding_mini_batch_size'],
            max_path_length=params['max_path_length'],
            reward_scale=params['reward_scale'],
        )

        tu.set_gpu_mode(params['use_gpu'], gpu_id=0)
        if params['use_gpu']:
            pearl.to()

        runner = LocalRunner(snapshot_config)
        runner.setup(
            algo=pearl,
            env=env[0](),
            sampler_cls=LocalSampler,
            sampler_args=dict(max_path_length=params['max_path_length']),
            n_workers=1,
            worker_class=PEARLWorker)

        worker_args = dict(deterministic=True, accum_context=True)
        meta_evaluator = MetaEvaluator(
            test_task_sampler=test_env_sampler,
            max_path_length=params['max_path_length'],
            worker_class=PEARLWorker,
            worker_args=worker_args,
            n_test_tasks=params['num_test_tasks'])
        pearl.evaluator = meta_evaluator
        runner.train(n_epochs=params['num_epochs'],
                     batch_size=params['batch_size'])
Example #14
    def __init__(
        self,
        env,
        skill_env,
        controller_policy,
        skill_actor,
        qf,
        vf,
        num_skills,
        num_train_tasks,
        num_test_tasks,
        test_env_sampler,
        sampler_class,  # to avoid circular import
        controller_class=ControllerPolicy,
        controller_lr=3E-4,
        qf_lr=3E-4,
        vf_lr=3E-4,
        policy_mean_reg_coeff=1E-3,
        policy_std_reg_coeff=1E-3,
        policy_pre_activation_coeff=0.,
        soft_target_tau=0.005,
        optimizer_class=torch.optim.Adam,
        meta_batch_size=64,
        num_steps_per_epoch=1000,
        num_tasks_sample=5,
        batch_size=1024,
        embedding_batch_size=1024,
        embedding_mini_batch_size=1024,
        max_path_length=1000,
        discount=0.99,
        replay_buffer_size=1000000,
    ):

        self._env = env
        self._skill_env = skill_env
        self._qf1 = qf
        self._qf2 = copy.deepcopy(qf)
        self._vf = vf
        self._skill_actor = skill_actor
        self._num_skills = num_skills
        self._num_train_tasks = num_train_tasks
        self._num_test_tasks = num_test_tasks

        self._policy_mean_reg_coeff = policy_mean_reg_coeff
        self._policy_std_reg_coeff = policy_std_reg_coeff
        self._policy_pre_activation_coeff = policy_pre_activation_coeff
        self._soft_target_tau = soft_target_tau

        self._meta_batch_size = meta_batch_size
        self._num_steps_per_epoch = num_steps_per_epoch
        self._num_tasks_sample = num_tasks_sample
        self._batch_size = batch_size
        self._embedding_batch_size = embedding_batch_size
        self._embedding_mini_batch_size = embedding_mini_batch_size
        self.max_path_length = max_path_length
        self._discount = discount
        self._replay_buffer_size = replay_buffer_size

        self._task_idx = None
        self._skill_idx = None  # do we really need it

        self._is_resuming = False

        worker_args = dict(num_skills=num_skills,
                           skill_actor_class=type(skill_actor),
                           controller_class=controller_class,
                           deterministic=False,
                           accum_context=False)
        self._evaluator = MetaEvaluator(
            test_task_sampler=test_env_sampler,
            max_path_length=max_path_length,
            worker_class=BasicHierachWorker,
            worker_args=worker_args,
            n_test_tasks=num_test_tasks,
            sampler_class=sampler_class,
            trajectory_batch_class=SkillTrajectoryBatch)
        self._average_rewards = []

        self._controller = controller_class(
            controller_policy=controller_policy,
            num_skills=num_skills,
            sub_actor=skill_actor)

        self._skills_replay_buffer = PathBuffer(replay_buffer_size)

        self._replay_buffers = {
            i: PathBuffer(replay_buffer_size)
            for i in range(num_train_tasks)
        }

        self.target_vf = copy.deepcopy(self._vf)
        self.vf_criterion = torch.nn.MSELoss()

        self._controller_optimizer = optimizer_class(
            self._controller.networks[1].parameters(),
            lr=controller_lr,
        )
        self.qf1_optimizer = optimizer_class(
            self._qf1.parameters(),
            lr=qf_lr,
        )
        self.qf2_optimizer = optimizer_class(
            self._qf2.parameters(),
            lr=qf_lr,
        )
        self.vf_optimizer = optimizer_class(
            self._vf.parameters(),
            lr=vf_lr,
        )
Example #15
def rl2_ppo_metaworld_ml1_push(ctxt, seed, meta_batch_size, n_epochs,
                               episode_per_task):
    """Train RL2 PPO with ML1 environment.

    Args:
        ctxt (ExperimentContext): The experiment configuration used by
            :class:`~Trainer` to create the :class:`~Snapshotter`.
        seed (int): Used to seed the random number generator to produce
            determinism.
        meta_batch_size (int): Meta batch size.
        n_epochs (int): Total number of epochs for training.
        episode_per_task (int): Number of training episodes per task.

    """
    set_seed(seed)
    ml1 = metaworld.ML1('push-v1')

    task_sampler = MetaWorldTaskSampler(ml1, 'train',
                                        lambda env, _: RL2Env(env))
    env = task_sampler.sample(1)[0]()
    test_task_sampler = SetTaskSampler(MetaWorldSetTaskEnv,
                                       env=MetaWorldSetTaskEnv(ml1, 'test'),
                                       wrapper=lambda env, _: RL2Env(env))
    env_spec = env.spec

    with TFTrainer(snapshot_config=ctxt) as trainer:
        policy = GaussianGRUPolicy(name='policy',
                                   hidden_dim=64,
                                   env_spec=env_spec,
                                   state_include_action=False)

        meta_evaluator = MetaEvaluator(test_task_sampler=test_task_sampler)

        baseline = LinearFeatureBaseline(env_spec=env_spec)

        algo = RL2PPO(meta_batch_size=meta_batch_size,
                      task_sampler=task_sampler,
                      env_spec=env_spec,
                      policy=policy,
                      baseline=baseline,
                      discount=0.99,
                      gae_lambda=0.95,
                      lr_clip_range=0.2,
                      optimizer_args=dict(batch_size=32,
                                          max_optimization_epochs=10),
                      stop_entropy_gradient=True,
                      entropy_method='max',
                      policy_ent_coeff=0.02,
                      center_adv=False,
                      meta_evaluator=meta_evaluator,
                      episodes_per_trial=episode_per_task)

        trainer.setup(algo,
                      task_sampler.sample(meta_batch_size),
                      sampler_cls=LocalSampler,
                      n_workers=meta_batch_size,
                      worker_class=RL2Worker,
                      worker_args=dict(n_episodes_per_trial=episode_per_task))

        trainer.train(n_epochs=n_epochs,
                      batch_size=episode_per_task *
                      env_spec.max_episode_length * meta_batch_size)
Example #16
class MetaKant(MetaRLAlgorithm):
    def __init__(
        self,
        env,
        skill_env,
        controller_policy,
        skill_actor,
        qf,
        vf,
        num_skills,
        num_train_tasks,
        num_test_tasks,
        latent_dim,
        encoder_hidden_sizes,
        test_env_sampler,
        sampler_class,  # to avoid circular import
        controller_class=OpenContextConditionedControllerPolicy,
        encoder_class=GaussianContextEncoder,
        encoder_module_class=MLPEncoder,
        is_encoder_recurrent=False,
        controller_lr=3E-4,
        qf_lr=3E-4,
        vf_lr=3E-4,
        context_lr=3E-4,
        policy_mean_reg_coeff=1E-3,
        policy_std_reg_coeff=1E-3,
        policy_pre_activation_coeff=0.,
        soft_target_tau=0.005,
        kl_lambda=.1,
        optimizer_class=torch.optim.Adam,
        use_next_obs_in_context=False,
        meta_batch_size=64,
        num_steps_per_epoch=1000,
        num_skills_reason_steps=1000,
        num_skills_sample=10,
        num_initial_steps=1500,
        num_tasks_sample=5,
        num_steps_prior=400,
        num_steps_posterior=0,
        num_extra_rl_steps_posterior=600,
        batch_size=1024,
        embedding_batch_size=1024,
        embedding_mini_batch_size=1024,
        max_path_length=1000,
        discount=0.99,
        replay_buffer_size=1000000,
        # TODO: the ratio needs to be tuned
        skills_reason_reward_scale=1,
        tasks_adapt_reward_scale=1.2,
        use_information_bottleneck=True,
        update_post_train=1,
    ):

        self._env = env
        self._skill_env = skill_env
        self._qf1 = qf
        self._qf2 = copy.deepcopy(qf)
        self._vf = vf
        self._skill_actor = skill_actor
        self._num_skills = num_skills
        self._num_train_tasks = num_train_tasks
        self._num_test_tasks = num_test_tasks
        self._latent_dim = latent_dim

        self._policy_mean_reg_coeff = policy_mean_reg_coeff
        self._policy_std_reg_coeff = policy_std_reg_coeff
        self._policy_pre_activation_coeff = policy_pre_activation_coeff
        self._soft_target_tau = soft_target_tau
        self._kl_lambda = kl_lambda
        self._use_next_obs_in_context = use_next_obs_in_context

        self._meta_batch_size = meta_batch_size
        self._num_skills_reason_steps = num_skills_reason_steps
        self._num_steps_per_epoch = num_steps_per_epoch
        self._num_initial_steps = num_initial_steps
        self._num_tasks_sample = num_tasks_sample
        self._num_skills_sample = num_skills_sample
        self._num_steps_prior = num_steps_prior
        self._num_steps_posterior = num_steps_posterior
        self._num_extra_rl_steps_posterior = num_extra_rl_steps_posterior
        self._batch_size = batch_size
        self._embedding_batch_size = embedding_batch_size
        self._embedding_mini_batch_size = embedding_mini_batch_size
        self.max_path_length = max_path_length
        self._discount = discount
        self._replay_buffer_size = replay_buffer_size

        self._is_encoder_recurrent = is_encoder_recurrent
        self._use_information_bottleneck = use_information_bottleneck
        self._skills_reason_reward_scale = skills_reason_reward_scale
        self._tasks_adapt_reward_scale = tasks_adapt_reward_scale
        self._update_post_train = update_post_train
        self._task_idx = None
        self._skill_idx = None  # do we really need it

        self._is_resuming = False

        worker_args = dict(num_skills=num_skills,
                           skill_actor_class=type(skill_actor),
                           controller_class=controller_class,
                           deterministic=False,
                           accum_context=True)
        self._evaluator = MetaEvaluator(
            test_task_sampler=test_env_sampler,
            max_path_length=max_path_length,
            worker_class=KantWorker,
            worker_args=worker_args,
            n_test_tasks=num_test_tasks,
            sampler_class=sampler_class,
            trajectory_batch_class=SkillTrajectoryBatch)
        self._average_rewards = []

        encoder_spec = self.get_env_spec(env[0](), latent_dim, num_skills,
                                         'encoder')
        encoder_in_dim = int(np.prod(encoder_spec.input_space.shape))
        encoder_out_dim = int(np.prod(encoder_spec.output_space.shape))

        encoder_module = encoder_module_class(
            input_dim=encoder_in_dim,
            output_dim=encoder_out_dim,
            hidden_sizes=encoder_hidden_sizes)

        context_encoder = encoder_class(encoder_module,
                                        use_information_bottleneck, latent_dim)

        self._controller = controller_class(
            latent_dim=latent_dim,
            context_encoder=context_encoder,
            controller_policy=controller_policy,
            num_skills=num_skills,
            sub_actor=skill_actor,
            use_information_bottleneck=use_information_bottleneck,
            use_next_obs=use_next_obs_in_context)

        self._skills_replay_buffer = PathBuffer(replay_buffer_size)

        self._replay_buffers = {
            i: PathBuffer(replay_buffer_size)
            for i in range(num_train_tasks)
        }

        self._context_replay_buffers = {
            i: PathBuffer(replay_buffer_size)
            for i in range(num_train_tasks)
        }

        self.target_vf = copy.deepcopy(self._vf)
        self.vf_criterion = torch.nn.MSELoss()

        self._controller_optimizer = optimizer_class(
            self._controller.networks[1].parameters(),
            lr=controller_lr,
        )
        self.qf1_optimizer = optimizer_class(
            self._qf1.parameters(),
            lr=qf_lr,
        )
        self.qf2_optimizer = optimizer_class(
            self._qf2.parameters(),
            lr=qf_lr,
        )
        self.vf_optimizer = optimizer_class(
            self._vf.parameters(),
            lr=vf_lr,
        )
        self.context_optimizer = optimizer_class(
            self._controller.networks[0].parameters(),
            lr=context_lr,
        )

    def train(self, runner):
        for _ in runner.step_epochs():
            epoch = runner.step_itr / self._num_steps_per_epoch

            if epoch == 0 or self._is_resuming:
                for idx in range(self._num_skills):
                    self._skill_idx = idx
                    self._obtain_skill_samples(runner, epoch,
                                               self._num_initial_steps)
                for idx in range(self._num_train_tasks):
                    self._task_idx = idx
                    self._obtain_task_samples(runner, epoch,
                                              self._num_initial_steps, np.inf)
                self._is_resuming = False

            logger.log('Sampling skills')
            for idx in range(self._num_skills_sample):
                self._skill_idx = idx
                # self._skills_replay_buffer.clear()
                self._obtain_skill_samples(runner, epoch,
                                           self._num_skills_reason_steps)

            logger.log('Training skill reasoning...')
            self._skills_reason_train_once()

            logger.log('Sampling tasks')
            for _ in range(self._num_tasks_sample):
                idx = np.random.randint(self._num_train_tasks)
                self._task_idx = idx
                self._context_replay_buffers[idx].clear()
                # obtain samples with z ~ prior
                logger.log("Obtaining samples with z ~ prior")
                if self._num_steps_prior > 0:
                    self._obtain_task_samples(runner, epoch,
                                              self._num_steps_prior, np.inf)
                # obtain samples with z ~ posterior
                logger.log("Obtaining samples with z ~ posterior")
                if self._num_steps_posterior > 0:
                    self._obtain_task_samples(runner, epoch,
                                              self._num_steps_posterior,
                                              self._update_post_train)
                # obtain extra samples for RL training but not the encoder
                logger.log(
                    "Obtaining extra samples for RL training but not encoder")
                if self._num_extra_rl_steps_posterior > 0:
                    self._obtain_task_samples(
                        runner,
                        epoch,
                        self._num_extra_rl_steps_posterior,
                        self._update_post_train,
                        add_to_enc_buffer=False)

            logger.log('Training task adapting...')
            self._tasks_adapt_train_once()

            runner.step_itr += 1

            logger.log('Evaluating...')
            # evaluate
            self._controller.reset_belief()
            self._average_rewards.append(self._evaluator.evaluate(self))

        return self._average_rewards

    def _skills_reason_train_once(self):
        for _ in range(self._num_steps_per_epoch):
            self._skills_reason_optimize_policy()

    def _tasks_adapt_train_once(self):
        for _ in range(self._num_steps_per_epoch):
            indices = np.random.choice(range(self._num_train_tasks),
                                       self._meta_batch_size)
            self._tasks_adapt_optimize_policy(indices)

    def _skills_reason_optimize_policy(self):
        self._controller.reset_belief()

        # data shape is (task, batch, feat)
        obs, actions, rewards, skills, next_obs, terms, context = \
            self._sample_skill_path()

        # skills_pred is distribution
        policy_outputs, skills_pred, task_z = self._controller(obs, context)

        _, policy_mean, policy_log_std, policy_log_pi = policy_outputs[:4]

        self.context_optimizer.zero_grad()
        if self._use_information_bottleneck:
            kl_div = self._controller.compute_kl_div()
            kl_loss = self._kl_lambda * kl_div
            kl_loss.backward(retain_graph=True)

        skills_target = skills.clone().detach().requires_grad_(True)\
            .to(tu.global_device())
        skills_pred = skills_pred.to(tu.global_device())

        policy_loss = (F.mse_loss(skills_pred.flatten(),
                                  skills_target.flatten()) *
                       self._skills_reason_reward_scale)

        mean_reg_loss = self._policy_mean_reg_coeff * (policy_mean**2).mean()
        std_reg_loss = self._policy_std_reg_coeff * (policy_log_std**2).mean()

        # took away the pre-activation reg term
        policy_reg_loss = mean_reg_loss + std_reg_loss
        policy_loss = policy_loss + policy_reg_loss

        self._controller_optimizer.zero_grad()
        policy_loss.backward()
        self._controller_optimizer.step()

    def _tasks_adapt_optimize_policy(self, indices):
        num_tasks = len(indices)
        obs, actions, rewards, skills, next_obs, terms, context = \
            self._sample_task_path(indices)
        self._controller.reset_belief(num_tasks=num_tasks)

        # data shape is (task, batch, feat)
        # new_skills_pred is distribution
        policy_outputs, new_skills_pred, task_z = self._controller(
            obs, context)
        new_actions, policy_mean, policy_log_std, policy_log_pi = \
            policy_outputs[:4]

        # flatten out the task dimension
        t, b, _ = obs.size()
        obs = obs.view(t * b, -1)
        skills = skills.view(t * b, -1)
        next_obs = next_obs.view(t * b, -1)

        # optimize qf and encoder networks
        # TODO try [obs, skills, actions] or [obs, skills, task_z]
        # FIXME prob need to reshape or tile task_z
        obs = obs.to(tu.global_device())
        skills = skills.to(tu.global_device())
        next_obs = next_obs.to(tu.global_device())

        q1_pred = self._qf1(torch.cat([obs, skills], dim=1), task_z)
        q2_pred = self._qf2(torch.cat([obs, skills], dim=1), task_z)
        v_pred = self._vf(obs, task_z.detach())

        with torch.no_grad():
            target_v_values = self.target_vf(next_obs, task_z)

        # KL constraint on z if probabilistic
        self.context_optimizer.zero_grad()
        if self._use_information_bottleneck:
            kl_div = self._controller.compute_kl_div()
            kl_loss = self._kl_lambda * kl_div
            kl_loss.backward(retain_graph=True)

        self.qf1_optimizer.zero_grad()
        self.qf2_optimizer.zero_grad()

        rewards_flat = rewards.view(b * t, -1)
        rewards_flat = rewards_flat * self._tasks_adapt_reward_scale
        terms_flat = terms.view(b * t, -1)
        q_target = rewards_flat + (
            1. - terms_flat) * self._discount * target_v_values
        qf_loss = torch.mean((q1_pred - q_target)**2) + torch.mean(
            (q2_pred - q_target)**2)
        qf_loss.backward()

        self.qf1_optimizer.step()
        self.qf2_optimizer.step()
        self.context_optimizer.step()

        new_skills_pred = new_skills_pred.to(tu.global_device())

        # compute min Q on the new actions
        q1 = self._qf1(torch.cat([obs, new_skills_pred], dim=1),
                       task_z.detach())
        q2 = self._qf2(torch.cat([obs, new_skills_pred], dim=1),
                       task_z.detach())
        min_q = torch.min(q1, q2)

        # optimize vf
        policy_log_pi = policy_log_pi.to(tu.global_device())

        v_target = min_q - policy_log_pi
        vf_loss = self.vf_criterion(v_pred, v_target.detach())
        self.vf_optimizer.zero_grad()
        vf_loss.backward()
        self.vf_optimizer.step()
        self._update_target_network()

        # optimize policy
        log_policy_target = min_q
        policy_loss = (policy_log_pi - log_policy_target).mean()

        mean_reg_loss = self._policy_mean_reg_coeff * (policy_mean**2).mean()
        std_reg_loss = self._policy_std_reg_coeff * (policy_log_std**2).mean()

        # took away pre-activation reg
        policy_reg_loss = mean_reg_loss + std_reg_loss
        policy_loss = policy_loss + policy_reg_loss

        self._controller_optimizer.zero_grad()
        policy_loss.backward()
        self._controller_optimizer.step()

    def _obtain_skill_samples(self, runner, itr, num_paths):
        self._controller.reset_belief()
        total_paths = 0

        while total_paths < num_paths:
            num_samples = num_paths * self.max_path_length
            paths = runner.obtain_samples(itr, num_samples, self._skill_actor,
                                          self._skill_env)
            total_paths += len(paths)

            for path in paths:
                p = {
                    'states': path['states'],
                    'actions': path['actions'],
                    'env_rewards': path['env_rewards'].reshape(-1, 1),
                    'skills_onehot': path['skills_onehot'],
                    'next_states': path['next_states'],
                    'dones': path['dones'].reshape(-1, 1)
                }
                self._skills_replay_buffer.add_path(p)

    def _obtain_task_samples(self,
                             runner,
                             itr,
                             num_paths,
                             update_posterior_rate,
                             add_to_enc_buffer=True):
        self._controller.reset_belief()
        total_paths = 0

        if update_posterior_rate != np.inf:
            num_paths_per_batch = update_posterior_rate
        else:
            num_paths_per_batch = num_paths

        while total_paths < num_paths:
            num_samples = num_paths_per_batch * self.max_path_length
            paths = runner.obtain_samples(itr, num_samples, self._controller,
                                          self._env[self._task_idx])
            total_paths += len(paths)

            for path in paths:
                p = {
                    'states': path['states'],
                    'actions': path['actions'],
                    'env_rewards': path['env_rewards'].reshape(-1, 1),
                    'skills_onehot': path['skills_onehot'],
                    'next_states': path['next_states'],
                    'dones': path['dones'].reshape(-1, 1)
                }
                self._replay_buffers[self._task_idx].add_path(p)

                if add_to_enc_buffer:
                    self._context_replay_buffers[self._task_idx].add_path(p)

            if update_posterior_rate != np.inf:
                context = self._sample_path_context(self._task_idx)
                self._controller.infer_posterior(context)

    def _sample_task_path(self, indices):
        if not hasattr(indices, '__iter__'):
            indices = [indices]

        initialized = False
        for idx in indices:
            path = self._context_replay_buffers[idx].sample_path()
            # should be replay_buffers[]
            # TODO: trim or extend batch to the same size

            context_o = path['states']
            context_a = path['actions']
            context_r = path['env_rewards']
            context_z = path['skills_onehot']
            context = np.hstack((context_o, context_a, context_r, context_z))
            if self._use_next_obs_in_context:
                context = np.hstack((context, path['next_states']))

            if not initialized:
                final_context = context[np.newaxis]
                o = path['states'][np.newaxis]
                a = path['actions'][np.newaxis]
                r = path['env_rewards'][np.newaxis]
                z = path['skills_onehot'][np.newaxis]
                no = path['next_states'][np.newaxis]
                d = path['dones'][np.newaxis]
                initialized = True
            else:
                # print(o.shape)
                # print(path['states'].shape)
                o = np.vstack((o, path['states'][np.newaxis]))
                a = np.vstack((a, path['actions'][np.newaxis]))
                r = np.vstack((r, path['env_rewards'][np.newaxis]))
                z = np.vstack((z, path['skills_onehot'][np.newaxis]))
                no = np.vstack((no, path['next_states'][np.newaxis]))
                d = np.vstack((d, path['dones'][np.newaxis]))
                final_context = np.vstack((final_context, context[np.newaxis]))

        o = torch.as_tensor(o, device=tu.global_device()).float()
        a = torch.as_tensor(a, device=tu.global_device()).float()
        r = torch.as_tensor(r, device=tu.global_device()).float()
        z = torch.as_tensor(z, device=tu.global_device()).float()
        no = torch.as_tensor(no, device=tu.global_device()).float()
        d = torch.as_tensor(d, device=tu.global_device()).float()
        final_context = torch.as_tensor(final_context,
                                        device=tu.global_device()).float()
        if len(indices) == 1:
            final_context = final_context.unsqueeze(0)

        return o, a, r, z, no, d, final_context

    def _sample_skill_path(self):
        path = self._skills_replay_buffer.sample_path()
        # TODO: trim or extend batch to the same size
        o = path['states']
        a = path['actions']
        r = path['env_rewards']
        z = path['skills_onehot']
        context = np.hstack((o, a, r, z))
        if self._use_next_obs_in_context:
            context = np.hstack((context, path['next_states']))

        context = context[np.newaxis]
        o = path['states'][np.newaxis]
        a = path['actions'][np.newaxis]
        r = path['env_rewards'][np.newaxis]
        z = path['skills_onehot'][np.newaxis]
        no = path['next_states'][np.newaxis]
        d = path['dones'][np.newaxis]

        o = torch.as_tensor(o, device=tu.global_device()).float()
        a = torch.as_tensor(a, device=tu.global_device()).float()
        r = torch.as_tensor(r, device=tu.global_device()).float()
        z = torch.as_tensor(z, device=tu.global_device()).float()
        no = torch.as_tensor(no, device=tu.global_device()).float()
        d = torch.as_tensor(d, device=tu.global_device()).float()
        context = torch.as_tensor(context, device=tu.global_device()).float()
        context = context.unsqueeze(0)

        return o, a, r, z, no, d, context

    def _sample_path_context(self, indices):
        if not hasattr(indices, '__iter__'):
            indices = [indices]

        initialized = False
        for idx in indices:
            path = self._context_replay_buffers[idx].sample_path()
            o = path['states']
            a = path['actions']
            r = path['env_rewards']
            z = path['skills_onehot']
            context = np.hstack((o, a, r, z))
            if self._use_next_obs_in_context:
                context = np.hstack((context, path['next_states']))

            if not initialized:
                final_context = context[np.newaxis]
                initialized = True
            else:
                final_context = np.vstack((final_context, context[np.newaxis]))

        final_context = torch.as_tensor(final_context,
                                        device=tu.global_device()).float()
        if len(indices) == 1:
            final_context = final_context.unsqueeze(0)

        return final_context

    def _update_target_network(self):
        for target_param, param in zip(self.target_vf.parameters(),
                                       self._vf.parameters()):
            target_param.data.copy_(target_param.data *
                                    (1.0 - self._soft_target_tau) +
                                    param.data * self._soft_target_tau)

    def __getstate__(self):
        data = self.__dict__.copy()
        del data['_skills_replay_buffer']
        del data['_replay_buffers']
        del data['_context_replay_buffers']
        return data

    def __setstate__(self, state):
        self.__dict__.update(state)

        self._skills_replay_buffer = PathBuffer(self._replay_buffer_size)

        self._replay_buffers = {
            i: PathBuffer(self._replay_buffer_size)
            for i in range(self._num_train_tasks)
        }

        self._context_replay_buffers = {
            i: PathBuffer(self._replay_buffer_size)
            for i in range(self._num_train_tasks)
        }
        self._is_resuming = True

    def to(self, device=None):
        device = device or tu.global_device()
        for net in self.networks:
            # print(net)
            net.to(device)

    @classmethod
    def get_env_spec(cls, env_spec, latent_dim, num_skills, module):
        obs_dim = int(np.prod(env_spec.observation_space.shape))
        # print("obs_dim is")
        # print(obs_dim)
        action_dim = int(np.prod(env_spec.action_space.shape))
        if module == 'encoder':
            in_dim = obs_dim + action_dim + num_skills + 1
            out_dim = latent_dim * 2
        elif module == 'vf':
            in_dim = obs_dim
            out_dim = latent_dim
        elif module == 'controller_policy':
            in_dim = obs_dim + latent_dim
            out_dim = num_skills
        elif module == 'qf':
            in_dim = obs_dim + latent_dim
            out_dim = num_skills

        in_space = akro.Box(low=-1, high=1, shape=(in_dim, ), dtype=np.float32)
        out_space = akro.Box(low=-1,
                             high=1,
                             shape=(out_dim, ),
                             dtype=np.float32)

        if module == 'encoder':
            spec = InOutSpec(in_space, out_space)
        elif module == 'vf':
            spec = EnvSpec(in_space, out_space)
        elif module == 'controller_policy':
            spec = EnvSpec(in_space, out_space)
        elif module == 'qf':
            spec = EnvSpec(in_space, out_space)
        return spec

    @property
    def policy(self):
        return self._controller

    @property
    def networks(self):
        return self._controller.networks + [self._controller] + [
            self._qf1, self._qf2, self._vf, self.target_vf
        ]

    def get_exploration_policy(self):
        return self._controller

    def adapt_policy(self, exploration_policy, exploration_trajectories):
        total_steps = sum(exploration_trajectories.lengths)
        o = exploration_trajectories.states
        a = exploration_trajectories.actions
        r = exploration_trajectories.env_rewards.reshape(total_steps, 1)
        s = exploration_trajectories.skills_onehot
        ctxt = np.hstack((o, a, r, s)).reshape(1, total_steps, -1)
        context = torch.as_tensor(ctxt, device=tu.global_device()).float()
        self._controller.infer_posterior(context)

        return self._controller
Example #17
    def __init__(
        self,
        env,
        skill_env,
        controller_policy,
        skill_actor,
        qf,
        vf,
        num_skills,
        num_train_tasks,
        num_test_tasks,
        latent_dim,
        encoder_hidden_sizes,
        test_env_sampler,
        sampler_class,  # to avoid circular import
        controller_class=OpenContextConditionedControllerPolicy,
        encoder_class=GaussianContextEncoder,
        encoder_module_class=MLPEncoder,
        is_encoder_recurrent=False,
        controller_lr=3E-4,
        qf_lr=3E-4,
        vf_lr=3E-4,
        context_lr=3E-4,
        policy_mean_reg_coeff=1E-3,
        policy_std_reg_coeff=1E-3,
        policy_pre_activation_coeff=0.,
        soft_target_tau=0.005,
        kl_lambda=.1,
        optimizer_class=torch.optim.Adam,
        use_next_obs_in_context=False,
        meta_batch_size=64,
        num_steps_per_epoch=1000,
        num_skills_reason_steps=1000,
        num_skills_sample=10,
        num_initial_steps=1500,
        num_tasks_sample=5,
        num_steps_prior=400,
        num_steps_posterior=0,
        num_extra_rl_steps_posterior=600,
        batch_size=1024,
        embedding_batch_size=1024,
        embedding_mini_batch_size=1024,
        max_path_length=1000,
        discount=0.99,
        replay_buffer_size=1000000,
        # TODO: the ratio needs to be tuned
        skills_reason_reward_scale=1,
        tasks_adapt_reward_scale=1.2,
        use_information_bottleneck=True,
        update_post_train=1,
    ):

        self._env = env
        self._skill_env = skill_env
        self._qf1 = qf
        self._qf2 = copy.deepcopy(qf)
        self._vf = vf
        self._skill_actor = skill_actor
        self._num_skills = num_skills
        self._num_train_tasks = num_train_tasks
        self._num_test_tasks = num_test_tasks
        self._latent_dim = latent_dim

        self._policy_mean_reg_coeff = policy_mean_reg_coeff
        self._policy_std_reg_coeff = policy_std_reg_coeff
        self._policy_pre_activation_coeff = policy_pre_activation_coeff
        self._soft_target_tau = soft_target_tau
        self._kl_lambda = kl_lambda
        self._use_next_obs_in_context = use_next_obs_in_context

        self._meta_batch_size = meta_batch_size
        self._num_skills_reason_steps = num_skills_reason_steps
        self._num_steps_per_epoch = num_steps_per_epoch
        self._num_initial_steps = num_initial_steps
        self._num_tasks_sample = num_tasks_sample
        self._num_skills_sample = num_skills_sample
        self._num_steps_prior = num_steps_prior
        self._num_steps_posterior = num_steps_posterior
        self._num_extra_rl_steps_posterior = num_extra_rl_steps_posterior
        self._batch_size = batch_size
        self._embedding_batch_size = embedding_batch_size
        self._embedding_mini_batch_size = embedding_mini_batch_size
        self.max_path_length = max_path_length
        self._discount = discount
        self._replay_buffer_size = replay_buffer_size

        self._is_encoder_recurrent = is_encoder_recurrent
        self._use_information_bottleneck = use_information_bottleneck
        self._skills_reason_reward_scale = skills_reason_reward_scale
        self._tasks_adapt_reward_scale = tasks_adapt_reward_scale
        self._update_post_train = update_post_train
        self._task_idx = None
        self._skill_idx = None  # do we really need it

        self._is_resuming = False

        worker_args = dict(num_skills=num_skills,
                           skill_actor_class=type(skill_actor),
                           controller_class=controller_class,
                           deterministic=False,
                           accum_context=True)
        self._evaluator = MetaEvaluator(
            test_task_sampler=test_env_sampler,
            max_path_length=max_path_length,
            worker_class=KantWorker,
            worker_args=worker_args,
            n_test_tasks=num_test_tasks,
            sampler_class=sampler_class,
            trajectory_batch_class=SkillTrajectoryBatch)
        self._average_rewards = []

        encoder_spec = self.get_env_spec(env[0](), latent_dim, num_skills,
                                         'encoder')
        encoder_in_dim = int(np.prod(encoder_spec.input_space.shape))
        encoder_out_dim = int(np.prod(encoder_spec.output_space.shape))

        encoder_module = encoder_module_class(
            input_dim=encoder_in_dim,
            output_dim=encoder_out_dim,
            hidden_sizes=encoder_hidden_sizes)

        context_encoder = encoder_class(encoder_module,
                                        use_information_bottleneck, latent_dim)

        self._controller = controller_class(
            latent_dim=latent_dim,
            context_encoder=context_encoder,
            controller_policy=controller_policy,
            num_skills=num_skills,
            sub_actor=skill_actor,
            use_information_bottleneck=use_information_bottleneck,
            use_next_obs=use_next_obs_in_context)

        self._skills_replay_buffer = PathBuffer(replay_buffer_size)

        self._replay_buffers = {
            i: PathBuffer(replay_buffer_size)
            for i in range(num_train_tasks)
        }

        self._context_replay_buffers = {
            i: PathBuffer(replay_buffer_size)
            for i in range(num_train_tasks)
        }

        self.target_vf = copy.deepcopy(self._vf)
        self.vf_criterion = torch.nn.MSELoss()

        self._controller_optimizer = optimizer_class(
            self._controller.networks[1].parameters(),
            lr=controller_lr,
        )
        self.qf1_optimizer = optimizer_class(
            self._qf1.parameters(),
            lr=qf_lr,
        )
        self.qf2_optimizer = optimizer_class(
            self._qf2.parameters(),
            lr=qf_lr,
        )
        self.vf_optimizer = optimizer_class(
            self._vf.parameters(),
            lr=vf_lr,
        )
        self.context_optimizer = optimizer_class(
            self._controller.networks[0].parameters(),
            lr=context_lr,
        )
Example #18
0
def torch_pearl_ml1_push(ctxt=None,
                         seed=1,
                         num_epochs=1000,
                         num_train_tasks=50,
                         num_test_tasks=10,
                         latent_size=7,
                         encoder_hidden_size=200,
                         net_size=300,
                         meta_batch_size=16,
                         num_steps_per_epoch=4000,
                         num_initial_steps=4000,
                         num_tasks_sample=15,
                         num_steps_prior=750,
                         num_extra_rl_steps_posterior=750,
                         batch_size=256,
                         embedding_batch_size=64,
                         embedding_mini_batch_size=64,
                         max_path_length=150,
                         reward_scale=10.,
                         use_gpu=False):
    """Train PEARL with ML1 environments.

    Args:
        ctxt (garage.experiment.ExperimentContext): The experiment
            configuration used by LocalRunner to create the snapshotter.
        seed (int): Used to seed the random number generator to produce
            determinism.
        num_epochs (int): Number of training epochs.
        num_train_tasks (int): Number of tasks for training.
        num_test_tasks (int): Number of tasks for testing.
        latent_size (int): Size of latent context vector.
        encoder_hidden_size (int): Output dimension of dense layer of the
            context encoder.
        net_size (int): Output dimension of a dense layer of Q-function and
            value function.
        meta_batch_size (int): Meta batch size.
        num_steps_per_epoch (int): Number of iterations per epoch.
        num_initial_steps (int): Number of transitions obtained per task before
            training.
        num_tasks_sample (int): Number of random tasks to obtain data for each
            iteration.
        num_steps_prior (int): Number of transitions to obtain per task with
            z ~ prior.
        num_extra_rl_steps_posterior (int): Number of additional transitions
            to obtain per task with z ~ posterior that are only used to train
            the policy and NOT the encoder.
        batch_size (int): Number of transitions in RL batch.
        embedding_batch_size (int): Number of transitions in context batch.
        embedding_mini_batch_size (int): Number of transitions in mini context
            batch; should be same as embedding_batch_size for non-recurrent
            encoder.
        max_path_length (int): Maximum path length.
        reward_scale (int): Reward scale.
        use_gpu (bool): Whether or not to use GPU for training.

    """
    set_seed(seed)
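    # three dense encoder layers of equal width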
    encoder_hidden_sizes = (encoder_hidden_size, encoder_hidden_size,
                            encoder_hidden_size)
    # create multi-task environment and sample tasks
    env_sampler = SetTaskSampler(
        lambda: GarageEnv(normalize(ML1.get_train_tasks('push-v1'))))
    env = env_sampler.sample(num_train_tasks)

    test_env_sampler = SetTaskSampler(
        lambda: GarageEnv(normalize(ML1.get_test_tasks('push-v1'))))

    runner = LocalRunner(ctxt)

    # instantiate networks
    augmented_env = PEARL.augment_env_spec(env[0](), latent_size)
    qf = ContinuousMLPQFunction(env_spec=augmented_env,
                                hidden_sizes=[net_size, net_size, net_size])

    vf_env = PEARL.get_env_spec(env[0](), latent_size, 'vf')
    vf = ContinuousMLPQFunction(env_spec=vf_env,
                                hidden_sizes=[net_size, net_size, net_size])

    inner_policy = TanhGaussianMLPPolicy(
        env_spec=augmented_env, hidden_sizes=[net_size, net_size, net_size])

    pearl = PEARL(
        env=env,
        policy_class=ContextConditionedPolicy,
        encoder_class=MLPEncoder,
        inner_policy=inner_policy,
        qf=qf,
        vf=vf,
        num_train_tasks=num_train_tasks,
        num_test_tasks=num_test_tasks,
        latent_dim=latent_size,
        encoder_hidden_sizes=encoder_hidden_sizes,
        test_env_sampler=test_env_sampler,
        meta_batch_size=meta_batch_size,
        num_steps_per_epoch=num_steps_per_epoch,
        num_initial_steps=num_initial_steps,
        num_tasks_sample=num_tasks_sample,
        num_steps_prior=num_steps_prior,
        num_extra_rl_steps_posterior=num_extra_rl_steps_posterior,
        batch_size=batch_size,
        embedding_batch_size=embedding_batch_size,
        embedding_mini_batch_size=embedding_mini_batch_size,
        max_path_length=max_path_length,
        reward_scale=reward_scale,
    )

    tu.set_gpu_mode(use_gpu, gpu_id=0)
    if use_gpu:
        pearl.to()

    runner.setup(algo=pearl,
                 env=env[0](),
                 sampler_cls=LocalSampler,
                 sampler_args=dict(max_path_length=max_path_length),
                 n_workers=1,
                 worker_class=PEARLWorker)

    worker_args = dict(deterministic=True, accum_context=True)
    meta_evaluator = MetaEvaluator(test_task_sampler=test_env_sampler,
                                   max_path_length=max_path_length,
                                   worker_class=PEARLWorker,
                                   worker_args=worker_args,
                                   n_test_tasks=num_test_tasks)
    pearl.evaluator = meta_evaluator
    runner.train(n_epochs=num_epochs, batch_size=batch_size)
Example #19
0
class MetaBasicHierch(MetaRLAlgorithm):
    def __init__(
        self,
        env,
        skill_env,
        controller_policy,
        skill_actor,
        qf,
        vf,
        num_skills,
        num_train_tasks,
        num_test_tasks,
        test_env_sampler,
        sampler_class,  # to avoid a circular import
        controller_class=ControllerPolicy,
        controller_lr=3E-4,
        qf_lr=3E-4,
        vf_lr=3E-4,
        policy_mean_reg_coeff=1E-3,
        policy_std_reg_coeff=1E-3,
        policy_pre_activation_coeff=0.,
        soft_target_tau=0.005,
        optimizer_class=torch.optim.Adam,
        meta_batch_size=64,
        num_steps_per_epoch=1000,
        num_initial_steps=1500,
        num_tasks_sample=5,
        batch_size=1024,
        embedding_batch_size=1024,
        embedding_mini_batch_size=1024,
        max_path_length=1000,
        discount=0.99,
        replay_buffer_size=1000000,
    ):

        self._env = env
        self._skill_env = skill_env
        self._qf1 = qf
        self._qf2 = copy.deepcopy(qf)
        self._vf = vf
        self._skill_actor = skill_actor
        self._num_skills = num_skills
        self._num_train_tasks = num_train_tasks
        self._num_test_tasks = num_test_tasks

        self._policy_mean_reg_coeff = policy_mean_reg_coeff
        self._policy_std_reg_coeff = policy_std_reg_coeff
        self._policy_pre_activation_coeff = policy_pre_activation_coeff
        self._soft_target_tau = soft_target_tau

        self._meta_batch_size = meta_batch_size
        self._num_steps_per_epoch = num_steps_per_epoch
        self._num_initial_steps = num_initial_steps
        self._num_tasks_sample = num_tasks_sample
        self._batch_size = batch_size
        self._embedding_batch_size = embedding_batch_size
        self._embedding_mini_batch_size = embedding_mini_batch_size
        self.max_path_length = max_path_length
        self._discount = discount
        self._replay_buffer_size = replay_buffer_size

        self._task_idx = None
        self._skill_idx = None  # do we really need it

        self._is_resuming = False

        worker_args = dict(num_skills=num_skills,
                           skill_actor_class=type(skill_actor),
                           controller_class=controller_class,
                           deterministic=False,
                           accum_context=False)
        self._evaluator = MetaEvaluator(
            test_task_sampler=test_env_sampler,
            max_path_length=max_path_length,
            worker_class=BasicHierachWorker,
            worker_args=worker_args,
            n_test_tasks=num_test_tasks,
            sampler_class=sampler_class,
            trajectory_batch_class=SkillTrajectoryBatch)
        self._average_rewards = []

        self._controller = controller_class(
            controller_policy=controller_policy,
            num_skills=num_skills,
            sub_actor=skill_actor)

        self._skills_replay_buffer = PathBuffer(replay_buffer_size)

        self._replay_buffers = {
            i: PathBuffer(replay_buffer_size)
            for i in range(num_train_tasks)
        }

        self.target_vf = copy.deepcopy(self._vf)
        self.vf_criterion = torch.nn.MSELoss()

        self._controller_optimizer = optimizer_class(
            self._controller.networks[1].parameters(),
            lr=controller_lr,
        )
        self.qf1_optimizer = optimizer_class(
            self._qf1.parameters(),
            lr=qf_lr,
        )
        self.qf2_optimizer = optimizer_class(
            self._qf2.parameters(),
            lr=qf_lr,
        )
        self.vf_optimizer = optimizer_class(
            self._vf.parameters(),
            lr=vf_lr,
        )

    def train(self, runner):
        for _ in runner.step_epochs():
            epoch = runner.step_itr / self._num_steps_per_epoch

            if epoch == 0 or self._is_resuming:
                for idx in range(self._num_train_tasks):
                    self._task_idx = idx
                    self._obtain_task_samples(runner, epoch,
                                              self._num_initial_steps)
                self._is_resuming = False

            logger.log('Sampling tasks')
            for _ in range(self._num_tasks_sample):
                idx = np.random.randint(self._num_train_tasks)
                self._task_idx = idx
                self._replay_buffers[idx].clear()
                self._obtain_task_samples(runner, epoch,
                                          self._num_steps_per_epoch)

            logger.log('Training task adapting...')
            self._tasks_adapt_train_once()

            runner.step_itr += 1

            logger.log('Evaluating...')
            # evaluate
            self._controller.reset_belief()
            self._average_rewards.append(self._evaluator.evaluate(self))

        return self._average_rewards

    def _tasks_adapt_train_once(self):
        for _ in range(self._num_steps_per_epoch):
            indices = np.random.choice(range(self._num_train_tasks),
                                       self._meta_batch_size)
            self._tasks_adapt_optimize_policy(indices)

    def _tasks_adapt_optimize_policy(self, indices):
        num_tasks = len(indices)
        obs, actions, rewards, skills, next_obs, terms = self._sample_task_path(
            indices)
        self._controller.reset_belief(num_tasks=num_tasks)

        # data shape is (task, batch, feat)
        # new_skills_pred is distribution
        policy_outputs, new_skills_pred, task_z = self._controller(obs)
        new_actions, policy_mean, policy_log_std, policy_log_pi = policy_outputs[:4]

        # flatten out the task dimension
        t, b, _ = obs.size()
        obs = obs.view(t * b, -1)
        # actions = actions.view(t * b, -1)
        skills = skills.view(t * b, -1)
        next_obs = next_obs.view(t * b, -1)

        # optimize qf and encoder networks
        # TODO try [obs, skills, actions] or [obs, skills, task_z]
        # FIXME prob need to reshape or tile task_z
        obs = obs.to(tu.global_device())
        skills = skills.to(tu.global_device())
        next_obs = next_obs.to(tu.global_device())

        q1_pred = self._qf1(torch.cat([obs, skills], dim=1))
        q2_pred = self._qf2(torch.cat([obs, skills], dim=1))

        self.qf1_optimizer.zero_grad()
        self.qf2_optimizer.zero_grad()

        rewards_flat = rewards.view(b * t, -1)
        terms_flat = terms.view(b * t, -1)
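        # Q target from the immediate reward; the discount term is not multiplied by a bootstrapped value here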
        q_target = rewards_flat + (1. - terms_flat) * self._discount
        qf_loss = torch.mean((q1_pred - q_target)**2) + torch.mean(
            (q2_pred - q_target)**2)
        qf_loss.backward()

        self.qf1_optimizer.step()
        self.qf2_optimizer.step()

        new_skills_pred = new_skills_pred.to(tu.global_device())

        # compute min Q on the new actions
        q1 = self._qf1(torch.cat([obs, new_skills_pred], dim=1),
                       task_z.detach())
        q2 = self._qf2(torch.cat([obs, new_skills_pred], dim=1),
                       task_z.detach())
        min_q = torch.min(q1, q2)

        # optimize vf
        policy_log_pi = policy_log_pi.to(tu.global_device())

        # optimize policy
        log_policy_target = min_q
        policy_loss = (policy_log_pi - log_policy_target).mean()

        mean_reg_loss = self._policy_mean_reg_coeff * (policy_mean**2).mean()
        std_reg_loss = self._policy_std_reg_coeff * (policy_log_std**2).mean()

        # took away pre-activation reg
        policy_reg_loss = mean_reg_loss + std_reg_loss
        policy_loss = policy_loss + policy_reg_loss

        self._controller_optimizer.zero_grad()
        policy_loss.backward()
        self._controller_optimizer.step()

    def _obtain_task_samples(self, runner, itr, num_paths):
        self._controller.reset_belief()
        total_paths = 0
        num_paths_per_batch = num_paths

        while total_paths < num_paths:
            num_samples = num_paths_per_batch * self.max_path_length
            paths = runner.obtain_samples(itr, num_samples, self._controller,
                                          self._env[self._task_idx])
            total_paths += len(paths)

            for path in paths:
                p = {
                    'states': path['states'],
                    'actions': path['actions'],
                    'env_rewards': path['env_rewards'].reshape(-1, 1),
                    'skills_onehot': path['skills_onehot'],
                    'next_states': path['next_states'],
                    'dones': path['dones'].reshape(-1, 1)
                }
                self._replay_buffers[self._task_idx].add_path(p)

    def _sample_task_path(self, indices):
        if not hasattr(indices, '__iter__'):
            indices = [indices]

        initialized = False
        for idx in indices:
            path = self._replay_buffers[idx].sample_path()
            # TODO: trim or extend batch to the same size
            if not initialized:
                o = path['states'][np.newaxis]
                a = path['actions'][np.newaxis]
                r = path['env_rewards'][np.newaxis]
                z = path['skills_onehot'][np.newaxis]
                no = path['next_states'][np.newaxis]
                d = path['dones'][np.newaxis]
                initialized = True
            else:
                o = np.vstack((o, path['states'][np.newaxis]))
                a = np.vstack((a, path['actions'][np.newaxis]))
                r = np.vstack((r, path['env_rewards'][np.newaxis]))
                z = np.vstack((z, path['skills_onehot'][np.newaxis]))
                no = np.vstack((no, path['next_states'][np.newaxis]))
                d = np.vstack((d, path['dones'][np.newaxis]))

        o = torch.as_tensor(o, device=tu.global_device()).float()
        a = torch.as_tensor(a, device=tu.global_device()).float()
        r = torch.as_tensor(r, device=tu.global_device()).float()
        z = torch.as_tensor(z, device=tu.global_device()).float()
        no = torch.as_tensor(no, device=tu.global_device()).float()
        d = torch.as_tensor(d, device=tu.global_device()).float()

        return o, a, r, z, no, d

    def _update_target_network(self):
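        # Polyak-average the value function into its target: theta_target <- (1 - tau) * theta_target + tau * theta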
        for target_param, param in zip(self.target_vf.parameters(),
                                       self._vf.parameters()):
            target_param.data.copy_(target_param.data *
                                    (1.0 - self._soft_target_tau) +
                                    param.data * self._soft_target_tau)

    def __getstate__(self):
        data = self.__dict__.copy()
        del data['_replay_buffers']
        return data

    def __setstate__(self, state):
        self.__dict__.update(state)

        self._replay_buffers = {
            i: PathBuffer(self._replay_buffer_size)
            for i in range(self._num_train_tasks)
        }

        self._is_resuming = True

    def to(self, device=None):
        device = device or tu.global_device()
        for net in self.networks:
            net.to(device)

    @classmethod
    def get_env_spec(cls, env_spec, num_skills, module):
        obs_dim = int(np.prod(env_spec.observation_space.shape))
        action_dim = int(np.prod(env_spec.action_space.shape))
        if module == 'controller_policy':
            in_dim = obs_dim
            out_dim = num_skills
        if module == 'qf':
            in_dim = obs_dim
            out_dim = num_skills

        in_space = akro.Box(low=-1, high=1, shape=(in_dim, ), dtype=np.float32)
        out_space = akro.Box(low=-1,
                             high=1,
                             shape=(out_dim, ),
                             dtype=np.float32)

        if module == 'controller_policy':
            spec = EnvSpec(in_space, out_space)
        if module == 'qf':
            spec = EnvSpec(in_space, out_space)
        return spec

    @property
    def policy(self):
        return self._controller

    @property
    def networks(self):
        return self._controller.networks + [self._controller] + [
            self._qf1, self._qf2, self._vf, self.target_vf
        ]

    def get_exploration_policy(self):
        return self._controller

    def adapt_policy(self, exploration_policy, exploration_trajectories):
        total_steps = sum(exploration_trajectories.lengths)
        o = exploration_trajectories.states
        a = exploration_trajectories.actions
        r = exploration_trajectories.env_rewards.reshape(total_steps, 1)
        s = exploration_trajectories.skills_onehot

        return self._controller
Example #20
0
class PEARL(MetaRLAlgorithm):
    r"""A PEARL model based on https://arxiv.org/abs/1903.08254.

    PEARL, which stands for Probabilistic Embeddings for Actor-Critic
    Reinforcement Learning, is an off-policy meta-RL algorithm. It is built
    on top of SAC, using two Q-functions and a value function, with the
    addition of an inference network that estimates the posterior
    :math:`q(z \| c)`. The policy is conditioned on the latent variable Z
    in order to adapt its behavior to specific tasks.

    Args:
        env (list[GarageEnv]): Batch of sampled environment updates (EnvUpdate),
            which, when invoked on environments, will configure them with new
            tasks.
        policy_class (garage.torch.policies.Policy): Context-conditioned policy
            class.
        encoder_class (garage.torch.embeddings.ContextEncoder): Encoder class
            for the encoder in context-conditioned policy.
        inner_policy (garage.torch.policies.Policy): Policy.
        qf (torch.nn.Module): Q-function.
        vf (torch.nn.Module): Value function.
        num_train_tasks (int): Number of tasks for training.
        num_test_tasks (int): Number of tasks for testing.
        latent_dim (int): Size of latent context vector.
        encoder_hidden_sizes (list[int]): Output dimension of dense layer(s) of
            the context encoder.
        test_env_sampler (garage.experiment.SetTaskSampler): Sampler for test
            tasks.
        policy_lr (float): Policy learning rate.
        qf_lr (float): Q-function learning rate.
        vf_lr (float): Value function learning rate.
        context_lr (float): Inference network learning rate.
        policy_mean_reg_coeff (float): Policy mean regulation weight.
        policy_std_reg_coeff (float): Policy std regulation weight.
        policy_pre_activation_coeff (float): Policy pre-activation weight.
        soft_target_tau (float): Interpolation parameter for doing the
            soft target update.
        kl_lambda (float): KL lambda value.
        optimizer_class (callable): Type of optimizer for training networks.
        use_information_bottleneck (bool): False means latent context is
            deterministic.
        use_next_obs_in_context (bool): Whether or not to use next observation
            in distinguishing between tasks.
        meta_batch_size (int): Meta batch size.
        num_steps_per_epoch (int): Number of iterations per epoch.
        num_initial_steps (int): Number of transitions obtained per task before
            training.
        num_tasks_sample (int): Number of random tasks to obtain data for each
            iteration.
        num_steps_prior (int): Number of transitions to obtain per task with
            z ~ prior.
        num_steps_posterior (int): Number of transitions to obtain per task
            with z ~ posterior.
        num_extra_rl_steps_posterior (int): Number of additional transitions
            to obtain per task with z ~ posterior that are only used to train
            the policy and NOT the encoder.
        batch_size (int): Number of transitions in RL batch.
        embedding_batch_size (int): Number of transitions in context batch.
        embedding_mini_batch_size (int): Number of transitions in mini context
            batch; should be same as embedding_batch_size for non-recurrent
            encoder.
        max_path_length (int): Maximum path length.
        discount (float): RL discount factor.
        replay_buffer_size (int): Maximum samples in replay buffer.
        reward_scale (int): Reward scale.
        update_post_train (int): How often to resample context when obtaining
            data during training (in trajectories).

    """

    # pylint: disable=too-many-statements
    def __init__(self,
                 env,
                 inner_policy,
                 qf,
                 vf,
                 num_train_tasks,
                 num_test_tasks,
                 latent_dim,
                 encoder_hidden_sizes,
                 test_env_sampler,
                 policy_class=ContextConditionedPolicy,
                 encoder_class=MLPEncoder,
                 policy_lr=3E-4,
                 qf_lr=3E-4,
                 vf_lr=3E-4,
                 context_lr=3E-4,
                 policy_mean_reg_coeff=1E-3,
                 policy_std_reg_coeff=1E-3,
                 policy_pre_activation_coeff=0.,
                 soft_target_tau=0.005,
                 kl_lambda=.1,
                 optimizer_class=torch.optim.Adam,
                 use_information_bottleneck=True,
                 use_next_obs_in_context=False,
                 meta_batch_size=64,
                 num_steps_per_epoch=1000,
                 num_initial_steps=100,
                 num_tasks_sample=100,
                 num_steps_prior=100,
                 num_steps_posterior=0,
                 num_extra_rl_steps_posterior=100,
                 batch_size=1024,
                 embedding_batch_size=1024,
                 embedding_mini_batch_size=1024,
                 max_path_length=1000,
                 discount=0.99,
                 replay_buffer_size=1000000,
                 reward_scale=1,
                 update_post_train=1):

        self._env = env
        self._qf1 = qf
        self._qf2 = copy.deepcopy(qf)
        self._vf = vf
        self._num_train_tasks = num_train_tasks
        self._num_test_tasks = num_test_tasks
        self._latent_dim = latent_dim

        self._policy_mean_reg_coeff = policy_mean_reg_coeff
        self._policy_std_reg_coeff = policy_std_reg_coeff
        self._policy_pre_activation_coeff = policy_pre_activation_coeff
        self._soft_target_tau = soft_target_tau
        self._kl_lambda = kl_lambda
        self._use_information_bottleneck = use_information_bottleneck
        self._use_next_obs_in_context = use_next_obs_in_context

        self._meta_batch_size = meta_batch_size
        self._num_steps_per_epoch = num_steps_per_epoch
        self._num_initial_steps = num_initial_steps
        self._num_tasks_sample = num_tasks_sample
        self._num_steps_prior = num_steps_prior
        self._num_steps_posterior = num_steps_posterior
        self._num_extra_rl_steps_posterior = num_extra_rl_steps_posterior
        self._batch_size = batch_size
        self._embedding_batch_size = embedding_batch_size
        self._embedding_mini_batch_size = embedding_mini_batch_size
        self.max_path_length = max_path_length
        self._discount = discount
        self._replay_buffer_size = replay_buffer_size
        self._reward_scale = reward_scale
        self._update_post_train = update_post_train
        self._task_idx = None

        self._is_resuming = False

        worker_args = dict(deterministic=True, accum_context=True)
        self._evaluator = MetaEvaluator(test_task_sampler=test_env_sampler,
                                        max_path_length=max_path_length,
                                        worker_class=PEARLWorker,
                                        worker_args=worker_args,
                                        n_test_tasks=num_test_tasks)

        encoder_spec = self.get_env_spec(env[0](), latent_dim, 'encoder')
        encoder_in_dim = int(np.prod(encoder_spec.input_space.shape))
        encoder_out_dim = int(np.prod(encoder_spec.output_space.shape))
        context_encoder = encoder_class(input_dim=encoder_in_dim,
                                        output_dim=encoder_out_dim,
                                        hidden_sizes=encoder_hidden_sizes)

        self._policy = policy_class(
            latent_dim=latent_dim,
            context_encoder=context_encoder,
            policy=inner_policy,
            use_information_bottleneck=use_information_bottleneck,
            use_next_obs=use_next_obs_in_context)

        # buffer for training RL update
        self._replay_buffers = {
            i: PathBuffer(replay_buffer_size)
            for i in range(num_train_tasks)
        }

        self._context_replay_buffers = {
            i: PathBuffer(replay_buffer_size)
            for i in range(num_train_tasks)
        }

        self.target_vf = copy.deepcopy(self._vf)
        self.vf_criterion = torch.nn.MSELoss()

        self._policy_optimizer = optimizer_class(
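            # networks[1] is assumed to be the inner policy; networks[0] (used by context_optimizer below) is the context encoder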
            self._policy.networks[1].parameters(),
            lr=policy_lr,
        )
        self.qf1_optimizer = optimizer_class(
            self._qf1.parameters(),
            lr=qf_lr,
        )
        self.qf2_optimizer = optimizer_class(
            self._qf2.parameters(),
            lr=qf_lr,
        )
        self.vf_optimizer = optimizer_class(
            self._vf.parameters(),
            lr=vf_lr,
        )
        self.context_optimizer = optimizer_class(
            self._policy.networks[0].parameters(),
            lr=context_lr,
        )

    def __getstate__(self):
        """Object.__getstate__.

        Returns:
            dict: the state to be pickled for the instance.

        """
        data = self.__dict__.copy()
        del data['_replay_buffers']
        del data['_context_replay_buffers']
        return data

    def __setstate__(self, state):
        """Object.__setstate__.

        Args:
            state (dict): unpickled state.

        """
        self.__dict__.update(state)
        self._replay_buffers = {
            i: PathBuffer(self._replay_buffer_size)
            for i in range(self._num_train_tasks)
        }

        self._context_replay_buffers = {
            i: PathBuffer(self._replay_buffer_size)
            for i in range(self._num_train_tasks)
        }
        self._is_resuming = True

    def train(self, runner):
        """Obtain samples, train, and evaluate for each epoch.

        Args:
            runner (LocalRunner): LocalRunner is passed to give algorithm
                the access to runner.step_epochs(), which provides services
                such as snapshotting and sampler control.

        """
        for _ in runner.step_epochs():
            epoch = runner.step_itr / self._num_steps_per_epoch

            # obtain initial set of samples from all train tasks
            if epoch == 0 or self._is_resuming:
                for idx in range(self._num_train_tasks):
                    self._task_idx = idx
                    self._obtain_samples(runner, epoch,
                                         self._num_initial_steps, np.inf)
                    self._is_resuming = False

            # obtain samples from random tasks
            for _ in range(self._num_tasks_sample):
                idx = np.random.randint(self._num_train_tasks)
                self._task_idx = idx
                self._context_replay_buffers[idx].clear()
                # obtain samples with z ~ prior
                if self._num_steps_prior > 0:
                    self._obtain_samples(runner, epoch, self._num_steps_prior,
                                         np.inf)
                # obtain samples with z ~ posterior
                if self._num_steps_posterior > 0:
                    self._obtain_samples(runner, epoch,
                                         self._num_steps_posterior,
                                         self._update_post_train)
                # obtain extra samples for RL training but not the encoder
                if self._num_extra_rl_steps_posterior > 0:
                    self._obtain_samples(runner,
                                         epoch,
                                         self._num_extra_rl_steps_posterior,
                                         self._update_post_train,
                                         add_to_enc_buffer=False)

            logger.log('Training...')
            # sample train tasks and optimize networks
            self._train_once()
            runner.step_itr += 1

            logger.log('Evaluating...')
            # evaluate
            self._policy.reset_belief()
            self._evaluator.evaluate(self)

    def _train_once(self):
        """Perform one iteration of training."""
        for _ in range(self._num_steps_per_epoch):
            indices = np.random.choice(range(self._num_train_tasks),
                                       self._meta_batch_size)
            self._optimize_policy(indices)

    def _optimize_policy(self, indices):
        """Perform algorithm optimizing.

        Args:
            indices (list): Tasks used for training.

        """
        num_tasks = len(indices)
        context = self._sample_context(indices)
        # clear context and reset belief of policy
        self._policy.reset_belief(num_tasks=num_tasks)

        # data shape is (task, batch, feat)
        obs, actions, rewards, next_obs, terms = self._sample_data(indices)
        policy_outputs, task_z = self._policy(obs, context)
        new_actions, policy_mean, policy_log_std, log_pi = policy_outputs[:4]

        # flatten out the task dimension
        t, b, _ = obs.size()
        obs = obs.view(t * b, -1)
        actions = actions.view(t * b, -1)
        next_obs = next_obs.view(t * b, -1)

        # optimize qf and encoder networks
        q1_pred = self._qf1(torch.cat([obs, actions], dim=1), task_z)
        q2_pred = self._qf2(torch.cat([obs, actions], dim=1), task_z)
        v_pred = self._vf(obs, task_z.detach())

        with torch.no_grad():
            target_v_values = self.target_vf(next_obs, task_z)

        # KL constraint on z if probabilistic
        self.context_optimizer.zero_grad()
        if self._use_information_bottleneck:
            kl_div = self._policy.compute_kl_div()
            kl_loss = self._kl_lambda * kl_div
            kl_loss.backward(retain_graph=True)

        self.qf1_optimizer.zero_grad()
        self.qf2_optimizer.zero_grad()

        rewards_flat = rewards.view(self._batch_size * num_tasks, -1)
        rewards_flat = rewards_flat * self._reward_scale
        terms_flat = terms.view(self._batch_size * num_tasks, -1)
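        # soft Bellman backup: Q_target = scaled reward + gamma * (1 - done) * V_target(s', z)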
        q_target = rewards_flat + (
            1. - terms_flat) * self._discount * target_v_values
        qf_loss = torch.mean((q1_pred - q_target)**2) + torch.mean(
            (q2_pred - q_target)**2)
        qf_loss.backward()

        self.qf1_optimizer.step()
        self.qf2_optimizer.step()
        self.context_optimizer.step()

        # compute min Q on the new actions
        q1 = self._qf1(torch.cat([obs, new_actions], dim=1), task_z.detach())
        q2 = self._qf2(torch.cat([obs, new_actions], dim=1), task_z.detach())
        min_q = torch.min(q1, q2)

        # optimize vf
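        # SAC-style soft value target: V(s, z) = min(Q1, Q2) - log pi(a | s, z)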
        v_target = min_q - log_pi
        vf_loss = self.vf_criterion(v_pred, v_target.detach())
        self.vf_optimizer.zero_grad()
        vf_loss.backward()
        self.vf_optimizer.step()
        self._update_target_network()

        # optimize policy
        log_policy_target = min_q
        policy_loss = (log_pi - log_policy_target).mean()

        mean_reg_loss = self._policy_mean_reg_coeff * (policy_mean**2).mean()
        std_reg_loss = self._policy_std_reg_coeff * (policy_log_std**2).mean()
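        # regularise pre-tanh activations to keep the squashed policy away from saturation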
        pre_tanh_value = policy_outputs[-1]
        pre_activation_reg_loss = self._policy_pre_activation_coeff * (
            (pre_tanh_value**2).sum(dim=1).mean())
        policy_reg_loss = (mean_reg_loss + std_reg_loss +
                           pre_activation_reg_loss)
        policy_loss = policy_loss + policy_reg_loss

        self._policy_optimizer.zero_grad()
        policy_loss.backward()
        self._policy_optimizer.step()

    def _obtain_samples(self,
                        runner,
                        itr,
                        num_samples,
                        update_posterior_rate,
                        add_to_enc_buffer=True):
        """Obtain samples.

        Args:
            runner (LocalRunner): LocalRunner.
            itr (int): Index of iteration (epoch).
            num_samples (int): Number of samples to obtain.
            update_posterior_rate (int): How often (in trajectories) to infer
                posterior of policy.
            add_to_enc_buffer (bool): Whether or not to add samples to encoder
                buffer.

        """
        self._policy.reset_belief()
        total_samples = 0

        if update_posterior_rate != np.inf:
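            # the posterior is re-inferred every update_posterior_rate trajectories, so collect that many trajectories' worth of samples per batch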
            num_samples_per_batch = (update_posterior_rate *
                                     self.max_path_length)
        else:
            num_samples_per_batch = num_samples

        while total_samples < num_samples:
            paths = runner.obtain_samples(itr, num_samples_per_batch,
                                          self._policy,
                                          self._env[self._task_idx])
            total_samples += sum([len(path['rewards']) for path in paths])

            for path in paths:
                p = {
                    'observations': path['observations'],
                    'actions': path['actions'],
                    'rewards': path['rewards'].reshape(-1, 1),
                    'next_observations': path['next_observations'],
                    'dones': path['dones'].reshape(-1, 1)
                }
                self._replay_buffers[self._task_idx].add_path(p)

                if add_to_enc_buffer:
                    self._context_replay_buffers[self._task_idx].add_path(p)

            if update_posterior_rate != np.inf:
                context = self._sample_context(self._task_idx)
                self._policy.infer_posterior(context)

    def _sample_data(self, indices):
        """Sample batch of training data from a list of tasks.

        Args:
            indices (list): List of task indices to sample from.

        Returns:
            torch.Tensor: Observations, with shape :math:`(X, N, O^*)` where X
                is the number of tasks. N is batch size.
            torch.Tensor: Actions, with shape :math:`(X, N, A^*)`.
            torch.Tensor: Rewards, with shape :math:`(X, N, 1)`.
            torch.Tensor: Next observations, with shape :math:`(X, N, O^*)`.
            torch.Tensor: Dones, with shape :math:`(X, N, 1)`.

        """
        # transitions sampled randomly from replay buffer
        initialized = False
        for idx in indices:
            batch = self._replay_buffers[idx].sample_transitions(
                self._batch_size)
            if not initialized:
                o = batch['observations'][np.newaxis]
                a = batch['actions'][np.newaxis]
                r = batch['rewards'][np.newaxis]
                no = batch['next_observations'][np.newaxis]
                d = batch['dones'][np.newaxis]
                initialized = True
            else:
                o = np.vstack((o, batch['observations'][np.newaxis]))
                a = np.vstack((a, batch['actions'][np.newaxis]))
                r = np.vstack((r, batch['rewards'][np.newaxis]))
                no = np.vstack((no, batch['next_observations'][np.newaxis]))
                d = np.vstack((d, batch['dones'][np.newaxis]))

        o = torch.as_tensor(o, device=tu.global_device()).float()
        a = torch.as_tensor(a, device=tu.global_device()).float()
        r = torch.as_tensor(r, device=tu.global_device()).float()
        no = torch.as_tensor(no, device=tu.global_device()).float()
        d = torch.as_tensor(d, device=tu.global_device()).float()

        return o, a, r, no, d

    def _sample_context(self, indices):
        """Sample batch of context from a list of tasks.

        Args:
            indices (list): List of task indices to sample from.

        Returns:
            torch.Tensor: Context data, with shape :math:`(X, N, C)`. X is the
                number of tasks. N is batch size. C is the combined size of
                observation, action, reward, and next observation if next
                observation is used in context. Otherwise, C is the combined
                size of observation, action, and reward.

        """
        # make method work given a single task index
        if not hasattr(indices, '__iter__'):
            indices = [indices]

        initialized = False
        for idx in indices:
            batch = self._context_replay_buffers[idx].sample_transitions(
                self._embedding_batch_size)
            o = batch['observations']
            a = batch['actions']
            r = batch['rewards']
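            # each context transition is the concatenation (o, a, r), optionally extended with the next observation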
            context = np.hstack((np.hstack((o, a)), r))
            if self._use_next_obs_in_context:
                context = np.hstack((context, batch['next_observations']))

            if not initialized:
                final_context = context[np.newaxis]
                initialized = True
            else:
                final_context = np.vstack((final_context, context[np.newaxis]))

        final_context = torch.as_tensor(final_context,
                                        device=tu.global_device()).float()
        if len(indices) == 1:
            final_context = final_context.unsqueeze(0)

        return final_context

    def _update_target_network(self):
        """Update parameters in the target vf network."""
        for target_param, param in zip(self.target_vf.parameters(),
                                       self._vf.parameters()):
            target_param.data.copy_(target_param.data *
                                    (1.0 - self._soft_target_tau) +
                                    param.data * self._soft_target_tau)

    @property
    def policy(self):
        """Return all the policy within the model.

        Returns:
            garage.torch.policies.Policy: Policy within the model.

        """
        return self._policy

    @property
    def networks(self):
        """Return all the networks within the model.

        Returns:
            list: A list of networks.

        """
        return self._policy.networks + [self._policy] + [
            self._qf1, self._qf2, self._vf, self.target_vf
        ]

    def get_exploration_policy(self):
        """Return a policy used before adaptation to a specific task.

        Each time it is retrieved, this policy should only be evaluated in one
        task.

        Returns:
            garage.Policy: The policy used to obtain samples that are later
                used for meta-RL adaptation.

        """
        return self._policy

    def adapt_policy(self, exploration_policy, exploration_trajectories):
        """Produce a policy adapted for a task.

        Args:
            exploration_policy (garage.Policy): A policy which was returned
                from get_exploration_policy(), and which generated
                exploration_trajectories by interacting with an environment.
                The caller may not use this object after passing it into this
                method.
            exploration_trajectories (garage.TrajectoryBatch): Trajectories to
                adapt to, generated by exploration_policy exploring the
                environment.

        Returns:
            garage.Policy: A policy adapted to the task represented by the
                exploration_trajectories.

        """
        total_steps = sum(exploration_trajectories.lengths)
        o = exploration_trajectories.observations
        a = exploration_trajectories.actions
        r = exploration_trajectories.rewards.reshape(total_steps, 1)
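        # stack (o, a, r) into a single (1, T, feat) context and infer the posterior over z before returning the policy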
        ctxt = np.hstack((o, a, r)).reshape(1, total_steps, -1)
        context = torch.as_tensor(ctxt, device=tu.global_device()).float()
        self._policy.infer_posterior(context)

        return self._policy

    def to(self, device=None):
        """Put all the networks within the model on device.

        Args:
            device (str): ID of GPU or CPU.

        """
        device = device or tu.global_device()
        for net in self.networks:
            net.to(device)

    @classmethod
    def augment_env_spec(cls, env_spec, latent_dim):
        """Augment environment by a size of latent dimension.

        Args:
            env_spec (garage.envs.EnvSpec): Environment specs to be augmented.
            latent_dim (int): Latent dimension.

        Returns:
            garage.envs.EnvSpec: Augmented environment specs.

        """
        obs_dim = int(np.prod(env_spec.observation_space.shape))
        action_dim = int(np.prod(env_spec.action_space.shape))
        aug_obs = akro.Box(low=-1,
                           high=1,
                           shape=(obs_dim + latent_dim, ),
                           dtype=np.float32)
        aug_act = akro.Box(low=-1,
                           high=1,
                           shape=(action_dim, ),
                           dtype=np.float32)
        return EnvSpec(aug_obs, aug_act)

    @classmethod
    def get_env_spec(cls, env_spec, latent_dim, module):
        """Get environment specs of encoder with latent dimension.

        Args:
            env_spec (garage.envs.EnvSpec): Environment specs.
            latent_dim (int): Latent dimension.
            module (str): Module to get environment specs for.

        Returns:
            garage.envs.InOutSpec: Module environment specs with latent
                dimension.

        """
        obs_dim = int(np.prod(env_spec.observation_space.shape))
        action_dim = int(np.prod(env_spec.action_space.shape))
        if module == 'encoder':
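            # encoder consumes (o, a, r) transitions, hence obs_dim + action_dim + 1; it outputs 2 * latent_dim values parameterising the posterior over z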
            in_dim = obs_dim + action_dim + 1
            out_dim = latent_dim * 2
        elif module == 'vf':
            in_dim = obs_dim
            out_dim = latent_dim
        in_space = akro.Box(low=-1, high=1, shape=(in_dim, ), dtype=np.float32)
        out_space = akro.Box(low=-1,
                             high=1,
                             shape=(out_dim, ),
                             dtype=np.float32)
        if module == 'encoder':
            spec = InOutSpec(in_space, out_space)
        elif module == 'vf':
            spec = EnvSpec(in_space, out_space)

        return spec
Example #21
0
    def __init__(self,
                 env,
                 inner_policy,
                 qf1,
                 qf2,
                 sampler,
                 num_train_tasks,
                 num_test_tasks,
                 latent_dim,
                 encoder_hidden_sizes,
                 test_env_sampler,
                 policy_class=ContextConditionedPolicy,
                 encoder_class=MLPEncoder,
                 policy_lr=3E-4,
                 qf_lr=3E-4,
                 context_lr=3E-4,
                 policy_mean_reg_coeff=1E-3,
                 policy_std_reg_coeff=1E-3,
                 policy_pre_activation_coeff=0.,
                 soft_target_tau=0.005,
                 kl_lambda=.1,
                 fixed_alpha=None,
                 target_entropy=None,
                 initial_log_entropy=0.,
                 optimizer_class=torch.optim.Adam,
                 use_information_bottleneck=True,
                 use_next_obs_in_context=False,
                 meta_batch_size=64,
                 num_steps_per_epoch=1000,
                 num_initial_steps=100,
                 num_tasks_sample=100,
                 num_steps_prior=100,
                 num_steps_posterior=0,
                 num_extra_rl_steps_posterior=100,
                 batch_size=1024,
                 embedding_batch_size=1024,
                 embedding_mini_batch_size=1024,
                 max_path_length=500,
                 discount=0.99,
                 replay_buffer_size=1000000,
                 single_alpha=False,
                 reward_scale=1,
                 update_post_train=1):

        self._env = env
        self._qf1 = qf1
        self._qf2 = qf2
        # use 2 target q networks
        self._target_qf1 = copy.deepcopy(self._qf1)
        self._target_qf2 = copy.deepcopy(self._qf2)

        self._num_train_tasks = num_train_tasks
        self._num_test_tasks = num_test_tasks
        self._latent_dim = latent_dim

        self._policy_lr = policy_lr
        self._qf_lr = qf_lr
        self._context_lr = context_lr
        self._policy_mean_reg_coeff = policy_mean_reg_coeff
        self._policy_std_reg_coeff = policy_std_reg_coeff
        self._policy_pre_activation_coeff = policy_pre_activation_coeff
        self._soft_target_tau = soft_target_tau
        self._kl_lambda = kl_lambda
        self._use_information_bottleneck = use_information_bottleneck
        self._use_next_obs_in_context = use_next_obs_in_context
        self._meta_batch_size = meta_batch_size
        self._num_steps_per_epoch = num_steps_per_epoch
        self._num_initial_steps = num_initial_steps
        self._num_tasks_sample = num_tasks_sample
        self._num_steps_prior = num_steps_prior
        self._num_steps_posterior = num_steps_posterior
        self._num_extra_rl_steps_posterior = num_extra_rl_steps_posterior
        self._batch_size = batch_size
        self._embedding_batch_size = embedding_batch_size
        self._embedding_mini_batch_size = embedding_mini_batch_size
        self.max_path_length = max_path_length
        self._discount = discount
        self._replay_buffer_size = replay_buffer_size
        self._reward_scale = reward_scale
        self._update_post_train = update_post_train
        self._task_idx = None
        self._is_resuming = False
        self._sampler = sampler
        self._optimizer_class = optimizer_class
        self._single_alpha = single_alpha

        worker_args = dict(deterministic=True, accum_context=True)
        self._evaluator = MetaEvaluator(test_task_sampler=test_env_sampler,
                                        worker_class=PEARLSAC2Worker,
                                        worker_args=worker_args,
                                        n_test_tasks=num_test_tasks)
        env_spec = env[0]()
        encoder_spec = self.get_env_spec(env_spec, latent_dim, 'encoder')
        encoder_in_dim = int(np.prod(encoder_spec.input_space.shape))
        encoder_out_dim = int(np.prod(encoder_spec.output_space.shape))
        context_encoder = encoder_class(input_dim=encoder_in_dim,
                                        output_dim=encoder_out_dim,
                                        hidden_sizes=encoder_hidden_sizes)

        # Automatic entropy coefficient tuning
        self._use_automatic_entropy_tuning = fixed_alpha is None
        self._initial_log_entropy = initial_log_entropy
        self._fixed_alpha = fixed_alpha
        if self._use_automatic_entropy_tuning:
            if target_entropy:
                self._target_entropy = target_entropy
            else:
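                # default SAC heuristic: target entropy = -|A| (negative action dimensionality)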
                self._target_entropy = -np.prod(
                    env_spec.action_space.shape).item()
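            # either a single temperature shared across tasks or one per training task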
            if self._single_alpha:
                self._log_alpha = torch.Tensor([self._initial_log_entropy
                                                ]).requires_grad_()
            else:
                self._log_alpha = torch.Tensor(
                    [self._initial_log_entropy] *
                    self._num_train_tasks).requires_grad_()
            self._alpha_optimizer = self._optimizer_class([self._log_alpha],
                                                          lr=self._policy_lr)
        else:
            if self._single_alpha:
                self._log_alpha = torch.Tensor([self._fixed_alpha]).log()
            else:
                self._log_alpha = torch.Tensor([self._fixed_alpha] *
                                               self._num_train_tasks).log()

        self._policy = policy_class(
            latent_dim=latent_dim,
            context_encoder=context_encoder,
            policy=inner_policy,
            use_information_bottleneck=use_information_bottleneck,
            use_next_obs=use_next_obs_in_context)

        # buffer for training RL update
        self._replay_buffers = {
            i: PathBuffer(self._replay_buffer_size)
            for i in range(num_train_tasks)
        }
        self._context_replay_buffers = {
            i: PathBuffer(self._replay_buffer_size)
            for i in range(num_train_tasks)
        }

        self._policy_optimizer = optimizer_class(
            self._policy.networks[1].parameters(),
            lr=self._policy_lr,
        )
        self.qf1_optimizer = optimizer_class(
            self._qf1.parameters(),
            lr=self._qf_lr,
        )
        self.qf2_optimizer = optimizer_class(
            self._qf2.parameters(),
            lr=self._qf_lr,
        )
        self.context_optimizer = optimizer_class(
            self._policy.networks[0].parameters(),
            lr=self._context_lr,
        )
Example #22
0
    def __init__(self,
                 env,
                 inner_policy,
                 qf,
                 vf,
                 num_train_tasks,
                 num_test_tasks,
                 latent_dim,
                 encoder_hidden_sizes,
                 test_env_sampler,
                 policy_class=ContextConditionedPolicy,
                 encoder_class=MLPEncoder,
                 policy_lr=3E-4,
                 qf_lr=3E-4,
                 vf_lr=3E-4,
                 context_lr=3E-4,
                 policy_mean_reg_coeff=1E-3,
                 policy_std_reg_coeff=1E-3,
                 policy_pre_activation_coeff=0.,
                 soft_target_tau=0.005,
                 kl_lambda=.1,
                 optimizer_class=torch.optim.Adam,
                 use_information_bottleneck=True,
                 use_next_obs_in_context=False,
                 meta_batch_size=64,
                 num_steps_per_epoch=1000,
                 num_initial_steps=100,
                 num_tasks_sample=100,
                 num_steps_prior=100,
                 num_steps_posterior=0,
                 num_extra_rl_steps_posterior=100,
                 batch_size=1024,
                 embedding_batch_size=1024,
                 embedding_mini_batch_size=1024,
                 max_path_length=1000,
                 discount=0.99,
                 replay_buffer_size=1000000,
                 reward_scale=1,
                 update_post_train=1):

        self._env = env
        self._qf1 = qf
        self._qf2 = copy.deepcopy(qf)
        self._vf = vf
        self._num_train_tasks = num_train_tasks
        self._num_test_tasks = num_test_tasks
        self._latent_dim = latent_dim

        self._policy_mean_reg_coeff = policy_mean_reg_coeff
        self._policy_std_reg_coeff = policy_std_reg_coeff
        self._policy_pre_activation_coeff = policy_pre_activation_coeff
        self._soft_target_tau = soft_target_tau
        self._kl_lambda = kl_lambda
        self._use_information_bottleneck = use_information_bottleneck
        self._use_next_obs_in_context = use_next_obs_in_context

        self._meta_batch_size = meta_batch_size
        self._num_steps_per_epoch = num_steps_per_epoch
        self._num_initial_steps = num_initial_steps
        self._num_tasks_sample = num_tasks_sample
        self._num_steps_prior = num_steps_prior
        self._num_steps_posterior = num_steps_posterior
        self._num_extra_rl_steps_posterior = num_extra_rl_steps_posterior
        self._batch_size = batch_size
        self._embedding_batch_size = embedding_batch_size
        self._embedding_mini_batch_size = embedding_mini_batch_size
        self.max_path_length = max_path_length
        self._discount = discount
        self._replay_buffer_size = replay_buffer_size
        self._reward_scale = reward_scale
        self._update_post_train = update_post_train
        self._task_idx = None

        self._is_resuming = False

        worker_args = dict(deterministic=True, accum_context=True)
        self._evaluator = MetaEvaluator(test_task_sampler=test_env_sampler,
                                        max_path_length=max_path_length,
                                        worker_class=PEARLWorker,
                                        worker_args=worker_args,
                                        n_test_tasks=num_test_tasks)

        encoder_spec = self.get_env_spec(env[0](), latent_dim, 'encoder')
        encoder_in_dim = int(np.prod(encoder_spec.input_space.shape))
        encoder_out_dim = int(np.prod(encoder_spec.output_space.shape))
        context_encoder = encoder_class(input_dim=encoder_in_dim,
                                        output_dim=encoder_out_dim,
                                        hidden_sizes=encoder_hidden_sizes)

        self._policy = policy_class(
            latent_dim=latent_dim,
            context_encoder=context_encoder,
            policy=inner_policy,
            use_information_bottleneck=use_information_bottleneck,
            use_next_obs=use_next_obs_in_context)

        # buffer for training RL update
        self._replay_buffers = {
            i: PathBuffer(replay_buffer_size)
            for i in range(num_train_tasks)
        }

        self._context_replay_buffers = {
            i: PathBuffer(replay_buffer_size)
            for i in range(num_train_tasks)
        }

        self.target_vf = copy.deepcopy(self._vf)
        self.vf_criterion = torch.nn.MSELoss()

        self._policy_optimizer = optimizer_class(
            self._policy.networks[1].parameters(),
            lr=policy_lr,
        )
        self.qf1_optimizer = optimizer_class(
            self._qf1.parameters(),
            lr=qf_lr,
        )
        self.qf2_optimizer = optimizer_class(
            self._qf2.parameters(),
            lr=qf_lr,
        )
        self.vf_optimizer = optimizer_class(
            self._vf.parameters(),
            lr=vf_lr,
        )
        self.context_optimizer = optimizer_class(
            self._policy.networks[0].parameters(),
            lr=context_lr,
        )
Example #23
0
class PEARLSAC2(MetaRLAlgorithm):
    """A PEARL model based on https://arxiv.org/abs/1903.08254.

    PEARL, which stands for Probabilistic Embeddings for Actor-Critic
    Reinforcement Learning, is an off-policy meta-RL algorithm. This variant
    is built directly on SAC with two Q-functions and their target networks
    (no separate value function), plus an inference network that estimates
    the posterior :math:`q(z \| c)`. The policy is conditioned on the latent
    variable Z in order to adapt its behavior to specific tasks.

    Args:
        env (list[GarageEnv]): Batch of sampled environment updates (EnvUpdate),
            which, when invoked on environments, will configure them with new
            tasks.
        policy_class (garage.torch.policies.Policy): Context-conditioned policy
            class.
        encoder_class (garage.torch.embeddings.ContextEncoder): Encoder class
            for the encoder in context-conditioned policy.
        inner_policy (garage.torch.policies.Policy): Policy.
        qf1 (torch.nn.Module): First Q-function.
        qf2 (torch.nn.Module): Second Q-function.
        sampler (garage.sampler.Sampler): Sampler used to obtain transitions.
        num_train_tasks (int): Number of tasks for training.
        num_test_tasks (int): Number of tasks for testing.
        latent_dim (int): Size of latent context vector.
        encoder_hidden_sizes (list[int]): Output dimension of dense layer(s) of
            the context encoder.
        test_env_sampler (garage.experiment.SetTaskSampler): Sampler for test
            tasks.
        policy_lr (float): Policy learning rate.
        qf_lr (float): Q-function learning rate.
        context_lr (float): Inference network learning rate.
        policy_mean_reg_coeff (float): Policy mean regularization weight.
        policy_std_reg_coeff (float): Policy std regularization weight.
        policy_pre_activation_coeff (float): Policy pre-activation weight.
        soft_target_tau (float): Interpolation parameter for doing the
            soft target update.
        kl_lambda (float): KL lambda value.
        fixed_alpha (float): A fixed value for the entropy coefficient alpha.
            If None, alpha is tuned automatically.
        target_entropy (float): Target entropy used when tuning alpha
            automatically; defaults to the negative of the action dimension.
        initial_log_entropy (float): Initial value of log alpha when alpha is
            tuned automatically.
        optimizer_class (callable): Type of optimizer for training networks.
        use_information_bottleneck (bool): False means latent context is
            deterministic.
        use_next_obs_in_context (bool): Whether or not to use next observation
            in distinguishing between tasks.
        meta_batch_size (int): Meta batch size.
        num_steps_per_epoch (int): Number of iterations per epoch.
        num_initial_steps (int): Number of transitions obtained per task before
            training.
        num_tasks_sample (int): Number of random tasks to obtain data for each
            iteration.
        num_steps_prior (int): Number of transitions to obtain per task with
            z ~ prior.
        num_steps_posterior (int): Number of transitions to obtain per task
            with z ~ posterior.
        num_extra_rl_steps_posterior (int): Number of additional transitions
            to obtain per task with z ~ posterior that are only used to train
            the policy and NOT the encoder.
        batch_size (int): Number of transitions in RL batch.
        embedding_batch_size (int): Number of transitions in context batch.
        embedding_mini_batch_size (int): Number of transitions in mini context
            batch; should be same as embedding_batch_size for non-recurrent
            encoder.
        max_path_length (int): Maximum path length.
        discount (float): RL discount factor.
        replay_buffer_size (int): Maximum samples in replay buffer.
        single_alpha (bool): Whether to share a single entropy coefficient
            across all training tasks instead of one per task.
        reward_scale (int): Reward scale.
        update_post_train (int): How often to resample context when obtaining
            data during training (in trajectories).

    """

    # pylint: disable=too-many-statements
    def __init__(self,
                 env,
                 inner_policy,
                 qf1,
                 qf2,
                 sampler,
                 num_train_tasks,
                 num_test_tasks,
                 latent_dim,
                 encoder_hidden_sizes,
                 test_env_sampler,
                 policy_class=ContextConditionedPolicy,
                 encoder_class=MLPEncoder,
                 policy_lr=3E-4,
                 qf_lr=3E-4,
                 context_lr=3E-4,
                 policy_mean_reg_coeff=1E-3,
                 policy_std_reg_coeff=1E-3,
                 policy_pre_activation_coeff=0.,
                 soft_target_tau=0.005,
                 kl_lambda=.1,
                 fixed_alpha=None,
                 target_entropy=None,
                 initial_log_entropy=0.,
                 optimizer_class=torch.optim.Adam,
                 use_information_bottleneck=True,
                 use_next_obs_in_context=False,
                 meta_batch_size=64,
                 num_steps_per_epoch=1000,
                 num_initial_steps=100,
                 num_tasks_sample=100,
                 num_steps_prior=100,
                 num_steps_posterior=0,
                 num_extra_rl_steps_posterior=100,
                 batch_size=1024,
                 embedding_batch_size=1024,
                 embedding_mini_batch_size=1024,
                 max_path_length=500,
                 discount=0.99,
                 replay_buffer_size=1000000,
                 single_alpha=False,
                 reward_scale=1,
                 update_post_train=1):

        self._env = env
        self._qf1 = qf1
        self._qf2 = qf2
        # use 2 target q networks
        self._target_qf1 = copy.deepcopy(self._qf1)
        self._target_qf2 = copy.deepcopy(self._qf2)

        self._num_train_tasks = num_train_tasks
        self._num_test_tasks = num_test_tasks
        self._latent_dim = latent_dim

        self._policy_lr = policy_lr
        self._qf_lr = qf_lr
        self._context_lr = context_lr
        self._policy_mean_reg_coeff = policy_mean_reg_coeff
        self._policy_std_reg_coeff = policy_std_reg_coeff
        self._policy_pre_activation_coeff = policy_pre_activation_coeff
        self._soft_target_tau = soft_target_tau
        self._kl_lambda = kl_lambda
        self._use_information_bottleneck = use_information_bottleneck
        self._use_next_obs_in_context = use_next_obs_in_context
        self._meta_batch_size = meta_batch_size
        self._num_steps_per_epoch = num_steps_per_epoch
        self._num_initial_steps = num_initial_steps
        self._num_tasks_sample = num_tasks_sample
        self._num_steps_prior = num_steps_prior
        self._num_steps_posterior = num_steps_posterior
        self._num_extra_rl_steps_posterior = num_extra_rl_steps_posterior
        self._batch_size = batch_size
        self._embedding_batch_size = embedding_batch_size
        self._embedding_mini_batch_size = embedding_mini_batch_size
        self.max_path_length = max_path_length
        self._discount = discount
        self._replay_buffer_size = replay_buffer_size
        self._reward_scale = reward_scale
        self._update_post_train = update_post_train
        self._task_idx = None
        self._is_resuming = False
        self._sampler = sampler
        self._optimizer_class = optimizer_class
        self._single_alpha = single_alpha

        worker_args = dict(deterministic=True, accum_context=True)
        self._evaluator = MetaEvaluator(test_task_sampler=test_env_sampler,
                                        worker_class=PEARLSAC2Worker,
                                        worker_args=worker_args,
                                        n_test_tasks=num_test_tasks)
        env_spec = env[0]()
        encoder_spec = self.get_env_spec(env_spec, latent_dim, 'encoder')
        encoder_in_dim = int(np.prod(encoder_spec.input_space.shape))
        encoder_out_dim = int(np.prod(encoder_spec.output_space.shape))
        context_encoder = encoder_class(input_dim=encoder_in_dim,
                                        output_dim=encoder_out_dim,
                                        hidden_sizes=encoder_hidden_sizes)

        # Automatic entropy coefficient tuning
        self._use_automatic_entropy_tuning = fixed_alpha is None
        self._initial_log_entropy = initial_log_entropy
        self._fixed_alpha = fixed_alpha
        if self._use_automatic_entropy_tuning:
            if target_entropy:
                self._target_entropy = target_entropy
            else:
                self._target_entropy = -np.prod(
                    env_spec.action_space.shape).item()
            if self._single_alpha:
                self._log_alpha = torch.Tensor([self._initial_log_entropy
                                                ]).requires_grad_()
            else:
                self._log_alpha = torch.Tensor(
                    [self._initial_log_entropy] *
                    self._num_train_tasks).requires_grad_()
            self._alpha_optimizer = self._optimizer_class([self._log_alpha],
                                                          lr=self._policy_lr)
        else:
            if self._single_alpha:
                self._log_alpha = torch.Tensor([self._fixed_alpha]).log()
            else:
                self._log_alpha = torch.Tensor([self._fixed_alpha] *
                                               self._num_train_tasks).log()

        self._policy = policy_class(
            latent_dim=latent_dim,
            context_encoder=context_encoder,
            policy=inner_policy,
            use_information_bottleneck=use_information_bottleneck,
            use_next_obs=use_next_obs_in_context)

        # buffer for training RL update
        self._replay_buffers = {
            i: PathBuffer(self._replay_buffer_size)
            for i in range(num_train_tasks)
        }
        self._context_replay_buffers = {
            i: PathBuffer(self._replay_buffer_size)
            for i in range(num_train_tasks)
        }

        self._policy_optimizer = optimizer_class(
            self._policy.networks[1].parameters(),
            lr=self._policy_lr,
        )
        self.qf1_optimizer = optimizer_class(
            self._qf1.parameters(),
            lr=self._qf_lr,
        )
        self.qf2_optimizer = optimizer_class(
            self._qf2.parameters(),
            lr=self._qf_lr,
        )
        self.context_optimizer = optimizer_class(
            self._policy.networks[0].parameters(),
            lr=self._context_lr,
        )

    def __getstate__(self):
        """Object.__getstate__.

        Returns:
            dict: the state to be pickled for the instance.

        """
        data = self.__dict__.copy()
        del data['_replay_buffers']
        del data['_context_replay_buffers']
        return data

    def __setstate__(self, state):
        """Object.__setstate__.

        Args:
            state (dict): unpickled state.

        """
        self.__dict__.update(state)
        self._replay_buffers = {
            i: PathBuffer(self._replay_buffer_size)
            for i in range(self._num_train_tasks)
        }

        self._context_replay_buffers = {
            i: PathBuffer(self._replay_buffer_size)
            for i in range(self._num_train_tasks)
        }
        self._is_resuming = True

    def update_env(self, env, evaluator, num_train_tasks, num_test_tasks):
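        """Swap in a new set of task environments and reset state.

        Rebuilds the per-task replay buffers and re-creates all optimizers so
        training can continue on the new tasks.

        Args:
            env (list): New batch of callables that build task environments.
            evaluator (MetaEvaluator): Evaluator for the new test tasks.
            num_train_tasks (int): Number of new training tasks.
            num_test_tasks (int): Number of new test tasks.

        """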
        print("Updating environments")
        self._env = env
        self._evaluator = evaluator
        self._num_train_tasks = num_train_tasks
        self._num_test_tasks = num_test_tasks
        # buffer for training RL update
        self._replay_buffers = {
            i: PathBuffer(self._replay_buffer_size)
            for i in range(num_train_tasks)
        }

        self._context_replay_buffers = {
            i: PathBuffer(self._replay_buffer_size)
            for i in range(num_train_tasks)
        }
        self._task_idx = 0
        print("Updated with new envipickleronment setup")
        self._policy_optimizer = torch.optim.Adam(
            self._policy.networks[1].parameters(),
            lr=3E-4,
        )
        self.qf1_optimizer = torch.optim.Adam(
            self._qf1.parameters(),
            lr=3E-4,
        )
        self.qf2_optimizer = torch.optim.Adam(
            self._qf2.parameters(),
            lr=3E-4,
        )
        # PEARLSAC2 has no value function; only reset a vf optimizer if a
        # _vf attribute happens to exist (e.g. when reused by the vf variant).
        if hasattr(self, '_vf'):
            self.vf_optimizer = torch.optim.Adam(
                self._vf.parameters(),
                lr=3E-4,
            )
        self.context_optimizer = torch.optim.Adam(
            self._policy.networks[0].parameters(),
            lr=3E-4,
        )

        print('Reset optimizer state')

    def fill_expert_traj(self, expert_traj_dir):
        print("Filling Expert trajectory to replay buffer")
        from os import listdir
        from os.path import isfile, join
        expert_traj_paths = [
            join(expert_traj_dir, f) for f in listdir(expert_traj_dir)
            if isfile(join(expert_traj_dir, f))
        ]
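        # Each pickle file is assumed to hold a dict of trajectory arrays keyed
        # by 'observations', 'actions', 'rewards', 'next_observations' and
        # 'dones', matching the keys consumed below.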
        expert_trajs = []
        for exp_path in expert_traj_paths:
            with open(exp_path, 'rb') as handle:
                data = pickle.load(handle)
                expert_trajs.append(data)

        for path in expert_trajs:
            p = {
                'observations': path['observations'],
                'actions': path['actions'],
                'rewards': path['rewards'].reshape(-1, 1),
                'next_observations': path['next_observations'],
                'dones': path['dones'].reshape(-1, 1)
            }
            self._replay_buffers[self._task_idx].add_path(p)
            self._context_replay_buffers[self._task_idx].add_path(p)

    def adapt_expert_traj(self, runner):
        """Obtain samples, train, and evaluate for each epoch.

                Args:
                    runner (LocalRunner): LocalRunner is passed to give algorithm
                        the access to runner.step_epochs(), which provides services
                        such as snapshotting and sampler control.

                """
        for _ in runner.step_epochs():
            logger.log('Adapting Policy {}...'.format(runner.step_itr))
            self._train_once()
            runner.step_itr += 1
            logger.log('Evaluating...')
            # evaluate
            self._policy.reset_belief()
            self._evaluator.evaluate(self)

    def reset(self):
        self._policy.reset_belief()

    def train(self, runner):
        """Obtain samples, train, and evaluate for each epoch.

        Args:
            runner (LocalRunner): LocalRunner is passed to give the algorithm
                access to runner.step_epochs(), which provides services
                such as snapshotting and sampler control.

        """
        for _ in runner.step_epochs():
            epoch = runner.step_itr / self._num_steps_per_epoch

            # obtain initial set of samples from all train tasks
            if epoch == 0 or self._is_resuming:
                for idx in range(self._num_train_tasks):
                    self._task_idx = idx
                    self._obtain_samples(runner, epoch,
                                         self._num_initial_steps, np.inf)
                    self._is_resuming = False

            # obtain samples from random tasks
            for _ in range(self._num_tasks_sample):
                idx = np.random.randint(self._num_train_tasks)
                self._task_idx = idx
                self._context_replay_buffers[idx].clear()
                # obtain samples with z ~ prior
                if self._num_steps_prior > 0:
                    self._obtain_samples(runner, epoch, self._num_steps_prior,
                                         np.inf)
                # obtain samples with z ~ posterior
                if self._num_steps_posterior > 0:
                    self._obtain_samples(runner, epoch,
                                         self._num_steps_posterior,
                                         self._update_post_train)
                # obtain extra samples for RL training but not the encoder
                if self._num_extra_rl_steps_posterior > 0:
                    self._obtain_samples(runner,
                                         epoch,
                                         self._num_extra_rl_steps_posterior,
                                         self._update_post_train,
                                         add_to_enc_buffer=False)

            logger.log('Training...')
            # sample train tasks and optimize networks
            self._train_once()
            runner.step_itr += 1

            logger.log('Evaluating...')
            # evaluate
            self._policy.reset_belief()
            self._evaluator.evaluate(self)

    def _train_once(self):
        """Perform one iteration of training."""
        policy_loss_list = []
        qf_loss_list = []
        alpha_loss_list = []
        alpha_list = []
        for _ in range(self._num_steps_per_epoch):
            indices = np.random.choice(range(self._num_train_tasks),
                                       self._meta_batch_size)
            policy_loss, qf_loss, alpha_loss, alpha = self._optimize_policy(
                indices)
            policy_loss_list.append(policy_loss)
            qf_loss_list.append(qf_loss)
            alpha_loss_list.append(alpha_loss)
            alpha_list.append(alpha)

        with tabular.prefix('MetaTrain/Average/'):
            tabular.record('PolicyLoss',
                           np.average(np.array(policy_loss_list)))
            tabular.record('QfLoss', np.average(np.array(qf_loss_list)))
            tabular.record('AlphaLoss', np.average(np.array(alpha_loss_list)))
            tabular.record('Alpha', np.average(np.array(alpha_list)))

    def _optimize_policy(self, indices):
        """Perform algorithm optimizing.

        Args:
            indices (list): Tasks used for training.

        """
        num_tasks = len(indices)
        context = self._sample_context(indices)
        # clear context and reset belief of policy
        self._policy.reset_belief(num_tasks=num_tasks)
        # data shape is (task, batch, feat)
        obs, actions, rewards, next_obs, terms = self._sample_data(indices)

        # flatten out the task dimension
        t, b, _ = obs.size()
        batch_obs = obs.view(t * b, -1)
        batch_action = actions.view(t * b, -1)
        batch_next_obs = next_obs.view(t * b, -1)

        policy_outputs, task_z = self._policy(next_obs, context)
        new_next_actions, policy_mean, policy_log_std, log_pi, pre_tanh = policy_outputs

        # ===== Critic Objective =====
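        # Soft Bellman target: the minimum of the two target Q-networks minus
        # the entropy term alpha * log_pi, as in standard SAC.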
        with torch.no_grad():
            alpha = self._get_log_alpha(indices).exp()
        q1_pred = self._qf1(torch.cat([batch_obs, batch_action], dim=1),
                            task_z)
        q2_pred = self._qf2(torch.cat([batch_obs, batch_action], dim=1),
                            task_z)
        target_q_values = torch.min(
            self._target_qf1(
                torch.cat([batch_next_obs, new_next_actions], dim=1), task_z),
            self._target_qf2(
                torch.cat([batch_next_obs, new_next_actions], dim=1),
                task_z)).flatten() - (alpha * log_pi.flatten())

        rewards_flat = rewards.view(self._batch_size * num_tasks, -1).flatten()
        rewards_flat = rewards_flat * self._reward_scale
        terms_flat = terms.view(self._batch_size * num_tasks, -1).flatten()

        with torch.no_grad():
            q_target = rewards_flat + (
                (1. - terms_flat) * self._discount) * target_q_values

        qf1_loss = F.mse_loss(q1_pred.flatten(), q_target)
        qf2_loss = F.mse_loss(q2_pred.flatten(), q_target)
        qf_loss = qf1_loss + qf2_loss

        # KL constraint on z if probabilistic
        self.context_optimizer.zero_grad()
        if self._use_information_bottleneck:
            kl_div = self._policy.compute_kl_div()
            kl_loss = self._kl_lambda * kl_div
            kl_loss.backward(retain_graph=True)

        # Optimize Q network and context encoder
        self.qf1_optimizer.zero_grad()
        self.qf2_optimizer.zero_grad()
        qf_loss.backward()
        self.qf1_optimizer.step()
        self.qf2_optimizer.step()
        self.context_optimizer.step()

        # ===== Actor Objective =====
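        # The actor maximizes min(Q1, Q2) - alpha * log_pi; task_z is detached
        # so the policy update does not backpropagate into the context encoder.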
        policy_outputs, task_z = self._policy(obs, context)
        new_actions, policy_mean, policy_log_std, log_pi, pre_tanh = policy_outputs
        # compute min Q on the new actions
        q1 = self._qf1(torch.cat([batch_obs, new_actions], dim=1),
                       task_z.detach())
        q2 = self._qf2(torch.cat([batch_obs, new_actions], dim=1),
                       task_z.detach())
        min_q = torch.min(q1, q2)
        # optimize policy
        policy_loss = ((alpha * log_pi) - min_q.flatten()).mean()
        mean_reg_loss = self._policy_mean_reg_coeff * (policy_mean**2).mean()
        std_reg_loss = self._policy_std_reg_coeff * (policy_log_std**2).mean()
        pre_tanh_value = policy_outputs[-1]
        pre_activation_reg_loss = self._policy_pre_activation_coeff * (
            (pre_tanh_value**2).sum(dim=1).mean())
        policy_reg_loss = (mean_reg_loss + std_reg_loss +
                           pre_activation_reg_loss)
        policy_loss += policy_reg_loss

        self._policy_optimizer.zero_grad()
        policy_loss.backward()
        self._policy_optimizer.step()

        # ===== Temperature Objective =====
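        # When alpha is tuned automatically, a gradient step on log-alpha pulls
        # the policy entropy toward the target entropy.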
        if self._use_automatic_entropy_tuning:
            alpha = (self._get_log_alpha(indices)).exp()
            alpha_loss = (-alpha *
                          (log_pi.detach() + self._target_entropy)).mean()
            self._alpha_optimizer.zero_grad()
            alpha_loss.backward()
            self._alpha_optimizer.step()
            alpha_loss_cpu = alpha_loss.detach().cpu().numpy()
        else:
            # No temperature update with a fixed alpha; report a zero loss so
            # the return signature stays consistent.
            alpha_loss_cpu = np.zeros(1)
        alpha_avg_cpu = np.average(alpha.detach().cpu().numpy())

        # ===== Update Target Network =====
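        # Polyak-average the online Q-networks into their targets at rate
        # soft_target_tau.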
        target_qfs = [self._target_qf1, self._target_qf2]
        qfs = [self._qf1, self._qf2]
        for target_qf, qf in zip(target_qfs, qfs):
            for t_param, param in zip(target_qf.parameters(), qf.parameters()):
                t_param.data.copy_(t_param.data *
                                   (1.0 - self._soft_target_tau) +
                                   param.data * self._soft_target_tau)

        qf_loss_cpu = qf_loss.detach().cpu().numpy()
        policy_loss_cpu = policy_loss.detach().cpu().numpy()
        return policy_loss_cpu, qf_loss_cpu, alpha_loss_cpu, alpha_avg_cpu

    def _obtain_samples(self,
                        runner,
                        itr,
                        num_samples,
                        update_posterior_rate,
                        add_to_enc_buffer=True):
        """Obtain samples.

        Args:
            runner (LocalRunner): LocalRunner.
            itr (int): Index of iteration (epoch).
            num_samples (int): Number of samples to obtain.
            update_posterior_rate (int): How often (in trajectories) to infer
                posterior of policy.
            add_to_enc_buffer (bool): Whether or not to add samples to encoder
                buffer.

        """
        self._policy.reset_belief()
        total_samples = 0

        if update_posterior_rate != np.inf:
            num_samples_per_batch = (update_posterior_rate *
                                     self.max_path_length)
        else:
            num_samples_per_batch = num_samples

        while total_samples < num_samples:
            paths = runner.obtain_samples(itr, num_samples_per_batch,
                                          self._policy,
                                          self._env[self._task_idx])
            total_samples += sum([len(path['rewards']) for path in paths])

            for path in paths:
                terminations = np.array([
                    step_type == StepType.TERMINAL
                    for step_type in path['step_types']
                ]).reshape(-1, 1)
                p = {
                    'observations': path['observations'],
                    'actions': path['actions'],
                    'rewards': path['rewards'].reshape(-1, 1),
                    'next_observations': path['next_observations'],
                    'dones': terminations
                }
                self._replay_buffers[self._task_idx].add_path(p)

                if add_to_enc_buffer:
                    self._context_replay_buffers[self._task_idx].add_path(p)

            if update_posterior_rate != np.inf:
                context = self._sample_context(self._task_idx)
                self._policy.infer_posterior(context)

    def _sample_data(self, indices):
        """Sample batch of training data from a list of tasks.

        Args:
            indices (list): List of task indices to sample from.

        Returns:
            torch.Tensor: Observations, with shape :math:`(X, N, O^*)` where X
                is the number of tasks. N is batch size.
            torch.Tensor: Actions, with shape :math:`(X, N, A^*)`.
            torch.Tensor: Rewards, with shape :math:`(X, N, 1)`.
            torch.Tensor: Next observations, with shape :math:`(X, N, O^*)`.
            torch.Tensor: Dones, with shape :math:`(X, N, 1)`.

        """
        # transitions sampled randomly from replay buffer
        initialized = False
        for idx in indices:
            batch = self._replay_buffers[idx].sample_transitions(
                self._batch_size)
            if not initialized:
                o = batch['observations'][np.newaxis]
                a = batch['actions'][np.newaxis]
                r = batch['rewards'][np.newaxis]
                no = batch['next_observations'][np.newaxis]
                d = batch['dones'][np.newaxis]
                initialized = True
            else:
                o = np.vstack((o, batch['observations'][np.newaxis]))
                a = np.vstack((a, batch['actions'][np.newaxis]))
                r = np.vstack((r, batch['rewards'][np.newaxis]))
                no = np.vstack((no, batch['next_observations'][np.newaxis]))
                d = np.vstack((d, batch['dones'][np.newaxis]))

        o = torch.as_tensor(o, device=global_device()).float()
        a = torch.as_tensor(a, device=global_device()).float()
        r = torch.as_tensor(r, device=global_device()).float()
        no = torch.as_tensor(no, device=global_device()).float()
        d = torch.as_tensor(d, device=global_device()).float()

        return o, a, r, no, d

    def _sample_context(self, indices):
        """Sample batch of context from a list of tasks.

        Args:
            indices (list): List of task indices to sample from.

        Returns:
            torch.Tensor: Context data, with shape :math:`(X, N, C)`. X is the
                number of tasks. N is batch size. C is the combined size of
                observation, action, reward, and next observation if next
                observation is used in context. Otherwise, C is the combined
                size of observation, action, and reward.

        """
        # make method work given a single task index
        if not hasattr(indices, '__iter__'):
            indices = [indices]

        initialized = False
        for idx in indices:
            batch = self._context_replay_buffers[idx].sample_transitions(
                self._embedding_batch_size)
            o = batch['observations']
            a = batch['actions']
            r = batch['rewards']
            context = np.hstack((np.hstack((o, a)), r))
            if self._use_next_obs_in_context:
                context = np.hstack((context, batch['next_observations']))

            if not initialized:
                final_context = context[np.newaxis]
                initialized = True
            else:
                final_context = np.vstack((final_context, context[np.newaxis]))

        final_context = torch.as_tensor(final_context,
                                        device=global_device()).float()
        if len(indices) == 1:
            final_context = final_context.unsqueeze(0)

        return final_context

    @property
    def policy(self):
        """Return all the policy within the model.

        Returns:
            garage.torch.policies.Policy: Policy within the model.

        """
        return self._policy

    @property
    def networks(self):
        """Return all the networks within the model.

        Returns:
            list: A list of networks.

        """
        return self._policy.networks + [
            self._policy, self._qf1, self._qf2, self._target_qf1,
            self._target_qf2
        ]

    def get_exploration_policy(self):
        """Return a policy used before adaptation to a specific task.

        Each time it is retrieved, this policy should only be evaluated in one
        task.

        Returns:
            garage.Policy: The policy used to obtain samples that are later
                used for meta-RL adaptation.

        """
        return self._policy

    def adapt_policy(self, exploration_policy, exploration_trajectories):
        """Produce a policy adapted for a task.

        Args:
            exploration_policy (garage.Policy): A policy which was returned
                from get_exploration_policy(), and which generated
                exploration_trajectories by interacting with an environment.
                The caller may not use this object after passing it into this
                method.
            exploration_trajectories (garage.TrajectoryBatch): Trajectories to
                adapt to, generated by exploration_policy exploring the
                environment.

        Returns:
            garage.Policy: A policy adapted to the task represented by the
                exploration_trajectories.

        """
        total_steps = sum(exploration_trajectories.lengths)
        o = exploration_trajectories.observations
        a = exploration_trajectories.actions
        r = exploration_trajectories.rewards.reshape(total_steps, 1)
        ctxt = np.hstack((o, a, r)).reshape(1, total_steps, -1)
        context = torch.as_tensor(ctxt, device=global_device()).float()
        self._policy.infer_posterior(context)

        return self._policy

    def to(self, device=None):
        """Put all the networks within the model on device.

        Args:
            device (str): ID of GPU or CPU.

        """
        device = device or global_device()
        for net in self.networks:
            net.to(device)

        if self._use_automatic_entropy_tuning:
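            # Note: log-alpha is rebuilt with torch.cuda.FloatTensor in both
            # branches of this method, so it assumes the global device is a GPU.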
            if self._single_alpha:
                self._log_alpha = torch.cuda.FloatTensor(
                    [self._initial_log_entropy],
                    device=global_device()).requires_grad_()
            else:
                self._log_alpha = torch.cuda.FloatTensor(
                    [self._initial_log_entropy] * self._num_train_tasks,
                    device=global_device()).requires_grad_()
            self._alpha_optimizer = self._optimizer_class([self._log_alpha],
                                                          lr=self._policy_lr)
        else:
            if self._single_alpha:
                self._log_alpha = torch.cuda.FloatTensor(
                    [self._fixed_alpha], device=global_device()).log()
            else:
                self._log_alpha = torch.cuda.FloatTensor(
                    [self._fixed_alpha] * self._num_train_tasks,
                    device=global_device()).log()

    @classmethod
    def augment_env_spec(cls, env_spec, latent_dim):
        """Augment environment by a size of latent dimension.

        Args:
            env_spec (garage.envs.EnvSpec): Environment specs to be augmented.
            latent_dim (int): Latent dimension.

        Returns:
            garage.envs.EnvSpec: Augmented environment specs.

        """
        obs_dim = int(np.prod(env_spec.observation_space.shape))
        action_dim = int(np.prod(env_spec.action_space.shape))
        aug_obs = akro.Box(low=-1,
                           high=1,
                           shape=(obs_dim + latent_dim, ),
                           dtype=np.float32)
        aug_act = akro.Box(low=-1,
                           high=1,
                           shape=(action_dim, ),
                           dtype=np.float32)
        return EnvSpec(aug_obs, aug_act)

    @classmethod
    def get_env_spec(cls, env_spec, latent_dim, module):
        """Get environment specs of encoder with latent dimension.

        Args:
            env_spec (garage.envs.EnvSpec): Environment specs.
            latent_dim (int): Latent dimension.
            module (str): Module to get environment specs for.

        Returns:
            garage.envs.InOutSpec: Module environment specs with latent
                dimension.

        """
        obs_dim = int(np.prod(env_spec.observation_space.shape))
        action_dim = int(np.prod(env_spec.action_space.shape))
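        # The context encoder consumes (obs, action, reward) tuples, hence the
        # +1 for the scalar reward; it outputs a mean and a variance per latent
        # dimension, hence latent_dim * 2.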
        if module == 'encoder':
            in_dim = obs_dim + action_dim + 1
            out_dim = latent_dim * 2
        elif module == 'vf':
            in_dim = obs_dim
            out_dim = latent_dim
        in_space = akro.Box(low=-1, high=1, shape=(in_dim, ), dtype=np.float32)
        out_space = akro.Box(low=-1,
                             high=1,
                             shape=(out_dim, ),
                             dtype=np.float32)
        if module == 'encoder':
            spec = InOutSpec(in_space, out_space)
        elif module == 'vf':
            spec = EnvSpec(in_space, out_space)

        return spec

    def _get_log_alpha(self, indices):
        """Return the value of log_alpha.

        Args:
            samples_data (dict): Transitions(S,A,R,S') that are sampled from
                the replay buffer. It should have the keys 'observation',
                'action', 'reward', 'terminal', and 'next_observations'.

        Note:
            samples_data's entries should be torch.Tensor's with the following
            shapes:
                observation: :math:`(N, O^*)`
                action: :math:`(N, A^*)`
                reward: :math:`(N, 1)`
                terminal: :math:`(N, 1)`
                next_observation: :math:`(N, O^*)`

        Returns:
            torch.Tensor: log_alpha. shape is (1, self.buffer_batch_size)

        """
        log_alpha = self._log_alpha
        if self._single_alpha:
            # A single shared alpha: broadcast it over every transition.
            return log_alpha.expand(len(indices) * self._batch_size)
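        # Build a one-hot task matrix so that each transition selects the
        # log-alpha of its own task via a single matrix product.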
        one_hots = np.zeros(
            (len(indices) * self._batch_size, self._num_train_tasks),
            dtype=np.float32)
        for i in range(len(indices)):
            one_hots[self._batch_size * i:self._batch_size * (i + 1),
                     indices[i]] = 1
        one_hots = torch.as_tensor(one_hots, device=global_device())
        ret = torch.mm(one_hots, log_alpha.unsqueeze(0).t()).squeeze()
        return ret
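
For orientation, here is a minimal, illustrative sketch of how PEARLSAC2 might be wired together. Everything outside the PEARLSAC2 call itself is an assumption: the garage imports follow the older garage API used elsewhere on this page (module paths differ across garage versions), the network sizes and task counts are arbitrary, and sampler=None is a placeholder because the sampler argument is only stored by the constructor shown above.

# Illustrative only -- imports and sizes are assumptions, not taken from the
# snippet above.
from garage.envs import GarageEnv, PointEnv
from garage.experiment import SetTaskSampler
from garage.torch.policies import TanhGaussianMLPPolicy
from garage.torch.q_functions import ContinuousMLPQFunction

latent_dim = 5
num_train_tasks, num_test_tasks = 10, 5
net_size = 300

# Task samplers yield callables; env[i]() builds a configured environment.
env_sampler = SetTaskSampler(lambda: GarageEnv(PointEnv()))
env = env_sampler.sample(num_train_tasks)
test_env_sampler = SetTaskSampler(lambda: GarageEnv(PointEnv()))

# Networks are built against the latent-augmented spec, mirroring what
# augment_env_spec does: observations gain latent_dim extra dimensions.
augmented_spec = PEARLSAC2.augment_env_spec(env[0](), latent_dim)
inner_policy = TanhGaussianMLPPolicy(env_spec=augmented_spec,
                                     hidden_sizes=[net_size] * 3)
qf1 = ContinuousMLPQFunction(env_spec=augmented_spec,
                             hidden_sizes=[net_size] * 3)
qf2 = ContinuousMLPQFunction(env_spec=augmented_spec,
                             hidden_sizes=[net_size] * 3)

pearl_sac = PEARLSAC2(env=env,
                      inner_policy=inner_policy,
                      qf1=qf1,
                      qf2=qf2,
                      sampler=None,  # placeholder; only stored by __init__
                      num_train_tasks=num_train_tasks,
                      num_test_tasks=num_test_tasks,
                      latent_dim=latent_dim,
                      encoder_hidden_sizes=(200, 200, 200),
                      test_env_sampler=test_env_sampler)

# A LocalRunner would then be set up with pearl_sac and env, and
# pearl_sac.train(runner) drives the epoch loop shown in train() above.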