Пример #1
0
def test_single_env_multi_goals_discrete(env_cls):
    """Check goal-id observations for one env whose goal space is discretized.

    A single env class is registered under the name ``'wrapped'``. Two goals
    are sampled from it and used to discretize its goal space. The
    ``'with_goal_id'`` observations returned by ``reset`` and ``step`` are
    then checked for shape and for a 1 at the expected one-hot position.

    Args:
        env_cls: environment class under test (presumably supplied by a
            pytest fixture or parametrization).
    """
    env_cls_dict = {'wrapped': env_cls}
    env_args_kwargs = {'wrapped': dict(args=[], kwargs={'task_id': 1})}
    mt_env = MultiClassMultiTaskEnv(
        task_env_cls_dict=env_cls_dict,
        task_args_kwargs=env_args_kwargs,
        sample_goals=True,
        obs_type='with_goal_id',
    )

    # Discretize the goal space of the single wrapped env with two goals.
    sampled_goals = mt_env.active_env.sample_goals_(2)
    assert len(sampled_goals) == 2
    mt_env.discretize_goal_space({'wrapped': sampled_goals})
    assert mt_env._fully_discretized

    # Every sampled task must carry both a task index and a goal index.
    sampled_tasks = mt_env.sample_tasks(2)
    for task in sampled_tasks:
        assert 'task' in task
        assert 'goal' in task
    chosen = sampled_tasks[0]
    mt_env.set_task(chosen)
    assert mt_env._active_task == chosen['task']

    reset_obs = mt_env.reset()
    step_obs, _reward, _done, _info = mt_env.step(mt_env.action_space.sample())
    for obs in (reset_obs, step_obs):
        assert np.all(mt_env.observation_space.shape == obs.shape)

    # After the first _max_obs_dim entries, position (task_id + goal index)
    # must hold a 1.
    for obs in (reset_obs, step_obs):
        hot_index = (env_args_kwargs['wrapped']['kwargs']['task_id'] +
                     chosen['goal'])
        assert obs[mt_env._max_obs_dim:][hot_index] == 1
    # NOTE(review): this slices from _max_plain_dim while the checks above
    # slice from _max_obs_dim; confirm which offset is intended.
    assert np.sum(reset_obs[mt_env._max_plain_dim:]) == 1
Пример #2
0
def test_multienv_single_goal(env_list):
    """Check that ``set_task`` activates an env of the registered class.

    One env per class in ``env_list`` is registered as ``'env-<i>'`` with
    ``task_id=i``. With ``sample_goals=False`` the wrapper must report itself
    fully discretized right after construction. Twice as many tasks as envs
    are then sampled. After each ``set_task`` the active env must be an
    instance of the class registered under that task's name.

    Args:
        env_list: sequence of environment classes (presumably supplied by a
            pytest fixture or parametrization).
    """
    env_cls_dict = {}
    env_args_kwargs = {}
    for idx, cls in enumerate(env_list):
        key = 'env-{}'.format(idx)
        env_cls_dict[key] = cls
        env_args_kwargs[key] = dict(args=[], kwargs={'task_id': idx})

    multi_task_env = MultiClassMultiTaskEnv(
        task_env_cls_dict=env_cls_dict,
        task_args_kwargs=env_args_kwargs,
        sample_goals=False,
        obs_type='with_goal_id',
        sample_all=True,
    )
    assert multi_task_env._fully_discretized

    n_tasks = 2 * len(env_list)
    tasks = multi_task_env.sample_tasks(n_tasks)
    assert len(tasks) == n_tasks
    for task in tasks:
        multi_task_env.set_task(task)
        active = multi_task_env.active_env
        task_name = multi_task_env._task_names[task % len(env_list)]
        assert isinstance(active, env_cls_dict[task_name])
