Example #1
0
    def test_get_actions(self, batch_size, hidden_sizes):
        """Check the mean, variance and action shape reported by the policy.

        The policy is built with all-ones weight initialisers and no hidden
        nonlinearity, and is fed all-ones observations. The expected mean
        used below is ``obs_dim`` times the product of the hidden layer
        widths, and the expected variance is ``init_std ** 2``.

        Args:
            batch_size (int): Number of observations in the batch.
            hidden_sizes (Sequence[int]): Hidden layer widths of the policy.

        """
        env_spec = MetaRLEnv(DummyBoxEnv())
        n_obs = env_spec.observation_space.flat_dim
        n_act = env_spec.action_space.flat_dim
        ones_obs = torch.ones([batch_size, n_obs], dtype=torch.float32)
        init_std = 2.

        policy_kwargs = dict(hidden_sizes=hidden_sizes,
                             init_std=init_std,
                             hidden_nonlinearity=None,
                             std_parameterization='exp',
                             hidden_w_init=nn.init.ones_,
                             output_w_init=nn.init.ones_)
        policy = GaussianMLPPolicy(env_spec=env_spec, **policy_kwargs)

        # Forward pass first; sampling actions happens afterwards.
        dist = policy(ones_obs)

        width_product = torch.Tensor(hidden_sizes).prod().item()
        expected_mean = torch.full([batch_size, n_act], n_obs * width_product)
        expected_variance = init_std**2
        action, prob = policy.get_actions(ones_obs)

        variance_target = torch.full((batch_size, n_act), expected_variance)
        assert np.array_equal(prob['mean'], expected_mean.numpy())
        assert dist.variance.equal(variance_target)
        assert action.shape == (batch_size, n_act)
Example #2
0
class TestMAML:
    """Tests for the policy-copy and adaptation helpers, run via MAMLPPO."""

    def setup_method(self):
        """Build the environment, policy, baseline and algorithm.

        Called before every test.
        """
        self.env = MetaRLEnv(
            normalize(HalfCheetahDirEnv(), expected_action_scale=10.))
        spec = self.env.spec
        self.policy = GaussianMLPPolicy(env_spec=spec,
                                        hidden_sizes=(64, 64),
                                        hidden_nonlinearity=torch.tanh,
                                        output_nonlinearity=None)
        self.baseline = LinearFeatureBaseline(env_spec=spec)
        maml_settings = dict(max_path_length=100,
                             meta_batch_size=5,
                             discount=0.99,
                             gae_lambda=1.,
                             inner_lr=0.1,
                             num_grad_updates=1)
        self.algo = MAMLPPO(env=self.env,
                            policy=self.policy,
                            baseline=self.baseline,
                            **maml_settings)

    def teardown_method(self):
        """Close the environment. Called after every test."""
        self.env.close()

    def test_get_exploration_policy(self, set_params, test_params):
        """Check that the exploration policy is an independent copy.

        Args:
            set_params (callable): Called as ``set_params(value, module)``
                through ``Module.apply``; defined outside this block.
            test_params (callable): Called as ``test_params(value, module)``
                through ``Module.apply``; defined outside this block.

        """
        fill_original = partial(set_params, 0.1)
        fill_copy = partial(set_params, 0.2)

        self.policy.apply(fill_original)
        exploration_policy = self.algo.get_exploration_policy()
        exploration_policy.apply(fill_copy)

        # Writing to the copy must not leak back into the original.
        self.policy.apply(partial(test_params, 0.1))
        exploration_policy.apply(partial(test_params, 0.2))

    def test_adapt_policy(self, set_params, test_params):
        """Check that adapting a copied policy changes only the copy.

        Args:
            set_params (callable): Called as ``set_params(value, module)``
                through ``Module.apply``; defined outside this block.
            test_params (callable): Called as ``test_params(value, module)``
                through ``Module.apply``; defined outside this block.

        """
        factory = WorkerFactory(seed=100, max_path_length=100)
        sampler = LocalSampler.from_worker_factory(factory, self.policy,
                                                   self.env)

        self.policy.apply(partial(set_params, 0.1))
        exploration_policy = self.algo.get_exploration_policy()
        samples = sampler.obtain_samples(0, 100, exploration_policy)
        self.algo.adapt_policy(exploration_policy, samples)

        # The original policy keeps the values written before adaptation.
        self.policy.apply(partial(test_params, 0.1))

        # At least one parameter tensor of the adapted copy has to differ.
        param_pairs = zip(exploration_policy.parameters(),
                          self.policy.parameters())
        any_changed = any(
            new.data.ne(old.data).sum() > 0 for new, old in param_pairs)
        if not any_changed:
            pytest.fail("Parameters of adapted policy should not be "
                        "identical to the old policy.")
