コード例 #1
0
def make_dataset(
    args, sampler: Optional[FrameSampler] = None, transform=None
) -> VideoDataset:
    """Build the unlabelled video dataset named by ``args.dataset_type``.

    Args:
        args: Object providing ``dataset_type``, ``dataset_root`` and, for the
            "image" type, ``image_filename_template``. ``dataset_type`` is
            matched case-insensitively against "gulp", "image" and "video".
        sampler: Frame sampler handed to the dataset. A ``FullVideoSampler``
            is created when this is omitted.
        transform: Optional transform handed to the dataset. For the "gulp"
            type it is prefixed with ``TimeApply(ToPILImage())``.

    Returns:
        The constructed dataset, labelled with ``DummyLabelSet()``.

    Raises:
        ValueError: If ``args.dataset_type`` names none of the known types.
    """
    frame_sampler = sampler if sampler is not None else FullVideoSampler()
    kind = args.dataset_type.lower()

    if kind == "gulp":
        # The caller's transform is composed after TimeApply(ToPILImage()).
        # Presumably this converts the gulp frames to PIL images before the
        # user transform sees them (TODO confirm TimeApply semantics).
        gulp_transform = transform
        if gulp_transform is not None:
            gulp_transform = Compose([TimeApply(ToPILImage()), gulp_transform])
        return GulpVideoDataset(
            args.dataset_root,
            label_set=DummyLabelSet(),
            sampler=frame_sampler,
            transform=gulp_transform,
        )

    if kind == "image":
        return ImageFolderVideoDataset(
            args.dataset_root,
            args.image_filename_template,
            label_set=DummyLabelSet(),
            sampler=frame_sampler,
            transform=transform,
        )

    if kind == "video":
        return VideoFolderDataset(
            args.dataset_root,
            label_set=DummyLabelSet(),
            sampler=frame_sampler,
            transform=transform,
        )

    raise ValueError("Unknown dataset type '{}'".format(args.dataset_type))
コード例 #2
0
    def test_transforms_are_passed_uint8_ndarray_video(self, gulp_path):
        """The transform must receive the video as a 4-D uint8 numpy array.

        The frames are captured inside the transform itself. The assertions
        therefore check what the transform was given, not what the dataset
        returns after calling it.
        """
        received = []

        def record_frames(frames):
            received.append(frames)
            return frames

        dataset = GulpVideoDataset(gulp_path, transform=record_frames)

        # Indexing must still yield a (video, label) pair.
        _video, _label = dataset[0]

        # A single item access should invoke the transform exactly once.
        assert len(received) == 1
        vid = received[0]
        # Exact type check on purpose: a numpy subclass should not satisfy it.
        assert type(vid) is np.ndarray
        assert vid.dtype == np.uint8
        assert vid.ndim == 4
コード例 #3
0
    def test_creating_gulp_video_dataset_from_gulp_directory(self, gulp_path):
        """A dataset built from an existing GulpDirectory reuses that directory."""
        directory = GulpDirectory(gulp_path)
        root = directory.output_dir

        dataset = GulpVideoDataset(root, gulp_directory=directory)

        assert dataset.gulp_dir == directory
        # The dataset's root path should match the directory's output_dir.
        assert dataset.gulp_dir.output_dir == dataset.root_path
        # The dataset length should equal the size of the merged metadata.
        assert len(directory.merged_meta_dict) == len(dataset)
コード例 #4
0
    def test_transform_is_called(self, gulp_path):
        """A frames-only transform is called once with the frames the dataset returns."""
        mock_transform = MockFramesOnlyTransform(lambda frames: frames)
        dataset = GulpVideoDataset(gulp_path, transform=mock_transform)

        returned_frames, _label = dataset[0]

        assert returned_frames is not None
        mock_transform.assert_called_once_with(returned_frames)
コード例 #5
0
    def test_filtering_videos(self, gulp_path):
        """Only videos accepted by the ``filter`` callable end up in the dataset."""
        video_ids = {"video1", "video2", "video3"}

        # Named so it does not shadow the builtin ``filter``. It is still
        # passed through the dataset's ``filter=`` keyword.
        def keep_video(video_id: str) -> bool:
            return video_id in video_ids

        dataset = GulpVideoDataset(gulp_path, filter=keep_video)

        assert len(dataset) == len(video_ids)
        assert set(dataset._video_ids) == video_ids
コード例 #6
0
    def test_dataset_throws_error_if_root_path_is_different_from_gulp_dir_path(
        self, gulp_path
    ):
        """A root path that differs from the GulpDirectory's output_dir is rejected."""
        gulp_dir = GulpDirectory(gulp_path)
        # Built outside ``pytest.raises`` on purpose: ``Path.with_name`` raises
        # ValueError itself for a path with an empty name. Inside the block,
        # that error would make the test pass spuriously.
        mismatched_root = Path(gulp_dir.output_dir).with_name("not-a-gulp-dir")

        with pytest.raises(ValueError):
            GulpVideoDataset(mismatched_root, gulp_directory=gulp_dir)
コード例 #7
0
    def test_transform_is_passed_target_if_it_supports_it(self, gulp_path):
        """A transform that accepts a target receives it as the ``target`` keyword."""
        mock_transform = MockFramesAndOptionalTargetTransform(
            lambda f: f, lambda t: t
        )
        dataset = GulpVideoDataset(
            gulp_path, transform=mock_transform, label_set=DummyLabelSet(1)
        )

        video, label = dataset[0]

        # The label is expected to come from DummyLabelSet(1).
        assert label == 1
        mock_transform.assert_called_once_with(video, target=label)
コード例 #8
0
def gulp_dataset(gulp_path):
    """Return a ``GulpVideoDataset`` rooted at ``gulp_path`` with no other options."""
    dataset = GulpVideoDataset(gulp_path)
    return dataset
コード例 #9
0
def gulp_dataset(gulp_dir):
    """Return a ``GulpVideoDataset`` whose root path argument is ``gulp_dir``."""
    # NOTE(review): ``gulp_dir`` is passed positionally as the root path.
    # Confirm the fixture yields a path and not a GulpDirectory object; other
    # code hands a GulpDirectory in via ``gulp_directory=``.
    dataset = GulpVideoDataset(gulp_dir)
    return dataset