from unittest.mock import Mock

import numpy as np

# Module paths below follow the garage test-suite layout; they are
# reconstructed for this fragment rather than taken from it.
from garage.envs import GarageEnv
from garage.tf.algos.te import TaskEmbeddingWorker
from tests.fixtures.envs.dummy import DummyBoxEnv


def test_te_worker(self):
    # Drive a single episode step by step and check that the worker records
    # the task one-hot in env_infos and the latent in agent_infos.
    worker = TaskEmbeddingWorker(seed=1,
                                 max_episode_length=self.max_episode_length,
                                 worker_number=1)
    worker.update_env(self.env)
    worker.update_agent(self.policy)
    worker.start_episode()
    while not worker.step_episode():
        pass
    episodes = worker.collect_episode()
    assert 'task_onehot' in episodes.env_infos
    assert episodes.env_infos['task_onehot'][0].shape == (4, )
    assert 'latent' in episodes.agent_infos
    assert episodes.agent_infos['latent'][0].shape == (1, )
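
# test_te_worker relies on fixtures (self.env, self.policy,
# self.max_episode_length) created by the enclosing test class's setup,
# which is not part of this fragment. The sketch below is an assumption,
# not the original setup: it reuses the Mock pattern from
# test_task_embedding_worker below, with a four-task one-hot and a
# 1-dimensional latent chosen to match the shapes asserted above.
def setup_method(self):
    self.max_episode_length = 100
    self.env = GarageEnv(DummyBoxEnv(obs_dim=(1, )))
    # Fake the multi-task interface: expose the active task as a one-hot.
    self.env.active_task_one_hot = np.array([1., 0., 0., 0.])
    self.env._active_task_one_hot = lambda: np.array([1., 0., 0., 0.])
    z = np.random.random(1)  # 1-dim latent, matching the (1, ) assert
    self.policy = Mock()
    self.policy.get_latent.return_value = (z, dict(mean=z))
    self.policy.latent_space.flatten.return_value = z
    self.policy.get_action_given_latent.return_value = (
        np.random.random(self.env.action_space.shape), dict(dummy='dummy'))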
def test_task_embedding_worker(self):
    env = GarageEnv(DummyBoxEnv(obs_dim=(1, )))
    # Fake the multi-task interface: expose the active task as a one-hot.
    env.active_task_one_hot = np.array([1., 0., 0., 0.])
    env._active_task_one_hot = lambda: np.array([1., 0., 0., 0.])

    # Stub a task-embedding policy that returns fixed action/latent values.
    a = np.random.random(env.action_space.shape)
    z = np.random.random(5)
    latent_info = dict(mean=np.random.random(5))
    agent_info = dict(dummy='dummy')
    policy = Mock()
    policy.get_latent.return_value = (z, latent_info)
    policy.latent_space.flatten.return_value = z
    policy.get_action_given_latent.return_value = (a, agent_info)

    worker = TaskEmbeddingWorker(seed=1,
                                 max_episode_length=100,
                                 worker_number=1)
    worker.update_agent(policy)
    worker.update_env(env)

    rollouts = worker.rollout()
    assert 'task_onehot' in rollouts.env_infos
    assert np.array_equal(rollouts.env_infos['task_onehot'][0],
                          env.active_task_one_hot)
    assert 'latent' in rollouts.agent_infos
    assert np.array_equal(rollouts.agent_infos['latent'][0], z)
    assert 'latent_mean' in rollouts.agent_infos
    assert np.array_equal(rollouts.agent_infos['latent_mean'][0],
                          latent_info['mean'])