Beispiel #1
0
    def test_image_dataset_from_directory_validation_split(self):
        """Training and validation subsets split the images 8 / 2.

        Both subsets are built from the same directory with
        ``validation_split=0.2`` and the same seed. With ``batch_size=10``,
        each subset's first batch is expected to contain all of its images:
        8 for ``'training'`` and 2 for ``'validation'``.
        """
        if PIL is None:
            return  # Skip test if PIL is not available.

        directory = self._prepare_directory(num_classes=2, count=10)
        for subset, expected_images in (('training', 8), ('validation', 2)):
            dataset = image_dataset.image_dataset_from_directory(
                directory,
                batch_size=10,
                image_size=(18, 18),
                validation_split=0.2,
                subset=subset,
                seed=1337)
            first_batch = next(iter(dataset))
            # Labelled datasets yield (images, labels) pairs.
            self.assertLen(first_batch, 2)
            self.assertEqual(first_batch[0].shape,
                             (expected_images, 18, 18, 3))
Beispiel #2
0
  def test_image_dataset_from_directory_binary(self):
    """Two-class directory: label shape and dtype for each ``label_mode``.

    Images are always a float32 batch of shape (8, 18, 18, 3). The labels are:
    ``'int'`` -> (8,) int32, ``'binary'`` -> (8, 1) float32, and
    ``'categorical'`` -> (8, 2) float32.
    """
    if PIL is None:
      return  # Skip test if PIL is not available.

    directory = self._prepare_directory(num_classes=2)
    label_expectations = (
        ('int', (8,), 'int32'),
        ('binary', (8, 1), 'float32'),
        ('categorical', (8, 2), 'float32'),
    )
    for mode, label_shape, label_dtype in label_expectations:
      dataset = image_dataset.image_dataset_from_directory(
          directory, batch_size=8, image_size=(18, 18), label_mode=mode)
      first_batch = next(iter(dataset))
      self.assertLen(first_batch, 2)
      images, labels = first_batch
      self.assertEqual(images.shape, (8, 18, 18, 3))
      self.assertEqual(images.dtype.name, 'float32')
      self.assertEqual(labels.shape, label_shape)
      self.assertEqual(labels.dtype.name, label_dtype)
Beispiel #3
0
    def test_image_dataset_from_directory_color_modes(self):
        """Each ``color_mode`` yields float32 images with its channel count.

        For every mode, a fresh 4-class directory is prepared in that color
        mode and then loaded with the same ``color_mode``. ``'rgba'`` should
        give 4 channels and ``'grayscale'`` should give 1.
        """
        if PIL is None:
            return  # Skip test if PIL is not available.

        channels_by_mode = (('rgba', 4), ('grayscale', 1))
        for mode, channels in channels_by_mode:
            directory = self._prepare_directory(num_classes=4,
                                                color_mode=mode)
            dataset = image_dataset.image_dataset_from_directory(
                directory,
                batch_size=8,
                image_size=(18, 18),
                color_mode=mode)
            first_batch = next(iter(dataset))
            self.assertLen(first_batch, 2)
            images = first_batch[0]
            self.assertEqual(images.shape, (8, 18, 18, channels))
            self.assertEqual(images.dtype.name, 'float32')
Beispiel #4
0
    def test_image_dataset_from_directory_multiclass(self):
        """Load a 4-class, 15-image directory under several label modes.

        Covers three modes:

        * ``label_mode=None``: batches are bare image tensors, and the batches
          of one full pass together contain all 15 images.
        * ``'int'``: one int32 label per image, shape (batch,).
        * ``'categorical'``: float32 labels of shape (batch, num_classes).
        """
        if PIL is None:
            return  # Skip test if PIL is not available.

        directory = self._prepare_directory(num_classes=4, count=15)

        dataset = image_dataset.image_dataset_from_directory(directory,
                                                             batch_size=8,
                                                             image_size=(18,
                                                                         18),
                                                             label_mode=None)
        batch = next(iter(dataset))
        self.assertEqual(batch.shape, (8, 18, 18, 3))

        # Count samples from the batches yielded by a single pass over the
        # dataset. Using the loop's own batch, rather than advancing a second
        # iterator in parallel, means the count reflects exactly what one
        # iteration produces.
        sample_count = sum(images.shape[0] for images in dataset)
        self.assertEqual(sample_count, 15)

        dataset = image_dataset.image_dataset_from_directory(directory,
                                                             batch_size=8,
                                                             image_size=(18,
                                                                         18),
                                                             label_mode='int')
        batch = next(iter(dataset))
        self.assertLen(batch, 2)
        self.assertEqual(batch[0].shape, (8, 18, 18, 3))
        self.assertEqual(batch[0].dtype.name, 'float32')
        self.assertEqual(batch[1].shape, (8, ))
        self.assertEqual(batch[1].dtype.name, 'int32')

        dataset = image_dataset.image_dataset_from_directory(
            directory,
            batch_size=8,
            image_size=(18, 18),
            label_mode='categorical')
        batch = next(iter(dataset))
        self.assertLen(batch, 2)
        self.assertEqual(batch[0].shape, (8, 18, 18, 3))
        self.assertEqual(batch[0].dtype.name, 'float32')
        self.assertEqual(batch[1].shape, (8, 4))
        self.assertEqual(batch[1].dtype.name, 'float32')
