예제 #1
0
    def test_image_dataset_from_directory_color_modes(self):
        """Check that `color_mode` controls the channel count of image batches.

        An 'rgba' dataset is expected to produce 4-channel batches and a
        'grayscale' dataset 1-channel batches, both with dtype float32.
        """
        if PIL is None:
            return  # PIL is unavailable, nothing to exercise.

        # (color mode, expected size of the last axis)
        expected_channels = (('rgba', 4), ('grayscale', 1))
        for mode, channels in expected_channels:
            directory = self._prepare_directory(num_classes=4,
                                                color_mode=mode)
            dataset = image_dataset.image_dataset_from_directory(
                directory,
                batch_size=8,
                image_size=(18, 18),
                color_mode=mode)
            first_batch = next(iter(dataset))
            self.assertLen(first_batch, 2)
            images = first_batch[0]
            self.assertEqual(images.shape, (8, 18, 18, channels))
            self.assertEqual(images.dtype.name, 'float32')
예제 #2
0
    def test_image_dataset_from_directory_validation_split(self):
        """Check first-batch sizes of the training and validation subsets.

        With `validation_split=0.2` and the same seed for both calls, the
        first 'training' batch is expected to hold 8 images and the first
        'validation' batch 2.
        """
        if PIL is None:
            return  # PIL is unavailable, nothing to exercise.

        directory = self._prepare_directory(num_classes=2, count=10)
        # (subset name, expected size of the first batch)
        expected_sizes = (('training', 8), ('validation', 2))
        for subset_name, expected_size in expected_sizes:
            dataset = image_dataset.image_dataset_from_directory(
                directory,
                batch_size=10,
                image_size=(18, 18),
                validation_split=0.2,
                subset=subset_name,
                seed=1337)
            first_batch = next(iter(dataset))
            self.assertLen(first_batch, 2)
            self.assertEqual(first_batch[0].shape,
                             (expected_size, 18, 18, 3))
예제 #3
0
    def test_image_dataset_from_directory_multiclass(self):
        """Check loading a 4-class directory under several label modes.

        Covers `label_mode=None` (each batch is a single image tensor),
        `label_mode="int"` and `label_mode="categorical"`, and verifies that
        one full pass over the unlabeled dataset yields 15 samples.
        """
        if PIL is None:
            return  # Skip test if PIL is not available.

        directory = self._prepare_directory(num_classes=4, count=15)

        # Without labels, a batch is the image tensor itself.
        dataset = image_dataset.image_dataset_from_directory(
            directory, batch_size=8, image_size=(18, 18), label_mode=None)
        batch = next(iter(dataset))
        self.assertEqual(batch.shape, (8, 18, 18, 3))

        # Sum the batch sizes over one full pass of the dataset.
        dataset = image_dataset.image_dataset_from_directory(
            directory, batch_size=8, image_size=(18, 18), label_mode=None)
        sample_count = 0
        for batch in dataset:
            sample_count += batch.shape[0]
        self.assertEqual(sample_count, 15)

        # Integer labels: one int32 value per image.
        dataset = image_dataset.image_dataset_from_directory(
            directory, batch_size=8, image_size=(18, 18), label_mode="int")
        batch = next(iter(dataset))
        self.assertLen(batch, 2)
        self.assertEqual(batch[0].shape, (8, 18, 18, 3))
        self.assertEqual(batch[0].dtype.name, "float32")
        self.assertEqual(batch[1].shape, (8, ))
        self.assertEqual(batch[1].dtype.name, "int32")

        # Categorical labels: one float32 vector of length 4 per image.
        dataset = image_dataset.image_dataset_from_directory(
            directory,
            batch_size=8,
            image_size=(18, 18),
            label_mode="categorical",
        )
        batch = next(iter(dataset))
        self.assertLen(batch, 2)
        self.assertEqual(batch[0].shape, (8, 18, 18, 3))
        self.assertEqual(batch[0].dtype.name, "float32")
        self.assertEqual(batch[1].shape, (8, 4))
        self.assertEqual(batch[1].dtype.name, "float32")
예제 #4
0
    def test_image_dataset_from_directory_standalone(self):
        """Check unlabeled loading of images from a directory and its subdirs.

        Extra images are saved directly into the parent directory, then the
        dataset is built with `labels=None`; batches are expected to be plain
        float32 image tensors, with 2 batches and 10 samples in total.
        """
        if PIL is None:
            return  # PIL is unavailable, nothing to exercise.

        directory = self._prepare_directory(count=7, num_classes=2)
        # Put additional images next to the class subdirectories.
        for index, image in enumerate(self._get_images(3)):
            target_path = os.path.join(directory, 'image_%s.jpg' % (index, ))
            image.save(target_path)

        dataset = image_dataset.image_dataset_from_directory(
            directory, batch_size=5, image_size=(18, 18), labels=None)

        # With `labels=None` a batch is the image tensor itself.
        first_batch = next(iter(dataset))
        self.assertEqual(first_batch.shape, (5, 18, 18, 3))
        self.assertEqual(first_batch.dtype.name, 'float32')

        # Walk the whole dataset once, recording the size of every batch.
        batch_sizes = [batch.shape[0] for batch in dataset]
        self.assertEqual(len(batch_sizes), 2)
        self.assertEqual(sum(batch_sizes), 10)
