예제 #1
0
def make_dataset(
    args, sampler: Optional[FrameSampler] = None, transform=None
) -> VideoDataset:
    """Construct the dataset selected by ``args.dataset_type``.

    Args:
        args: Namespace-like object. ``dataset_type`` is matched
            case-insensitively against ``"gulp"``, ``"image"`` and
            ``"video"``. ``dataset_root`` is read for every known type.
            ``image_filename_template`` is read only for ``"image"``.
        sampler: Frame sampler handed to the dataset. A new
            ``FullVideoSampler()`` is used when this is ``None``.
        transform: Optional transform handed to the dataset. For ``"gulp"``
            a non-``None`` transform is composed after
            ``TimeApply(ToPILImage())``.

    Returns:
        A ``GulpVideoDataset``, ``ImageFolderVideoDataset`` or
        ``VideoFolderDataset``, each given a fresh ``DummyLabelSet()``.

    Raises:
        ValueError: If ``args.dataset_type`` names none of the known types.
    """
    chosen_sampler = FullVideoSampler() if sampler is None else sampler
    kind = args.dataset_type.lower()

    if kind == "gulp":
        # Prefix the caller's transform with TimeApply(ToPILImage()); when no
        # transform was given, the dataset receives None unchanged.
        gulp_transform = (
            None
            if transform is None
            else Compose([TimeApply(ToPILImage()), transform])
        )
        return GulpVideoDataset(
            args.dataset_root,
            label_set=DummyLabelSet(),
            sampler=chosen_sampler,
            transform=gulp_transform,
        )

    if kind == "image":
        return ImageFolderVideoDataset(
            args.dataset_root,
            args.image_filename_template,
            label_set=DummyLabelSet(),
            sampler=chosen_sampler,
            transform=transform,
        )

    if kind == "video":
        return VideoFolderDataset(
            args.dataset_root,
            label_set=DummyLabelSet(),
            sampler=chosen_sampler,
            transform=transform,
        )

    # Report the type exactly as the caller spelled it, not the lowered form.
    raise ValueError("Unknown dataset type '{}'".format(args.dataset_type))
    def test_using_custom_frame_counter(self, image_folder):
        """A user-supplied ``frame_counter`` determines every video's length."""

        def frame_counter(path):
            # Ignore the path and report a fixed length so the expected
            # value is known in advance.
            return 10

        dataset = ImageFolderVideoDataset(
            image_folder, "frame_{:05d}.jpg", frame_counter=frame_counter
        )

        lengths = list(dataset.video_lengths)
        # all() is True for an empty sequence, so make sure there is
        # something to check; otherwise the test would pass vacuously.
        assert lengths
        assert all(length == 10 for length in lengths)
예제 #3
0
    def test_all_videos_folders_are_present_in_video_dirs_by_default(
            self, dataset_dir):
        """With default arguments, every generated video folder is discovered."""
        expected_count = 10
        self.make_video_dirs(dataset_dir, expected_count)

        dataset = ImageFolderVideoDataset(
            dataset_dir, filename_template="frame_{:05d}.jpg"
        )

        discovered = dataset._video_dirs
        assert len(discovered) == expected_count
    def test_video_ids(self, dataset_dir, fs):
        """``video_ids`` names every created video file, in sorted order.

        ``fs`` is presumably the pyfakefs fake-filesystem fixture consumed by
        ``make_video_files`` -- confirm against the test class.
        """
        video_count = 10
        self.make_video_files(dataset_dir, fs, video_count)

        # Fixed frame count for every video; presumably avoids inspecting the
        # generated files' contents -- confirm.
        dataset = ImageFolderVideoDataset(dataset_dir,
                                          "frame_{:05d}.jpg",
                                          frame_counter=(lambda path: 10))

        expected = sorted("video{}.mp4".format(i) for i in range(video_count))
        assert [video_id.name for video_id in dataset.video_ids] == expected
예제 #5
0
    def test_transform_is_applied(self, dataset_dir):
        """Indexing the dataset returns exactly what the transform was called with."""
        self.make_video_dirs(dataset_dir, 1)
        # Identity transform whose calls can be asserted on afterwards.
        identity = MockFramesOnlyTransform(lambda frames: frames)

        dataset = ImageFolderVideoDataset(
            dataset_dir, "frame_{:05d}.jpg", transform=identity
        )

        first_item = dataset[0]
        identity.assert_called_once_with(first_item)
예제 #6
0
    def test_labels_are_accessible(self, dataset_dir):
        """Labels produced by the label set are exposed via ``dataset.labels``."""
        self.make_video_dirs(dataset_dir, 10)

        dataset = ImageFolderVideoDataset(
            dataset_dir,
            "frame_{:05d}.jpg",
            # Label each video by the last character of ``p``; ``p`` is
            # presumably the video's path/id string ending in "videoN".
            label_set=LambdaLabelSet(lambda p: int(p[-1])),
        )

        # A single equality checks both the count and every label, and gives
        # a readable diff when it fails.
        labels = list(dataset.labels)
        assert labels == list(range(10))
예제 #7
0
    def test_filtering_video_folders(self, dataset_dir):
        """Only video folders accepted by ``filter`` are kept, in order."""
        self.make_video_dirs(dataset_dir, 10)

        def ends_in_1_2_or_3(video_path: Path):
            # Deliberately not named ``filter`` to avoid shadowing the builtin.
            return video_path.name.endswith(("1", "2", "3"))

        dataset = ImageFolderVideoDataset(dataset_dir,
                                          "frame_{:05d}.jpg",
                                          filter=ends_in_1_2_or_3)

        kept_names = [video_dir.name for video_dir in dataset._video_dirs]
        assert kept_names == ["video1", "video2", "video3"]
예제 #8
0
    def test_transform_is_passed_target_if_it_supports_it(self, dataset_dir):
        """A transform accepting a target receives the item's label as ``target``."""
        self.make_video_dirs(dataset_dir, 1)
        # Identity on both frames and target, with call assertions available.
        transform = MockFramesAndOptionalTargetTransform(
            lambda f: f, lambda t: t
        )
        label_set = DummyLabelSet(1)

        dataset = ImageFolderVideoDataset(
            dataset_dir,
            "frame_{:05d}.jpg",
            transform=transform,
            label_set=label_set,
        )

        frames, target = dataset[0]

        assert target == 1
        transform.assert_called_once_with(frames, target=target)
def image_folder_video_dataset(image_folder):
    """Return an ``ImageFolderVideoDataset`` over *image_folder*.

    Frames are located with the zero-padded template ``frame_{:05d}.jpg``.
    """
    template = "frame_{:05d}.jpg"
    return ImageFolderVideoDataset(image_folder, filename_template=template)