Example #1
0
def test_merged_configs():
    """Check that a comma-joined config path merges both config files.

    The merged config must report the same ``TASK.TYPE`` as ``CFG_EQA``
    and the same ``ENVIRONMENT.MAX_EPISODE_STEPS`` as ``CFG_TEST``.
    """
    base_cfg = get_config(CFG_TEST)
    eqa_cfg = get_config(CFG_EQA)
    combined_paths = "{},{}".format(CFG_TEST, CFG_EQA)
    merged_cfg = get_config(combined_paths)

    assert merged_cfg.TASK.TYPE == eqa_cfg.TASK.TYPE
    expected_steps = base_cfg.ENVIRONMENT.MAX_EPISODE_STEPS
    assert merged_cfg.ENVIRONMENT.MAX_EPISODE_STEPS == expected_steps
Example #2
0
def test_new_keys_merged_configs():
    """Check that keys only present in ``CFG_NEW_KEYS`` survive a merge.

    The merged config must carry ``CFG_NEW_KEYS``'s value for
    ``TASK.MY_NEW_TASK_PARAM`` and ``CFG_TEST``'s value for
    ``ENVIRONMENT.MAX_EPISODE_STEPS``.
    """
    base_cfg = get_config(CFG_TEST)
    extra_keys_cfg = get_config(CFG_NEW_KEYS)
    combined_paths = "{},{}".format(CFG_TEST, CFG_NEW_KEYS)
    merged_cfg = get_config(combined_paths)

    expected_param = extra_keys_cfg.TASK.MY_NEW_TASK_PARAM
    assert merged_cfg.TASK.MY_NEW_TASK_PARAM == expected_param
    expected_steps = base_cfg.ENVIRONMENT.MAX_EPISODE_STEPS
    assert merged_cfg.ENVIRONMENT.MAX_EPISODE_STEPS == expected_steps
Example #3
0
def get_test_config(name: str):
    """Load config ``name`` and adapt it for lightweight multi-task testing.

    Calls ``pytest.skip`` in two cases:
    - the scene file named by ``SIMULATOR.SCENE`` does not exist;
    - ``MULTI_TASK.TASKS`` has fewer than two entries.

    Otherwise it changes the config as follows and returns it frozen:
    - the global ``DATASET`` and every task's ``DATASET`` are pointed at the
      habitat-test-scenes pointnav data with split ``"test"``;
    - any top-level ``RGB_SENSOR`` / ``DEPTH_SENSOR`` section is shrunk to
      64x64.
    """
    # Small test dataset; "{split}" is kept as a literal placeholder string.
    test_data_path = (
        "data/datasets/pointnav/habitat-test-scenes/v1/{split}/{split}.json.gz"
    )
    config = get_config(name)

    if not os.path.exists(config.SIMULATOR.SCENE):
        pytest.skip("Please download Habitat test data to data folder.")
    if len(config.MULTI_TASK.TASKS) < 2:
        pytest.skip(
            "Please use a configuration with at least 2 tasks for testing."
        )

    def _use_test_split(dataset_node):
        # Same two assignments, in the same order, for every dataset node.
        dataset_node.DATA_PATH = test_data_path
        dataset_node.SPLIT = "test"

    config.defrost()
    _use_test_split(config.DATASET)
    # Per-task dataset sections are overridden too, not just the global one.
    for task_node in config.MULTI_TASK.TASKS:
        _use_test_split(task_node.DATASET)

    # Small observations keep the test fast; only top-level sections are checked.
    for sensor_key in ("RGB_SENSOR", "DEPTH_SENSOR"):
        if sensor_key in config:
            sensor_node = getattr(config, sensor_key)
            sensor_node.WIDTH = 64
            sensor_node.HEIGHT = 64
    config.freeze()
    return config
Example #4
0
def test_overwrite_options():
    """Check that ``opts`` overrides ``ENVIRONMENT.MAX_EPISODE_STEPS``.

    Reloads ``CFG_TEST`` once for each value in
    ``range(MAX_TEST_STEPS_LIMIT)``, passing that value as an override. Each
    time it asserts that the loaded config reports exactly that value.
    """
    for limit in range(MAX_TEST_STEPS_LIMIT):
        override = ["ENVIRONMENT.MAX_EPISODE_STEPS", limit]
        loaded = get_config(config_paths=CFG_TEST, opts=override)
        assert loaded.ENVIRONMENT.MAX_EPISODE_STEPS == limit, (
            "Overwriting of config options failed."
        )
Example #5
0
def test_standard_config_compatibility():
    """Run a plain single-task config (pointnav.yaml) through MultiTaskEnv.

    Resets the environment and then steps it with actions sampled from
    ``env.action_space`` until ``env.episode_over`` is set. It then asserts
    that at least one step was taken, while the environment is still open.
    """
    cfg = get_config("configs/tasks/pointnav.yaml")
    with MultiTaskEnv(config=cfg) as env:
        env.reset()
        num_steps = 0
        while not env.episode_over:
            sampled_action = env.action_space.sample()
            env.step(sampled_action)
            num_steps += 1
        assert (
            num_steps > 0
        ), "You should have performed at least one step with no interruptions"
