def test_decoders(self):
  """Check `as_dataset` element specs with and without image decoding.

  Builds an `ImageFolder` with an explicit `dtype` and `shape` over a mocked
  file tree, then verifies that:

  * `decoders=None` yields an `image` spec matching the configured
    shape/dtype;
  * `SkipDecoding` on `'image'` yields a scalar string `image` spec;
  * a decoder keyed on `'text'`, which is not one of the expected element
    keys, raises `ValueError`.
  """
  # NOTE(review): another `test_decoders` is defined right after this one.
  # If both live in the same class, this definition is shadowed and never
  # runs -- confirm, then rename or drop one of them.
  images = [
      'root_dir/train/label1/img1.png',
      'root_dir/train/label3/img3.png',
      'root_dir/train/label3/img1.png',
      'root_dir/train/label3/img2.png',
      'root_dir/train/label2/img1.png',
      'root_dir/train/label2/img2.png',
  ]

  with tfds.testing.MockFs() as fs:
    for path in images:
      fs.add_file(path)

    # The builder and every dataset it creates are kept inside the mocked
    # filesystem so nothing can reach the real one.
    builder = tfds.ImageFolder(
        root_dir='root_dir',
        dtype=tf.uint8,
        shape=(128, 128, 1),
    )

    # Decoded images should be found if passing decoders=None
    ds = builder.as_dataset(split='train', decoders=None)
    expected_element_spec = {
        'image/filename': tf.TensorSpec(shape=(), dtype=tf.string),
        'image': tf.TensorSpec(shape=(128, 128, 1), dtype=tf.uint8),
        'label': tf.TensorSpec(shape=(), dtype=tf.int64),
    }
    self.assertEqual(ds.element_spec, expected_element_spec)

    # Encoded images should be found if passing decoders=SkipDecoding()
    ds = builder.as_dataset(
        split='train', decoders={'image': tfds.decode.SkipDecoding()})
    expected_element_spec = {
        'image/filename': tf.TensorSpec(shape=(), dtype=tf.string),
        'image': tf.TensorSpec(shape=(), dtype=tf.string),
        'label': tf.TensorSpec(shape=(), dtype=tf.int64),
    }
    self.assertEqual(ds.element_spec, expected_element_spec)

    # Unused keys should throw ValueError
    with self.assertRaises(ValueError):
      builder.as_dataset(
          split='train', decoders={'text': tfds.decode.SkipDecoding()})
def test_decoders(self):
  """Verify the dataset element spec for decoded and undecoded images.

  Three cases are exercised against the `train` split of a mocked folder:
  no decoders, `SkipDecoding` on the `image` feature, and a decoder dict
  whose key is expected to be rejected with a `ValueError`.
  """
  image_paths = [
      'root_dir/train/label1/img1.png',
      'root_dir/train/label3/img3.png',
      'root_dir/train/label3/img1.png',
      'root_dir/train/label3/img2.png',
      'root_dir/train/label2/img1.png',
      'root_dir/train/label2/img2.png',
  ]

  # Scalar specs shared by both expectations below.
  string_scalar = tf.TensorSpec(shape=(), dtype=tf.string)
  int64_scalar = tf.TensorSpec(shape=(), dtype=tf.int64)

  with tfds.testing.MockFs() as mock_fs:
    for image_path in image_paths:
      mock_fs.add_file(image_path)

    folder_builder = tfds.ImageFolder(root_dir='root_dir', dtype=tf.uint8)

    # decoders=None: the image spec is the decoded uint8 tensor.
    decoded_ds = folder_builder.as_dataset(split='train', decoders=None)
    decoded_spec = {
        'image/filename': string_scalar,
        'image': tf.TensorSpec(shape=(None, None, 3), dtype=tf.uint8),
        'label': int64_scalar,
    }
    self.assertEqual(decoded_ds.element_spec, decoded_spec)

    # SkipDecoding on 'image': the image spec is a scalar string.
    encoded_ds = folder_builder.as_dataset(
        split='train',
        decoders={'image': tfds.decode.SkipDecoding()},
    )
    encoded_spec = {
        'image/filename': string_scalar,
        'image': string_scalar,
        'label': int64_scalar,
    }
    self.assertEqual(encoded_ds.element_spec, encoded_spec)

    # A decoder key that is not accepted must fail with a matching message.
    with self.assertRaisesWithPredicateMatch(ValueError, 'Unrecognized keys'):
      folder_builder.as_dataset(
          split='train',
          decoders={'invalid_key': tfds.decode.SkipDecoding()},
      )
