def test_malformed_directory(self):
    """Building TFRecords for a split with no raw-data folder must fail.

    The split name ``'val'`` differs from the ``'validation'`` split used by
    ``test_build_imagenet_validation``. The path derived from
    ``imagenet_original_data_path`` and ``split`` is therefore expected not
    to exist, and ``DataGenerator.build_imagenet_tf_record`` is expected to
    raise.
    """
    import shutil

    imagenet_original_data_path = '/localdata/datasets/imagenet-raw-data'
    output_directory = '/tmp/temporary_imagenet_dataset_directory'
    split = 'val'
    # The same output directory is used by the validation-build test.
    # Remove anything a failed build may have left there so it cannot leak
    # into other tests. ignore_errors: the build may fail before the
    # directory is ever created.
    self.addCleanup(shutil.rmtree, output_directory, ignore_errors=True)
    # NOTE(review): the test expects NameError for a missing split
    # directory. A missing path would more usually surface as
    # FileNotFoundError / OSError, so confirm against
    # DataGenerator.build_imagenet_tf_record.
    with self.assertRaises(NameError):
        # If the path imagenet_original_data_path + split doesn't exist,
        # the function is expected to raise.
        DataGenerator.build_imagenet_tf_record(
            imagenet_original_data_path, split,
            output_directory=output_directory)
def test_build_imagenet_validation(self):
    """Build the ImageNet validation TFRecords and check the loaded dataset.

    Converts the raw ``'validation'`` split under
    ``imagenet_original_data_path`` into TFRecords in ``output_directory``.
    It then loads them back with ``DataGenerator.get_imagenet`` and checks
    the returned dataset type, image shape, example count and class count.
    """
    import shutil

    imagenet_original_data_path = '/localdata/datasets/imagenet-raw-data'
    output_directory = '/tmp/temporary_imagenet_dataset_directory'
    # Register the cleanup before building, so the generated records are
    # removed even when the build or an assertion below fails. A trailing
    # rmtree would be skipped in that case. ignore_errors: the build may
    # fail before the directory exists.
    self.addCleanup(shutil.rmtree, output_directory, ignore_errors=True)

    DataGenerator.build_imagenet_tf_record(
        imagenet_original_data_path, 'validation',
        output_directory=output_directory)
    ds, img_shape, num_examples, num_classes = DataGenerator.get_imagenet(
        output_directory, 'validation')

    self.assertIsInstance(ds, tf.data.Dataset)
    self.assertEqual(img_shape, (224, 224, 3))
    self.assertEqual(num_examples, 50000)
    self.assertEqual(num_classes, 1000)