# Example #1 (예제 #1)
# 0
def make_celeba_dataset(img_dir,
                        label_path,
                        att_names,
                        batch_size,
                        load_size=286,
                        crop_size=256,
                        training=True,
                        drop_remainder=True,
                        shuffle=True,
                        repeat=1):
    """Build a batched (image, attribute-label) dataset from a label file.

    ``label_path`` is parsed twice with ``np.genfromtxt``: column 0 is read as
    the image file name, and columns 1..12 are read as integer attribute
    values. Only the attribute columns selected by ``att_names`` (via the
    module-level ``ATT_ID`` mapping) are kept.

    Args:
        img_dir: directory joined (with ``py.join``) in front of each file name.
        label_path: whitespace-delimited text file described above.
        att_names: attribute names to keep; each must be a key of ``ATT_ID``.
        batch_size: number of examples per batch.
        load_size: side length images are resized to before cropping.
        crop_size: side length of the final square crop.
        training: True -> random horizontal flip + random crop;
            False -> center crop only.
        drop_remainder: whether the last partial batch is dropped.
        shuffle: permute the (path, label) pairs up front and forward the
            flag to ``tl.disk_image_batch_dataset``.
        repeat: forwarded to ``tl.disk_image_batch_dataset``.

    Returns:
        ``(dataset, len_dataset)``. ``len_dataset`` is the number of batches
        in a single pass over the files; ``repeat`` is not factored in.
    """
    file_names = np.genfromtxt(label_path, dtype=str, usecols=0)
    img_paths = np.array([py.join(img_dir, name) for name in file_names])

    all_labels = np.genfromtxt(label_path, dtype=int, usecols=range(1, 13))
    att_columns = np.array([ATT_ID[name] for name in att_names])
    labels = all_labels[:, att_columns]

    if shuffle:
        # A single permutation applied to both arrays keeps paths and labels aligned.
        order = np.random.permutation(len(img_paths))
        img_paths, labels = img_paths[order], labels[order]

    def _to_model_range(img, label):
        # Clip pixels to [0, 255] and rescale to [-1, 1]; labels go through
        # (x + 1) // 2 (presumably turning -1/1 attributes into 0/1 — confirm
        # against the label file).
        img = tf.clip_by_value(img, 0, 255) / 127.5 - 1
        label = (label + 1) // 2
        return img, label

    if training:
        # Augmentations tried earlier and currently disabled:
        # tl.random_rotate, tl.color_jitter, tl.random_grayscale.
        def map_fn_(img, label):
            img = tf.image.resize(img, [load_size, load_size])
            img = tf.image.random_flip_left_right(img)
            img = tf.image.random_crop(img, [crop_size, crop_size, 3])
            return _to_model_range(img, label)
    else:
        def map_fn_(img, label):
            img = tf.image.resize(img, [load_size, load_size])
            img = tl.center_crop(img, size=crop_size)
            return _to_model_range(img, label)

    dataset = tl.disk_image_batch_dataset(img_paths,
                                          batch_size,
                                          labels=labels,
                                          drop_remainder=drop_remainder,
                                          map_fn=map_fn_,
                                          shuffle=shuffle,
                                          repeat=repeat)

    n_examples = len(img_paths)
    len_dataset = (n_examples // batch_size if drop_remainder
                   else int(np.ceil(n_examples / batch_size)))

    return dataset, len_dataset
# Example #2 (예제 #2)
# 0
def make_dataset(img_dir,
                 batch_size,
                 load_size=286,
                 crop_size=256,
                 n_channels=3,
                 training=True,
                 drop_remainder=True,
                 shuffle=True,
                 repeat=1):
    """Build a batched, unlabeled image dataset from every file in a directory.

    Args:
        img_dir: directory whose entries (``py.glob(img_dir, '*')``) are used,
            sorted by path.
        batch_size: number of images per batch.
        load_size: side length images are resized to before cropping.
        crop_size: side length of the final square center crop.
        n_channels: 1 converts RGB input to grayscale; any other value leaves
            the image's channels untouched.
        training: True adds a random horizontal flip before the crop.
        drop_remainder: whether the last partial batch is dropped.
        shuffle: permute the file list up front and forward the flag to
            ``tl.disk_image_batch_dataset``.
        repeat: forwarded to ``tl.disk_image_batch_dataset``.

    Returns:
        ``(dataset, len_dataset)``. ``len_dataset`` is the number of batches
        in a single pass over the files; ``repeat`` is not factored in.
    """
    img_paths = sorted(py.glob(img_dir, '*'))
    if shuffle:
        img_paths = np.random.permutation(img_paths)

    def _load(img):
        # Optional grayscale conversion, then resize to load_size x load_size.
        if n_channels == 1:
            img = tf.image.rgb_to_grayscale(img)
        return tf.image.resize(img, [load_size, load_size])

    def _to_model_range(img):
        # Clip pixels to [0, 255] and rescale to [-1, 1].
        return tf.clip_by_value(img, 0, 255) / 127.5 - 1

    if training:
        def _map_fn(img):
            img = tf.image.random_flip_left_right(_load(img))
            # Training also uses a center crop; the random-crop alternative
            # (tf.image.random_crop to [crop_size, crop_size, n_channels]) is
            # deliberately disabled.
            img = tl.center_crop(img, size=crop_size)
            return _to_model_range(img)
    else:
        def _map_fn(img):
            img = tl.center_crop(_load(img), size=crop_size)
            return _to_model_range(img)

    dataset = tl.disk_image_batch_dataset(img_paths,
                                          batch_size,
                                          drop_remainder=drop_remainder,
                                          map_fn=_map_fn,
                                          shuffle=shuffle,
                                          repeat=repeat)

    n_images = len(img_paths)
    len_dataset = (n_images // batch_size if drop_remainder
                   else int(np.ceil(n_images / batch_size)))

    return dataset, len_dataset
