def test_cubemap_stiching(test_cfg_path: str, mode: str, camera: str):
    """Run an experiment with extra cube-face cameras wired to a projection.

    For each base sensor (``RGB_SENSOR`` / ``DEPTH_SENSOR``) that is already
    listed in the agent's sensors, the base sensor is given the first
    orientation and five deep copies are registered on the simulator config,
    one per remaining orientation. The UUIDs of those copies are written to
    the CUBE2EQ or CUBE2FISH observation-transform config (depending on
    ``camera``) before the experiment is launched through ``execute_exp``.

    Args:
        test_cfg_path: Passed to ``get_config`` as ``config_paths``.
        mode: Forwarded unchanged to ``execute_exp``.
        camera: ``"equirec"`` or ``"fisheye"`` selects which transform config
            receives the sensor UUIDs; any other value leaves both untouched.
    """
    meta_config = get_config(config_paths=test_cfg_path)
    meta_config.defrost()
    config = meta_config.TASK_CONFIG
    simulator = config.SIMULATOR
    agent_sensors = simulator.AGENT_0.SENSORS
    # One rotation triple per cube face; index 0 is used by the base sensor.
    orient = [
        [0, math.pi, 0],  # Back
        [-math.pi / 2, 0, 0],  # Down
        [0, 0, 0],  # Front
        [0, math.pi / 2, 0],  # Right
        [0, 3 / 2 * math.pi, 0],  # Left
        [math.pi / 2, 0, 0],  # Up
    ]
    sensor_uuids = []

    # RGB first, then DEPTH, so the appended sensor order stays stable.
    for prefix in ("RGB", "DEPTH"):
        base_name = f"{prefix}_SENSOR"
        if base_name not in agent_sensors:
            continue
        base_sensor = getattr(simulator, base_name)
        base_sensor.ORIENTATION = orient[0]
        # NOTE(review): only the copies' UUIDs are collected; the base
        # sensor's own UUID is never added, so five UUIDs per sensor type
        # reach the transform config -- confirm this is intended.
        for camera_id, orientation in enumerate(orient[1:], start=1):
            camera_template = f"{prefix}_{camera_id}"
            camera_config = deepcopy(base_sensor)
            camera_config.ORIENTATION = orientation
            camera_config.UUID = camera_template.lower()
            sensor_uuids.append(camera_config.UUID)
            setattr(simulator, camera_template, camera_config)
            agent_sensors.append(camera_template)

    meta_config.TASK_CONFIG = config
    meta_config.SENSORS = agent_sensors
    obs_transforms = meta_config.RL.POLICY.OBS_TRANSFORMS
    if camera == "equirec":
        obs_transforms.CUBE2EQ.SENSOR_UUIDS = tuple(sensor_uuids)
    elif camera == "fisheye":
        obs_transforms.CUBE2FISH.SENSOR_UUIDS = tuple(sensor_uuids)
    meta_config.freeze()
    try:
        execute_exp(meta_config, mode)
    finally:
        # Deinit the process group even when the experiment raises, so a
        # failing run does not leak an initialized group into later tests.
        if torch.distributed.is_initialized():
            torch.distributed.destroy_process_group()