Пример #3
0
def test_multitask_env_images(env_list):
    """Check that ``get_image`` renders an 84x84 frame with three channels.

    Every class in ``env_list`` is registered as ``'env-<i>'`` with a
    ``'plain'`` inner observation and ``task_id=i``. The first sampled task
    is activated and reset before an image is requested.

    Args:
        env_list: sequence of environment classes (presumably supplied by a
            pytest fixture or parametrization).
    """
    env_cls_dict = {}
    env_args_kwargs = {}
    for idx, cls in enumerate(env_list):
        name = 'env-{}'.format(idx)
        env_cls_dict[name] = cls
        env_args_kwargs[name] = dict(
            args=[],
            kwargs={'obs_type': 'plain', 'task_id': idx},
        )

    mt_env = MultiClassMultiTaskEnv(
        task_env_cls_dict=env_cls_dict,
        task_args_kwargs=env_args_kwargs,
        sample_goals=False,
        obs_type='with_goal_and_id',
        sample_all=True,
    )
    assert mt_env._fully_discretized

    sampled = mt_env.sample_tasks(len(env_list))
    mt_env.set_task(sampled[0])
    mt_env.reset()

    frame = mt_env.get_image(width=84, height=84)
    dims = frame.shape
    # Checked one axis at a time (short-circuiting) in height, width,
    # channel order.
    assert dims[0] == 84 and dims[1] == 84 and dims[2] == 3
Пример #4
0
def test_multienv_multigoals_fully_discretized(env_list):
    """Check goal-and-id observations for several envs with discrete goals.

    Env ``i`` of ``env_list`` is registered as ``'env-<i>'`` with a plain
    inner observation and ``task_id=i``. Its goal space is then discretized
    with ``i + 1`` goals sampled from that same env. After the first sampled
    task is activated, the observations returned by ``reset`` and by
    ``step`` must match ``observation_space`` in shape. Each must also carry
    exactly one 1 in its one-hot block, at the active task's position.

    Args:
        env_list: sequence of environment classes (presumably supplied by a
            pytest fixture or parametrization).
    """
    env_cls_dict = {
        'env-{}'.format(i): env_cls
        for i, env_cls in enumerate(env_list)
    }
    env_args_kwargs = {
        'env-{}'.format(i): dict(args=[],
                                 kwargs={
                                     'obs_type': 'plain',
                                     'task_id': i
                                 })
        for i, _ in enumerate(env_list)
    }
    multi_task_env = MultiClassMultiTaskEnv(
        task_env_cls_dict=env_cls_dict,
        task_args_kwargs=env_args_kwargs,
        sample_goals=True,
        obs_type='with_goal_and_id',
        sample_all=False,
    )

    # Sample each env's goals from that env's own goal space. Sampling from
    # ``active_env`` (one single env) would feed every task goals drawn from
    # the wrong goal space. Assumes _task_names[i] is the name of
    # _task_envs[i] (parallel sequences) -- TODO confirm against the env.
    goals_dict = {
        name: task_env.sample_goals_(i + 1)
        for i, (name, task_env) in enumerate(
            zip(multi_task_env._task_names, multi_task_env._task_envs))
    }
    multi_task_env.discretize_goal_space(goals_dict)
    assert multi_task_env._fully_discretized

    tasks_with_goals = multi_task_env.sample_tasks(2)
    for t in tasks_with_goals:
        assert 'task' in t
        assert 'goal' in t
    multi_task_env.set_task(tasks_with_goals[0])
    assert multi_task_env._active_task == tasks_with_goals[0]['task']

    reset_obs = multi_task_env.reset()
    step_obs, _, _, _ = multi_task_env.step(
        multi_task_env.action_space.sample())
    assert np.all(multi_task_env.observation_space.shape == reset_obs.shape)
    assert np.all(multi_task_env.observation_space.shape == step_obs.shape)

    # Observation layout assumed by these checks (TODO confirm):
    #   [plain obs (_max_obs_dim) | goal (goal_dim) | one-hot (_n_discrete_goals)]
    plain_dim = multi_task_env._max_obs_dim
    goal_dim = 3  # assumes a 3-D (xyz) goal -- TODO confirm
    one_hot_start = plain_dim + goal_dim
    one_hot_end = one_hot_start + multi_task_env._n_discrete_goals
    hot_index = one_hot_start + multi_task_env.active_task
    # Check BOTH observations. Previously the sum check was duplicated on
    # reset_obs and step_obs was never checked.
    for obs in (reset_obs, step_obs):
        assert obs[hot_index] == 1, obs
        # The whole one-hot block, not just the tail after the hot index,
        # must contain exactly one set entry.
        assert np.sum(obs[one_hot_start:one_hot_end]) == 1, obs