예제 #5
0
    def test_image_dataset_from_directory_not_batched(self):
        """Check that `batch_size=None` yields elements with a rank-3 shape."""
        if PIL is None:
            return  # PIL is unavailable, nothing to exercise.

        directory = self._prepare_directory(num_classes=2, count=2)
        dataset = image_dataset.image_dataset_from_directory(
            directory,
            batch_size=None,
            image_size=(18, 18),
            label_mode=None,
            shuffle=False)
        first_sample = next(iter(dataset))
        # The shape is expected to have exactly three entries.
        rank = len(first_sample.shape)
        self.assertEqual(rank, 3)
예제 #6
0
    def test_image_dataset_from_directory_crop_to_aspect_ratio(self):
        """Check the batch shape when `crop_to_aspect_ratio=True` is passed."""
        if PIL is None:
            return  # PIL is unavailable, nothing to exercise.

        directory = self._prepare_directory(num_classes=2, count=5)
        options = dict(batch_size=5,
                       image_size=(18, 18),
                       crop_to_aspect_ratio=True)
        dataset = image_dataset.image_dataset_from_directory(
            directory, **options)
        first_batch = next(iter(dataset))
        self.assertLen(first_batch, 2)
        self.assertEqual(first_batch[0].shape, (5, 18, 18, 3))
예제 #7
0
    def test_image_dataset_from_directory_binary(self):
        """Check label shape and dtype for a 2-class directory.

        The same directory is loaded with the "int", "binary" and
        "categorical" label modes; images are always expected as float32
        batches of shape (8, 18, 18, 3), while the labels vary per mode.
        """
        if PIL is None:
            return  # PIL is unavailable, nothing to exercise.

        directory = self._prepare_directory(num_classes=2)
        # (label mode, expected label shape, expected label dtype name)
        expectations = (
            ("int", (8, ), "int32"),
            ("binary", (8, 1), "float32"),
            ("categorical", (8, 2), "float32"),
        )
        for mode, label_shape, label_dtype in expectations:
            dataset = image_dataset.image_dataset_from_directory(
                directory,
                batch_size=8,
                image_size=(18, 18),
                label_mode=mode)
            first_batch = next(iter(dataset))
            self.assertLen(first_batch, 2)
            images = first_batch[0]
            self.assertEqual(images.shape, (8, 18, 18, 3))
            self.assertEqual(images.dtype.name, "float32")
            labels = first_batch[1]
            self.assertEqual(labels.shape, label_shape)
            self.assertEqual(labels.dtype.name, label_dtype)
예제 #8
0
    def test_image_dataset_from_directory_manual_labels(self):
        """Check that an explicit `labels` list comes back with the batch."""
        if PIL is None:
            return  # PIL is unavailable, nothing to exercise.

        directory = self._prepare_directory(num_classes=2, count=2)
        # Shuffling is disabled so the labels keep the order they were given.
        dataset = image_dataset.image_dataset_from_directory(
            directory,
            batch_size=8,
            image_size=(18, 18),
            labels=[0, 1],
            shuffle=False)
        first_batch = next(iter(dataset))
        self.assertLen(first_batch, 2)
        returned_labels = first_batch[1]
        self.assertAllClose(returned_labels, [0, 1])
예제 #9
0
    def test_sample_count(self):
        """Check that one full pass over the dataset yields 15 samples."""
        if PIL is None:
            return  # PIL is unavailable, nothing to exercise.

        directory = self._prepare_directory(num_classes=4, count=15)
        dataset = image_dataset.image_dataset_from_directory(
            directory, batch_size=8, image_size=(18, 18), label_mode=None)
        # Add up the leading dimension of every batch.
        total_samples = sum(batch.shape[0] for batch in dataset)
        self.assertEqual(total_samples, 15)
예제 #10
0
    def test_image_dataset_from_directory_follow_links(self):
        """Check the sample count for nested dirs with `follow_links=True`."""
        if PIL is None:
            return  # PIL is unavailable, nothing to exercise.

        directory = self._prepare_directory(num_classes=2,
                                            count=25,
                                            nested_dirs=True)
        dataset = image_dataset.image_dataset_from_directory(
            directory,
            batch_size=8,
            image_size=(18, 18),
            label_mode=None,
            follow_links=True)
        # Add up the leading dimension of every batch.
        total_samples = sum(batch.shape[0] for batch in dataset)
        self.assertEqual(total_samples, 25)