Example #6
0
def test_tasks_keep_defaults():
    """Check per-task overrides and inheritance of default TASK keys.

    Part 1: sets the first task's ``TYPE`` to a custom value and asserts it
    now differs from the global ``TASK.TYPE``.

    Part 2: asserts that every key of the default ``_C.TASK`` node is
    present in every entry of ``MULTI_TASK.TASKS``. The failure message
    names the missing key and the index of the task that lacks it.
    """
    defaults = _C.TASK.clone()
    cfg = get_config(MULITASK_TEST_FILENAME)
    cfg.defrost()
    cfg.MULTI_TASK.TASKS[0].TYPE = "MyCustomTestTask"
    cfg.freeze()
    assert (
        cfg.MULTI_TASK.TASKS[0].TYPE != cfg.TASK.TYPE
    ), "Each task's properties should be overridable"
    for k in defaults:
        for i, task in enumerate(cfg.MULTI_TASK.TASKS):
            assert k in task, (
                "Default property {!r} should be inherited by each task "
                "(missing from task {})".format(k, i)
            )
Example #7
0
def test_global_dataset_config():
    """Tasks without their own DATASET must inherit the global DATASET.

    Steps:
    - remove ``DATASET`` from every task in the multi-task test YAML;
    - set the global ``DATASET.TYPE`` to a sentinel value;
    - save the YAML and reload it through ``get_config``;
    - assert that every task reports the sentinel dataset type.
    """
    datatype = "MyDatasetType"
    raw_yaml = open_yaml(MULITASK_TEST_FILENAME)
    for task_entry in raw_yaml["MULTI_TASK"]["TASKS"]:
        # Leave only the global dataset so the tasks have to inherit it.
        if "DATASET" in task_entry:
            del task_entry["DATASET"]

    raw_yaml["DATASET"]["TYPE"] = datatype
    # NOTE(review): saved under a relative name but reloaded from /tmp; this
    # only lines up if save_yaml writes into /tmp -- confirm against save_yaml.
    save_yaml("habitat_config_test.yaml", raw_yaml)
    cfg = get_config("/tmp/habitat_config_test.yaml")

    failure_msg = (
        "Each task should inherit global dataset when dataset is not specified"
    )
    for task in cfg.MULTI_TASK.TASKS:
        assert task.DATASET.TYPE == cfg.DATASET.TYPE == datatype, failure_msg
Example #8
0
def test_global_dataset_config_override():
    """A task with its own DATASET keeps it; the other tasks inherit the global one.

    Steps:
    - remove ``DATASET`` from every task in the multi-task test YAML;
    - give only the first task a dataset with a custom ``TYPE`` and
      ``DATA_PATH``;
    - save the YAML and reload it through ``get_config``.

    The first task must report the custom type, which must also differ from
    the global type. Every other task must report the global type.
    """
    datatype = "MyDatasetType"
    datapath = "/some/path/"
    config = open_yaml(MULITASK_TEST_FILENAME)
    for task in config["MULTI_TASK"]["TASKS"]:
        if "DATASET" in task:
            del task["DATASET"]
    # Only the first task gets its own, different dataset.
    config["MULTI_TASK"]["TASKS"][0]["DATASET"] = {
        "TYPE": datatype,
        "DATA_PATH": datapath,
    }
    # NOTE(review): saved under a relative name but reloaded from /tmp; this
    # only lines up if save_yaml writes into /tmp -- confirm against save_yaml.
    save_yaml("habitat_config_test.yaml", config)
    cfg = get_config("/tmp/habitat_config_test.yaml")
    # Every task except the first should carry the global dataset config.
    for i, task in enumerate(cfg.MULTI_TASK.TASKS):
        if i == 0:
            assert (
                task.DATASET.TYPE == datatype != cfg.DATASET.TYPE
            ), "First task should have a different dataset"
        else:
            assert (
                task.DATASET.TYPE == cfg.DATASET.TYPE
            ), "Each task should inherit global dataset when dataset is not specified"


if __name__ == "__main__":
    args = ArgumentParser()
    args.add_argument(
        "-i",
        "--interactive",
        help="Run demo interactively",
        action="store_true",
    )
    args = args.parse_args()
    ### One Env, many tasks ###
    # cfg = get_config('pointnav.yaml')
    cfg = get_config("configs/test/habitat_multitask_example.yaml")
    # cfg.defrost()
    # cfg.TASKS[0].SENSORS = ["PROXIMITY_SENSOR"]
    # cfg.freeze()
    with MultiTaskEnv(config=cfg) as env:
        print("{} episodes created from config file".format(len(
            env._episodes)))
        scene_sort_keys = {}
        for e in env.episodes:
            if e.scene_id not in scene_sort_keys:
                scene_sort_keys[e.scene_id] = len(scene_sort_keys)
        print("Number of scenes", scene_sort_keys)
        # usual OpenAI Gym-like env-agent loop
        n_episodes = 2
        print(env._tasks)
        taks_id = 0