Example #3
0
 def test_entropy(self):
     """Check that ``policy.entropy`` matches the distribution's entropy.

     The policy is called on an all-zero observation and the entropy of the
     returned distribution is compared with ``policy.entropy(obs)``.
     """
     env_spec = MetaRLEnv(DummyBoxEnv())
     zero_obs = torch.Tensor([0, 0, 0, 0]).float()
     policy_kwargs = dict(hidden_sizes=(1, ),
                          init_std=1.,
                          hidden_nonlinearity=None,
                          std_parameterization='exp',
                          hidden_w_init=nn.init.ones_,
                          output_w_init=nn.init.ones_)
     policy = GaussianMLPPolicy(env_spec=env_spec, **policy_kwargs)

     entropy_from_dist = policy(zero_obs).entropy()
     entropy_from_policy = policy.entropy(zero_obs)
     assert torch.allclose(entropy_from_dist, entropy_from_policy)
def ppo_pendulum(ctxt=None, seed=1):
    """Run PPO on the ``InvertedDoublePendulum-v2`` environment.

    Args:
        ctxt (metarl.experiment.ExperimentContext): Experiment configuration
            handed to LocalRunner, which uses it to create the snapshotter.
        seed (int): Seed for the random number generator, for determinism.

    """
    set_seed(seed)
    env = MetaRLEnv(env_name='InvertedDoublePendulum-v2')
    runner = LocalRunner(ctxt)
    spec = env.spec

    # Both networks use tanh hidden units and no output nonlinearity.
    activations = dict(hidden_nonlinearity=torch.tanh,
                       output_nonlinearity=None)
    policy = GaussianMLPPolicy(spec, hidden_sizes=[64, 64], **activations)
    value_function = GaussianMLPValueFunction(env_spec=spec,
                                              hidden_sizes=(32, 32),
                                              **activations)

    ppo_settings = dict(max_path_length=100, discount=0.99, center_adv=False)
    algo = PPO(env_spec=spec,
               policy=policy,
               value_function=value_function,
               **ppo_settings)

    runner.setup(algo, env)
    runner.train(n_epochs=100, batch_size=10000)
def test_maml_trpo_dummy_named_env():
    """Test MAML-TRPO with a dummy environment that has env_name.

    Trains for two epochs on ``DummyMultiTaskBoxEnv``; the test passes if
    training completes without raising. The environment is closed when the
    test ends, whether or not training succeeded.
    """
    env = MetaRLEnv(
        normalize(DummyMultiTaskBoxEnv(), expected_action_scale=10.))
    try:
        policy = GaussianMLPPolicy(
            env_spec=env.spec,
            hidden_sizes=(64, 64),
            hidden_nonlinearity=torch.tanh,
            output_nonlinearity=None,
        )
        value_function = GaussianMLPValueFunction(env_spec=env.spec,
                                                  hidden_sizes=(32, 32))

        rollouts_per_task = 2
        max_path_length = 100

        runner = LocalRunner(snapshot_config)
        algo = MAMLTRPO(env=env,
                        policy=policy,
                        value_function=value_function,
                        max_path_length=max_path_length,
                        meta_batch_size=5,
                        discount=0.99,
                        gae_lambda=1.,
                        inner_lr=0.1,
                        num_grad_updates=1)

        runner.setup(algo, env)
        runner.train(n_epochs=2,
                     batch_size=rollouts_per_task * max_path_length)
    finally:
        # Release the environment even if training raised.
        env.close()
