def _download_cifar_split(split, is_training):
    """Build a CIFAR-10 dataset from raw fixed-length binary record files.

    This does NOT go through TFDS. It reads the record files named by
    `_CIFAR_FILENAMES[split]` from the directory `FLAGS.cifar_dir`, splits
    them into records of `_RECORD_BYTES` bytes each, and parses every record
    with `_parse_cifar`.

    Args:
      split: Key into `_CIFAR_FILENAMES` selecting which record files to read.
      is_training: Unused. Presumably accepted so this loader shares a
        signature with other split loaders (TODO confirm against callers).

    Returns:
      The result of `experiment_utils.download_dataset` applied to the parsed
      `tf.data` dataset.
    """
    del is_training  # Unused; see docstring.
    filenames = [
        os.path.join(FLAGS.cifar_dir, f) for f in _CIFAR_FILENAMES[split]
    ]
    # Each file is read as a flat sequence of records of exactly
    # `_RECORD_BYTES` bytes; `_parse_cifar` decodes each raw record.
    dataset = tf.data.FixedLengthRecordDataset(filenames, _RECORD_BYTES).map(
        _parse_cifar)
    return experiment_utils.download_dataset(dataset)
def _download_alt_dataset(config):
    """Load a TFDS dataset named by `config` and hand it to download_dataset.

    Args:
      config: Object exposing `alt_dataset_name` (the TFDS dataset name) and
        `split` (the TFDS split to load).

    Returns:
      The result of `experiment_utils.download_dataset` applied to the loaded
      dataset. With `as_supervised=True`, TFDS yields `(input, label)` tuples
      instead of feature dicts.
    """
    name = config.alt_dataset_name
    split = config.split
    tfds_dataset = tfds.load(name, split=split, as_supervised=True)
    return experiment_utils.download_dataset(tfds_dataset)