예제 #1
0
  def test_dataset_from_directory_binary(self):
    """Checks first-batch shapes and dtypes per label mode with two classes.

    Builds a dataset from a directory prepared with `num_classes=2` for each
    of `label_mode` 'int', 'binary' and 'categorical', and verifies that the
    first batch is an (images, labels) pair with the expected shapes/dtypes.
    """
    if PIL is None:
      # Report a skip rather than returning, so that a missing PIL does not
      # show up as a passing test.
      self.skipTest('PIL is not available.')

    directory = self._prepare_directory(num_classes=2)

    # 'int' labels: a rank-1 int32 tensor, one entry per image.
    dataset = image_pipeline.dataset_from_directory(
        directory, batch_size=8, image_size=(18, 18), label_mode='int')
    batch = next(iter(dataset))
    self.assertLen(batch, 2)
    self.assertEqual(batch[0].shape, (8, 18, 18, 3))
    self.assertEqual(batch[0].dtype.name, 'float32')
    self.assertEqual(batch[1].shape, (8,))
    self.assertEqual(batch[1].dtype.name, 'int32')

    # 'binary' labels: a float32 column of shape (batch, 1).
    dataset = image_pipeline.dataset_from_directory(
        directory, batch_size=8, image_size=(18, 18), label_mode='binary')
    batch = next(iter(dataset))
    self.assertLen(batch, 2)
    self.assertEqual(batch[0].shape, (8, 18, 18, 3))
    self.assertEqual(batch[0].dtype.name, 'float32')
    self.assertEqual(batch[1].shape, (8, 1))
    self.assertEqual(batch[1].dtype.name, 'float32')

    # 'categorical' labels: float32 of shape (batch, 2) for the two classes.
    dataset = image_pipeline.dataset_from_directory(
        directory, batch_size=8, image_size=(18, 18), label_mode='categorical')
    batch = next(iter(dataset))
    self.assertLen(batch, 2)
    self.assertEqual(batch[0].shape, (8, 18, 18, 3))
    self.assertEqual(batch[0].dtype.name, 'float32')
    self.assertEqual(batch[1].shape, (8, 2))
    self.assertEqual(batch[1].dtype.name, 'float32')
예제 #2
0
  def test_dataset_from_directory_validation_split(self):
    """Checks the training/validation batch sizes for `validation_split=0.2`.

    With a directory prepared using `count=10` and `batch_size=10`, the first
    batch of the 'training' subset is expected to hold 8 images and the first
    batch of the 'validation' subset 2 images.
    """
    if PIL is None:
      # Report a skip rather than returning, so that a missing PIL does not
      # show up as a passing test.
      self.skipTest('PIL is not available.')

    directory = self._prepare_directory(num_classes=2, count=10)

    # Training subset.
    dataset = image_pipeline.dataset_from_directory(
        directory, batch_size=10, image_size=(18, 18),
        validation_split=0.2, subset='training')
    batch = next(iter(dataset))
    self.assertLen(batch, 2)
    self.assertEqual(batch[0].shape, (8, 18, 18, 3))

    # Validation subset.
    dataset = image_pipeline.dataset_from_directory(
        directory, batch_size=10, image_size=(18, 18),
        validation_split=0.2, subset='validation')
    batch = next(iter(dataset))
    self.assertLen(batch, 2)
    self.assertEqual(batch[0].shape, (2, 18, 18, 3))