def test_maml_trpo_pendulum():
    """Test MAML-TRPO on the HalfCheetahDir environment.

    Despite the function name, the environment built here is
    ``HalfCheetahDirEnv`` and the algorithm is ``MAMLTRPO``. The test trains
    for five epochs and requires the value returned by ``runner.train`` to
    exceed -5. The environment is closed even when training or the
    assertion fails.
    """
    env = MetaRLEnv(normalize(HalfCheetahDirEnv(), expected_action_scale=10.))
    try:
        policy = GaussianMLPPolicy(
            env_spec=env.spec,
            hidden_sizes=(64, 64),
            hidden_nonlinearity=torch.tanh,
            output_nonlinearity=None,
        )
        value_function = GaussianMLPValueFunction(env_spec=env.spec,
                                                  hidden_sizes=(32, 32))

        rollouts_per_task = 5
        max_path_length = 100

        runner = LocalRunner(snapshot_config)
        algo = MAMLTRPO(env=env,
                        policy=policy,
                        value_function=value_function,
                        max_path_length=max_path_length,
                        meta_batch_size=5,
                        discount=0.99,
                        gae_lambda=1.,
                        inner_lr=0.1,
                        num_grad_updates=1)

        runner.setup(algo, env)
        last_avg_ret = runner.train(n_epochs=5,
                                    batch_size=rollouts_per_task *
                                    max_path_length)

        assert last_avg_ret > -5
    finally:
        # Previously skipped whenever the assertion above failed.
        env.close()
Example #7
0
def run_task(snapshot_config, *_):
    """Train TRPO on ``InvertedDoublePendulum-v2`` for 100 epochs.

    Args:
        snapshot_config (metarl.experiment.SnapshotConfig): Snapshot
            configuration passed to LocalRunner to create the snapshotter.
            If None, it will create one with default settings.
        _ : Unused parameters

    """
    env = TfEnv(env_name='InvertedDoublePendulum-v2')
    runner = LocalRunner(snapshot_config)
    spec = env.spec

    policy = GaussianMLPPolicy(spec,
                               hidden_sizes=[32, 32],
                               hidden_nonlinearity=torch.tanh,
                               output_nonlinearity=None)
    baseline = LinearFeatureBaseline(env_spec=spec)

    trpo_settings = dict(max_path_length=100, discount=0.99, center_adv=False)
    algo = TRPO(env_spec=spec,
                policy=policy,
                baseline=baseline,
                **trpo_settings)

    runner.setup(algo, env)
    runner.train(n_epochs=100, batch_size=1024)
Example #8
0
def maml_trpo_metaworld_ml10(ctxt, seed, epochs, rollouts_per_task,
                             meta_batch_size):
    """Train MAML-TRPO on the ML10 train tasks, evaluating on its test tasks.

    Args:
        ctxt (metarl.experiment.ExperimentContext): Experiment configuration
            handed to LocalRunner, which uses it to create the snapshotter.
        seed (int): Seed for the random number generator, for determinism.
        epochs (int): Number of training epochs.
        rollouts_per_task (int): Number of rollouts per epoch per task
            for training.
        meta_batch_size (int): Number of tasks sampled per batch.

    """
    set_seed(seed)
    train_env = MetaRLEnv(
        normalize(mwb.ML10.get_train_tasks(), expected_action_scale=10.))
    spec = train_env.spec

    policy = GaussianMLPPolicy(env_spec=spec,
                               hidden_sizes=(100, 100),
                               hidden_nonlinearity=torch.tanh,
                               output_nonlinearity=None)
    value_function = GaussianMLPValueFunction(env_spec=spec,
                                              hidden_sizes=(32, 32),
                                              hidden_nonlinearity=torch.tanh,
                                              output_nonlinearity=None)

    max_path_length = 100

    # One wrapped, normalized environment per ML10 test task name.
    test_task_names = mwb.ML10.get_test_tasks().all_task_names
    test_tasks = []
    for task_name in test_task_names:
        scaled = normalize(mwb.ML10.from_task(task_name),
                           expected_action_scale=10.)
        test_tasks.append(MetaRLEnv(scaled))
    test_sampler = EnvPoolSampler(test_tasks)

    meta_evaluator = MetaEvaluator(test_task_sampler=test_sampler,
                                   max_path_length=max_path_length,
                                   n_test_tasks=len(test_task_names))

    runner = LocalRunner(ctxt)
    maml_settings = dict(max_path_length=max_path_length,
                         meta_batch_size=meta_batch_size,
                         discount=0.99,
                         gae_lambda=1.,
                         inner_lr=0.1,
                         num_grad_updates=1,
                         meta_evaluator=meta_evaluator)
    algo = MAMLTRPO(env=train_env,
                    policy=policy,
                    value_function=value_function,
                    **maml_settings)

    runner.setup(algo, train_env)
    steps_per_epoch = rollouts_per_task * max_path_length
    runner.train(n_epochs=epochs, batch_size=steps_per_epoch)
 def setup_method(self):
     """Create the environment, policy and value function.

     Called before every test.
     """
     raw_env = gym.make('InvertedDoublePendulum-v2')
     self.env = MetaRLEnv(normalize(raw_env))
     spec = self.env.spec
     self.policy = GaussianMLPPolicy(env_spec=spec,
                                     hidden_sizes=(64, 64),
                                     hidden_nonlinearity=torch.tanh,
                                     output_nonlinearity=None)
     self.value_function = GaussianMLPValueFunction(env_spec=spec)
