Example #1
0
def check_rasterizer(cfg: dict, rasterizer: Rasterizer,
                     zarr_dataset: ChunkedDataset) -> None:
    """Rasterize a few sample frames of a dataset and sanity-check the output.

    For the first frame, frame 50 (only if the dataset has it) and the last
    frame, the frame history ending at that frame is selected with
    ``get_history_slice`` (step size 1, current frame included). It is
    rasterized together with its agents and empty traffic-light faces, and
    then checked:

    * the raster is 3-D, has ``rasterizer.num_channels()`` channels, spatial
      shape ``cfg["raster_params"]["raster_size"]``, values in [0, 1] and
      dtype float32;
    * ``rasterizer.to_rgb`` returns an image with the same spatial shape,
      three channels and dtype uint8.

    Args:
        cfg: config dict; reads ``["model_params"]["history_num_frames"]``
            and ``["raster_params"]["raster_size"]``.
        rasterizer: the rasterizer under test.
        zarr_dataset: dataset whose ``frames`` and ``agents`` are used.

    Raises:
        AssertionError: if the dataset has no frames or any check fails.
    """
    frames = zarr_dataset.frames[:]  # Load all frames into memory
    num_frames = len(frames)
    assert num_frames > 0, "dataset has no frames to rasterize"

    # These config values do not depend on the sampled frame; read them once.
    history_num_frames = cfg["model_params"]["history_num_frames"]
    raster_size = tuple(cfg["raster_params"]["raster_size"])

    # Frame 50 does not exist in datasets with fewer than 51 frames, and
    # slicing past the end would silently clamp or select nothing, so skip it.
    # The set also drops duplicates (the last frame may be frame 0 or 50).
    sample_frames = sorted(
        {idx for idx in (0, 50, num_frames - 1) if idx < num_frames})

    for current_frame in sample_frames:
        s = get_history_slice(current_frame,
                              history_num_frames,
                              1,
                              include_current_state=True)
        frames_to_rasterize = frames[s]
        agents = filter_agents_by_frames(frames_to_rasterize,
                                         zarr_dataset.agents)
        # TODO: use real traffic-light faces (TL_FACES); for now one empty
        # face array is passed per entry of ``agents``.
        tl_faces = [np.empty(0, dtype=TL_FACE_DTYPE) for _ in agents]
        im = rasterizer.rasterize(frames_to_rasterize, agents, tl_faces)
        assert len(im.shape) == 3
        assert im.shape[-1] == rasterizer.num_channels()
        assert im.shape[:2] == raster_size
        assert im.max() <= 1
        assert im.min() >= 0
        assert im.dtype == np.float32

        rgb_im = rasterizer.to_rgb(im)
        assert im.shape[:2] == rgb_im.shape[:2]
        assert rgb_im.shape[2] == 3  # RGB has three channels
        assert rgb_im.dtype == np.uint8
def check_rasterizer(cfg: dict, rasterizer: Rasterizer,
                     dataset: ChunkedStateDataset) -> None:
    """Rasterize a few sample frames of a dataset and sanity-check the output.

    For the first frame, frame 50 (only if the dataset has it) and the last
    frame, the frame history ending at that frame is selected with
    ``get_history_slice`` (configured history length and step size, current
    frame included). It is rasterized together with its agents, and then
    checked:

    * the raster is 3-D, has spatial shape
      ``cfg["raster_params"]["raster_size"]``, at least three channels,
      values in [0, 1] and dtype float32;
    * ``rasterizer.to_rgb`` returns an image with the same spatial shape,
      three channels and dtype uint8.

    Args:
        cfg: config dict; reads ``["model_params"]["history_num_frames"]``,
            ``["model_params"]["history_step_size"]`` and
            ``["raster_params"]["raster_size"]``.
        rasterizer: the rasterizer under test.
        dataset: dataset whose ``frames`` and ``agents`` are used.

    Raises:
        AssertionError: if the dataset has no frames or any check fails.
    """
    frames = dataset.frames[:]  # Load all frames into memory
    num_frames = len(frames)
    assert num_frames > 0, "dataset has no frames to rasterize"

    # These config values do not depend on the sampled frame; read them once.
    history_num_frames = cfg["model_params"]["history_num_frames"]
    history_step_size = cfg["model_params"]["history_step_size"]
    raster_size = tuple(cfg["raster_params"]["raster_size"])

    # Frame 50 does not exist in datasets with fewer than 51 frames, and
    # slicing past the end would silently clamp or select nothing, so skip it.
    # The set also drops duplicates (the last frame may be frame 0 or 50).
    sample_frames = sorted(
        {idx for idx in (0, 50, num_frames - 1) if idx < num_frames})

    for current_frame in sample_frames:
        s = get_history_slice(current_frame,
                              history_num_frames,
                              history_step_size,
                              include_current_state=True)
        frames_to_rasterize = frames[s]
        agents = filter_agents_by_frames(frames_to_rasterize, dataset.agents)

        im = rasterizer.rasterize(frames_to_rasterize, agents)
        assert len(im.shape) == 3
        assert im.shape[:2] == raster_size
        assert im.shape[2] >= 3
        assert im.max() <= 1
        assert im.min() >= 0
        assert im.dtype == np.float32

        rgb_im = rasterizer.to_rgb(im)
        assert im.shape[:2] == rgb_im.shape[:2]
        assert rgb_im.shape[2] == 3  # RGB has three channels
        assert rgb_im.dtype == np.uint8
def test_history_slice() -> None:
    """Check ``get_history_slice`` against hand-computed slices.

    Positional arguments are (current index, number of history states, step
    size), per the first case's comment. The expected slices walk backwards
    (negative step), so they list the most recent state first. With
    ``include_current_state=True`` the expected slice starts at the current
    index itself.
    """
    # Current index 10, 2 history states with step size 3
    # Should yield indices 7 and 4 (most recent first), so slice(7, 3, -3)
    assert get_history_slice(10, 2, 3) == slice(7, 3, -3)
    # Including the current state adds index 10 in front: 10, 7, 4
    assert get_history_slice(10, 2, 3,
                             include_current_state=True) == slice(10, 3, -3)

    # Stop is exclusive: 17, 14 / 17, 14, 11 / 9, 8
    assert get_history_slice(20, 2, 3) == slice(17, 13, -3)
    assert get_history_slice(20, 3, 3) == slice(17, 10, -3)
    assert get_history_slice(10, 2, 1) == slice(9, 7, -1)

    # Same cases with the current index prepended
    assert get_history_slice(20, 2, 3,
                             include_current_state=True) == slice(20, 13, -3)
    assert get_history_slice(20, 3, 3,
                             include_current_state=True) == slice(20, 10, -3)
    assert get_history_slice(10, 2, 1,
                             include_current_state=True) == slice(10, 7, -1)

    # Not possible here to go past the first state, should give an empty slice
    # (stepping back 3 from index 1, 2 or 0 never reaches a valid index >= 0)
    # note range(10)[slice(0, 0, -3)] == []
    assert get_history_slice(1, 2, 3) == slice(0, 0, -3)
    assert get_history_slice(2, 2, 3) == slice(0, 0, -3)
    assert get_history_slice(0, 2, 3) == slice(0, 0, -3)

    # Partially possible here: only index 0 (= 3 - 3) exists, so the slice keeps
    # just that element; stop is None rather than -1 because a stop of -1 would
    # wrap around to the end of the sequence. range(10)[slice(0, None, -3)] == [0]
    assert get_history_slice(3, 2, 3) == slice(0, None, -3)