예제 #3
0
  def test_dataset_from_directory_color_modes(self):
    """Checks the channel dimension of image batches for non-RGB color modes.

    'rgba' is expected to yield 4 channels and 'grayscale' 1 channel; in both
    cases the image batch is expected to be float32.
    """
    if PIL is None:
      # Report a skip rather than returning, so that a missing PIL does not
      # show up as a passing test.
      self.skipTest('PIL is not available.')

    # RGBA images: last dimension is 4.
    directory = self._prepare_directory(num_classes=4, color_mode='rgba')
    dataset = image_pipeline.dataset_from_directory(
        directory, batch_size=8, image_size=(18, 18), color_mode='rgba')
    batch = next(iter(dataset))
    self.assertLen(batch, 2)
    self.assertEqual(batch[0].shape, (8, 18, 18, 4))
    self.assertEqual(batch[0].dtype.name, 'float32')

    # Grayscale images: last dimension is 1.
    directory = self._prepare_directory(num_classes=4, color_mode='grayscale')
    dataset = image_pipeline.dataset_from_directory(
        directory, batch_size=8, image_size=(18, 18), color_mode='grayscale')
    batch = next(iter(dataset))
    self.assertLen(batch, 2)
    self.assertEqual(batch[0].shape, (8, 18, 18, 1))
    self.assertEqual(batch[0].dtype.name, 'float32')
예제 #4
0
  def test_dataset_from_directory_manual_labels(self):
    """Checks that an explicit `labels` list is returned as the batch labels.

    Shuffling is disabled so the labels of the first batch can be compared
    against the list that was passed in.
    """
    if PIL is None:
      # Report a skip rather than returning, so that a missing PIL does not
      # show up as a passing test.
      self.skipTest('PIL is not available.')

    directory = self._prepare_directory(num_classes=2, count=2)
    dataset = image_pipeline.dataset_from_directory(
        directory, batch_size=8, image_size=(18, 18),
        labels=[0, 1], shuffle=False)
    batch = next(iter(dataset))
    self.assertLen(batch, 2)
    self.assertAllClose(batch[1], [0, 1])
예제 #5
0
  def test_sample_count(self):
    """Checks that a full pass over an unlabeled dataset yields 15 samples.

    The directory is prepared with `count=15`; with `label_mode=None` each
    element is used directly as the image batch, and the batch sizes are
    expected to sum to 15.
    """
    if PIL is None:
      # Report a skip rather than returning, so that a missing PIL does not
      # show up as a passing test.
      self.skipTest('PIL is not available.')

    directory = self._prepare_directory(num_classes=4, count=15)
    dataset = image_pipeline.dataset_from_directory(
        directory, batch_size=8, image_size=(18, 18), label_mode=None)
    sample_count = 0
    for batch in dataset:
      sample_count += batch.shape[0]
    self.assertEqual(sample_count, 15)
예제 #6
0
  def test_dataset_from_directory_follow_links(self):
    """Checks the sample count with `follow_links=True` on nested directories.

    The directory is prepared with `count=25` and `nested_dirs=True`; a full
    pass over the unlabeled dataset is expected to yield 25 samples.
    """
    if PIL is None:
      # Report a skip rather than returning, so that a missing PIL does not
      # show up as a passing test.
      self.skipTest('PIL is not available.')

    directory = self._prepare_directory(num_classes=2, count=25,
                                        nested_dirs=True)
    dataset = image_pipeline.dataset_from_directory(
        directory, batch_size=8, image_size=(18, 18), label_mode=None,
        follow_links=True)
    sample_count = 0
    for batch in dataset:
      sample_count += batch.shape[0]
    self.assertEqual(sample_count, 25)
