def _scratch_paths(config, dirname):
    """Return `(basepath, data_path, save_path)` for the folder `dirname`.

    The root is `config['scratch']` when present, otherwise the value of
    `get_default_scratch()`.
    """
    scratch = config.get('scratch', get_default_scratch())
    basepath = os.path.join(scratch, dirname)
    data_path = os.path.join(basepath, 'data')
    save_path = os.path.join(basepath, 'ckpts')
    return basepath, data_path, save_path


def load_dataset(config):
    """Load dataset following instruction in `config`.

    Args:
      config: Mapping of settings. Keys read here:
        'dataset': Required. Either a name accepted by
          `dataset_is_mnist_family` or the exact string 'CELEBA'.
        'scratch': Optional root directory; defaults to
          `get_default_scratch()`.
        'crop_width', 'img_width': Required for 'CELEBA' (they select the
          `.npy` file postfix). Optional for the MNIST family, where they
          default to None and are only passed through to the result.

    Returns:
      An `ObjectBlob` with fields `crop_width`, `img_width`, `basepath`,
      `data_path`, `save_path`, `train_data`, `attr_train`, `eval_data`,
      `attr_eval` and `attribute_names`.

    Raises:
      NotImplementedError: If `config['dataset']` is neither an MNIST-family
        dataset nor 'CELEBA'.
    """
    dataset = config['dataset']

    if dataset_is_mnist_family(dataset):
        # Not needed to load MNIST-family data; only forwarded to the result.
        crop_width = config.get('crop_width', None)
        img_width = config.get('img_width', None)

        basepath, data_path, save_path = _scratch_paths(
            config, dataset.lower())

        tf.gfile.MakeDirs(data_path)
        tf.gfile.MakeDirs(save_path)

        # black-on-white MNIST (harder to learn than white-on-black MNIST)
        # Running locally (pre-download data locally)
        mnist_train, mnist_eval, mnist_test = local_mnist.read_data_sets(
            data_path, one_hot=True)

        # The train and eval splits are merged for training; the test split
        # is what gets reported as the evaluation set.
        train_data = np.concatenate(
            [mnist_train.images, mnist_eval.images], axis=0)
        attr_train = np.concatenate(
            [mnist_train.labels, mnist_eval.labels], axis=0)
        eval_data = mnist_test.images
        attr_eval = mnist_test.labels

        attribute_names = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']

    elif dataset == 'CELEBA':
        crop_width = config['crop_width']
        img_width = config['img_width']
        postfix = '_crop_%d_res_%d.npy' % (crop_width, img_width)

        # Load Data
        basepath, data_path, save_path = _scratch_paths(config, 'celeba')
        (train_data, eval_data, _, attr_train, attr_eval, _,
         attribute_names) = _load_celeba(data_path, postfix)

    else:
        # Name the offending value so a mistyped dataset is easy to diagnose.
        raise NotImplementedError('Unsupported dataset: %r' % (dataset,))

    return ObjectBlob(
        crop_width=crop_width,
        img_width=img_width,
        basepath=basepath,
        data_path=data_path,
        save_path=save_path,
        train_data=train_data,
        attr_train=attr_train,
        eval_data=eval_data,
        attr_eval=attr_eval,
        attribute_names=attribute_names,
    )