def maml_vpg_half_cheetah_dir(ctxt, seed, epochs, rollouts_per_task,
                              meta_batch_size):
    """Train MAML-VPG on ``HalfCheetahDirEnv`` with a meta evaluator.

    Args:
        ctxt (metarl.experiment.ExperimentContext): Experiment configuration
            handed to LocalRunner, which uses it to create the snapshotter.
        seed (int): Seed for the random number generator, for determinism.
        epochs (int): Number of training epochs.
        rollouts_per_task (int): Number of rollouts per epoch per task
            for training.
        meta_batch_size (int): Number of tasks sampled per batch.

    """
    set_seed(seed)
    env = MetaRLEnv(normalize(HalfCheetahDirEnv(), expected_action_scale=10.))
    spec = env.spec

    # Both networks use tanh hidden units and no output nonlinearity.
    activations = dict(hidden_nonlinearity=torch.tanh,
                       output_nonlinearity=None)
    policy = GaussianMLPPolicy(env_spec=spec,
                               hidden_sizes=(64, 64),
                               **activations)
    value_function = GaussianMLPValueFunction(env_spec=spec,
                                              hidden_sizes=(32, 32),
                                              **activations)

    max_path_length = 100

    # The sampler receives a constructor so it can build environments itself.
    task_sampler = SetTaskSampler(lambda: MetaRLEnv(
        normalize(HalfCheetahDirEnv(), expected_action_scale=10.)))

    evaluator_settings = dict(max_path_length=max_path_length,
                              n_test_tasks=1,
                              n_test_rollouts=10)
    meta_evaluator = MetaEvaluator(test_task_sampler=task_sampler,
                                   **evaluator_settings)

    runner = LocalRunner(ctxt)
    algo = MAMLVPG(env=env,
                   policy=policy,
                   value_function=value_function,
                   max_path_length=max_path_length,
                   meta_batch_size=meta_batch_size,
                   discount=0.99,
                   gae_lambda=1.,
                   inner_lr=0.1,
                   num_grad_updates=1,
                   meta_evaluator=meta_evaluator)

    runner.setup(algo, env)
    steps_per_epoch = rollouts_per_task * max_path_length
    runner.train(n_epochs=epochs, batch_size=steps_per_epoch)
Example #11
0
 def setup_method(self):
     """Create the environment, policy and baseline.

     Called before every test.
     """
     scaled_env = normalize(HalfCheetahDirEnv(), expected_action_scale=10.)
     self.env = MetaRLEnv(scaled_env)
     spec = self.env.spec
     self.policy = GaussianMLPPolicy(env_spec=spec,
                                     hidden_sizes=(64, 64),
                                     hidden_nonlinearity=torch.tanh,
                                     output_nonlinearity=None)
     self.baseline = LinearFeatureBaseline(env_spec=spec)
Example #12
0
 def setup_method(self):
     """Create the environment, policy, baseline and MAML-PPO algorithm.

     Called before every test.
     """
     scaled_env = normalize(HalfCheetahDirEnv(), expected_action_scale=10.)
     self.env = MetaRLEnv(scaled_env)
     spec = self.env.spec
     self.policy = GaussianMLPPolicy(env_spec=spec,
                                     hidden_sizes=(64, 64),
                                     hidden_nonlinearity=torch.tanh,
                                     output_nonlinearity=None)
     self.baseline = LinearFeatureBaseline(env_spec=spec)
     maml_settings = dict(max_path_length=100,
                          meta_batch_size=5,
                          discount=0.99,
                          gae_lambda=1.,
                          inner_lr=0.1,
                          num_grad_updates=1)
     self.algo = MAMLPPO(env=self.env,
                         policy=self.policy,
                         baseline=self.baseline,
                         **maml_settings)
