def load_data(self, dataset_name, batch_size, use_debug_mode=False):
    """Load a dataset and attach train/eval/test iterators to this object.

    Args:
        dataset_name: one of 'celeba', 'binarized-mnist', 'church_outdoor'.
        batch_size: total batch size across GPUs; must be a multiple of
            self.nr_gpu so each GPU gets an equal share.
        use_debug_mode: if True, train/eval use tiny 2-batch subsets for
            fast debugging runs.

    Raises:
        ValueError: if dataset_name is unknown or batch_size does not
            divide evenly across self.nr_gpu.

    Side effects: sets self.data_set, self.batch_size, self.num_channels,
    self.vrange, self.train_set, self.eval_set and self.test_set.
    """
    # Per-dataset configuration: (data_dir, loader class, channels, value range).
    # NOTE(review): data dirs are hard-coded cluster paths — consider moving
    # them to CLI/config; kept as-is to preserve behavior.
    configs = {
        'celeba': (
            "/data/ziz/not-backed-up/datasets-ziz-all/processed_data/CelebA",
            load_data.CelebA, 3, [-1., 1.]),
        'binarized-mnist': (
            # previously: /data/ziz/not-backed-up/datasets-ziz-all/processed_data/mnist
            "/data/ziz/not-backed-up/jxu/mnist",
            load_data.BinarizedMNIST, 1, [0, 1]),
        'church_outdoor': (
            # previously: /data/ziz/not-backed-up/datasets-ziz-all/raw_data/lsun/church_outdoor
            "/data/ziz/not-backed-up/jxu/church_outdoor",
            load_data.ChurchOutdoor, 3, [-1., 1.]),
    }
    # Validate inputs with explicit raises (asserts vanish under `python -O`),
    # and do it BEFORE mutating any state on self.
    if dataset_name not in configs:
        raise ValueError("cannot find the dataset")
    if batch_size % self.nr_gpu != 0:
        raise ValueError(
            "Batch of data cannot be evenly distributed to {0} GPUs".format(self.nr_gpu))

    self.data_set = dataset_name
    self.batch_size = batch_size
    data_dir, loader_cls, self.num_channels, self.vrange = configs[dataset_name]
    data_set = loader_cls(data_dir=data_dir, batch_size=batch_size,
                          img_size=self.img_size)

    # Debug mode shrinks train/eval to two batches; eval normally samples
    # 10 shuffled training batches. Test set is always the full set, unshuffled.
    if use_debug_mode:
        train_limit = eval_limit = batch_size * 2
    else:
        train_limit, eval_limit = -1, batch_size * 10
    self.train_set = data_set.train(shuffle=True, limit=train_limit)
    self.eval_set = data_set.train(shuffle=True, limit=eval_limit)
    self.test_set = data_set.test(shuffle=False, limit=-1)
args = parser.parse_args()
# Test mode always runs on the small debug subsets.
if args.mode == 'test':
    args.debug = True
args.nr_gpu = len(args.gpus.split(","))
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
print('input args:\n', json.dumps(vars(args), indent=4, separators=(',', ':')))  # pretty print args

if not os.path.exists(args.save_dir):
    os.makedirs(args.save_dir)
tf.set_random_seed(args.seed)

# Total batch size is the per-GPU batch replicated across all visible GPUs.
batch_size = args.batch_size * args.nr_gpu
if 'celeba' in args.data_set:
    data_set = load_data.CelebA(data_dir=args.data_dir, batch_size=batch_size, img_size=args.img_size)
elif 'svhn' in args.data_set:
    data_set = load_data.SVHN(data_dir=args.data_dir, batch_size=batch_size, img_size=args.img_size)
else:
    # BUG FIX: previously an unrecognized data_set fell through silently,
    # leaving `data_set` unbound and crashing later with a NameError.
    raise ValueError("unsupported data_set: {0}".format(args.data_set))

# Debug mode shrinks train/eval to two batches; eval normally samples
# 10 shuffled training batches. Test set is always the full set, unshuffled.
if args.debug:
    train_data = data_set.train(shuffle=True, limit=batch_size * 2)
    eval_data = data_set.train(shuffle=True, limit=batch_size * 2)
    test_data = data_set.test(shuffle=False, limit=-1)
else:
    train_data = data_set.train(shuffle=True, limit=-1)
    eval_data = data_set.train(shuffle=True, limit=batch_size * 10)
    test_data = data_set.test(shuffle=False, limit=-1)

# masks