Ejemplo n.º 1
0
  def test_decoders(self):
    """Checks `as_dataset` element specs with and without decoders.

    Builds an `ImageFolder` (uint8, shape (128, 128, 1)) over a mocked
    `root_dir/train/<label>/` tree and verifies that:
      * `decoders=None` yields decoded uint8 images of shape (128, 128, 1);
      * `decoders={'image': SkipDecoding()}` leaves 'image' as a scalar
        `tf.string` (the encoded bytes);
      * a decoder keyed on a feature that does not exist raises ValueError.
    """
    images = [
        'root_dir/train/label1/img1.png',
        'root_dir/train/label3/img3.png',
        'root_dir/train/label3/img1.png',
        'root_dir/train/label3/img2.png',
        'root_dir/train/label2/img1.png',
        'root_dir/train/label2/img2.png',
    ]

    with tfds.testing.MockFs() as fs:
      for path in images:
        fs.add_file(path)

      builder = tfds.ImageFolder(
          root_dir='root_dir', dtype=tf.uint8, shape=(128, 128, 1),
      )

      # Decoded images should be found if passing decoders=None.
      ds = builder.as_dataset(split='train', decoders=None)
      expected_element_spec = {
          'image/filename': tf.TensorSpec(shape=(), dtype=tf.string),
          'image': tf.TensorSpec(shape=(128, 128, 1), dtype=tf.uint8),
          'label': tf.TensorSpec(shape=(), dtype=tf.int64),
      }
      self.assertEqual(ds.element_spec, expected_element_spec)

      # Encoded images should be found if passing decoders=SkipDecoding().
      ds = builder.as_dataset(
          split='train', decoders={'image': tfds.decode.SkipDecoding()})
      expected_element_spec = {
          'image/filename': tf.TensorSpec(shape=(), dtype=tf.string),
          'image': tf.TensorSpec(shape=(), dtype=tf.string),
          'label': tf.TensorSpec(shape=(), dtype=tf.int64),
      }
      self.assertEqual(ds.element_spec, expected_element_spec)

      # A decoder for a key that is not one of the features ('text' here)
      # should throw ValueError.
      with self.assertRaises(ValueError):
        builder.as_dataset(
            split='train', decoders={'text': tfds.decode.SkipDecoding()})
Ejemplo n.º 2
0
    def test_decoders(self):
        """Verifies that `as_dataset` honours per-feature decoders.

        With no decoders, 'image' is expected to be a uint8 tensor of shape
        (None, None, 3). With `SkipDecoding` on 'image' it stays a scalar
        string. A decoder keyed on an unknown feature must raise ValueError
        mentioning 'Unrecognized keys'.
        """
        train_files = [
            'root_dir/train/label1/img1.png',
            'root_dir/train/label3/img3.png',
            'root_dir/train/label3/img1.png',
            'root_dir/train/label3/img2.png',
            'root_dir/train/label2/img1.png',
            'root_dir/train/label2/img2.png',
        ]

        with tfds.testing.MockFs() as fs:
            for path in train_files:
                fs.add_file(path)

            builder = tfds.ImageFolder(root_dir='root_dir', dtype=tf.uint8)

            # No decoders: the image feature comes back decoded.
            decoded_ds = builder.as_dataset(split='train', decoders=None)
            # Specs shared by both variants; only 'image' differs.
            scalar_specs = {
                'image/filename': tf.TensorSpec(shape=(), dtype=tf.string),
                'label': tf.TensorSpec(shape=(), dtype=tf.int64),
            }
            decoded_spec = {
                **scalar_specs,
                'image': tf.TensorSpec(shape=(None, None, 3),
                                       dtype=tf.uint8),
            }
            self.assertEqual(decoded_ds.element_spec, decoded_spec)

            # SkipDecoding on 'image': the encoded bytes are kept as-is.
            skip_image = {'image': tfds.decode.SkipDecoding()}
            encoded_ds = builder.as_dataset(split='train',
                                            decoders=skip_image)
            encoded_spec = {
                **scalar_specs,
                'image': tf.TensorSpec(shape=(), dtype=tf.string),
            }
            self.assertEqual(encoded_ds.element_spec, encoded_spec)

            # A decoder keyed on a non-existent feature must be rejected.
            with self.assertRaisesWithPredicateMatch(ValueError,
                                                     'Unrecognized keys'):
                bad_decoders = {'invalid_key': tfds.decode.SkipDecoding()}
                builder.as_dataset(split='train', decoders=bad_decoders)