Example #13
0
def run_metarl(env, seed, log_dir):
    """Create metarl PyTorch MAML model and training.

    Hyperparameters are read from the module-level ``hyper_parameters``
    mapping. Logger outputs are registered on ``dowel_logger`` for the
    duration of the run and removed afterwards, including when setup or
    training raises.

    Args:
        env (MetaRLEnv): Environment of the task.
        seed (int): Random positive integer for the trial.
        log_dir (str): Log dir path.

    Returns:
        str: Path to output csv file

    """
    deterministic.set_seed(seed)

    policy = GaussianMLPPolicy(
        env_spec=env.spec,
        hidden_sizes=hyper_parameters['hidden_sizes'],
        hidden_nonlinearity=torch.tanh,
        output_nonlinearity=None,
    )

    baseline = LinearFeatureBaseline(env_spec=env.spec)

    algo = MAMLTRPO(env=env,
                    policy=policy,
                    baseline=baseline,
                    max_path_length=hyper_parameters['max_path_length'],
                    discount=hyper_parameters['discount'],
                    gae_lambda=hyper_parameters['gae_lambda'],
                    meta_batch_size=hyper_parameters['meta_batch_size'],
                    inner_lr=hyper_parameters['inner_lr'],
                    max_kl_step=hyper_parameters['max_kl'],
                    num_grad_updates=hyper_parameters['num_grad_update'])

    tabular_log_file = osp.join(log_dir, 'progress.csv')

    try:
        # Set up logger since we are not using run_experiment
        dowel_logger.add_output(dowel.StdOutput())
        dowel_logger.add_output(dowel.CsvOutput(tabular_log_file))
        dowel_logger.add_output(dowel.TensorBoardOutput(log_dir))

        snapshot_config = SnapshotConfig(snapshot_dir=log_dir,
                                         snapshot_mode='all',
                                         snapshot_gap=1)

        runner = LocalRunner(snapshot_config=snapshot_config)
        runner.setup(algo, env, sampler_args=dict(n_envs=5))
        runner.train(n_epochs=hyper_parameters['n_epochs'],
                     batch_size=(hyper_parameters['fast_batch_size'] *
                                 hyper_parameters['max_path_length']))
    finally:
        # Always detach the outputs so a failed trial does not leave them
        # registered on the shared logger for the next trial.
        dowel_logger.remove_all()

    return tabular_log_file
    def test_is_pickleable(self, batch_size, hidden_sizes):
        """Check that a pickle round-trip preserves the policy's outputs.

        The reported mean must be identical before and after the
        round-trip, and the sampled actions must have the same shape.

        Args:
            batch_size (int): Number of observations in the batch.
            hidden_sizes (Sequence[int]): Hidden layer widths of the policy.

        """
        env_spec = TfEnv(DummyBoxEnv())
        n_obs = env_spec.observation_space.flat_dim
        ones_obs = torch.ones([batch_size, n_obs], dtype=torch.float32)

        policy_kwargs = dict(hidden_sizes=hidden_sizes,
                             init_std=2.,
                             hidden_nonlinearity=None,
                             std_parameterization='exp',
                             hidden_w_init=nn.init.ones_,
                             output_w_init=nn.init.ones_)
        policy = GaussianMLPPolicy(env_spec=env_spec, **policy_kwargs)

        action_before, info_before = policy.get_actions(ones_obs)

        restored_policy = pickle.loads(pickle.dumps(policy))
        action_after, info_after = restored_policy.get_actions(ones_obs)

        assert np.array_equal(info_before['mean'], info_after['mean'])
        assert action_before.shape == action_after.shape
Example #15
0
    def setup_method(self):
        """Create the environment, runner, policy and shared algorithm kwargs.

        Called before every test.
        """
        self._env = MetaRLEnv(gym.make('InvertedDoublePendulum-v2'))
        self._runner = LocalRunner(snapshot_config)

        spec = self._env.spec
        self._policy = GaussianMLPPolicy(env_spec=spec,
                                         hidden_sizes=[64, 64],
                                         hidden_nonlinearity=torch.tanh,
                                         output_nonlinearity=None)
        baseline = LinearFeatureBaseline(env_spec=spec)
        # Keyword arguments stored for the tests to reuse.
        self._params = {
            'env_spec': spec,
            'policy': self._policy,
            'baseline': baseline,
            'max_path_length': 100,
            'discount': 0.99,
        }
