Exemplo n.º 1
0
def check_rasterizer(cfg: dict, rasterizer: Rasterizer,
                     zarr_dataset: ChunkedDataset) -> None:
    """Sanity-check the rasterizer output on a handful of dataset frames.

    For frame indices 0, 50 and the last one, the history window ending at
    that frame is rasterized (with empty traffic-light faces) and both the
    raw raster and its RGB rendering are checked with assertions.

    :param cfg: configuration; reads ``model_params.history_num_frames`` and
        ``raster_params.raster_size``
    :param rasterizer: the rasterizer under test
    :param zarr_dataset: dataset providing ``frames`` and ``agents``
    :raises AssertionError: if any shape, range or dtype check fails
    """
    all_frames = zarr_dataset.frames[:]  # pull every frame into memory once
    probe_indices = [0, 50, len(all_frames) - 1]

    for frame_idx in probe_indices:
        num_history = cfg["model_params"]["history_num_frames"]
        window = get_history_slice(
            frame_idx, num_history, 1, include_current_state=True
        )
        history_frames = all_frames[window]
        history_agents = filter_agents_by_frames(history_frames,
                                                 zarr_dataset.agents)
        # One empty traffic-light-face array per entry of the filtered agents.
        # TODO TR_FACES
        empty_tl_faces = [
            np.empty(0, dtype=TL_FACE_DTYPE) for _ in history_agents
        ]

        raster = rasterizer.rasterize(history_frames, history_agents,
                                      empty_tl_faces)

        # Raw raster: 3 dimensions, expected size/channels, float32 in [0, 1].
        expected_size = tuple(cfg["raster_params"]["raster_size"]) \
            if len(raster.shape) == 3 else None
        assert expected_size is not None
        assert raster.shape[-1] == rasterizer.num_channels()
        assert raster.shape[:2] == expected_size
        assert raster.max() <= 1
        assert raster.min() >= 0
        assert raster.dtype == np.float32

        # RGB rendering: same spatial size, three uint8 channels.
        rgb = rasterizer.to_rgb(raster)
        assert raster.shape[:2] == rgb.shape[:2]
        assert rgb.shape[2] == 3
        assert rgb.dtype == np.uint8
Exemplo n.º 2
0
def check_rasterizer(cfg: dict, rasterizer: Rasterizer,
                     dataset: ChunkedStateDataset) -> None:
    """Sanity-check the rasterizer output on a handful of dataset frames.

    For frame indices 0, 50 and the last one, the history window ending at
    that frame is rasterized and both the raw raster and its RGB rendering
    are checked with assertions.

    :param cfg: configuration; reads ``model_params.history_num_frames``,
        ``model_params.history_step_size`` and ``raster_params.raster_size``
    :param rasterizer: the rasterizer under test
    :param dataset: dataset providing ``frames`` and ``agents``
    :raises AssertionError: if any shape, range or dtype check fails
    """
    all_frames = dataset.frames[:]  # pull every frame into memory once
    probe_indices = [0, 50, len(all_frames) - 1]

    for frame_idx in probe_indices:
        model_params = cfg["model_params"]
        num_history = model_params["history_num_frames"]
        step_size = cfg["model_params"]["history_step_size"]
        window = get_history_slice(
            frame_idx, num_history, step_size, include_current_state=True
        )
        history_frames = all_frames[window]
        history_agents = filter_agents_by_frames(history_frames,
                                                 dataset.agents)

        raster = rasterizer.rasterize(history_frames, history_agents)

        # Raw raster: 3 dimensions, expected size, >= 3 channels,
        # float32 values in [0, 1].
        assert len(raster.shape) == 3
        expected_size = tuple(cfg["raster_params"]["raster_size"])
        assert raster.shape[:2] == expected_size
        assert raster.shape[2] >= 3
        assert raster.max() <= 1
        assert raster.min() >= 0
        assert raster.dtype == np.float32

        # RGB rendering: same spatial size, three uint8 channels.
        rgb = rasterizer.to_rgb(raster)
        assert raster.shape[:2] == rgb.shape[:2]
        assert rgb.shape[2] == 3
        assert rgb.dtype == np.uint8
Exemplo n.º 3
0
def save_input_raster(rasterizer: Rasterizer, image: torch.Tensor, num_images: int = 20,
                      output_folder: str = 'raster_inputs') -> None:
    """Save one input raster as an RGB PNG; stop the process once enough exist.

    The image is written to ``<output_folder>/input<i>.png`` where ``i`` is
    the lowest index whose file does not exist yet, so repeated calls fill
    ``input0.png``, ``input1.png``, ... in order.

    :param rasterizer: rasterizer whose ``to_rgb`` converts the raster to RGB
    :param image: raster tensor; its axes are permuted with ``(1, 2, 0)``
        before conversion (presumably channels-first input -- confirm
        against the caller)
    :param num_images: number of images to save before terminating
    :param output_folder: directory to save the images in; created
        (including missing parents) if it does not exist
    :return: None
    :raises SystemExit: after saving the image with index
        ``num_images - 1`` or higher, i.e. once ``num_images`` images exist
        (assuming the files are numbered contiguously from 0)
    """
    import sys  # local import: only needed for the termination path below

    raster = image.permute(1, 2, 0).cpu().numpy()
    im = Image.fromarray(rasterizer.to_rgb(raster))

    out_dir = Path(output_folder)
    out_dir.mkdir(parents=True, exist_ok=True)

    # Find the first unused "input<i>.png" slot.
    i = 0
    img_path = out_dir / f"input{i}.png"
    while img_path.exists():
        i += 1
        img_path = out_dir / f"input{i}.png"

    im.save(img_path)

    # ``i`` is zero-based, so ``i + 1`` images exist now. ``>=`` (rather than
    # ``==``) also terminates when the folder already held more than
    # ``num_images`` files from an earlier run.
    if i + 1 >= num_images:
        sys.exit()