    def test_mnist(self):
        dirpath = '/tmp/mnist_test_sample_directory'
        os.makedirs(dirpath, exist_ok=True)

        ds, img_shape, num_examples, num_classes = DataGenerator.get_dataset_from_name(
            ds_name='mnist', ds_path=dirpath, split='train[0:3]')

        assert num_classes == 10
        assert num_examples == 3
        assert tuple(img_shape) == (28, 28, 1)
        assert isinstance(ds, tf.data.Dataset)

        ds, img_shape, num_examples, num_classes = DataGenerator.get_dataset_from_name(
            ds_name='mnist', ds_path=dirpath, split='test[0:3]')

        assert num_classes == 10
        assert num_examples == 3
        assert tuple(img_shape) == (28, 28, 1)
        assert isinstance(ds, tf.data.Dataset)

    def test_unsupported_dataset(self):
        with self.assertRaises(NameError):
            _ = DataGenerator.get_dataset_from_name(ds_name='foo')

    def test_mnist_wrong_dir(self):
        with self.assertRaises(NameError):
            _ = DataGenerator.get_dataset_from_name(ds_name='mnist',
                                                    ds_path='foo')
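
# The test methods above use self.assertRaises, so they presumably belong to a
# unittest.TestCase subclass. A minimal harness sketch under that assumption
# (the class name DataGeneratorTest is hypothetical, not from the snippet):
import unittest

class DataGeneratorTest(unittest.TestCase):
    # test_mnist, test_unsupported_dataset and test_mnist_wrong_dir
    # from above would be defined here.
    pass

if __name__ == '__main__':
    unittest.main()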
    @staticmethod
    def get_dataset(dataset_name: str,
                    dataset_path: str,
                    split: str,
                    img_datatype: tf.dtypes.DType,
                    micro_batch_size: int,
                    shuffle: bool = False,
                    accelerator_side_preprocess: bool = True,
                    eight_bit_transfer: Optional[EightBitTransfer] = None,
                    apply_preprocessing: bool = True,
                    pipeline_num_parallel: int = 48,
                    seed: Optional[int] = None):
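        """Build a tf.data input pipeline for the named dataset.

        Looks up the dataset via DataGenerator.get_dataset_from_name, applies
        dataset-specific preprocessing (CIFAR-10, ImageNet, or a generic
        cache/shuffle/normalize path), optionally compresses images for
        eight-bit host-to-device transfer, then batches, repeats and
        prefetches. Returns a tuple of
        (ds, img_shape, dataset_size, num_classes, preprocess_fn).
        """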

        logging.info(f'dataset_name = {dataset_name}')

        # popdist reports how many distributed training processes were launched.
        if popdist.getNumInstances() == 1:
            logging.info('Since the training is run in a single process, setting dataset pipeline threading '
                         'and prefetching buffer size to tf.data.AUTOTUNE.')
            pipeline_num_parallel = prefetch_size = tf.data.AUTOTUNE
        else:
            prefetch_size = PREFETCH_BUFFER_SIZE
            logging.info(f'Setting number of threads for the dataset pipeline to {pipeline_num_parallel}, '
                         f'and the prefetching buffer size to {prefetch_size}.')

        ds, img_shape, dataset_size, num_classes = DataGenerator.get_dataset_from_name(
            ds_name=dataset_name, ds_path=dataset_path, split=split)

        preprocess_fn = None
        if apply_preprocessing:
            if dataset_name == 'cifar10':
                ds, preprocess_fn = DataTransformer.cifar_preprocess(
                    ds,
                    buffer_size=dataset_size,
                    img_type=img_datatype,
                    is_training=(split == 'train'),
                    accelerator_side_preprocess=accelerator_side_preprocess,
                    pipeline_num_parallel=pipeline_num_parallel,
                )
            elif dataset_name == 'imagenet':
                ds, preprocess_fn = DataTransformer.imagenet_preprocessing(
                    ds,
                    img_type=img_datatype,
                    is_training=(split == 'train'),
                    accelerator_side_preprocess=accelerator_side_preprocess,
                    pipeline_num_parallel=pipeline_num_parallel,
                    seed=seed
                )
                if shuffle:
                    # Shuffle the preprocessed examples (applied after the
                    # preprocessing map above).
                    ds = ds.shuffle(buffer_size=IMAGENET_SHUFFLE_BUFFER_SIZE)
            else:
                # Generic path for other datasets: cache/shuffle, then
                # normalize; preprocess_fn stays None.
                ds = DataTransformer.cache_shuffle(ds, buffer_size=dataset_size, shuffle=(split == 'train'))
                ds = DataTransformer.normalization(ds, img_type=img_datatype)


            if eight_bit_transfer is not None:
                # Compress images to uint8 on the host so that less data
                # travels over the host-to-device link.
                ds = ds.map(lambda x, y: (eight_bit_transfer.compress(x), y), num_parallel_calls=pipeline_num_parallel)

            ds = ds.batch(batch_size=micro_batch_size, drop_remainder=True)
            ds = ds.repeat().prefetch(prefetch_size)

            cpu_memory_usage = psutil.virtual_memory().percent

            # psutil reports usage as a percentage capped at 100, so comparing
            # against '> 100' can never be true; warn as usage nears the limit
            # instead (the 95% threshold is an illustrative choice).
            if cpu_memory_usage > 95:
                logging.warning(f'CPU memory usage is at {cpu_memory_usage}%, so the program is likely to crash.')

        return ds, img_shape, dataset_size, num_classes, preprocess_fn
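
# Hypothetical call site for get_dataset above. The enclosing class is not
# shown in this excerpt, so it is invoked here as if exposed on DataGenerator;
# all argument values are illustrative.
ds, img_shape, dataset_size, num_classes, preprocess_fn = DataGenerator.get_dataset(
    dataset_name='cifar10',
    dataset_path='/tmp/datasets',
    split='train',
    img_datatype=tf.float16,
    micro_batch_size=16,
    shuffle=True)

# The pipeline repeats indefinitely, so take a fixed number of batches.
for images, labels in ds.take(1):
    print(images.shape)  # e.g. (16, 32, 32, 3) for CIFAR-10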