def mtppo_metaworld_mt10(ctxt, seed, epochs, batch_size, n_worker):
    """Train PPO over the MT10 train tasks wrapped in a MultiEnvWrapper.

    Args:
        ctxt (metarl.experiment.ExperimentContext): Experiment configuration
            handed to LocalRunner, which uses it to create the snapshotter.
        seed (int): Seed for the random number generator, for determinism.
        epochs (int): Number of training epochs.
        batch_size (int): Number of environment steps in one batch.
        n_worker (int): The number of workers the sampler should use.

    """
    set_seed(seed)
    task_names = mwb.MT10.get_train_tasks().all_task_names
    # One normalized environment per task, visited in round-robin order.
    task_envs = [
        normalize(MetaRLEnv(mwb.MT10.from_task(name))) for name in task_names
    ]
    env = MultiEnvWrapper(task_envs,
                          sample_strategy=round_robin_strategy,
                          mode='vanilla')
    spec = env.spec

    activations = dict(hidden_nonlinearity=torch.tanh,
                       output_nonlinearity=None)
    policy = GaussianMLPPolicy(env_spec=spec,
                               hidden_sizes=(64, 64),
                               **activations)
    value_function = GaussianMLPValueFunction(env_spec=spec,
                                              hidden_sizes=(32, 32),
                                              **activations)

    ppo_settings = dict(max_path_length=128,
                        discount=0.99,
                        gae_lambda=0.95,
                        center_adv=True,
                        lr_clip_range=0.2)
    algo = PPO(env_spec=spec,
               policy=policy,
               value_function=value_function,
               **ppo_settings)

    runner = LocalRunner(ctxt)
    runner.setup(algo, env, n_workers=n_worker)
    runner.train(n_epochs=epochs, batch_size=batch_size)
Example #17
0
def trpo_pendulum_ray_sampler(ctxt=None, seed=1):
    """Train TRPO on ``InvertedDoublePendulum-v2`` using RaySampler.

    Args:
        ctxt (metarl.experiment.ExperimentContext): Experiment configuration
            handed to LocalRunner, which uses it to create the snapshotter.
        seed (int): Seed for the random number generator, for determinism.

    """
    # Since this is an example, we are running ray in a reduced state.
    # One can comment the ray.init call out in order to run ray at full
    # capacity.
    reduced_ray_settings = dict(memory=52428800,
                                object_store_memory=78643200,
                                ignore_reinit_error=True,
                                log_to_driver=False,
                                include_webui=False)
    ray.init(**reduced_ray_settings)
    deterministic.set_seed(seed)
    env = MetaRLEnv(env_name='InvertedDoublePendulum-v2')
    runner = LocalRunner(ctxt)
    spec = env.spec

    activations = dict(hidden_nonlinearity=torch.tanh,
                       output_nonlinearity=None)
    policy = GaussianMLPPolicy(spec, hidden_sizes=[32, 32], **activations)
    value_function = GaussianMLPValueFunction(env_spec=spec,
                                              hidden_sizes=(32, 32),
                                              **activations)

    algo = TRPO(env_spec=spec,
                policy=policy,
                value_function=value_function,
                max_path_length=100,
                discount=0.99,
                center_adv=False)

    runner.setup(algo, env, sampler_cls=RaySampler)
    runner.train(n_epochs=100, batch_size=1024)
Example #18
0
def run_task(snapshot_config, *_):
    """Train MAML-VPG on ``HalfCheetahDirEnv`` for 300 epochs.

    Args:
        snapshot_config (metarl.experiment.SnapshotConfig): Snapshot
            configuration passed to LocalRunner to create the snapshotter.
            If None, it will create one with default settings.
        _ : Unused parameters

    """
    env = MetaRLEnv(normalize(HalfCheetahDirEnv(), expected_action_scale=10.))
    spec = env.spec

    policy = GaussianMLPPolicy(env_spec=spec,
                               hidden_sizes=(64, 64),
                               hidden_nonlinearity=torch.tanh,
                               output_nonlinearity=None)
    baseline = LinearFeatureBaseline(env_spec=spec)

    rollouts_per_task = 20
    max_path_length = 100
    steps_per_epoch = rollouts_per_task * max_path_length

    runner = LocalRunner(snapshot_config)
    maml_settings = dict(max_path_length=max_path_length,
                         meta_batch_size=40,
                         discount=0.99,
                         gae_lambda=1.,
                         inner_lr=0.1,
                         num_grad_updates=1)
    algo = MAMLVPG(env=env, policy=policy, baseline=baseline, **maml_settings)

    runner.setup(algo, env)
    runner.train(n_epochs=300, batch_size=steps_per_epoch)