# Example #3 (예제 #3)
# 0
    def __init__(self,
                 data_dir,
                 atts,
                 img_resize,
                 batch_size,
                 prefetch_batch=_N_CPU + 1,
                 drop_remainder=True,
                 num_threads=_N_CPU,
                 shuffle=True,
                 shuffle_buffer_size=None,
                 repeat=-1,
                 sess=None,
                 split='train',
                 crop=True):
        """Build the CelebA attribute dataset for one split.

        Reads image file names (column 0) and the requested attribute columns
        from ``<data_dir>/list_attr_celeba.txt``, skipping its first two lines.
        It resolves each name to an image file, slices the rows for ``split``,
        and passes everything to ``tl.disk_image_batch_dataset``.

        Args:
            data_dir: root folder holding ``list_attr_celeba.txt`` and the
                image folder(s).
            atts: attribute names; each must be a key of ``Celeba.att_dict``.
            img_resize: output side length. 64 selects a 148x148 crop at
                offset (h=40, w=15); any other value selects a 170x170 crop at
                (h=26, w=3). The crop only applies when ``crop`` is True.
            batch_size, prefetch_batch, drop_remainder, num_threads, shuffle,
            shuffle_buffer_size, repeat: forwarded to
                ``tl.disk_image_batch_dataset``. For ``split='test'``,
                drop_remainder, shuffle and repeat are forced to False, False
                and 1.
            sess: forwarded to ``self._bulid``.
            split: 'test' -> rows [182637:], 'val' -> rows [182000:182637];
                any other value is treated as train -> rows [:182000].
            crop: True reads ``img_align_celeba[_png]`` and crops here;
                False reads ``img_crop_celeba[_png]`` and only resizes.

        Raises:
            FileNotFoundError: if neither the PNG nor the JPG image folder
                for the chosen ``crop`` mode exists.
        """
        super(Celeba, self).__init__()

        list_file = os.path.join(data_dir, 'list_attr_celeba.txt')
        # crop=True: use the aligned images and crop them in _map_func below.
        # crop=False: use an already-cropped copy and only resize.
        if crop:
            img_dir_jpg = os.path.join(data_dir, 'img_align_celeba')
            img_dir_png = os.path.join(data_dir, 'img_align_celeba_png')
        else:
            img_dir_jpg = os.path.join(data_dir, 'img_crop_celeba')
            img_dir_png = os.path.join(data_dir, 'img_crop_celeba_png')

        # dtype=str instead of np.str: that alias of the builtin was removed in
        # NumPy 1.24 and raises AttributeError there.
        names = np.loadtxt(list_file, skiprows=2, usecols=[0], dtype=str)
        # Prefer the PNG copy when it exists (file names are listed as .jpg).
        if os.path.exists(img_dir_png):
            img_paths = [
                os.path.join(img_dir_png, name.replace('jpg', 'png'))
                for name in names
            ]
        elif os.path.exists(img_dir_jpg):
            img_paths = [os.path.join(img_dir_jpg, name) for name in names]
        else:
            # Without this branch img_paths would be unbound and the failure
            # would surface later as an UnboundLocalError.
            raise FileNotFoundError(
                'No CelebA image directory found: expected %s or %s'
                % (img_dir_png, img_dir_jpg))

        # +1 because column 0 of the list file holds the file name.
        att_id = [Celeba.att_dict[att] + 1 for att in atts]
        labels = np.loadtxt(list_file,
                            skiprows=2,
                            usecols=att_id,
                            dtype=np.int64)

        if img_resize == 64:
            # crop as how VAE/GAN do
            offset_h = 40
            offset_w = 15
            img_size = 148
        else:
            offset_h = 26
            offset_w = 3
            img_size = 170

        def _map_func(img, label):
            if crop:
                img = tf.image.crop_to_bounding_box(img, offset_h, offset_w,
                                                    img_size, img_size)
            # Alternative (disabled): default-method resize_images, then
            # / 127.5 - 1 directly, without clipping.
            img = tf.image.resize_images(img, [img_resize, img_resize],
                                         tf.image.ResizeMethod.BICUBIC)
            # Bicubic resampling can overshoot [0, 255]; clip, then map to [-1, 1].
            img = tf.clip_by_value(img, 0, 255) / 127.5 - 1
            # Presumably maps -1/1 attribute values to 0/1 — confirm with the list file.
            label = (label + 1) // 2
            return img, label

        # Fixed row-index boundaries split the list file into train/val/test.
        if split == 'test':
            drop_remainder = False
            shuffle = False
            repeat = 1
            img_paths = img_paths[182637:]
            labels = labels[182637:]
        elif split == 'val':
            img_paths = img_paths[182000:182637]
            labels = labels[182000:182637]
        else:
            img_paths = img_paths[:182000]
            labels = labels[:182000]

        dataset = tl.disk_image_batch_dataset(
            img_paths=img_paths,
            labels=labels,
            batch_size=batch_size,
            prefetch_batch=prefetch_batch,
            drop_remainder=drop_remainder,
            map_func=_map_func,
            num_threads=num_threads,
            shuffle=shuffle,
            shuffle_buffer_size=shuffle_buffer_size,
            repeat=repeat)
        # `_bulid` (spelling as used here) is presumably provided by the base class.
        self._bulid(dataset, sess)

        self._img_num = len(img_paths)