Beispiel #5
0
    def test_image_dataset_from_directory_standalone(self):
        """With ``labels=None``, images come from the directory and its subdirs.

        Three extra images are written straight into the top-level directory,
        alongside the 7 prepared by ``_prepare_directory``. The dataset is then
        expected to yield 10 images in 2 batches (batch size 5).
        """
        if PIL is None:
            return  # Skip test if PIL is not available.

        directory = self._prepare_directory(count=7, num_classes=2)
        # Extra images saved in the parent directory itself, not in a class
        # subdirectory.
        extra_images = self._get_images(3)
        for index, image in enumerate(extra_images):
            image.save(os.path.join(directory, 'image_%s.jpg' % (index, )))

        dataset = image_dataset.image_dataset_from_directory(
            directory, batch_size=5, image_size=(18, 18), labels=None)

        # Without labels, each element is a plain image tensor, not a tuple.
        first_batch = next(iter(dataset))
        self.assertEqual(first_batch.shape, (5, 18, 18, 3))
        self.assertEqual(first_batch.dtype.name, 'float32')

        batch_sizes = [images.shape[0] for images in dataset]
        self.assertEqual(len(batch_sizes), 2)
        self.assertEqual(sum(batch_sizes), 10)
Beispiel #6
0
  def test_image_dataset_from_directory_crop_to_aspect_ratio(self):
    """``crop_to_aspect_ratio=True`` still yields images of ``image_size``."""
    if PIL is None:
      return  # Skip test if PIL is not available.

    directory = self._prepare_directory(num_classes=2, count=5)
    loader_options = dict(
        batch_size=5, image_size=(18, 18), crop_to_aspect_ratio=True)
    dataset = image_dataset.image_dataset_from_directory(
        directory, **loader_options)
    first_batch = next(iter(dataset))
    self.assertLen(first_batch, 2)
    images = first_batch[0]
    self.assertEqual(images.shape, (5, 18, 18, 3))
Beispiel #7
0
  def test_image_dataset_from_directory_manual_labels(self):
    """Explicitly supplied ``labels`` come back unchanged when not shuffled.

    The directory holds 2 images and the labels ``[0, 1]`` are passed in.
    With ``shuffle=False`` and a batch size larger than the dataset, the
    single batch's labels should equal the list that was supplied.
    """
    if PIL is None:
      return  # Skip test if PIL is not available.

    directory = self._prepare_directory(num_classes=2, count=2)
    dataset = image_dataset.image_dataset_from_directory(
        directory,
        batch_size=8,
        image_size=(18, 18),
        labels=[0, 1],
        shuffle=False)
    first_batch = next(iter(dataset))
    self.assertLen(first_batch, 2)
    _, labels = first_batch
    self.assertAllClose(labels, [0, 1])
Beispiel #8
0
  def test_sample_count(self):
    """One pass over an unlabelled dataset yields every one of the 15 images."""
    if PIL is None:
      return  # Skip test if PIL is not available.

    directory = self._prepare_directory(num_classes=4, count=15)
    dataset = image_dataset.image_dataset_from_directory(
        directory, batch_size=8, image_size=(18, 18), label_mode=None)
    total_images = sum(images.shape[0] for images in dataset)
    self.assertEqual(total_images, 15)
