Ejemplo n.º 1
0
    def test_malformed_directory(self):
        """Building TFRecords for a split whose raw directory is missing must fail.

        The split name ``'val'`` is passed to the builder. The expectation is
        that ``<raw data path>`` joined with that split does not exist, so
        ``DataGenerator.build_imagenet_tf_record`` raises instead of producing
        records.
        """
        raw_data_root = '/localdata/datasets/imagenet-raw-data'
        bad_split = 'val'
        records_dir = '/tmp/temporary_imagenet_dataset_directory'

        # NOTE(review): NameError is an unusual exception type for a missing
        # path (FileNotFoundError / ValueError are more typical); confirm it is
        # what DataGenerator.build_imagenet_tf_record actually raises.
        with self.assertRaises(NameError):
            DataGenerator.build_imagenet_tf_record(
                raw_data_root, bad_split, output_directory=records_dir)
Ejemplo n.º 2
0
    def test_build_imagenet_validation(self):
        """Build the ImageNet validation TFRecords and check the loaded metadata.

        The test proceeds in three steps:

        1. Convert the raw ``'validation'`` split into TFRecords under a
           scratch output directory.
        2. Reload the records with ``DataGenerator.get_imagenet``.
        3. Check the returned dataset type, image shape, example count and
           class count.

        The output directory is removed after the test even if the build or
        any assertion fails.
        """
        imagenet_original_data_path = '/localdata/datasets/imagenet-raw-data'
        output_directory = '/tmp/temporary_imagenet_dataset_directory'

        # Register cleanup *before* building so that a failure in the build or
        # in any assertion below does not leave stale records in /tmp that
        # could affect the next run. ignore_errors=True because the directory
        # may never have been created if the build fails early.
        self.addCleanup(shutil.rmtree, output_directory, ignore_errors=True)

        DataGenerator.build_imagenet_tf_record(
            imagenet_original_data_path,
            'validation',
            output_directory=output_directory)

        ds, img_shape, num_examples, num_classes = DataGenerator.get_imagenet(
            output_directory, 'validation')

        self.assertIsInstance(ds, tf.data.Dataset)
        self.assertEqual(img_shape, (224, 224, 3))
        self.assertEqual(num_examples, 50000)
        self.assertEqual(num_classes, 1000)