def make_celeba_dataset(img_dir, label_path, att_names, batch_size, load_size=286, crop_size=256, training=True, drop_remainder=True, shuffle=True, repeat=1):
    """Build a batched CelebA dataset yielding (image, binary-label) pairs.

    Args:
        img_dir: Directory containing the image files.
        label_path: Attribute list file; column 0 is the image name, the
            following columns are +1/-1 attribute values.
        att_names: Attribute names to select, looked up in ``ATT_ID``.
        batch_size: Samples per batch.
        load_size: Side length images are resized to before cropping.
        crop_size: Final square crop side length.
        training: If True, apply random flip + random crop augmentation;
            otherwise a deterministic center crop.
        drop_remainder: Drop the final partial batch if True.
        shuffle: Shuffle sample order (both here and in the dataset helper).
        repeat: Epoch repeat count passed to the dataset helper.

    Returns:
        (dataset, len_dataset): the dataset pipeline and the number of
        batches per epoch.
    """
    img_names = np.genfromtxt(label_path, dtype=str, usecols=0)
    img_paths = np.array([py.join(img_dir, img_name) for img_name in img_names])
    # NOTE(review): only columns 1..12 are read, so every index produced by
    # ATT_ID below must be < 12 -- confirm against ATT_ID's definition.
    labels = np.genfromtxt(label_path, dtype=int, usecols=range(1, 13))
    labels = labels[:, np.array([ATT_ID[att_name] for att_name in att_names])]

    if shuffle:
        idx = np.random.permutation(len(img_paths))
        img_paths = img_paths[idx]
        labels = labels[idx]

    def map_fn_(img, label):
        # Shared preprocessing; only the flip/crop step differs between
        # training and evaluation (previously two near-identical closures).
        img = tf.image.resize(img, [load_size, load_size])
        if training:
            # img = tl.random_rotate(img, 5)
            img = tf.image.random_flip_left_right(img)
            img = tf.image.random_crop(img, [crop_size, crop_size, 3])
            # img = tl.color_jitter(img, 25, 0.2, 0.2, 0.1)
            # img = tl.random_grayscale(img, p=0.3)
        else:
            img = tl.center_crop(img, size=crop_size)
        img = tf.clip_by_value(img, 0, 255) / 127.5 - 1  # scale to [-1, 1]
        label = (label + 1) // 2  # map {-1, +1} -> {0, 1}
        return img, label

    dataset = tl.disk_image_batch_dataset(img_paths,
                                          batch_size,
                                          labels=labels,
                                          drop_remainder=drop_remainder,
                                          map_fn=map_fn_,
                                          shuffle=shuffle,
                                          repeat=repeat)

    if drop_remainder:
        len_dataset = len(img_paths) // batch_size
    else:
        len_dataset = int(np.ceil(len(img_paths) / batch_size))

    return dataset, len_dataset
def make_dataset(img_dir, batch_size, load_size=286, crop_size=256, n_channels=3, training=True, drop_remainder=True, shuffle=True, repeat=1):
    """Build a batched image dataset from a directory of image files.

    Args:
        img_dir: Directory whose files (glob '*') are the samples.
        batch_size: Samples per batch.
        load_size: Side length images are resized to before cropping.
        crop_size: Final square crop side length.
        n_channels: If 1, convert images to grayscale; otherwise left as-is.
        training: If True, add a random horizontal flip (the random crop is
            intentionally disabled, see the commented-out line).
        drop_remainder: Drop the final partial batch if True.
        shuffle: Shuffle the file order (and within the dataset helper).
        repeat: Epoch repeat count passed to the dataset helper.

    Returns:
        (dataset, len_dataset): the dataset pipeline and the number of
        batches per epoch.
    """
    img_paths = sorted(py.glob(img_dir, '*'))
    if shuffle:
        img_paths = np.random.permutation(img_paths)

    def _map_fn(img):
        # Single preprocessing closure; training differs from evaluation
        # only by the random flip (previously two duplicated closures).
        if n_channels == 1:
            img = tf.image.rgb_to_grayscale(img)
        img = tf.image.resize(img, [load_size, load_size])
        if training:
            img = tf.image.random_flip_left_right(img)
        img = tl.center_crop(img, size=crop_size)
        # img = tf.image.random_crop(img, [crop_size, crop_size, n_channels])
        img = tf.clip_by_value(img, 0, 255) / 127.5 - 1  # scale to [-1, 1]
        return img

    dataset = tl.disk_image_batch_dataset(img_paths,
                                          batch_size,
                                          drop_remainder=drop_remainder,
                                          map_fn=_map_fn,
                                          shuffle=shuffle,
                                          repeat=repeat)

    if drop_remainder:
        len_dataset = len(img_paths) // batch_size
    else:
        len_dataset = int(np.ceil(len(img_paths) / batch_size))

    return dataset, len_dataset
