def setup_method(self):
        """Build the ContextConditionedPolicy fixture shared by each test.

        Stores on ``self``: ``latent_dim``, ``env_spec`` (the GarageEnv-wrapped
        DummyBoxEnv), the flattened ``obs_dim`` and ``action_dim``, the
        context-encoder input width ``encoder_input_dim`` and the assembled
        ``module``.
        """
        self.latent_dim = 5
        self.env_spec = GarageEnv(DummyBoxEnv())
        latent_box = akro.Box(low=-1,
                              high=1,
                              shape=(self.latent_dim, ),
                              dtype=np.float32)
        obs_space = self.env_spec.observation_space
        act_space = self.env_spec.action_space

        # The wrapped policy sees an observation paired with a latent vector,
        # so its spec uses a Tuple(observation, latent) observation space.
        policy_spec = EnvSpec(akro.Tuple((obs_space, latent_box)), act_space)

        self.obs_dim = int(np.prod(obs_space.shape))
        self.action_dim = int(np.prod(act_space.shape))
        # One context row = flattened observation + action + a scalar reward.
        self.encoder_input_dim = self.obs_dim + self.action_dim + 1
        # Encoder emits twice the latent size (presumably a mean and a
        # variance per latent dimension — confirm against the module).
        encoder_output_dim = 2 * self.latent_dim

        encoder = MLPEncoder(input_dim=self.encoder_input_dim,
                             output_dim=encoder_output_dim,
                             hidden_nonlinearity=None,
                             hidden_sizes=(3, 2, encoder_output_dim),
                             hidden_w_init=nn.init.ones_,
                             output_w_init=nn.init.ones_)

        base_policy = TanhGaussianMLPPolicy(env_spec=policy_spec,
                                            hidden_sizes=(3, 5, 7),
                                            hidden_nonlinearity=F.relu,
                                            output_nonlinearity=None)

        self.module = ContextConditionedPolicy(latent_dim=self.latent_dim,
                                               context_encoder=encoder,
                                               policy=base_policy,
                                               use_information_bottleneck=True,
                                               use_next_obs=False)