예제 #7
0
  def test_dataset_from_directory_errors(self):
    """Checks that invalid arguments raise `ValueError` with a matching message.

    Each case calls `dataset_from_directory` on a directory prepared with
    `num_classes=3, count=5` and asserts on a fragment of the error message.
    """
    if PIL is None:
      # Report a skip rather than returning, so that a missing PIL does not
      # show up as a passing test.
      self.skipTest('PIL is not available.')

    directory = self._prepare_directory(num_classes=3, count=5)

    # Invalid `labels` value.
    with self.assertRaisesRegex(ValueError, '`labels` argument should be'):
      _ = image_pipeline.dataset_from_directory(
          directory, labels=None)

    # Unknown `label_mode`.
    with self.assertRaisesRegex(ValueError, '`label_mode` argument must be'):
      _ = image_pipeline.dataset_from_directory(directory, label_mode='other')

    # Unknown `color_mode`.
    with self.assertRaisesRegex(ValueError, '`color_mode` must be one of'):
      _ = image_pipeline.dataset_from_directory(directory, color_mode='other')

    # `class_names` combined with explicitly provided labels.
    with self.assertRaisesRegex(
        ValueError, 'only pass `class_names` if the labels are inferred'):
      _ = image_pipeline.dataset_from_directory(
          directory, labels=[0, 0, 1, 1, 1],
          class_names=['class_0', 'class_1', 'class_2'])

    # Fewer labels (4) than the `count=5` used to prepare the directory.
    with self.assertRaisesRegex(
        ValueError,
        'Expected the lengths of `labels` to match the number of images'):
      _ = image_pipeline.dataset_from_directory(directory, labels=[0, 0, 1, 1])

    # `class_names` listing only two names for a three-class directory.
    with self.assertRaisesRegex(
        ValueError, '`class_names` passed did not match'):
      _ = image_pipeline.dataset_from_directory(
          directory, class_names=['class_0', 'class_2'])

    # 'binary' label mode on a three-class directory.
    with self.assertRaisesRegex(ValueError, 'there must exactly 2 classes'):
      _ = image_pipeline.dataset_from_directory(directory, label_mode='binary')

    # `validation_split` outside the accepted range.
    with self.assertRaisesRegex(ValueError,
                                '`validation_split` must be between 0 and 1'):
      _ = image_pipeline.dataset_from_directory(directory, validation_split=2)

    # Unknown `subset` name.
    with self.assertRaisesRegex(ValueError,
                                '`subset` must be either "training" or'):
      _ = image_pipeline.dataset_from_directory(
          directory, validation_split=0.2, subset='other')
예제 #8
0
  def test_dataset_from_directory_multiclass(self):
    """Checks batch shapes, dtypes and sample count with four classes.

    Covers `label_mode=None` (first-batch shape and total sample count over a
    full pass), 'int' and 'categorical' on a directory prepared with
    `num_classes=4, count=15`.
    """
    if PIL is None:
      # Report a skip rather than returning, so that a missing PIL does not
      # show up as a passing test.
      self.skipTest('PIL is not available.')

    directory = self._prepare_directory(num_classes=4, count=15)

    # No labels: each element is used directly as the image batch.
    dataset = image_pipeline.dataset_from_directory(
        directory, batch_size=8, image_size=(18, 18), label_mode=None)
    batch = next(iter(dataset))
    self.assertEqual(batch.shape, (8, 18, 18, 3))

    # A full pass is expected to yield 15 samples. Accumulate from the loop
    # variable directly; a second, independent iterator over the same dataset
    # is unnecessary for counting.
    dataset = image_pipeline.dataset_from_directory(
        directory, batch_size=8, image_size=(18, 18), label_mode=None)
    sample_count = 0
    for batch in dataset:
      sample_count += batch.shape[0]
    self.assertEqual(sample_count, 15)

    # 'int' labels: a rank-1 int32 tensor, one entry per image.
    dataset = image_pipeline.dataset_from_directory(
        directory, batch_size=8, image_size=(18, 18), label_mode='int')
    batch = next(iter(dataset))
    self.assertLen(batch, 2)
    self.assertEqual(batch[0].shape, (8, 18, 18, 3))
    self.assertEqual(batch[0].dtype.name, 'float32')
    self.assertEqual(batch[1].shape, (8,))
    self.assertEqual(batch[1].dtype.name, 'int32')

    # 'categorical' labels: float32 of shape (batch, 4) for the four classes.
    dataset = image_pipeline.dataset_from_directory(
        directory, batch_size=8, image_size=(18, 18), label_mode='categorical')
    batch = next(iter(dataset))
    self.assertLen(batch, 2)
    self.assertEqual(batch[0].shape, (8, 18, 18, 3))
    self.assertEqual(batch[0].dtype.name, 'float32')
    self.assertEqual(batch[1].shape, (8, 4))
    self.assertEqual(batch[1].dtype.name, 'float32')