def test_properties(self):
  """Check discovered splits, example counts, labels and image settings.

  A mocked tree with `train`, `val` and `test` splits is scanned both through
  `image_folder._get_split_label_images` and through `tfds.ImageFolder`. The
  fixture includes an upper-case `.PNG` file, which is expected among the
  examples, and a `.txt` file, which is expected to be left out.
  """
  mock_files = [
      'root_dir/train/label1/img1.png',
      'root_dir/train/label3/img3.png',
      'root_dir/train/label3/img1.png',
      'root_dir/train/label3/img2.png',
      'root_dir/train/label2/img1.png',
      'root_dir/train/label2/img2.png',
      'root_dir/val/label1/img1.png',
      'root_dir/val/label2/img2.png',
      'root_dir/test/label1/img1.png',
      'root_dir/test/label2/img1.png',
      'root_dir/test/label4/img1.PNG',
      'root_dir/test/label4/unsuported.txt',
  ]

  def to_examples(path_label_pairs):
    # Turn (image_path, label) pairs into the helper's example records.
    return [
        image_folder._Example(image_path=image_path, label=label)
        for image_path, label in path_label_pairs
    ]

  with tfds.testing.MockFs() as mock_fs:
    for mock_file in mock_files:
      mock_fs.add_file(mock_file)

    split_examples, labels = image_folder._get_split_label_images('root_dir')
    builder = tfds.ImageFolder(root_dir='root_dir')
    builder_params = tfds.ImageFolder(
        root_dir='root_dir',
        dtype=tf.uint16,
        shape=(128, 128, 1),
    )

    # List equality is order-sensitive, so each split is listed in exactly
    # the order the helper is expected to return it.
    # NOTE(review): the train order is not sorted -- presumably the helper
    # shuffles deterministically; confirm against its implementation.
    expected_split_examples = {
        'train': to_examples([
            ('root_dir/train/label2/img1.png', 'label2'),
            ('root_dir/train/label3/img3.png', 'label3'),
            ('root_dir/train/label3/img2.png', 'label3'),
            ('root_dir/train/label3/img1.png', 'label3'),
            ('root_dir/train/label2/img2.png', 'label2'),
            ('root_dir/train/label1/img1.png', 'label1'),
        ]),
        'val': to_examples([
            ('root_dir/val/label2/img2.png', 'label2'),
            ('root_dir/val/label1/img1.png', 'label1'),
        ]),
        'test': to_examples([
            ('root_dir/test/label1/img1.png', 'label1'),
            ('root_dir/test/label2/img1.png', 'label2'),
            ('root_dir/test/label4/img1.PNG', 'label4'),
        ]),
    }
    self.assertEqual(split_examples, expected_split_examples)

    # Per-split counts reported by the builder's metadata.
    splits_info = builder.info.splits
    self.assertEqual(splits_info['train'].num_examples, 6)
    self.assertEqual(splits_info['val'].num_examples, 2)
    self.assertEqual(splits_info['test'].num_examples, 3)

    # Labels from the helper and from the builder's label feature must agree.
    expected_labels = ['label1', 'label2', 'label3', 'label4']
    self.assertEqual(expected_labels, labels)
    self.assertEqual(builder.info.features['label'].names, expected_labels)

    # Builder created without shape/dtype arguments.
    default_image = builder.info.features['image']
    self.assertEqual(default_image.shape, (None, None, 3))
    self.assertEqual(default_image.dtype, tf.uint8)

    # Builder created with explicit shape/dtype arguments.
    custom_image = builder_params.info.features['image']
    self.assertEqual(custom_image.shape, (128, 128, 1))
    self.assertEqual(custom_image.dtype, tf.uint16)