Example #19
0
def mtppo_metaworld_ml1_push(ctxt, seed, epochs, batch_size):
    """Train PPO on the ML1 ``push-v1`` train tasks.

    Args:
        ctxt (metarl.experiment.ExperimentContext): Experiment configuration
            handed to LocalRunner, which uses it to create the snapshotter.
        seed (int): Seed for the random number generator, for determinism.
        epochs (int): Number of training epochs.
        batch_size (int): Number of environment steps in one batch.

    """
    set_seed(seed)
    push_tasks = mwb.ML1.get_train_tasks('push-v1')
    env = MetaRLEnv(normalize(push_tasks))
    spec = env.spec

    # Hidden layer sizes are left at the library defaults for both networks.
    activations = dict(hidden_nonlinearity=torch.tanh,
                       output_nonlinearity=None)
    policy = GaussianMLPPolicy(env_spec=spec, **activations)
    value_function = GaussianMLPValueFunction(env_spec=spec, **activations)

    ppo_settings = dict(max_path_length=128,
                        discount=0.99,
                        gae_lambda=0.95,
                        center_adv=True,
                        lr_clip_range=0.2)
    algo = PPO(env_spec=spec,
               policy=policy,
               value_function=value_function,
               **ppo_settings)

    runner = LocalRunner(ctxt)
    runner.setup(algo, env)
    runner.train(n_epochs=epochs, batch_size=batch_size)