Ejemplo n.º 3
0
    def test_properties(self):
        """Checks split discovery, label names and image feature properties.

        Mocks a `root_dir` tree with 'train', 'val' and 'test' splits. The
        tree includes an upper-case '.PNG' file and a '.txt' file that must not
        show up among the examples. The test checks:
          * the per-split examples and sorted labels returned by
            `_get_split_label_images`;
          * the split sizes and label names recorded in `builder.info`;
          * the default image feature, which is shape (None, None, 3) and
            uint8;
          * that explicit `dtype`/`shape` constructor arguments are honoured.
        """
        mock_files = [
            'root_dir/train/label1/img1.png',
            'root_dir/train/label3/img3.png',
            'root_dir/train/label3/img1.png',
            'root_dir/train/label3/img2.png',
            'root_dir/train/label2/img1.png',
            'root_dir/train/label2/img2.png',
            'root_dir/val/label1/img1.png',
            'root_dir/val/label2/img2.png',
            'root_dir/test/label1/img1.png',
            'root_dir/test/label2/img1.png',
            'root_dir/test/label4/img1.PNG',
            'root_dir/test/label4/unsuported.txt',
        ]

        with tfds.testing.MockFs() as fs:
            for path in mock_files:
                fs.add_file(path)

            split_examples, labels = image_folder._get_split_label_images(
                'root_dir')
            default_builder = tfds.ImageFolder(root_dir='root_dir')
            custom_builder = tfds.ImageFolder(root_dir='root_dir',
                                              dtype=tf.uint16,
                                              shape=(128, 128, 1))

        def _examples(*pairs):
            # Expected `_Example` list built from (image_path, label) pairs.
            return [
                image_folder._Example(image_path=image_path, label=label)
                for image_path, label in pairs
            ]

        expected_split_examples = {
            'train': _examples(
                ('root_dir/train/label2/img1.png', 'label2'),
                ('root_dir/train/label3/img3.png', 'label3'),
                ('root_dir/train/label3/img2.png', 'label3'),
                ('root_dir/train/label3/img1.png', 'label3'),
                ('root_dir/train/label2/img2.png', 'label2'),
                ('root_dir/train/label1/img1.png', 'label1'),
            ),
            'val': _examples(
                ('root_dir/val/label2/img2.png', 'label2'),
                ('root_dir/val/label1/img1.png', 'label1'),
            ),
            'test': _examples(
                ('root_dir/test/label1/img1.png', 'label1'),
                ('root_dir/test/label2/img1.png', 'label2'),
                ('root_dir/test/label4/img1.PNG', 'label4'),
            ),
        }
        self.assertEqual(split_examples, expected_split_examples)

        # Number of examples recorded for each split.
        for split_name, count in (('train', 6), ('val', 2), ('test', 3)):
            self.assertEqual(
                default_builder.info.splits[split_name].num_examples, count)

        expected_labels = ['label1', 'label2', 'label3', 'label4']
        self.assertEqual(expected_labels, labels)
        self.assertEqual(default_builder.info.features['label'].names,
                         expected_labels)

        # Default image feature: 3 channels, unknown height/width, uint8.
        default_image = default_builder.info.features['image']
        self.assertEqual(default_image.shape, (None, None, 3))
        self.assertEqual(default_image.dtype, tf.uint8)

        # Explicit constructor arguments override the defaults.
        custom_image = custom_builder.info.features['image']
        self.assertEqual(custom_image.shape, (128, 128, 1))
        self.assertEqual(custom_image.dtype, tf.uint16)