"""Tests for ContextConditionedPolicy and PEARLWorker."""
# NOTE: these import paths assume the garage package layout circa v2020;
# adjust them if your garage version organizes modules differently.
import akro
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from garage.envs import EnvSpec, GarageEnv
from garage.torch.algos.pearl import PEARLWorker
from garage.torch.embeddings import MLPEncoder
from garage.torch.policies import (ContextConditionedPolicy,
                                   TanhGaussianMLPPolicy)
from tests.fixtures.envs.dummy import DummyBoxEnv


class TestContextConditionedPolicy:
    """Tests for ContextConditionedPolicy.

    setup_method needs an enclosing test class; a pytest-style wrapper is
    assumed here.
    """

    def setup_method(self):
        """Setup for all test methods."""
        self.latent_dim = 5
        self.env_spec = GarageEnv(DummyBoxEnv())
        latent_space = akro.Box(low=-1,
                                high=1,
                                shape=(self.latent_dim, ),
                                dtype=np.float32)

        # add latent space to observation space to create a new space
        augmented_obs_space = akro.Tuple(
            (self.env_spec.observation_space, latent_space))
        augmented_env_spec = EnvSpec(augmented_obs_space,
                                     self.env_spec.action_space)

        self.obs_dim = int(np.prod(self.env_spec.observation_space.shape))
        self.action_dim = int(np.prod(self.env_spec.action_space.shape))
        reward_dim = 1
        # the encoder consumes (observation, action, reward) transitions
        self.encoder_input_dim = self.obs_dim + self.action_dim + reward_dim
        # the encoder emits a mean and a variance per latent dimension
        encoder_output_dim = self.latent_dim * 2
        encoder_hidden_sizes = (3, 2, encoder_output_dim)

        context_encoder = MLPEncoder(input_dim=self.encoder_input_dim,
                                     output_dim=encoder_output_dim,
                                     hidden_nonlinearity=None,
                                     hidden_sizes=encoder_hidden_sizes,
                                     hidden_w_init=nn.init.ones_,
                                     output_w_init=nn.init.ones_)

        context_policy = TanhGaussianMLPPolicy(env_spec=augmented_env_spec,
                                               hidden_sizes=(3, 5, 7),
                                               hidden_nonlinearity=F.relu,
                                               output_nonlinearity=None)

        self.module = ContextConditionedPolicy(
            latent_dim=self.latent_dim,
            context_encoder=context_encoder,
            policy=context_policy,
            use_information_bottleneck=True,
            use_next_obs=False)
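
    # A minimal sketch of one test method exercising the fixture above; it
    # assumes ContextConditionedPolicy exposes reset_belief() and the
    # z_means/z_vars posterior buffers (neither appears in the original).
    def test_reset_belief(self):
        """Check that reset_belief() re-initializes the latent posterior."""
        expected_shape = [1, self.latent_dim]
        self.module.reset_belief()
        # After a reset, the posterior over the latent context should be a
        # standard Gaussian: zero means and unit variances, one row per task.
        assert torch.all(
            torch.eq(self.module.z_means, torch.zeros(expected_shape)))
        assert torch.all(
            torch.eq(self.module.z_vars, torch.ones(expected_shape)))
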

def test_methods():
    """Test PEARLWorker methods."""
    env_spec = GarageEnv(DummyBoxEnv())
    latent_dim = 5
    latent_space = akro.Box(low=-1,
                            high=1,
                            shape=(latent_dim, ),
                            dtype=np.float32)

    # add latent space to observation space to create a new space
    augmented_obs_space = akro.Tuple(
        (env_spec.observation_space, latent_space))
    augmented_env_spec = EnvSpec(augmented_obs_space, env_spec.action_space)

    obs_dim = int(np.prod(env_spec.observation_space.shape))
    action_dim = int(np.prod(env_spec.action_space.shape))
    reward_dim = 1
    encoder_input_dim = obs_dim + action_dim + reward_dim
    encoder_output_dim = latent_dim * 2
    encoder_hidden_sizes = (3, 2, encoder_output_dim)

    context_encoder = MLPEncoder(input_dim=encoder_input_dim,
                                 output_dim=encoder_output_dim,
                                 hidden_nonlinearity=None,
                                 hidden_sizes=encoder_hidden_sizes,
                                 hidden_w_init=nn.init.ones_,
                                 output_w_init=nn.init.ones_)

    policy = TanhGaussianMLPPolicy(env_spec=augmented_env_spec,
                                   hidden_sizes=(3, 5, 7),
                                   hidden_nonlinearity=F.relu,
                                   output_nonlinearity=None)

    context_policy = ContextConditionedPolicy(latent_dim=latent_dim,
                                              context_encoder=context_encoder,
                                              policy=policy,
                                              use_information_bottleneck=True,
                                              use_next_obs=False)

    max_path_length = 20

    # A default worker: stochastic actions, no context accumulation.
    worker1 = PEARLWorker(seed=1,
                          max_path_length=max_path_length,
                          worker_number=1)
    worker1.update_agent(context_policy)
    worker1.update_env(env_spec)
    rollouts = worker1.rollout()

    assert rollouts.observations.shape == (max_path_length, obs_dim)
    assert rollouts.actions.shape == (max_path_length, action_dim)
    assert rollouts.rewards.shape == (max_path_length, )

    # A deterministic worker that also accumulates each step's
    # (observation, action, reward) transition into the policy's context.
    worker2 = PEARLWorker(seed=1,
                          max_path_length=max_path_length,
                          worker_number=1,
                          deterministic=True,
                          accum_context=True)
    worker2.update_agent(context_policy)
    worker2.update_env(env_spec)
    rollouts = worker2.rollout()

    # One transition per step should have been accumulated.
    assert context_policy.context.shape == (1, max_path_length,
                                            encoder_input_dim)
    assert rollouts.observations.shape == (max_path_length, obs_dim)
    assert rollouts.actions.shape == (max_path_length, action_dim)
    assert rollouts.rewards.shape == (max_path_length, )