def __init__(self, data_dir, atts, img_resize, batch_size, prefetch_batch=_N_CPU + 1, drop_remainder=True, num_threads=_N_CPU, shuffle=True, shuffle_buffer_size=None, repeat=-1, sess=None, split='train', crop=True):
    """Create a CelebA dataset pipeline for the given split.

    Args:
        data_dir: Root directory holding 'list_attr_celeba.txt' and the
            image folders.
        atts: Attribute names, looked up in ``Celeba.att_dict``.
        img_resize: Target image size; selects the crop geometry below.
        batch_size: Samples per batch.
        prefetch_batch: Batches to prefetch.
        drop_remainder: Drop the final partial batch (forced False for test).
        num_threads: Parallel map threads.
        shuffle: Shuffle samples (forced False for the test split).
        shuffle_buffer_size: Shuffle buffer size for the dataset helper.
        repeat: Epoch repeat count (-1 = forever; forced 1 for test).
        sess: TensorFlow session handed to the base-class builder.
        split: One of 'train' / 'val' / 'test' (fixed CelebA index ranges).
        crop: If True, use the aligned images and crop a fixed bounding box;
            if False, use the pre-cropped images without cropping.
    """
    super(Celeba, self).__init__()

    list_file = os.path.join(data_dir, 'list_attr_celeba.txt')
    if crop:
        img_dir_jpg = os.path.join(data_dir, 'img_align_celeba')
        img_dir_png = os.path.join(data_dir, 'img_align_celeba_png')
    else:
        img_dir_jpg = os.path.join(data_dir, 'img_crop_celeba')
        img_dir_png = os.path.join(data_dir, 'img_crop_celeba_png')

    # `np.str` was a deprecated alias removed in NumPy 1.24; plain `str`
    # is the documented equivalent.
    names = np.loadtxt(list_file, skiprows=2, usecols=[0], dtype=str)
    if os.path.exists(img_dir_png):
        img_paths = [os.path.join(img_dir_png, name.replace('jpg', 'png')) for name in names]
    elif os.path.exists(img_dir_jpg):
        img_paths = [os.path.join(img_dir_jpg, name) for name in names]
    else:
        # Previously fell through leaving `img_paths` unbound (NameError
        # further down); fail fast with an explicit message instead.
        raise FileNotFoundError(
            'no image directory found: tried %s and %s' % (img_dir_png, img_dir_jpg))

    # Attribute columns are 1-based in the list file (column 0 is the name).
    att_id = [Celeba.att_dict[att] + 1 for att in atts]
    labels = np.loadtxt(list_file, skiprows=2, usecols=att_id, dtype=np.int64)

    if img_resize == 64:
        # crop as how VAE/GAN do
        offset_h = 40
        offset_w = 15
        img_size = 148
    else:
        offset_h = 26
        offset_w = 3
        img_size = 170

    def _map_func(img, label):
        if crop:
            img = tf.image.crop_to_bounding_box(img, offset_h, offset_w, img_size, img_size)
        # img = tf.image.resize_images(img, [img_resize, img_resize]) / 127.5 - 1
        # or
        # img = tf.image.resize_images(img, [img_resize, img_resize], tf.image.ResizeMethod.BICUBIC)
        # NOTE(review): the resize to `img_resize` is commented out, so images
        # keep the cropped size -- confirm downstream consumers expect that.
        img = tf.clip_by_value(img, 0, 255) / 127.5 - 1  # scale to [-1, 1]
        label = (label + 1) // 2  # map {-1, +1} -> {0, 1}
        return img, label

    # Standard CelebA split boundaries (by image index).
    if split == 'test':
        drop_remainder = False
        shuffle = False
        repeat = 1
        img_paths = img_paths[182637:]
        labels = labels[182637:]
    elif split == 'val':
        img_paths = img_paths[182000:182637]
        labels = labels[182000:182637]
    else:
        img_paths = img_paths[:182000]
        labels = labels[:182000]

    dataset = tl.disk_image_batch_dataset(img_paths=img_paths,
                                          labels=labels,
                                          batch_size=batch_size,
                                          prefetch_batch=prefetch_batch,
                                          drop_remainder=drop_remainder,
                                          map_func=_map_func,
                                          num_threads=num_threads,
                                          shuffle=shuffle,
                                          shuffle_buffer_size=shuffle_buffer_size,
                                          repeat=repeat)
    self._bulid(dataset, sess)  # sic: `_bulid` is the base-class method name

    self._img_num = len(img_paths)