Example #20
0
class TestMAML:
    """Tests for the policy-copy and adaptation helpers, run via MAMLPPO."""

    def setup_method(self):
        """Build the environment, policy, value function and algorithm.

        Called before every test.
        """
        self.env = MetaRLEnv(
            normalize(HalfCheetahDirEnv(), expected_action_scale=10.))
        spec = self.env.spec
        self.policy = GaussianMLPPolicy(env_spec=spec,
                                        hidden_sizes=(64, 64),
                                        hidden_nonlinearity=torch.tanh,
                                        output_nonlinearity=None)
        self.value_function = GaussianMLPValueFunction(env_spec=spec,
                                                       hidden_sizes=(32, 32))
        maml_settings = dict(max_path_length=100,
                             meta_batch_size=5,
                             discount=0.99,
                             gae_lambda=1.,
                             inner_lr=0.1,
                             num_grad_updates=1)
        self.algo = MAMLPPO(env=self.env,
                            policy=self.policy,
                            value_function=self.value_function,
                            **maml_settings)

    def teardown_method(self):
        """Close the environment. Called after every test."""
        self.env.close()

    @staticmethod
    def _set_params(v, m):
        """Fill the weight and bias of a ``torch.nn.Linear`` with ``v``.

        Modules of any other type are left untouched.
        """
        if not isinstance(m, torch.nn.Linear):
            return
        for param in (m.weight, m.bias):
            param.data.fill_(v)

    @staticmethod
    def _test_params(v, m):
        """Assert the weight and bias of a ``torch.nn.Linear`` all equal ``v``.

        Modules of any other type are not checked.
        """
        if not isinstance(m, torch.nn.Linear):
            return
        for param in (m.weight, m.bias):
            assert torch.all(torch.eq(param.data, v))

    def test_get_exploration_policy(self):
        """Check that the exploration policy is an independent copy."""
        self.policy.apply(partial(self._set_params, 0.1))
        exploration_policy = self.algo.get_exploration_policy()
        exploration_policy.apply(partial(self._set_params, 0.2))

        # Writing to the copy must not leak back into the original.
        self.policy.apply(partial(self._test_params, 0.1))
        exploration_policy.apply(partial(self._test_params, 0.2))

    def test_adapt_policy(self):
        """Check that adapting a copied policy changes only the copy."""
        factory = WorkerFactory(seed=100, max_path_length=100)
        sampler = LocalSampler.from_worker_factory(factory, self.policy,
                                                   self.env)

        self.policy.apply(partial(self._set_params, 0.1))
        exploration_policy = self.algo.get_exploration_policy()
        samples = sampler.obtain_samples(0, 100, exploration_policy)
        self.algo.adapt_policy(exploration_policy, samples)

        # The original policy keeps the values written before adaptation.
        self.policy.apply(partial(self._test_params, 0.1))

        # At least one parameter tensor of the adapted copy has to differ.
        param_pairs = zip(exploration_policy.parameters(),
                          self.policy.parameters())
        any_changed = any(
            new.data.ne(old.data).sum() > 0 for new, old in param_pairs)
        if not any_changed:
            pytest.fail('Parameters of adapted policy should not be '
                        'identical to the old policy.')
    def test_module(self, reward_dim, latent_dim, hidden_sizes, updates):
        """Test all methods of ContextConditionedPolicy.

        Covers, in order: ``reset_belief``, ``sample_from_belief``,
        ``detach_z``, ``update_context``, ``infer_posterior``, ``forward``,
        ``get_action`` and ``compute_kl_div``. Shape checks compare the
        complete shape, so a tensor of the wrong rank fails the test.

        Args:
            reward_dim (int): Size of the reward entry in the context.
            latent_dim (int): Size of the latent variable.
            hidden_sizes (Sequence[int]): Hidden layer widths of the policy.
            updates (int): Number of times the context is updated.

        """
        env_spec = TfEnv(DummyBoxEnv())
        latent_space = akro.Box(low=-1,
                                high=1,
                                shape=(latent_dim, ),
                                dtype=np.float32)

        # add latent space to observation space to create a new space
        augmented_obs_space = akro.Tuple(
            (env_spec.observation_space, latent_space))
        augmented_env_spec = EnvSpec(augmented_obs_space,
                                     env_spec.action_space)

        obs_dim = int(np.prod(env_spec.observation_space.shape))
        action_dim = int(np.prod(env_spec.action_space.shape))
        encoder_input_dim = obs_dim + action_dim + reward_dim
        encoder_output_dim = latent_dim * 2
        encoder_hidden_sizes = (3, 2, encoder_output_dim)

        context_encoder = RecurrentEncoder(input_dim=encoder_input_dim,
                                           output_dim=encoder_output_dim,
                                           hidden_nonlinearity=None,
                                           hidden_sizes=encoder_hidden_sizes,
                                           hidden_w_init=nn.init.ones_,
                                           output_w_init=nn.init.ones_)

        # policy needs to be able to accept obs_dim + latent_dim as input dim
        policy = GaussianMLPPolicy(env_spec=augmented_env_spec,
                                   hidden_sizes=hidden_sizes,
                                   hidden_nonlinearity=F.relu,
                                   output_nonlinearity=None)

        module = ContextConditionedPolicy(latent_dim=latent_dim,
                                          context_encoder=context_encoder,
                                          policy=policy,
                                          use_ib=True,
                                          use_next_obs=False)

        expected_shape = [1, latent_dim]
        module.reset_belief()
        assert torch.all(torch.eq(module.z_means, torch.zeros(expected_shape)))
        assert torch.all(torch.eq(module.z_vars, torch.ones(expected_shape)))

        # Compare whole shapes: an element-wise check over zip() stops at the
        # shorter sequence and would accept a tensor of the wrong rank.
        module.sample_from_belief()
        assert list(module.z.shape) == expected_shape

        module.detach_z()
        assert module.z.requires_grad is False

        context_dict = {
            'observation': np.ones(obs_dim),
            'action': np.ones(action_dim),
            'reward': np.ones(reward_dim),
            'next_observation': np.ones(obs_dim),
        }

        for _ in range(updates):
            module.update_context(context_dict)
        assert torch.all(
            torch.eq(module._context, torch.ones(updates, encoder_input_dim)))

        context = torch.randn(1, 1, encoder_input_dim)
        module.infer_posterior(context)
        assert list(module.z.shape) == expected_shape

        t, b = 1, 2
        obs = torch.randn((t, b, obs_dim), dtype=torch.float32)
        policy_output, task_z_out = module.forward(obs, context)
        assert policy_output is not None
        expected_shape = [b, latent_dim]
        assert list(task_z_out.shape) == expected_shape

        obs = torch.randn(obs_dim)
        action = module.get_action(obs)
        assert len(action) == action_dim

        kl_div = module.compute_kl_div()
        assert kl_div != 0