def check_rasterizer(cfg: dict, rasterizer: Rasterizer, zarr_dataset: ChunkedDataset) -> None:
    """Rasterize a few sample frames and validate the raster and its RGB rendering.

    Three frames are probed: index 0, index 50 and the last frame. For each probe, the
    history slice (step size 1, current frame included) is rasterized together with the
    agents of those frames and empty traffic-light face arrays.

    NOTE(review): a later definition named ``check_rasterizer`` in this file replaces
    this one at import time; confirm which variant is meant to be used.

    Args:
        cfg: config dict; ``cfg["model_params"]["history_num_frames"]`` and
            ``cfg["raster_params"]["raster_size"]`` are read.
        rasterizer: rasterizer under test (``rasterize``, ``num_channels`` and
            ``to_rgb`` are called).
        zarr_dataset: dataset whose ``frames`` and ``agents`` are read.

    Raises:
        AssertionError: if any check on the raster or on its RGB image fails.
    """
    all_frames = zarr_dataset.frames[:]  # materialise every frame in memory
    n_history = cfg["model_params"]["history_num_frames"]

    for frame_idx in [0, 50, len(all_frames) - 1]:
        window = get_history_slice(frame_idx, n_history, 1, include_current_state=True)
        frame_window = all_frames[window]
        frame_agents = filter_agents_by_frames(frame_window, zarr_dataset.agents)
        # TODO: load real traffic-light faces (TL_FACES); for now one empty array is
        # passed per entry of the filtered agents list.
        empty_faces = [np.empty(0, dtype=TL_FACE_DTYPE) for _ in frame_agents]
        raster = rasterizer.rasterize(frame_window, frame_agents, empty_faces)

        # Raster: 3-D, channel count declared by the rasterizer, configured spatial
        # size, float32 values within [0, 1].
        assert len(raster.shape) == 3
        assert raster.shape[-1] == rasterizer.num_channels()
        assert raster.shape[:2] == tuple(cfg["raster_params"]["raster_size"])
        assert raster.max() <= 1
        assert raster.min() >= 0
        assert raster.dtype == np.float32

        # RGB rendering: same spatial size, three uint8 channels.
        rgb = rasterizer.to_rgb(raster)
        assert raster.shape[:2] == rgb.shape[:2]
        assert rgb.shape[2] == 3
        assert rgb.dtype == np.uint8
def check_rasterizer(cfg: dict, rasterizer: Rasterizer, dataset: ChunkedStateDataset) -> None:
    """Rasterize a few sample frames and validate the raster and its RGB rendering.

    Three frames are probed: index 0, index 50 and the last frame. For each probe, the
    history slice (configured step size, current frame included) is rasterized together
    with the agents of those frames.

    NOTE(review): this redefines a ``check_rasterizer`` that appears earlier in this file
    and shadows it at import time; confirm which variant callers should get.

    Args:
        cfg: config dict; ``cfg["model_params"]["history_num_frames"]``,
            ``cfg["model_params"]["history_step_size"]`` and
            ``cfg["raster_params"]["raster_size"]`` are read.
        rasterizer: rasterizer under test (``rasterize`` and ``to_rgb`` are called).
        dataset: dataset whose ``frames`` and ``agents`` are read.

    Raises:
        AssertionError: if any check on the raster or on its RGB image fails.
    """
    all_frames = dataset.frames[:]  # materialise every frame in memory
    model_params = cfg["model_params"]
    n_history = model_params["history_num_frames"]
    step = model_params["history_step_size"]

    for frame_idx in [0, 50, len(all_frames) - 1]:
        window = get_history_slice(frame_idx, n_history, step, include_current_state=True)
        frame_window = all_frames[window]
        frame_agents = filter_agents_by_frames(frame_window, dataset.agents)
        raster = rasterizer.rasterize(frame_window, frame_agents)

        # Raster: 3-D, configured spatial size, at least 3 channels, float32 in [0, 1].
        assert len(raster.shape) == 3
        assert raster.shape[:2] == tuple(cfg["raster_params"]["raster_size"])
        assert raster.shape[2] >= 3
        assert raster.max() <= 1
        assert raster.min() >= 0
        assert raster.dtype == np.float32

        # RGB rendering: same spatial size, three uint8 channels.
        rgb = rasterizer.to_rgb(raster)
        assert raster.shape[:2] == rgb.shape[:2]
        assert rgb.shape[2] == 3
        assert rgb.dtype == np.uint8
def test_history_slice() -> None:
    """Check the slices returned by ``get_history_slice`` for a table of inputs.

    The expected slices walk backwards in time (negative step). For example, from index
    10 with 2 history states and step size 3, the history indices are 7 and 4, i.e.
    ``slice(7, 3, -3)``; with ``include_current_state=True`` the slice starts at the
    current index instead: ``slice(10, 3, -3)``.
    """
    with_current = {"include_current_state": True}
    # Each case: (positional args, keyword args, expected slice).
    cases = [
        # Current index 10, 2 history states, step 3 -> indices 7 and 4.
        ((10, 2, 3), {}, slice(7, 3, -3)),
        ((10, 2, 3), with_current, slice(10, 3, -3)),
        ((20, 2, 3), {}, slice(17, 13, -3)),
        ((20, 3, 3), {}, slice(17, 10, -3)),
        ((10, 2, 1), {}, slice(9, 7, -1)),
        ((20, 2, 3), with_current, slice(20, 13, -3)),
        ((20, 3, 3), with_current, slice(20, 10, -3)),
        ((10, 2, 1), with_current, slice(10, 7, -1)),
        # The first history index would fall before index 0, so the slice is empty,
        # e.g. list(range(10)[slice(0, 0, -3)]) == [].
        ((1, 2, 3), {}, slice(0, 0, -3)),
        ((2, 2, 3), {}, slice(0, 0, -3)),
        ((0, 2, 3), {}, slice(0, 0, -3)),
        # Only part of the history exists: just index 0,
        # e.g. list(range(10)[slice(0, None, -3)]) == [0].
        ((3, 2, 3), {}, slice(0, None, -3)),
    ]
    for args, kwargs, expected in cases:
        assert get_history_slice(*args, **kwargs) == expected