Example #1
def build_rasterizer(cfg: dict,
                     data_manager: DataManager,
                     debug: bool = False) -> Rasterizer:
    """Build a SemGraphRasterizer when map_type is "semantic_graph"; otherwise defer to the stock l5kit factory."""
    raster_cfg = cfg["raster_params"]
    map_type = raster_cfg["map_type"]

    if map_type == "semantic_graph":
        dataset_meta_key = raster_cfg["dataset_meta_key"]
        filter_agents_threshold = raster_cfg["filter_agents_threshold"]
        history_num_frames = cfg["model_params"]["history_num_frames"]

        render_context = RenderContext(
            raster_size_px=np.array(raster_cfg["raster_size"]),
            pixel_size_m=np.array(raster_cfg["pixel_size"]),
            center_in_raster_ratio=np.array(raster_cfg["ego_center"]),
        )

        semantic_map_filepath = data_manager.require(
            raster_cfg["semantic_map_key"])
        dataset_meta = _load_metadata(dataset_meta_key, data_manager)
        world_to_ecef = np.array(dataset_meta["world_to_ecef"],
                                 dtype=np.float64)

        return SemGraphRasterizer(render_context, filter_agents_threshold,
                                  history_num_frames, semantic_map_filepath,
                                  world_to_ecef, debug)
    else:
        return l5kit_build_rasterizer(cfg, data_manager)
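
# Usage sketch for the "semantic_graph" branch above. The keys mirror the lookups in
# build_rasterizer; the concrete values and the LocalDataManager class (from l5kit.data)
# are illustrative assumptions, not taken from this snippet.
from l5kit.data import LocalDataManager

example_cfg = {
    "raster_params": {
        "map_type": "semantic_graph",
        "raster_size": [224, 224],
        "pixel_size": [0.5, 0.5],
        "ego_center": [0.25, 0.5],
        "filter_agents_threshold": 0.5,
        "dataset_meta_key": "meta.json",
        "semantic_map_key": "semantic_map/semantic_map.pb",
    },
    "model_params": {"history_num_frames": 10},
}
rasterizer = build_rasterizer(example_cfg, LocalDataManager(None), debug=False)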
Example #2
    def from_cfg(data_manager: DataManager, cfg: dict) -> "MapAPI":
        """Build a MapAPI object starting from a config file and a data manager

        :param data_manager: a data manager object to resolve paths
        :param cfg: the config dict
        :return: a MapAPI object
        """
        raster_cfg = cfg["raster_params"]
        dataset_meta_key = raster_cfg["dataset_meta_key"]

        semantic_map_filepath = data_manager.require(
            raster_cfg["semantic_map_key"])
        dataset_meta = load_metadata(data_manager.require(dataset_meta_key))
        world_to_ecef = np.array(dataset_meta["world_to_ecef"],
                                 dtype=np.float64)

        return MapAPI(semantic_map_filepath, world_to_ecef)
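
# Minimal usage sketch for MapAPI.from_cfg above. It assumes a LocalDataManager (from
# l5kit.data) that can resolve raster_params["semantic_map_key"] and
# raster_params["dataset_meta_key"] in a cfg dict shaped like the earlier examples.
# The get_lane_coords call is only illustrative; the lane id is a hypothetical placeholder.
from l5kit.data import LocalDataManager

map_api = MapAPI.from_cfg(LocalDataManager(None), cfg)
lane_coords = map_api.get_lane_coords("hypothetical_lane_id")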
Example #3
def _load_metadata(meta_key: str, data_manager: DataManager) -> dict:
    """
    Load a json metadata file

    Args:
        meta_key (str): relative key to the metadata
        data_manager (DataManager): DataManager used for requiring files

    Returns:
        dict: metadata as a dict
    """
    metadata_path = data_manager.require(meta_key)
    with open(metadata_path, "r") as f:
        metadata: dict = json.load(f)
    return metadata
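
# Based on how callers use the returned dict, the metadata JSON is expected to carry at
# least a "world_to_ecef" entry that np.array(..., dtype=np.float64) can turn into a pose
# matrix. A placeholder sketch (the identity is illustrative, not real calibration data):
import numpy as np

example_meta = {"world_to_ecef": np.eye(4).tolist()}
world_to_ecef = np.array(example_meta["world_to_ecef"], dtype=np.float64)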
Example #4
def _load_satellite_map(image_key: str,
                        data_manager: DataManager) -> np.ndarray:
    """Loads image from given key.

    Args:
        image_key (str): key to the image (e.g. ``maps/my_satellite_image.png``)
        data_manager (DataManager): DataManager used for requiring files

    Returns:
        np.ndarray: the loaded image as an RGB array
    """

    image_path = data_manager.require(image_key)
    image = cv2.imread(image_path)
    if image is None:
        raise Exception(f"Failed to load image from {image_path}")

    return image[..., ::-1]  # OpenCV loads BGR; reverse the channel axis to get RGB
Example #5
def build_rasterizer_tl(cfg: dict, data_manager: DataManager) -> Rasterizer:
    """Factory function for rasterizers, reads the config, loads required data and initializes the correct rasterizer.

    Args:
        cfg (dict): Config.
        data_manager (DataManager): Datamanager that is used to require files to be present.

    Raises:
        NotImplementedError: Thrown when the ``map_type`` read from the config doesn't have an associated rasterizer
        type in this factory function. If you have custom rasterizers, you can wrap this function in your own factory
        function and catch this error.

    Returns:
        Rasterizer: Rasterizer initialized given the supplied config.
    """
    raster_cfg = cfg["raster_params"]
    map_type = raster_cfg["map_type"]
    dataset_meta_key = raster_cfg["dataset_meta_key"]

    render_context = RenderContext(
        raster_size_px=np.array(raster_cfg["raster_size"]),
        pixel_size_m=np.array(raster_cfg["pixel_size"]),
        center_in_raster_ratio=np.array(raster_cfg["ego_center"]),
    )
    filter_agents_threshold = raster_cfg["filter_agents_threshold"]
    history_num_frames = cfg["model_params"]["history_num_frames"]

    semantic_map_filepath = data_manager.require(raster_cfg["semantic_map_key"])
    try:
        dataset_meta = _load_metadata(dataset_meta_key, data_manager)
        world_to_ecef = np.array(dataset_meta["world_to_ecef"], dtype=np.float64)
    except (KeyError, FileNotFoundError):  # TODO remove when new dataset version is available
        world_to_ecef = get_hardcoded_world_to_ecef()

    return SemBoxTLRasterizer(render_context, filter_agents_threshold, history_num_frames, semantic_map_filepath,
                              world_to_ecef)
Example #6
def build_dataloader(
    cfg: Dict,
    split: str,
    data_manager: DataManager,
    dataset_class: Callable,
    rasterizer: Rasterizer,
    perturbation: Optional[Perturbation] = None,
) -> DataLoader:
    """
    Util function to build a dataloader from a dataset of dataset_class. Note we have to pass rasterizer and
    perturbation as the factory functions for those are likely to change between repos.

    Args:
        cfg (dict): configuration dict
        split (str): this will be used to index the cfg to get the correct datasets (train or val currently)
        data_manager (DataManager): manager for resolving paths
        dataset_class (Callable): a class object (EgoDataset or AgentDataset currently) to build the dataset
        rasterizer (Rasterizer): the rasterizer for the dataset
        perturbation (Optional[Perturbation]): an optional perturbation object

    Returns:
        DataLoader: pytorch Dataloader object built with Concat and Sub datasets
    """

    data_loader_cfg = cfg[f"{split}_data_loader"]
    datasets = []
    for dataset_param in data_loader_cfg["datasets"]:
        zarr_dataset_path = data_manager.require(key=dataset_param["key"])
        zarr_dataset = ChunkedStateDataset(path=zarr_dataset_path)
        zarr_dataset.open()
        zarr_dataset.scenes = get_combined_scenes(zarr_dataset.scenes)

        #  Let's wrap the opened zarr dataset in the supplied dataset class.
        dataset = dataset_class(cfg,
                                zarr_dataset,
                                rasterizer,
                                perturbation=perturbation)

        scene_indices = dataset_param["scene_indices"]
        scene_subsets = []

        if dataset_param["scene_indices"][0] == -1:  # TODO replace with empty
            scene_subset = Subset(dataset, np.arange(0, len(dataset)))
            scene_subsets.append(scene_subset)
        else:
            for scene_idx in scene_indices:
                valid_indices = dataset.get_scene_indices(scene_idx)
                scene_subset = Subset(dataset, valid_indices)
                scene_subsets.append(scene_subset)

        datasets.extend(scene_subsets)

    #  Let's concatenate the training scenes into one dataset for the data loader to load from.
    concat_dataset: ConcatDataset = ConcatDataset(datasets)

    #  Initialize the data loader that our training loop will iterate on.
    batch_size = data_loader_cfg["batch_size"]
    shuffle = data_loader_cfg["shuffle"]
    num_workers = data_loader_cfg["num_workers"]
    dataloader = DataLoader(dataset=concat_dataset,
                            batch_size=batch_size,
                            shuffle=shuffle,
                            num_workers=num_workers)

    return dataloader
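
# Sketch of the config slice build_dataloader reads for split="train". The key names follow
# the lookups in the function body; the dataset key and numeric values are illustrative
# assumptions. A scene_indices of [-1] selects every scene in the zarr (see the branch above).
example_loader_cfg = {
    "train_data_loader": {
        "datasets": [{"key": "scenes/train.zarr", "scene_indices": [-1]}],
        "batch_size": 32,
        "shuffle": True,
        "num_workers": 4,
    }
}
# train_loader = build_dataloader(cfg, "train", data_manager, EgoDataset, rasterizer)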
Example #7
def build_rasterizer(cfg: dict, data_manager: DataManager) -> Rasterizer:
    """Factory function for rasterizers, reads the config, loads required data and initializes the correct rasterizer.

    Args:
        cfg (dict): Config.
        data_manager (DataManager): DataManager that is used to require files to be present.

    Raises:
        NotImplementedError: Thrown when the ``map_type`` read from the config doesn't have an associated rasterizer
        type in this factory function. If you have custom rasterizers, you can wrap this function in your own factory
        function and catch this error.

    Returns:
        Rasterizer: Rasterizer initialized given the supplied config.
    """
    raster_cfg = cfg["raster_params"]
    map_type = raster_cfg["map_type"]
    dataset_meta_key = raster_cfg["dataset_meta_key"]

    render_context = RenderContext(
        raster_size_px=np.array(raster_cfg["raster_size"]),
        pixel_size_m=np.array(raster_cfg["pixel_size"]),
        center_in_raster_ratio=np.array(raster_cfg["ego_center"]),
    )

    filter_agents_threshold = raster_cfg["filter_agents_threshold"]
    history_num_frames = cfg["model_params"]["history_num_frames"]

    if map_type in ["py_satellite", "satellite_debug"]:
        pass
        # sat_image = _load_satellite_map(raster_cfg["satellite_map_key"], data_manager)
        #
        # try:
        #     dataset_meta = _load_metadata(dataset_meta_key, data_manager)
        #     world_to_ecef = np.array(dataset_meta["world_to_ecef"], dtype=np.float64)
        #     ecef_to_aerial = np.array(dataset_meta["ecef_to_aerial"], dtype=np.float64)
        #
        # except (KeyError, FileNotFoundError):  # TODO remove when new dataset version is available
        #     world_to_ecef = get_hardcoded_world_to_ecef()
        #     ecef_to_aerial = get_hardcoded_ecef_to_aerial()
        #
        # world_to_aerial = np.matmul(ecef_to_aerial, world_to_ecef)
        # if map_type == "py_satellite":
        #     return SatBoxRasterizer(
        #         render_context, filter_agents_threshold, history_num_frames, sat_image, world_to_aerial,
        #     )
        # else:
        #     return SatelliteRasterizer(render_context, sat_image, world_to_aerial)

    elif map_type in ["py_semantic", "semantic_debug"]:
        semantic_map_filepath = data_manager.require(
            raster_cfg["semantic_map_key"])
        try:
            dataset_meta = _load_metadata(dataset_meta_key, data_manager)
            world_to_ecef = np.array(dataset_meta["world_to_ecef"],
                                     dtype=np.float64)
        except (KeyError, FileNotFoundError):  # TODO remove when new dataset version is available
            world_to_ecef = get_hardcoded_world_to_ecef()
        if map_type == "py_semantic":

            # return SemBoxRasterizer(
            return SemBoxRasterizerCompressed(
                render_context,
                filter_agents_threshold,
                history_num_frames,
                semantic_map_filepath,
                world_to_ecef,
            )
        else:
            return SemanticRasterizer(render_context, semantic_map_filepath,
                                      world_to_ecef)

    # elif map_type == "box_debug":
    #     return BoxRasterizer(render_context, filter_agents_threshold, history_num_frames)
    # elif map_type == "stub_debug":
    #     return StubRasterizer(render_context, filter_agents_threshold)
    else:
        raise NotImplementedError(
            f"Rasterizer for map type {map_type} is not supported.")