Beispiel #1
0
 def test_te_worker(self):
     """Run one rollout through TaskEmbeddingWorker and inspect its infos.

     After a complete rollout, env_infos must carry a 'task_onehot' entry
     whose first element has shape (4,), and agent_infos must carry a
     'latent' entry whose first element has shape (1,).
     """
     te_worker = TaskEmbeddingWorker(
         seed=1, max_path_length=100, worker_number=1)
     te_worker.update_env(self.env)
     te_worker.update_agent(self.policy)

     # Drive the rollout until step_rollout() reports completion.
     te_worker.start_rollout()
     finished = te_worker.step_rollout()
     while not finished:
         finished = te_worker.step_rollout()

     batch = te_worker.collect_rollout()
     env_infos = batch.env_infos
     agent_infos = batch.agent_infos
     assert 'task_onehot' in env_infos.keys()
     assert env_infos['task_onehot'][0].shape == (4, )
     assert 'latent' in agent_infos.keys()
     assert agent_infos['latent'][0].shape == (1, )
Beispiel #2
0
 def test_te_worker(self):
     """Run one episode through TaskEmbeddingWorker and inspect its infos.

     After the episode finishes, env_infos must contain 'task_onehot' with a
     first element of shape (4,), and agent_infos must contain 'latent' with
     a first element of shape (1,).
     """
     te_worker = TaskEmbeddingWorker(seed=1,
                                     max_episode_length=self.max_episode_length,
                                     worker_number=1)
     te_worker.update_env(self.env)
     te_worker.update_agent(self.policy)

     # Step until step_episode() signals that the episode is over.
     te_worker.start_episode()
     while True:
         if te_worker.step_episode():
             break

     batch = te_worker.collect_episode()
     env_infos, agent_infos = batch.env_infos, batch.agent_infos
     assert 'task_onehot' in env_infos.keys()
     assert env_infos['task_onehot'][0].shape == (4, )
     assert 'latent' in agent_infos.keys()
     assert agent_infos['latent'][0].shape == (1, )
Beispiel #3
0
    def test_task_embedding_worker(self):
        """Check that rollout() records task one-hot and latent information.

        A Mock policy supplies a fixed latent ``z`` (with a 'mean' entry in
        its info dict) and a fixed action. The resulting batch must expose
        the env's one-hot under env_infos['task_onehot'], ``z`` under
        agent_infos['latent'], and the latent mean under
        agent_infos['latent_mean'].
        """
        task_vec = [1., 0., 0., 0.]
        env = GarageEnv(DummyBoxEnv(obs_dim=(1, )))
        env.active_task_one_hot = np.array(task_vec)
        # The lambda builds a fresh array on each call, like the attribute.
        env._active_task_one_hot = lambda: np.array(task_vec)

        # Random draws keep the same order: action, latent, latent mean.
        action = np.random.random(env.action_space.shape)
        latent = np.random.random(5)
        latent_info = {'mean': np.random.random(5)}
        agent_info = {'dummy': 'dummy'}

        policy = Mock()
        policy.configure_mock(**{
            'get_latent.return_value': (latent, latent_info),
            'latent_space.flatten.return_value': latent,
            'get_action_given_latent.return_value': (action, agent_info),
        })

        te_worker = TaskEmbeddingWorker(seed=1,
                                        max_path_length=100,
                                        worker_number=1)
        te_worker.update_agent(policy)
        te_worker.update_env(env)

        batch = te_worker.rollout()
        env_infos = batch.env_infos
        agent_infos = batch.agent_infos

        assert 'task_onehot' in env_infos
        assert np.array_equal(env_infos['task_onehot'][0],
                              env.active_task_one_hot)
        assert 'latent' in agent_infos
        assert np.array_equal(agent_infos['latent'][0], latent)
        assert 'latent_mean' in agent_infos
        assert np.array_equal(agent_infos['latent_mean'][0],
                              latent_info['mean'])