Ejemplo n.º 2
0
def test_methods():
    """Test PEARLWorker rollouts with and without context accumulation.

    Two workers each roll out one path of ``horizon`` steps and must return
    correctly shaped observations, actions and rewards. The second worker,
    created with ``deterministic=True`` and ``accum_context=True``, must also
    leave a ``(1, horizon, encoder_input_dim)`` context on the policy.
    """
    env = GarageEnv(DummyBoxEnv())
    latent_dim = 5
    latent_box = akro.Box(low=-1,
                          high=1,
                          shape=(latent_dim, ),
                          dtype=np.float32)
    obs_space = env.observation_space
    act_space = env.action_space

    # The wrapped policy sees an observation paired with a latent vector.
    policy_spec = EnvSpec(akro.Tuple((obs_space, latent_box)), act_space)

    obs_dim = int(np.prod(obs_space.shape))
    action_dim = int(np.prod(act_space.shape))
    # One context row = flattened observation + action + a scalar reward.
    encoder_input_dim = obs_dim + action_dim + 1
    encoder_output_dim = 2 * latent_dim

    encoder = MLPEncoder(input_dim=encoder_input_dim,
                         output_dim=encoder_output_dim,
                         hidden_nonlinearity=None,
                         hidden_sizes=(3, 2, encoder_output_dim),
                         hidden_w_init=nn.init.ones_,
                         output_w_init=nn.init.ones_)

    base_policy = TanhGaussianMLPPolicy(env_spec=policy_spec,
                                        hidden_sizes=(3, 5, 7),
                                        hidden_nonlinearity=F.relu,
                                        output_nonlinearity=None)

    context_policy = ContextConditionedPolicy(latent_dim=latent_dim,
                                              context_encoder=encoder,
                                              policy=base_policy,
                                              use_information_bottleneck=True,
                                              use_next_obs=False)

    horizon = 20

    def check_shapes(batch):
        """Assert the per-step array shapes of one rollout."""
        assert batch.observations.shape == (horizon, obs_dim)
        assert batch.actions.shape == (horizon, action_dim)
        assert batch.rewards.shape == (horizon, )

    default_worker = PEARLWorker(seed=1,
                                 max_path_length=horizon,
                                 worker_number=1)
    default_worker.update_agent(context_policy)
    default_worker.update_env(env)
    check_shapes(default_worker.rollout())

    context_worker = PEARLWorker(seed=1,
                                 max_path_length=horizon,
                                 worker_number=1,
                                 deterministic=True,
                                 accum_context=True)
    context_worker.update_agent(context_policy)
    context_worker.update_env(env)
    batch = context_worker.rollout()

    assert context_policy.context.shape == (1, horizon, encoder_input_dim)
    check_shapes(batch)
    def test_module(self, reward_dim, latent_dim, hidden_sizes, updates):
        """Test all methods of ContextConditionedPolicy.

        Builds a recurrent context encoder and a Gaussian MLP policy over the
        observation-plus-latent space. It then exercises belief reset and
        sampling, ``detach_z``, context accumulation, posterior inference,
        the forward pass, ``get_action`` and ``compute_kl_div``.

        Args:
            reward_dim: Width of the reward part of each context row.
            latent_dim: Width of the latent task variable ``z``.
            hidden_sizes: Hidden layer sizes passed to the policy MLP.
            updates: Number of ``update_context`` calls to make.
        """
        env_spec = TfEnv(DummyBoxEnv())
        latent_space = akro.Box(low=-1,
                                high=1,
                                shape=(latent_dim, ),
                                dtype=np.float32)

        # add latent space to observation space to create a new space
        augmented_obs_space = akro.Tuple(
            (env_spec.observation_space, latent_space))
        augmented_env_spec = EnvSpec(augmented_obs_space,
                                     env_spec.action_space)

        obs_dim = int(np.prod(env_spec.observation_space.shape))
        action_dim = int(np.prod(env_spec.action_space.shape))
        # use_next_obs=False below, so a context row is obs + action + reward.
        encoder_input_dim = obs_dim + action_dim + reward_dim
        encoder_output_dim = latent_dim * 2
        encoder_hidden_sizes = (3, 2, encoder_output_dim)

        context_encoder = RecurrentEncoder(input_dim=encoder_input_dim,
                                           output_dim=encoder_output_dim,
                                           hidden_nonlinearity=None,
                                           hidden_sizes=encoder_hidden_sizes,
                                           hidden_w_init=nn.init.ones_,
                                           output_w_init=nn.init.ones_)

        # policy needs to be able to accept obs_dim + latent_dim as input dim
        policy = GaussianMLPPolicy(env_spec=augmented_env_spec,
                                   hidden_sizes=hidden_sizes,
                                   hidden_nonlinearity=F.relu,
                                   output_nonlinearity=None)

        module = ContextConditionedPolicy(latent_dim=latent_dim,
                                          context_encoder=context_encoder,
                                          policy=policy,
                                          use_ib=True,
                                          use_next_obs=False)

        # After reset the belief has zero means and unit variances.
        z_shape = (1, latent_dim)
        module.reset_belief()
        assert torch.all(torch.eq(module.z_means, torch.zeros(z_shape)))
        assert torch.all(torch.eq(module.z_vars, torch.ones(z_shape)))

        # Compare whole shapes: a zip()-based element check truncates to the
        # shorter sequence and would let a rank mismatch pass unnoticed.
        module.sample_from_belief()
        assert tuple(module.z.shape) == z_shape

        module.detach_z()
        assert module.z.requires_grad is False

        context_dict = {
            'observation': np.ones(obs_dim),
            'action': np.ones(action_dim),
            'reward': np.ones(reward_dim),
            'next_observation': np.ones(obs_dim),
        }

        # All inputs are ones, so the accumulated context is expected to be
        # an all-ones (updates, encoder_input_dim) tensor.
        for _ in range(updates):
            module.update_context(context_dict)
        assert torch.all(
            torch.eq(module._context, torch.ones(updates, encoder_input_dim)))

        context = torch.randn(1, 1, encoder_input_dim)
        module.infer_posterior(context)
        assert tuple(module.z.shape) == z_shape

        t, b = 1, 2
        obs = torch.randn((t, b, obs_dim), dtype=torch.float32)
        policy_output, task_z_out = module.forward(obs, context)
        assert policy_output is not None
        assert tuple(task_z_out.shape) == (b, latent_dim)

        obs = torch.randn(obs_dim)
        # NOTE(review): if get_action returns an (action, info) pair, len()
        # is 2 regardless of the action size — confirm against the policy API.
        action = module.get_action(obs)
        assert len(action) == action_dim

        kl_div = module.compute_kl_div()
        assert kl_div != 0