def test_get_scene_dataset(dmg: LocalDataManager, tmp_path: Path, zarr_dataset: ChunkedDataset) -> None:
    """Check that each scene of a concatenated zarr matches the single-scene input.

    ``single_scene.zarr`` is concatenated ``concat_count`` times into a fresh
    zarr under ``tmp_path``. Every scene extracted with ``get_scene_dataset``
    must have scenes, frames, agents and traffic-light faces equal to those of
    ``zarr_dataset``. ``zarr_dataset`` is presumably the fixture opened on the
    same single-scene zarr; confirm this in conftest. Requesting a scene index
    past the end must raise ``ValueError``.
    """
    concat_count = 4
    zarr_input_path = dmg.require("single_scene.zarr")
    # uuid4 keeps the output path unique inside the per-test tmp dir
    zarr_output_path = str(tmp_path / f"{uuid4()}.zarr")

    zarr_concat([zarr_input_path] * concat_count, zarr_output_path)
    zarr_cat_dataset = ChunkedDataset(zarr_output_path)
    zarr_cat_dataset.open()

    # all scenes should be the same as the input one
    # np.all replaces np.alltrue, which was deprecated in NumPy 1.25 and
    # removed in NumPy 2.0.
    for scene_idx in range(concat_count):
        zarr_scene = zarr_cat_dataset.get_scene_dataset(scene_idx)
        assert np.all(zarr_scene.scenes == np.asarray(zarr_dataset.scenes))
        assert np.all(zarr_scene.frames == np.asarray(zarr_dataset.frames))
        assert np.all(zarr_scene.agents == np.asarray(zarr_dataset.agents))
        assert np.all(zarr_scene.tl_faces == np.asarray(zarr_dataset.tl_faces))

    with pytest.raises(ValueError):
        zarr_cat_dataset.get_scene_dataset(concat_count + 1)
def test_dataset_frames_subset(zarr_dataset: ChunkedDataset) -> None:
    """Cut frames [10, 25) out of the first scene and validate the resulting subset.

    The subset must contain exactly one scene and exactly the requested frames
    (compared via ``ego_translation``). Its agents and traffic-light faces must
    equal the slices selected from the first and last kept frames. Its single
    scene's ``frame_index_interval`` must be rebased to ``(0, n_frames)``.
    """
    scene_dataset = zarr_dataset.get_scene_dataset(0)
    start, end = 10, 25

    subset = get_frames_subset(scene_dataset, start, end)
    subset_frame_count = len(subset.frames)

    assert len(subset.scenes) == 1
    assert subset_frame_count == end - start

    expected_translation = scene_dataset.frames["ego_translation"][start:end]
    assert np.all(subset.frames["ego_translation"] == expected_translation)

    # The agent / TL-face ranges are derived from the first and last kept frames.
    first_frame, last_frame = scene_dataset.frames[[start, end - 1]]
    agents_range = get_agents_slice_from_frames(first_frame, last_frame)
    tl_range = get_tl_faces_slice_from_frames(first_frame, last_frame)

    assert np.all(subset.agents == scene_dataset.agents[agents_range])
    assert np.all(subset.tl_faces == scene_dataset.tl_faces[tl_range])

    # The subset's only scene must now start at frame 0 and cover every kept frame.
    assert np.all(subset.scenes["frame_index_interval"] == (0, len(subset.frames)))