def test_video_dataset(self):
        """Exercise LightlyDataset on a folder containing video files.

        When VIDEO_DATASET_AVAILABLE is falsy, the test only checks that
        building a LightlyDataset from a folder holding a ``.avi`` file
        raises ImportError, emits a warning, and returns early.

        Otherwise it builds a dataset from ``self.input_dir`` (after calling
        ``self.create_video_dataset()``), dumps it to a fresh temporary
        directory, and checks that the number of dumped entries equals
        ``len(dataset)``.

        Temporary directories are removed in ``finally`` blocks so they do
        not leak when an assertion or a library call fails.
        """
        if not VIDEO_DATASET_AVAILABLE:
            tmp_dir = tempfile.mkdtemp()
            try:
                # Simulate a video: the video dataset will check whether
                # there exists a file with a video extension, so it is
                # enough to fake a video file here.
                path = os.path.join(tmp_dir, 'my_file.png')
                fake_data = torchvision.datasets.FakeData(
                    size=1, image_size=(3, 32, 32))
                image, _ = fake_data[0]
                # The image is written under a .png name and then renamed so
                # that the folder ends up with a file carrying a video
                # extension.
                image.save(path)
                os.rename(path, os.path.join(tmp_dir, 'my_file.avi'))
                with self.assertRaises(ImportError):
                    LightlyDataset(from_folder=tmp_dir)
            finally:
                shutil.rmtree(tmp_dir)

            warnings.warn(
                'Did not test video dataset because of missing requirements')
            return

        self.create_video_dataset()
        dataset = LightlyDataset(from_folder=self.input_dir)

        out_dir = tempfile.mkdtemp()
        try:
            dataset.dump(out_dir)
            self.assertEqual(len(os.listdir(out_dir)), len(dataset))
        finally:
            # mkdtemp leaves deletion to the caller.
            shutil.rmtree(out_dir)
    def test_create_lightly_dataset_from_folder(self):
        """Build a LightlyDataset from a folder tree and verify its contents.

        A dataset with 5 subfolders of 10 samples each is requested from
        ``self.create_dataset``. The test then checks that:

        * ``get_filenames()`` and ``len(dataset)`` both report 50 entries,
        * the filenames equal every ``<folder>/<sample>`` combination
          (compared after sorting),
        * ``dump`` writes files into subdirectories of the output directory
          whose total count equals ``len(dataset)``.

        Both the input and the output directory are removed in a ``finally``
        block so a failing assertion does not leave them behind.
        """
        n_subfolders = 5
        n_samples_per_subfolder = 10
        n_tot_files = n_subfolders * n_samples_per_subfolder

        dataset_dir, folder_names, sample_names = self.create_dataset(
            n_subfolders, n_samples_per_subfolder)
        out_dir = tempfile.mkdtemp()
        try:
            dataset = LightlyDataset(from_folder=dataset_dir)
            filenames = dataset.get_filenames()

            # Expected relative paths: one per (subfolder, sample) pair.
            fnames = [
                os.path.join(dir_name, fname)
                for dir_name in folder_names
                for fname in sample_names
            ]

            self.assertEqual(len(filenames), n_tot_files)
            self.assertEqual(len(dataset), n_tot_files)
            self.assertListEqual(sorted(fnames), sorted(filenames))

            dataset.dump(out_dir)
            # Count the dumped files across all subdirectories of out_dir.
            n_dumped = sum(
                len(os.listdir(os.path.join(out_dir, subdir)))
                for subdir in os.listdir(out_dir))
            self.assertEqual(n_dumped, len(dataset))
        finally:
            shutil.rmtree(dataset_dir)
            shutil.rmtree(out_dir)
# Example #3
# 0
    def test_video_dataset_available(self):
        """Dump a subset of a video dataset and verify only that subset lands.

        The dataset is built from ``self.input_dir`` after calling
        ``self.create_video_dataset()``. The filenames from index
        ``len(dataset) // 2`` onwards are passed to ``dump``; the test then
        checks that the output directory holds exactly as many entries as
        were selected and that each entry is one of the selected filenames.

        The output directory is removed in a ``finally`` block.
        """
        self.create_video_dataset()
        dataset = LightlyDataset(input_dir=self.input_dir)

        # Compute the selection once. Its length is len - len // 2, which
        # differs from len // 2 when the dataset has an odd number of items,
        # so the count assertion below compares against the selection itself.
        selected = dataset.get_filenames()[(len(dataset) // 2):]

        out_dir = tempfile.mkdtemp()
        try:
            dataset.dump(out_dir, selected)
            dumped = os.listdir(out_dir)
            self.assertEqual(len(dumped), len(selected))
            for filename in dumped:
                self.assertIn(filename, selected)
        finally:
            # mkdtemp leaves deletion to the caller.
            shutil.rmtree(out_dir)