# Example #2
0
def test_cubemap_stiching(test_cfg_path: str, mode: str, camera: str,
                          sensor_type: str):
    """Exercise the cubemap projection transforms for one sensor type.

    Six deep copies of ``{sensor_type}_SENSOR`` are registered on the
    simulator config, one per cube-face orientation, and their UUIDs are
    collected. What happens next depends on ``camera``:

    * ``"equirect"`` / ``"fisheye"``: the UUIDs are written to the CUBE2EQ /
      CUBE2FISH transform config and the experiment is run via
      ``execute_exp``.
    * ``"cubemap"``: observations are converted cubemap -> equirect ->
      cubemap and the blurred result is compared with the blurred input.

    Args:
        test_cfg_path: Passed to ``get_config`` as ``config_paths``.
        mode: Forwarded unchanged to ``execute_exp``.
        camera: One of ``"equirect"``, ``"fisheye"`` or ``"cubemap"``.
        sensor_type: Sensor name prefix; the code special-cases ``"RGB"``.

    Raises:
        ValueError: If ``camera`` is not one of the supported names. This is
            raised only after the config has been built and frozen.
    """
    meta_config = get_config(config_paths=test_cfg_path)
    meta_config.defrost()
    config = meta_config.TASK_CONFIG
    CAMERA_NUM = 6
    # One rotation triple per cube face.
    orient = [
        [0, math.pi, 0],  # Back
        [-math.pi / 2, 0, 0],  # Down
        [0, 0, 0],  # Front
        [0, math.pi / 2, 0],  # Right
        [0, 3 / 2 * math.pi, 0],  # Left
        [math.pi / 2, 0, 0],  # Up
    ]
    sensor_uuids = []

    # Make sure the base sensor itself is enabled before cloning it.
    base_name = f"{sensor_type}_SENSOR"
    if base_name not in config.SIMULATOR.AGENT_0.SENSORS:
        config.SIMULATOR.AGENT_0.SENSORS.append(base_name)
    sensor = getattr(config.SIMULATOR, base_name)
    for camera_id in range(CAMERA_NUM):
        camera_template = f"{sensor_type}_{camera_id}"
        camera_config = deepcopy(sensor)
        camera_config.ORIENTATION = orient[camera_id]
        camera_config.UUID = camera_template.lower()
        sensor_uuids.append(camera_config.UUID)
        setattr(config.SIMULATOR, camera_template, camera_config)
        config.SIMULATOR.AGENT_0.SENSORS.append(camera_template)

    meta_config.TASK_CONFIG = config
    meta_config.SENSORS = config.SIMULATOR.AGENT_0.SENSORS
    if camera == "equirect":
        meta_config.RL.POLICY.OBS_TRANSFORMS.CUBE2EQ.SENSOR_UUIDS = tuple(
            sensor_uuids)
    elif camera == "fisheye":
        meta_config.RL.POLICY.OBS_TRANSFORMS.CUBE2FISH.SENSOR_UUIDS = tuple(
            sensor_uuids)
    meta_config.freeze()
    if camera in ["equirect", "fisheye"]:
        try:
            execute_exp(meta_config, mode)
        finally:
            # Deinit the process group even when the experiment raises, so a
            # failing run does not leak an initialized group into later tests.
            if torch.distributed.is_initialized():
                torch.distributed.destroy_process_group()

    elif camera == "cubemap":
        # 1) Generate an equirect image from cubemap images.
        # 2) Generate cubemap images from the equirect image.
        # 3) Compare the input and output cubemap
        env_fn_args = []
        for split in ["train", "val"]:
            # `config` was frozen above, so work on an unfrozen clone.
            tmp_config = config.clone()
            tmp_config.defrost()
            tmp_config.DATASET["SPLIT"] = split
            tmp_config.freeze()
            env_fn_args.append((tmp_config, None))

        # Only the first observations are needed; the context manager is
        # exited right after the reset.
        with VectorEnv(env_fn_args=env_fn_args) as envs:
            observations = envs.reset()
        batch = batch_obs(observations)
        # Keep an untouched copy in case the transforms modify `batch`.
        orig_batch = deepcopy(batch)

        #  ProjectionTransformer
        # NOTE(review): (256, 512) and (256, 256) are presumably the output
        # (height, width) of each transform -- confirm against their
        # constructors.
        obs_trans_to_eq = baseline_registry.get_obs_transformer(
            "CubeMap2Equirect")
        cube2equirect = obs_trans_to_eq(sensor_uuids, (256, 512))
        obs_trans_to_cube = baseline_registry.get_obs_transformer(
            "Equirect2CubeMap")
        equirect2cube = obs_trans_to_cube(cube2equirect.target_uuids,
                                          (256, 256))

        # Cubemap to Equirect to Cubemap
        batch_eq = cube2equirect(batch)
        batch_cube = equirect2cube(batch_eq)

        # Extract input and output cubemap
        output_cube = batch_cube[cube2equirect.target_uuids[0]]
        input_cube = [orig_batch[key] for key in sensor_uuids]
        # Stack the faces on a new dim 1, then merge it into dim 0 so the
        # input has one leading axis like `output_cube`.
        input_cube = torch.stack(input_cube, dim=1)
        input_cube = torch.flatten(input_cube, end_dim=1)

        # Apply blur to absorb difference (blur, etc.) caused by conversion
        if sensor_type == "RGB":
            # Scale RGB values by 1/255 so the threshold below applies.
            output_cube = output_cube.float() / 255
            input_cube = input_cube.float() / 255
        output_cube = output_cube.permute((0, 3, 1, 2))  # NHWC => NCHW
        input_cube = input_cube.permute((0, 3, 1, 2))  # NHWC => NCHW
        apply_blur = torch.nn.AvgPool2d(kernel_size=5, stride=3, padding=2)
        output_cube = apply_blur(output_cube)
        input_cube = apply_blur(input_cube)

        # Calculate the difference
        diff = torch.abs(output_cube - input_cube)
        assert diff.mean().item() < 0.01
    else:
        raise ValueError(f"Unknown camera name: {camera}")