Beispiel #9
0
    def test_image_dataset_from_directory_not_batched(self):
        """With ``batch_size=None``, the dataset yields individual images.

        The test pins the exact per-image shape and dtype, not just the rank,
        so a wrong spatial size or channel count is also caught.
        """
        if PIL is None:
            return  # Skip test if PIL is not available.

        directory = self._prepare_directory(num_classes=2, count=2)
        dataset = image_dataset.image_dataset_from_directory(directory,
                                                             batch_size=None,
                                                             image_size=(18,
                                                                         18),
                                                             label_mode=None,
                                                             shuffle=False)
        sample = next(iter(dataset))
        # An unbatched element has no leading batch axis: (height, width,
        # channels), with 3 channels for the default color mode.
        self.assertEqual(len(sample.shape), 3)
        self.assertEqual(sample.shape, (18, 18, 3))
        self.assertEqual(sample.dtype.name, 'float32')
Beispiel #10
0
  def test_image_dataset_from_directory_follow_links(self):
    """``follow_links=True`` over a nested directory still finds all 25 images."""
    if PIL is None:
      return  # Skip test if PIL is not available.

    directory = self._prepare_directory(
        num_classes=2, count=25, nested_dirs=True)
    dataset = image_dataset.image_dataset_from_directory(
        directory,
        batch_size=8,
        image_size=(18, 18),
        label_mode=None,
        follow_links=True)
    batch_sizes = [images.shape[0] for images in dataset]
    self.assertEqual(sum(batch_sizes), 25)
Beispiel #11
0
  def test_static_shape_in_graph(self):
    """Images traced inside a ``tf.function`` keep a static shape.

    The batch axis is expected to be unknown (``None``), while the image
    dimensions are fixed at (18, 18, 3).
    """
    if PIL is None:
      return  # Skip test if PIL is not available.

    directory = self._prepare_directory(num_classes=2)
    dataset = image_dataset.image_dataset_from_directory(
        directory, batch_size=8, image_size=(18, 18), label_mode='int')
    # The traced function refers to the test case through a plain local name.
    outer_case = self

    @tf.function
    def check_traced_shape(images_and_labels):
      for images, _ in images_and_labels.take(1):
        outer_case.assertListEqual(images.shape.as_list(), [None, 18, 18, 3])

    check_traced_shape(dataset)
Beispiel #12
0
    def test_image_dataset_from_directory_errors(self):
        """Invalid argument combinations raise ValueError with a clear message.

        Each case pairs a regex expected in the error message with the keyword
        arguments that should trigger it. Cases run in order against the same
        3-class, 5-image directory.
        """
        if PIL is None:
            return  # Skip test if PIL is not available.

        directory = self._prepare_directory(num_classes=3, count=5)

        invalid_cases = [
            ('`labels` argument should be', dict(labels='other')),
            ('`label_mode` argument must be', dict(label_mode='other')),
            ('`color_mode` must be one of', dict(color_mode='other')),
            ('only pass `class_names` if the labels are inferred',
             dict(labels=[0, 0, 1, 1, 1],
                  class_names=['class_0', 'class_1', 'class_2'])),
            ('Expected the lengths of `labels` to match the number of files',
             dict(labels=[0, 0, 1, 1])),
            ('`class_names` passed did not match',
             dict(class_names=['class_0', 'class_2'])),
            ('there must exactly 2 classes', dict(label_mode='binary')),
            ('`validation_split` must be between 0 and 1',
             dict(validation_split=2)),
            ('`subset` must be either "training" or',
             dict(validation_split=0.2, subset='other')),
            ('`validation_split` must be set',
             dict(validation_split=0, subset='training')),
            ('must provide a `seed`',
             dict(validation_split=0.2, subset='training')),
        ]
        for message_regex, bad_kwargs in invalid_cases:
            with self.assertRaisesRegex(ValueError, message_regex):
                image_dataset.image_dataset_from_directory(directory,
                                                           **bad_kwargs)
Beispiel #13
0
 def test_image_dataset_from_directory_no_images(self):
     """A directory prepared with zero images is rejected with ValueError."""
     empty_directory = self._prepare_directory(num_classes=2, count=0)
     with self.assertRaisesRegex(ValueError, 'No images found.'):
         image_dataset.image_dataset_from_directory(empty_directory)