Пример #5
0
def test_ml3():
    """Sample ML3 tasks repeatedly and check each activated goal is valid.

    Uses the module-level ``ml3_env_cls_dict`` / ``ml3_env_args_kwargs``
    (defined outside this view). Two rounds of three tasks are sampled. Each
    task must expose ``'task'`` and ``'goal'``, and after ``set_task`` +
    ``reset`` the active env's goal must lie inside its ``goal_space``.
    """
    mt_env = MultiClassMultiTaskEnv(
        task_env_cls_dict=ml3_env_cls_dict,
        task_args_kwargs=ml3_env_args_kwargs,
        sample_goals=True,
        obs_type='plain',
    )
    n_rounds, n_tasks = 2, 3
    for _round in range(n_rounds):
        tasks = mt_env.sample_tasks(n_tasks)
        assert len(tasks) == n_tasks
        for task in tasks:
            keys = task.keys()
            assert 'task' in keys
            assert 'goal' in keys
            mt_env.set_task(task)
            mt_env.reset()
            active = mt_env.active_env
            assert active.goal_space.contains(active.goal)
Пример #6
0
def test_reach_push_pick_place(env_cls):
    """Check one env class serving three task types with one fixed goal each.

    The same ``env_cls`` is registered for ``'pick_place'``, ``'reach'`` and
    ``'push'``, each with its own ``task_type`` kwarg. Each task type's goal
    space is discretized to a single fixed goal. Tasks are sampled in two
    rounds. After ``set_task`` + ``reset``, the active env's goal must equal
    that task's fixed goal and its ``task_type`` must match the task name.

    Args:
        env_cls: environment class under test (presumably supplied by a
            pytest fixture or parametrization).
    """
    task_types = ['pick_place', 'reach', 'push']
    env_dict = dict.fromkeys(task_types, env_cls)
    env_args_kwargs = {}
    for task_type in task_types:
        env_args_kwargs[task_type] = dict(
            args=[],
            kwargs={'task_type': task_type, 'obs_type': 'plain', 'task_id': 1},
        )

    # sample_goals=True, yet each env still samples only one goal: every
    # env's discrete goal space contains exactly one goal.
    multi_task_env = MultiClassMultiTaskEnv(
        task_env_cls_dict=env_dict,
        task_args_kwargs=env_args_kwargs,
        obs_type='with_goal_id',
        sample_goals=True,
        sample_all=True,
    )
    goals_dict = {
        'pick_place': [np.array([0.1, 0.8, 0.2])],
        'reach': [np.array([-0.1, 0.8, 0.2])],
        'push': [np.array([0.1, 0.8, 0.02])],
    }
    multi_task_env.discretize_goal_space(goals_dict)
    assert multi_task_env._fully_discretized

    n_tasks = len(env_dict)
    # Two rounds so that repeated sampling is exercised as well.
    for _round in range(2):
        tasks = multi_task_env.sample_tasks(n_tasks)
        assert len(tasks) == n_tasks
        for task in tasks:
            keys = task.keys()
            assert 'task' in keys
            assert 'goal' in keys
            multi_task_env.set_task(task)
            multi_task_env.reset()
            name = multi_task_env._task_names[task['task']]
            active = multi_task_env.active_env
            assert np.array_equal(active.goal, goals_dict[name][0])
            assert active.task_type == name
Пример #7
0
 def set_task(self, task):
     """Activate ``task`` using ``MultiClassMultiTaskEnv``'s implementation.

     The parent method is looked up explicitly on ``MultiClassMultiTaskEnv``
     rather than through ``super()``. The lookup therefore starts at that
     class regardless of where this override sits in the subclass's MRO.
     Whatever the parent returns is passed back unchanged.
     """
     parent_impl = MultiClassMultiTaskEnv.set_task
     return parent_impl(self, task)