예제 #11
0
    def test_static_shape_in_graph(self):
        """Check the static shape of image batches inside a `tf.function`."""
        if PIL is None:
            return  # PIL is unavailable, nothing to exercise.

        directory = self._prepare_directory(num_classes=2)
        dataset = image_dataset.image_dataset_from_directory(
            directory, batch_size=8, image_size=(18, 18), label_mode='int')
        # Alias the test case so the decorated function below can assert on it.
        outer_test = self

        @tf.function
        def check_static_shape(traced_dataset):
            # Only the first element is inspected; its leading dimension is
            # expected to be unknown (None) while the rest are fixed.
            for images, _ in traced_dataset.take(1):
                outer_test.assertListEqual(images.shape.as_list(),
                                           [None, 18, 18, 3])

        check_static_shape(dataset)
예제 #12
0
    def test_image_dataset_from_directory_errors(self):
        """Check that invalid arguments raise `ValueError` with known messages.

        Each case pairs a pattern expected in the error message with the
        keyword arguments passed alongside the directory. Cases run in order
        and the first unmet expectation fails the test.
        """
        if PIL is None:
            return  # PIL is unavailable, nothing to exercise.

        directory = self._prepare_directory(num_classes=3, count=5)

        # (expected message pattern, keyword arguments for the call)
        invalid_calls = (
            ('`labels` argument should be', dict(labels='other')),
            ('`label_mode` argument must be', dict(label_mode='other')),
            ('`color_mode` must be one of', dict(color_mode='other')),
            ('only pass `class_names` if `labels="inferred"`',
             dict(labels=[0, 0, 1, 1, 1],
                  class_names=['class_0', 'class_1', 'class_2'])),
            ('Expected the lengths of `labels` to match the number of files',
             dict(labels=[0, 0, 1, 1])),
            ('`class_names` passed did not match',
             dict(class_names=['class_0', 'class_2'])),
            ('there must be exactly 2', dict(label_mode='binary')),
            ('`validation_split` must be between 0 and 1',
             dict(validation_split=2)),
            ('`subset` must be either "training" or',
             dict(validation_split=0.2, subset='other')),
            ('`validation_split` must be set',
             dict(validation_split=0, subset='training')),
            ('must provide a `seed`',
             dict(validation_split=0.2, subset='training')),
        )
        for message_pattern, call_kwargs in invalid_calls:
            with self.assertRaisesRegex(ValueError, message_pattern):
                _ = image_dataset.image_dataset_from_directory(
                    directory, **call_kwargs)
예제 #13
0
 def test_image_dataset_from_directory_no_images(self):
     """Check that a directory prepared with `count=0` raises `ValueError`."""
     empty_directory = self._prepare_directory(num_classes=2, count=0)
     expect_no_images_error = self.assertRaisesRegex(ValueError,
                                                     'No images found.')
     with expect_no_images_error:
         _ = image_dataset.image_dataset_from_directory(empty_directory)
예제 #14
0
    def test_image_dataset_from_directory_errors(self):
        """Check that invalid arguments raise `ValueError` with known messages.

        Every check calls `image_dataset_from_directory` on the same prepared
        directory with one set of bad keyword arguments and expects the error
        message to match the given pattern.
        """
        if PIL is None:
            return  # PIL is unavailable, nothing to exercise.

        directory = self._prepare_directory(num_classes=3, count=5)

        def assert_rejected(message_pattern, **call_kwargs):
            # Run one call that is expected to fail with a matching message.
            with self.assertRaisesRegex(ValueError, message_pattern):
                _ = image_dataset.image_dataset_from_directory(
                    directory, **call_kwargs)

        # Unknown values for the enumerated string arguments.
        assert_rejected("`labels` argument should be", labels="other")
        assert_rejected("`label_mode` argument must be", label_mode="other")
        assert_rejected("`color_mode` must be one of", color_mode="other")

        # Explicit labels and class names that cannot be used as given.
        assert_rejected(
            'only pass `class_names` if `labels="inferred"`',
            labels=[0, 0, 1, 1, 1],
            class_names=["class_0", "class_1", "class_2"],
        )
        assert_rejected(
            "Expected the lengths of `labels` to match the number of files",
            labels=[0, 0, 1, 1],
        )
        assert_rejected("`class_names` passed did not match",
                        class_names=["class_0", "class_2"])
        assert_rejected("there must be exactly 2", label_mode="binary")

        # Invalid combinations of the validation-split arguments.
        assert_rejected("`validation_split` must be between 0 and 1",
                        validation_split=2)
        assert_rejected(
            '`subset` must be either "training", '
            '"validation" or "both"',
            validation_split=0.2,
            subset="other",
        )
        assert_rejected("`validation_split` must be set",
                        validation_split=0,
                        subset="training")
        assert_rejected("must provide a `seed`",
                        validation_split=0.2,
                        subset="training")