def test_merged_configs():
    test_config = get_config(CFG_TEST)
    eqa_config = get_config(CFG_EQA)
    merged_config = get_config("{},{}".format(CFG_TEST, CFG_EQA))
    assert merged_config.TASK.TYPE == eqa_config.TASK.TYPE
    assert (
        merged_config.ENVIRONMENT.MAX_EPISODE_STEPS
        == test_config.ENVIRONMENT.MAX_EPISODE_STEPS
    )
def test_new_keys_merged_configs():
    test_config = get_config(CFG_TEST)
    new_keys_config = get_config(CFG_NEW_KEYS)
    merged_config = get_config("{},{}".format(CFG_TEST, CFG_NEW_KEYS))
    assert (
        merged_config.TASK.MY_NEW_TASK_PARAM
        == new_keys_config.TASK.MY_NEW_TASK_PARAM
    )
    assert (
        merged_config.ENVIRONMENT.MAX_EPISODE_STEPS
        == test_config.ENVIRONMENT.MAX_EPISODE_STEPS
    )
def get_test_config(name: str):
    # use test dataset for lighter testing
    datapath = (
        "data/datasets/pointnav/habitat-test-scenes/v1/{split}/{split}.json.gz"
    )
    cfg = get_config(name)
    if not os.path.exists(cfg.SIMULATOR.SCENE):
        pytest.skip("Please download Habitat test data to data folder.")
    if len(cfg.MULTI_TASK.TASKS) < 2:
        pytest.skip(
            "Please use a configuration with at least 2 tasks for testing."
        )
    cfg.defrost()
    cfg.DATASET.DATA_PATH = datapath
    cfg.DATASET.SPLIT = "test"
    # also make sure each task's dataset config is overridden
    for task in cfg.MULTI_TASK.TASKS:
        task.DATASET.DATA_PATH = datapath
        task.DATASET.SPLIT = "test"
    # and work with small observations for testing
    if "RGB_SENSOR" in cfg:
        cfg.RGB_SENSOR.WIDTH = 64
        cfg.RGB_SENSOR.HEIGHT = 64
    if "DEPTH_SENSOR" in cfg:
        cfg.DEPTH_SENSOR.WIDTH = 64
        cfg.DEPTH_SENSOR.HEIGHT = 64
    cfg.freeze()
    return cfg
def test_overwrite_options():
    for steps_limit in range(MAX_TEST_STEPS_LIMIT):
        config = get_config(
            config_paths=CFG_TEST,
            opts=["ENVIRONMENT.MAX_EPISODE_STEPS", steps_limit],
        )
        assert (
            config.ENVIRONMENT.MAX_EPISODE_STEPS == steps_limit
        ), "Overwriting of config options failed."
def test_standard_config_compatibility():
    cfg = get_config("configs/tasks/pointnav.yaml")
    with MultiTaskEnv(config=cfg) as env:
        env.reset()
        actions = 0
        while not env.episode_over:
            # execute random action
            env.step(env.action_space.sample())
            actions += 1
        assert (
            actions >= 1
        ), "You should have performed at least one step with no interruptions"
def test_tasks_keep_defaults():
    defaults = _C.TASK.clone()
    cfg = get_config(MULITASK_TEST_FILENAME)
    cfg.defrost()
    cfg.MULTI_TASK.TASKS[0].TYPE = "MyCustomTestTask"
    cfg.freeze()
    assert (
        cfg.MULTI_TASK.TASKS[0].TYPE != cfg.TASK.TYPE
    ), "Each task's properties should be overridable"
    for k in defaults.keys():
        for task in cfg.MULTI_TASK.TASKS:
            assert (
                k in task
            ), "Default properties should be inherited by each task"
def test_global_dataset_config():
    datatype = "MyDatasetType"
    config = open_yaml(MULITASK_TEST_FILENAME)
    for task in config["MULTI_TASK"]["TASKS"]:
        if "DATASET" in task:
            del task["DATASET"]
    config["DATASET"]["TYPE"] = datatype
    # write modified config to a temporary file and reload it
    save_yaml("/tmp/habitat_config_test.yaml", config)
    # load test config
    cfg = get_config("/tmp/habitat_config_test.yaml")
    # make sure each task gets the global dataset config
    for task in cfg.MULTI_TASK.TASKS:
        assert (
            task.DATASET.TYPE == cfg.DATASET.TYPE == datatype
        ), "Each task should inherit the global dataset when no dataset is specified"
def test_global_dataset_config_override():
    datatype = "MyDatasetType"
    datapath = "/some/path/"
    config = open_yaml(MULITASK_TEST_FILENAME)
    for task in config["MULTI_TASK"]["TASKS"]:
        if "DATASET" in task:
            del task["DATASET"]
    # one task gets its own dataset config
    config["MULTI_TASK"]["TASKS"][0]["DATASET"] = {
        "TYPE": datatype,
        "DATA_PATH": datapath,
    }
    # write modified config to a temporary file and reload it
    save_yaml("/tmp/habitat_config_test.yaml", config)
    # load test config
    cfg = get_config("/tmp/habitat_config_test.yaml")
    # make sure every task but the first gets the global dataset config
    for i, task in enumerate(cfg.MULTI_TASK.TASKS):
        if i == 0:
            assert (
                task.DATASET.TYPE == datatype != cfg.DATASET.TYPE
            ), "First task should have a different dataset"
        else:
            assert (
                task.DATASET.TYPE == cfg.DATASET.TYPE
            ), "Each task should inherit the global dataset when no dataset is specified"
    if flag:
        cv2.imshow("RGB", rgb2bgr(obs["rgb"]))


if __name__ == "__main__":
    args = ArgumentParser()
    args.add_argument(
        "-i",
        "--interactive",
        help="Run demo interactively",
        action="store_true",
    )
    args = args.parse_args()

    ### One Env, many tasks ###
    # cfg = get_config('pointnav.yaml')
    cfg = get_config("configs/test/habitat_multitask_example.yaml")
    # cfg.defrost()
    # cfg.TASKS[0].SENSORS = ["PROXIMITY_SENSOR"]
    # cfg.freeze()
    with MultiTaskEnv(config=cfg) as env:
        print(
            "{} episodes created from config file".format(len(env._episodes))
        )
        scene_sort_keys = {}
        for e in env.episodes:
            if e.scene_id not in scene_sort_keys:
                scene_sort_keys[e.scene_id] = len(scene_sort_keys)
        print("Number of scenes", scene_sort_keys)
        # usual OpenAI Gym-like env-agent loop
        n_episodes = 2
        print(env._tasks)
        taks_id = 0