Exemple #1
0
def data_iterator_tiny_digits(digits, batch_size=64, shuffle=False, rng=None):
    """Wrap a digits dataset in a simple data iterator.

    Args:
        digits: Object exposing ``images`` and ``target``; both are indexed by
            sample position and ``target.shape[0]`` gives the sample count.
        batch_size (int): Number of samples per batch.
        shuffle (bool): Passed through to ``data_iterator_simple``.
        rng: Passed through to ``data_iterator_simple``.

    Returns:
        The iterator built by ``data_iterator_simple`` with the file cache
        disabled.
    """
    num_samples = digits.target.shape[0]

    def load_func(index):
        """Return (image with a new leading axis, int32 label array)."""
        image = digits.images[index]
        target = digits.target[index]
        label = np.array([target]).astype(np.int32)
        return image[None], label

    return data_iterator_simple(load_func,
                                num_samples,
                                batch_size,
                                shuffle,
                                rng,
                                with_file_cache=False)
Exemple #2
0
def edges2shoes_data_iterator(img_path,
                              batch_size=1,
                              normalize_method=lambda x: (x - 127.5) / 127.5,
                              num_samples=-1,
                              shuffle=True,
                              rng=None):
    """Create an iterator over side-by-side image pairs stored as JPEG files.

    Every ``*.jpg`` directly under ``img_path`` is read as RGB, normalized and
    split vertically down the middle into a left half (A) and a right
    half (B), each returned in channel-first layout.

    Args:
        img_path (str): Directory containing the ``*.jpg`` files.
        batch_size (int): Number of samples per batch.
        normalize_method (callable): Applied to the decoded image array before
            it is split.
        num_samples (int): Number of samples exposed by the iterator; ``-1``
            means every file found.
        shuffle (bool): Passed through to ``data_iterator_simple``.
        rng: Passed through to ``data_iterator_simple``.

    Returns:
        The iterator built by ``data_iterator_simple`` with the file cache
        disabled.
    """
    imgs = glob.glob("{}/*.jpg".format(img_path))

    if num_samples == -1:
        num_samples = len(imgs)
    else:
        logger.info(
            "Num. of data ({}) is used for debugging".format(num_samples))

    def load_func(i):
        """Load image ``i`` and return its (left, right) halves as CHW."""
        # NOTE(review): scipy.misc.imread is no longer shipped by recent SciPy
        # releases; confirm the pinned SciPy version or migrate the reader.
        img = scipy.misc.imread(imgs[i], mode="RGB")
        img = normalize_method(img)
        _, w, _ = img.shape
        img_A = img[:, 0:w // 2, :].transpose((2, 0, 1))
        img_B = img[:, w // 2:, :].transpose((2, 0, 1))
        return img_A, img_B

    return data_iterator_simple(load_func,
                                num_samples,
                                batch_size,
                                shuffle=shuffle,
                                rng=rng,
                                with_file_cache=False)
Exemple #3
0
def data_iterator_fewshot(img_path,
                          batch_size,
                          imsize=(256, 256),
                          num_samples=1000,
                          shuffle=True,
                          rng=None):
    """Create an iterator over every ``*.jpg`` found recursively in a folder.

    Args:
        img_path (str): Root directory searched recursively for ``*.jpg``.
        batch_size (int): Number of samples per batch.
        imsize (tuple): Target size handed to ``imresize``.
        num_samples (int): Number of samples exposed by the iterator; ``-1``
            means every file found.
        shuffle (bool): Passed through to ``data_iterator_simple``.
        rng: Passed through to ``data_iterator_simple``.

    Returns:
        The iterator built by ``data_iterator_simple`` with both the file
        cache and the memory cache disabled. Each sample is
        ``(image, index)``.
    """
    pattern = "{}/**/*.jpg".format(img_path)
    image_files = glob.glob(pattern, recursive=True)

    if num_samples != -1:
        logger.info(
            "Num. of data ({}) is used for debugging".format(num_samples))
    else:
        num_samples = len(image_files)

    def load_func(i):
        """Read image ``i``, resize it, go channel-first and rescale."""
        raw = imread(image_files[i], num_channels=3)
        chw = imresize(raw, imsize).transpose(2, 0, 1)
        # Maps a 0..255 input onto -1..1.
        scaled = chw / 255. * 2. - 1.
        return scaled, i

    iterator_options = dict(shuffle=shuffle,
                            rng=rng,
                            with_file_cache=False,
                            with_memory_cache=False)
    return data_iterator_simple(load_func, num_samples, batch_size,
                                **iterator_options)
Exemple #4
0
def data_iterator_tiny_digits(digits, batch_size=64, shuffle=False, rng=None):
    """Build a simple data iterator over the samples of ``digits``.

    Args:
        digits: Object exposing ``images`` and ``target``; both are indexed by
            sample position and ``target.shape[0]`` gives the sample count.
        batch_size (int): Number of samples per batch.
        shuffle (bool): Passed through to ``data_iterator_simple``.
        rng: Passed through to ``data_iterator_simple``.

    Returns:
        The iterator built by ``data_iterator_simple`` with the file cache
        disabled.
    """
    def load_func(index):
        """Load one image (leading axis added) and its int32 label."""
        sample = digits.images[index]
        label = np.array([digits.target[index]]).astype(np.int32)
        return sample[None], label

    total = digits.target.shape[0]
    return data_iterator_simple(
        load_func, total, batch_size, shuffle, rng, with_file_cache=False)
Exemple #5
0
def data_iterator_celeba(img_path,
                         batch_size,
                         imsize=(128, 128),
                         num_samples=100,
                         shuffle=True,
                         rng=None):
    """Create an iterator over center-cropped ``*.png`` images.

    Args:
        img_path (str): Directory containing the ``*.png`` files.
        batch_size (int): Number of samples per batch.
        imsize (tuple): Accepted for interface compatibility; the crop below
            is a fixed 128 x 128 window and does not read this value.
        num_samples (int): Number of samples exposed by the iterator; ``-1``
            means every file found.
        shuffle (bool): Passed through to ``data_iterator_simple``.
        rng: Passed through to ``data_iterator_simple``.

    Returns:
        The iterator built by ``data_iterator_simple`` with the file cache
        disabled. Each sample is ``(image, None)``.
    """
    image_files = glob.glob("{}/*.png".format(img_path))

    if num_samples != -1:
        logger.info(
            "Num. of data ({}) is used for debugging".format(num_samples))
    else:
        num_samples = len(image_files)

    # Fixed crop window: 64 pixels on each side of the point (89, 121).
    center_x, center_y, half = 89, 121, 64

    def load_func(i):
        """Read image ``i``, crop it, go channel-first and rescale."""
        picture = imread(image_files[i])
        rows = slice(center_y - half, center_y + half)
        cols = slice(center_x - half, center_x + half)
        cropped = picture[rows, cols, :].transpose(2, 0, 1) / 255.
        return cropped * 2. - 1., None

    return data_iterator_simple(load_func,
                                num_samples,
                                batch_size,
                                shuffle=shuffle,
                                rng=rng,
                                with_file_cache=False)
def test_sliced_data_iterator(test_data_csv_png_10, num_of_slices, size,
                              batch_size, shuffle):
    """Check that slices of one iterator share few samples within an epoch.

    A base iterator over ``size`` samples (each sample is its own position)
    is sliced ``num_of_slices`` ways. Every slice is drained for a number of
    epochs derived from ``lcm(batch_size, size)``, the values seen are grouped
    per epoch, and for each epoch the overlap between the first slice and each
    other slice must stay below ``size // num_of_slices``.

    ``test_data_csv_png_10`` is accepted but not read by this test.
    """
    def test_load_func(position):
        return np.full((1), position, dtype=np.float32)

    di = data_iterator_simple(test_load_func,
                              size,
                              batch_size,
                              shuffle=shuffle)

    import six
    if six.PY2:
        from fractions import gcd
    else:
        from math import gcd

    def lcm(a, b):
        # Integer division keeps the result an int on Python 3 as well.
        return abs(a * b) // gcd(a, b) if a and b else 0

    max_epoch = lcm(batch_size, size) // size

    # One {epoch: [values...]} mapping per slice.
    all_data = []
    for slice_pos in range(num_of_slices):
        sliced_di = di.slice(rng=None,
                             num_of_slices=num_of_slices,
                             slice_pos=slice_pos)
        sliced_data = {}
        while True:
            current_epoch = sliced_di.epoch
            if current_epoch > max_epoch + 1:
                break
            data = sliced_di.next()
            seen = sliced_data.setdefault(current_epoch, [])
            for dat in data:
                seen.extend(dat)
        all_data.append(sliced_data)

    # Regroup as {epoch: [set of values for slice 0, slice 1, ...]}.
    epochs = {}
    for sliced_data in all_data:
        for epoch in sorted(sliced_data):
            epochs.setdefault(epoch, []).append(set(sliced_data[epoch]))

    for epoch in sorted(epochs):
        amount = size // num_of_slices
        first = epochs[epoch][0]
        for other in epochs[epoch][1:]:
            dup = first & other
            assert len(dup) < amount
def data_iterator_sr(num_examples,
                     batch_size,
                     gt_image,
                     lq_image,
                     train,
                     shuffle,
                     rng=None):
    """Create an iterator yielding (ground-truth, low-quality) image pairs.

    Args:
        num_examples (int): Number of samples exposed by the iterator.
        batch_size (int): Number of samples per batch.
        gt_image: Indexable collection passed item-wise to ``read_image`` for
            the ground-truth image.
        lq_image: Indexable collection passed item-wise to ``read_image`` for
            the low-quality image.
        train (bool): When true, a random aligned crop and random flips /
            rotation are applied; otherwise the ground truth is passed
            through ``modcrop``.
        shuffle (bool): Passed through to ``data_iterator_simple``.
        rng: Passed through to ``data_iterator_simple``. The augmentation
            below draws from the ``random`` module, not from ``rng``.

    Returns:
        The iterator built by ``data_iterator_simple`` with both the file
        cache and the memory cache disabled.
    """
    # Imported lazily inside the function; configuration supplies
    # ``train.scale`` and ``train.gt_size``.
    from args import get_config
    conf = get_config()

    def dataset_load_func(i):
        """Load pair ``i`` and return ``(gt_img, lq_img)`` in CHW layout."""
        # get images from the list
        scale = conf.train.scale
        gt_size = conf.train.gt_size
        gt_img = read_image(gt_image[i])
        lq_img = read_image(lq_image[i])
        if not train:
            gt_img = modcrop(gt_img, scale)
        # NOTE(review): gt_img is passed bare here while lq_img is wrapped in
        # a list below - confirm channel_convert accepts both forms.
        gt_img = channel_convert(gt_img.shape[2], gt_img, color="RGB")
        if train:
            # randomly crop
            H, W, C = lq_img.shape
            lq_size = gt_size // scale
            # max(0, ...) keeps the range valid when the image is smaller
            # than the crop.
            rnd_h = random.randint(0, max(0, H - lq_size))
            rnd_w = random.randint(0, max(0, W - lq_size))
            lq_img = lq_img[rnd_h:rnd_h + lq_size, rnd_w:rnd_w + lq_size, :]
            # The ground-truth crop starts at the low-quality offset
            # multiplied by the scale factor.
            rnd_h_gt, rnd_w_gt = int(rnd_h * scale), int(rnd_w * scale)
            gt_img = gt_img[rnd_h_gt:rnd_h_gt + gt_size,
                            rnd_w_gt:rnd_w_gt + gt_size, :]
            # horizontal and vertical flipping and rotation
            hflip, rot = [True, True]
            # Each augmentation is drawn independently with probability 0.5;
            # the same flags are applied to both images.
            hflip = hflip and random.random() < 0.5
            vflip = rot and random.random() < 0.5
            rot90 = rot and random.random() < 0.5
            lq_img = augment(lq_img, hflip, rot90, vflip)
            gt_img = augment(gt_img, hflip, rot90, vflip)
            # Only executed in the training branch.
            lq_img = channel_convert(C, [lq_img], color="RGB")[0]
        # BGR to RGB and HWC to CHW
        if gt_img.shape[2] == 3:
            # Reverses the channel axis of both images.
            gt_img = gt_img[:, :, [2, 1, 0]]
            lq_img = lq_img[:, :, [2, 1, 0]]

        gt_img = np.ascontiguousarray(np.transpose(gt_img, (2, 0, 1)))
        lq_img = np.ascontiguousarray(np.transpose(lq_img, (2, 0, 1)))
        return gt_img, lq_img

    return data_iterator_simple(dataset_load_func,
                                num_examples,
                                batch_size,
                                shuffle=shuffle,
                                rng=rng,
                                with_file_cache=False,
                                with_memory_cache=False)
Exemple #8
0
def get_data_loader(attr_path,
                    image_dir,
                    batch_size,
                    batch_size_valid,
                    image_size,
                    attribute='Bangs'):
    """Build the training and test iterators for a single attribute.

    The dataset returned by ``get_data_dict`` is shuffled in place after
    seeding NumPy's global random state with 313; the last 4000 entries form
    the test split and everything before them the training split.

    Args:
        attr_path: Passed to ``get_data_dict``.
        image_dir: Passed to ``stargan_load_func`` as ``image_dir``.
        batch_size (int): Batch size of the training iterator.
        batch_size_valid (int): Batch size of the test iterator.
        image_size: Passed to ``stargan_load_func`` as both ``image_size``
            and ``crop_size``.
        attribute (str): The attribute requested from ``get_data_dict``.

    Returns:
        tuple: ``(training_iterator, test_iterator)``.
    """
    holdout = 4000

    dataset, _attr2idx, _idx2attr = get_data_dict(attr_path, [attribute])
    np.random.seed(313)
    np.random.shuffle(dataset)
    test_dataset = dataset[-holdout:]
    training_dataset = dataset[:-holdout]
    print("Use {} images for training.".format(len(training_dataset)))

    def build_iterator(subset, samples_per_batch):
        """Create an uncached iterator over ``subset``."""
        loader = functools.partial(stargan_load_func,
                                   dataset=subset,
                                   image_dir=image_dir,
                                   image_size=image_size,
                                   crop_size=image_size)
        return data_iterator_simple(loader,
                                    len(subset),
                                    samples_per_batch,
                                    with_file_cache=False,
                                    with_memory_cache=False)

    # The training iterator is created first, then the test iterator.
    data_iterator = build_iterator(training_dataset, batch_size)
    test_data_iterator = build_iterator(test_dataset, batch_size_valid)
    return data_iterator, test_data_iterator
Exemple #9
0
def munit_data_iterator(img_path,
                        batch_size=1,
                        image_size=256,
                        num_samples=-1,
                        normalize_method=lambda x: (x - 127.5) / 127.5,
                        shuffle=True,
                        rng=None):
    """Create an iterator over images given as a list, a folder or one file.

    Args:
        img_path (list or str): A list of file paths, a directory whose
            ``jpg`` / ``jpeg`` / ``png`` files (lower or upper case extension)
            are collected, or the path of a single such image file.
        batch_size (int): Number of samples per batch.
        image_size (int): Images are resized to ``(image_size, image_size)``.
        num_samples (int): Number of samples exposed by the iterator; ``-1``
            means every image collected.
        normalize_method (callable): Applied to the resized image array.
        shuffle (bool): Passed through to ``data_iterator_simple``.
        rng: Passed through to ``data_iterator_simple``.

    Returns:
        The iterator built by ``data_iterator_simple`` with the file cache
        disabled. Each sample is ``(image, None)``.

    Raises:
        ValueError: If ``img_path`` is a string that is neither a directory
            nor a path ending in one of the supported extensions.
    """
    # Order matters: it fixes the order in which directory matches are listed.
    image_exts = (".jpg", ".JPG", ".jpeg", ".JPEG", ".png", ".PNG")

    if isinstance(img_path, list):
        imgs = list(img_path)
    elif os.path.isdir(img_path):
        imgs = []
        for ext in image_exts:
            imgs += glob.glob("{}/*{}".format(img_path, ext))
    elif img_path.endswith(image_exts):
        imgs = [img_path]
    else:
        raise ValueError(
            "Path specified is not `directory path` or `list of files`.")

    if num_samples == -1:
        num_samples = len(imgs)
    else:
        logger.info(
            "Num. of data ({}) is used for debugging".format(num_samples))

    def load_func(i):
        """Read image ``i`` as RGB, resize, normalize and go channel-first."""
        # NOTE(review): scipy.misc.imread / imresize are no longer shipped by
        # recent SciPy releases; confirm the pinned SciPy version or migrate.
        img = scipy.misc.imread(imgs[i], mode="RGB")
        img = scipy.misc.imresize(img, (image_size, image_size))
        img = normalize_method(img)
        img = img.transpose((2, 0, 1))
        return img, None

    return data_iterator_simple(load_func,
                                num_samples,
                                batch_size,
                                shuffle=shuffle,
                                rng=rng,
                                with_file_cache=False)
def data_iterator_segmentation(num_examples,
                               batch_size,
                               image_path_file,
                               label_path_file,
                               rng=None,
                               target_width=513,
                               target_height=513,
                               train=True):
    """Create a shuffled iterator yielding (image, label, mask) triples.

    Args:
        num_examples (int): Number of samples exposed by the iterator.
        batch_size (int): Number of samples per batch.
        image_path_file: Passed to ``load_paths`` to obtain the image paths.
        label_path_file: Passed to ``load_paths`` to obtain the label paths.
        rng: Passed through to ``data_iterator_simple``.
        target_width (int): Passed to the preprocessing function.
        target_height (int): Passed to the preprocessing function.
        train (bool): Passed to the preprocessing function.

    Returns:
        The iterator built by ``data_iterator_simple`` with shuffling enabled
        and both the file cache and the memory cache disabled.
    """
    image_paths = load_paths(image_path_file)
    label_paths = load_paths(label_path_file)

    def read_label(path):
        """Read a label from an image file (path contains 'png') or .npy."""
        if 'png' in path:
            raw = imageio.imread(path, as_gray=False, pilmode="RGB")
        else:
            raw = np.load(path, allow_pickle=True)
        label = raw.astype('int32')
        # Give two-dimensional labels a trailing channel axis.
        return label[..., None] if label.ndim == 2 else label

    def image_label_load_func(i):
        '''
        Returns:
            image: c x h x w array
            label: c x h x w array
            mask: c x h x w array
        '''
        bgr = cv2.imread(image_paths[i]).astype('float32')
        blue, green, red = cv2.split(bgr)
        rgb = cv2.merge([red, green, blue])
        lab = read_label(label_paths[i])

        processed = image_preprocess.preprocess_image_and_label(
            rgb, lab, target_width, target_height, train=train)
        img, lab, mask = processed

        # Move the last axis to the front for each returned array.
        return np.rollaxis(img, 2), np.rollaxis(lab, 2), np.rollaxis(mask, 2)

    return data_iterator_simple(image_label_load_func,
                                num_examples,
                                batch_size,
                                shuffle=True,
                                rng=rng,
                                with_file_cache=False,
                                with_memory_cache=False)
Exemple #11
0
def test_data_iterator_simple(test_data_csv_png_10, batch_size, shuffle):
    """Feed images listed in a CSV file through ``data_iterator_simple``.

    Each CSV row is ``<image file name>, <integer label>``; file names are
    resolved relative to the CSV's directory and rows whose image is missing
    are skipped. The resulting iterator is validated with
    ``check_data_iterator_result``.
    """
    csv_dir = os.path.dirname(test_data_csv_png_10)
    src_data = []
    with open(test_data_csv_png_10) as csv_file:
        for line in csv_file:
            values = [field.strip() for field in line.split(',')]
            img_file_name = os.path.join(csv_dir, values[0])
            if not os.path.exists(img_file_name):
                continue
            with open(img_file_name, 'rb') as img_file:
                image = load_image(img_file)
                src_data.append((image, [int(values[1])]))

    def test_load_func(position):
        return src_data[position]

    iterator = data_iterator_simple(
        test_load_func, len(src_data), batch_size, shuffle=shuffle)
    with iterator as di:
        check_data_iterator_result(di, batch_size, shuffle, False)
Exemple #12
0
def test_data_iterator_simple(test_data_csv_png_10, batch_size, shuffle,
                              stop_exhausted):
    """Run ``data_iterator_simple`` over CSV-listed images with epoch hooks.

    Each CSV row is ``<image file name>, <integer label>``; rows whose image
    is missing are skipped. Two callbacks check that the reported epoch equals
    ``expect_epoch[0]`` and that they run on the thread that started the test.

    NOTE(review): both ``begin_epoch`` and ``end_epoch`` are registered through
    ``register_epoch_end_callback``, and the pair is registered twice when
    ``batch_size // size == 0`` - confirm this is intended.
    """
    csv_dir = os.path.dirname(test_data_csv_png_10)
    src_data = []
    with open(test_data_csv_png_10) as csv_file:
        for line in csv_file:
            values = [field.strip() for field in line.split(',')]
            img_file_name = os.path.join(csv_dir, values[0])
            if not os.path.exists(img_file_name):
                continue
            with open(img_file_name, 'rb') as img_file:
                image = load_image(img_file)
                src_data.append((image, [int(values[1])]))

    def test_load_func(position):
        return src_data[position]

    def end_epoch(epoch):
        print(f"{epoch} == {expect_epoch[0]}")
        assert epoch == expect_epoch[0], "Failed for end epoch check"
        caller = threading.current_thread().ident
        assert caller == main_thread, "Failed for thread checking"

    def begin_epoch(epoch):
        print(f"{epoch} == {expect_epoch[0]}")
        assert epoch == expect_epoch[0], "Failed for begin epoch check"
        caller = threading.current_thread().ident
        assert caller == main_thread, "Failed for thread checking"

    size = len(src_data)
    main_thread = threading.current_thread().ident
    # Single-element list so the checker can update the value the callbacks
    # read.
    expect_epoch = [0]
    with data_iterator_simple(test_load_func,
                              size,
                              batch_size,
                              shuffle=shuffle,
                              stop_exhausted=stop_exhausted) as di:
        registrations = 2 if batch_size // size == 0 else 1
        for _ in range(registrations):
            di.register_epoch_end_callback(begin_epoch)
            di.register_epoch_end_callback(end_epoch)
        check_data_iterator_result(di, batch_size, shuffle, False,
                                   stop_exhausted, expect_epoch)
Exemple #13
0
def test_data_iterator_simple(test_data_csv_png_10, batch_size, shuffle, stop_exhausted):
    """Feed CSV-listed images through ``data_iterator_simple``.

    Each CSV row is ``<image file name>, <integer label>``; file names are
    resolved relative to the CSV's directory and rows whose image is missing
    are skipped. ``stop_exhausted`` is forwarded both to the iterator and to
    ``check_data_iterator_result``.
    """
    csv_dir = os.path.dirname(test_data_csv_png_10)
    src_data = []
    with open(test_data_csv_png_10) as csv_file:
        for line in csv_file:
            values = [field.strip() for field in line.split(',')]
            img_file_name = os.path.join(csv_dir, values[0])
            if not os.path.exists(img_file_name):
                continue
            with open(img_file_name, 'rb') as img_file:
                image = load_image(img_file)
                src_data.append((image, [int(values[1])]))

    def test_load_func(position):
        return src_data[position]

    iterator = data_iterator_simple(test_load_func,
                                    len(src_data),
                                    batch_size,
                                    shuffle=shuffle,
                                    stop_exhausted=stop_exhausted)
    with iterator as di:
        check_data_iterator_result(
            di, batch_size, shuffle, False, stop_exhausted)
Exemple #14
0
def data_iterator_celeba(img_path,
                         attributes,
                         transform=None,
                         batch_size=32,
                         num_samples=-1,
                         shuffle=True,
                         rng=None):
    """
    create celebA data iterator
    Args:
        img_path(list) : list of image paths
        attributes (dict) : mapping from image path to its attribute entry
        transform : callable invoked as ``transform(image=array)`` whose
                    result is indexed with ``'image'`` (data augmentation);
                    when None the decoded image is used unchanged
        batch_size (int) :  number of samples contained in each generated batch
        num_samples (int) : number of samples taken in data loader
                            (if num_samples=-1, it will take all the images in the dataset)
        shuffle (bool) : shuffle the data
        rng : passed through to ``data_iterator_simple``
    Returns:
        simple data iterator
    """
    imgs = img_path
    attr = attributes
    if num_samples == -1:
        num_samples = len(imgs)
    else:
        logger.info(
            "Num. of data ({}) is used for debugging".format(num_samples))

    def load_func(i):
        """Return (channel-first image, attribute entry) for sample ``i``."""
        path = imgs[i]
        # The context manager releases the file handle once the pixel data
        # has been copied into the array.
        with Image.open(path) as pillow_image:
            image = np.array(pillow_image)
        # ``transform`` defaults to None, so it must not be called blindly.
        if transform is not None:
            image = transform(image=image)['image']
        return image.transpose(2, 0, 1), attr[path]

    return data_iterator_simple(load_func,
                                num_samples,
                                batch_size,
                                shuffle=shuffle,
                                rng=rng,
                                with_file_cache=False)
Exemple #15
0
def data_iterator_segmentation(batch_size,
                               image_paths,
                               label_paths,
                               rng=None,
                               train=True):
    '''
    Returns a data iterator object for semantic image segmentation dataset.

    Args:
        batch_size (int): Batch size
        image_paths (list of str): A list of image paths
        label_paths (list of str): A list of label image paths; must have the
            same length as `image_paths`.
        rng (None or numpy.random.RandomState):
            A random number generator passed both to the iterator and to the
            preprocessing function.
        train (bool): Used as the iterator's `shuffle` flag.
    '''
    num_examples = len(image_paths)
    assert num_examples == len(label_paths)

    def image_label_load_func(i):
        '''
        Returns:
            image: c x h x w array
            label: c x h x w array
        '''
        image = cv2.imread(image_paths[i], cv2.IMREAD_COLOR)
        label = palette_png_reader(label_paths[i])
        return image_preprocess.preprocess_image_and_label(image,
                                                           label,
                                                           rng=rng)

    iterator = data_iterator_simple(image_label_load_func,
                                    num_examples,
                                    batch_size,
                                    shuffle=train,
                                    rng=rng)
    return iterator
Exemple #16
0
# Number of full batches per pass; a trailing partial batch is dropped by the
# floor division.
num_train_batch = len(x_train) // batch_size
num_dev_batch = len(x_test) // batch_size


def load_train_func(index):
    """Return the (input, target) training pair stored at ``index``."""
    return x_train[index], y_train[index]


def load_dev_func(index):
    """Return the (input, target) development pair stored at ``index``."""
    return x_test[index], y_test[index]


# Both iterators shuffle and bypass the file cache.
train_data_iter = data_iterator_simple(load_train_func,
                                       len(x_train),
                                       batch_size,
                                       shuffle=True,
                                       with_file_cache=False)
# NOTE(review): the development iterator is shuffled as well - confirm that
# evaluation does not rely on a fixed sample order.
dev_data_iter = data_iterator_simple(load_dev_func,
                                     len(x_test),
                                     batch_size,
                                     shuffle=True,
                                     with_file_cache=False)


def build_self_attention_model(train=True):
    """Declare the input variables and the attention mask of the model.

    NOTE(review): the body visible here stops after ``attention_mask`` is
    computed - nothing is returned and ``train`` / ``t`` are not used yet.
    The remainder of the function is presumably missing; confirm against the
    full source.
    """
    x = nn.Variable((batch_size, max_len))
    t = nn.Variable((batch_size, 1))
    mask = get_mask(x)
    # (1 - mask) scaled by the most negative float32: positions where mask is
    # 1 become 0, positions where mask is 0 become a very large negative value.
    attention_mask = (F.constant(1, shape=mask.shape) - mask) * F.constant(
        np.finfo(np.float32).min, shape=mask.shape)
# Number of full batches per pass; a trailing partial batch is dropped by the
# floor division.
num_train_batch = len(x_train) // batch_size
num_valid_batch = len(x_valid) // batch_size


def load_train_func(index):
    """Return the (input, target) training pair stored at ``index``."""
    return x_train[index], y_train[index]


def load_valid_func(index):
    """Return the (input, target) validation pair stored at ``index``."""
    return x_valid[index], y_valid[index]


# Both iterators shuffle and bypass the file cache.
train_data_iter = data_iterator_simple(load_train_func,
                                       len(x_train),
                                       batch_size,
                                       shuffle=True,
                                       with_file_cache=False)
# NOTE(review): the validation iterator is shuffled as well - confirm that
# evaluation does not rely on a fixed sample order.
valid_data_iter = data_iterator_simple(load_valid_func,
                                       len(x_valid),
                                       batch_size,
                                       shuffle=True,
                                       with_file_cache=False)

# Graph definition: inputs are (batch, sentence_length); targets carry one
# trailing axis of size 1 per position.
x = nn.Variable((batch_size, sentence_length))
t = nn.Variable((batch_size, sentence_length, 1))
# Embedding -> LSTM -> affine layers wrapped in TimeDistributed.
# NOTE(review): LSTM and TimeDistributed are defined elsewhere; presumably the
# affine layers are applied per time step - confirm against their definitions.
h = PF.embed(x, vocab_size, embedding_size)
h = LSTM(h, hidden, return_sequences=True)
h = TimeDistributed(PF.affine)(h, hidden, name='hidden')
y = TimeDistributed(PF.affine)(h, vocab_size, name='output')
def test_sliced_data_iterator_equivalence(test_data_csv_png_10, num_of_slices, size, batch_size, shuffle):
    """Compare sliced iterators against iterators built on pre-split data.

    A base iterator over ``size`` integer samples is sliced
    ``num_of_slices`` ways. As a reference, the same samples are split into
    contiguous blocks and one plain iterator is built per block. Both groups
    are drained for the same number of batches; the running batch counts must
    match and the set of values seen by each group must be equal.

    ``test_data_csv_png_10`` is accepted but not read by this test.
    """

    def lcm(a, b):
        # Integer division keeps the result an int.
        return abs(a * b) // math.gcd(a, b) if a and b else 0

    max_epoch = lcm(batch_size, size) // size

    def test_load_func(position):
        # Builtin ``int`` replaces the ``np.int`` alias, which current NumPy
        # releases no longer provide.
        return np.full((1), position, dtype=int)

    def simple_load_func(data_set, position):
        return data_set[position]

    def get_data(iter_list, iter_num):
        """Yield ``iter_num`` batches per iterator, then the running count."""
        total = 0
        for it in iter_list:
            for _ in range(iter_num):
                yield it.next()
                total += 1
            yield total
        yield total

    # Batches to draw from each slice, rounded to the nearest integer.
    iter_num = int((max_epoch * size) / (num_of_slices * batch_size) + 0.5)

    di = data_iterator_simple(test_load_func, size,
                              batch_size, shuffle=shuffle)
    sliced_di_list = [
        di.slice(rng=None, num_of_slices=num_of_slices, slice_pos=slice_pos)
        for slice_pos in range(num_of_slices)
    ]

    ref_di_list = []
    all_data = [np.full((1), position, dtype=int)
                for position in range(size)]
    slice_sample_size = size / num_of_slices
    for slice_pos in range(num_of_slices):
        # Block boundaries are rounded to the nearest sample.
        start_index = int(slice_sample_size * slice_pos + 0.5)
        end_index = int(slice_sample_size * (slice_pos + 1) + 0.5)
        slice_block_size = end_index - start_index
        sliced_data = all_data[start_index: end_index]
        ref_di = data_iterator_simple(
            partial(simple_load_func, sliced_data), slice_block_size, batch_size, shuffle=shuffle)
        ref_di_list.append(ref_di)

    try:
        set_a = set()
        set_b = set()
        for ref, t in zip(get_data(ref_di_list, iter_num), get_data(sliced_di_list, iter_num)):
            if isinstance(ref, tuple):
                ref, t = ref[0], t[0]
            if isinstance(ref, np.ndarray):
                set_a = set_a.union(set(ref))
                set_b = set_b.union(set(t))
            else:
                # Non-array items are the running batch counts.
                assert ref == t
        assert set_a == set_b
    finally:
        # Close the iterators even when an assertion above fails.
        for it in ref_di_list + sliced_di_list:
            it.close()
Exemple #19
0
def load_train_func(index):
    """Return one positive pair plus randomly drawn negative indices.

    The pair stored at ``pdata[index]`` is mapped to ids through ``pdict``.
    Negative ids are drawn without replacement from all ids, with the two
    ids of the pair given probability zero.

    Returns:
        tuple: ``(id of x, id of y, array of negative_sample_size ids)``.
    """
    vocab_size = len(pdict)
    x, y = pdata[index]
    x_id = pdict[x]
    y_id = pdict[y]

    # Uniform weights over every id except the two members of the pair.
    weights = np.ones(vocab_size)
    weights[x_id] = 0.0
    weights[y_id] = 0.0
    weights /= vocab_size - 2

    negatives = np.random.choice(range(vocab_size),
                                 negative_sample_size,
                                 replace=False,
                                 p=weights)
    return x_id, y_id, negatives


# Shuffled training iterator over every entry of ``pdata``; the file cache is
# bypassed.
train_data_iter = data_iterator_simple(load_train_func,
                                       len(pdata),
                                       batch_size,
                                       shuffle=True,
                                       with_file_cache=False)
"""
"""


def distance(u, v, eps=1e-5):
    """Compute ``acosh(1 + 2 * |u - v|^2 / (alpha * beta))`` row-wise.

    ``alpha`` is ``max(eps, 1 - |u|^2)`` and ``beta`` is
    ``max(eps, 1 - |v|^2)``; all squared norms are summed over axis 1.

    Args:
        u: First operand.
        v: Second operand.
        eps (float): Lower bound applied to ``alpha`` and ``beta`` so the
            denominator stays positive.

    Returns:
        The resulting distance expression.
    """
    def squared_norm(x):
        return F.sum(F.pow_scalar(x, 2), axis=1)

    def clamped_complement(sq):
        return F.maximum2(F.constant(eps, shape=sq.shape), 1.0 - sq)

    u_sq = squared_norm(u)
    v_sq = squared_norm(v)
    diff_sq = squared_norm(u - v)
    alpha = clamped_complement(u_sq)
    beta = clamped_complement(v_sq)

    ratio = 2 * diff_sq / (alpha * beta)
    return F.acosh(1 + ratio)
Exemple #20
0
def data_iterator_imagenet(img_path,
                           dirname_to_label_path,
                           batch_size=16,
                           ih=128,
                           iw=128,
                           n_classes=1000,
                           class_id=-1,
                           noise=True,
                           normalize=lambda x: x / 128.0 - 1.0,
                           train=True,
                           shuffle=True,
                           rng=None):
    """Create an iterator over ``*.JPEG`` images stored in class folders.

    Sub-directories of ``img_path`` are sorted by name and the first
    ``n_classes`` are used. Every image is resized to ``(iw, ih)``, converted
    to RGB, moved to channel-first layout and passed through ``normalize``.

    Args:
        img_path (str): Root directory holding one sub-directory per class.
        dirname_to_label_path: Passed to ``create_dirname_label_maps``
            (training only).
        batch_size (int): Number of samples per batch.
        ih (int): Target image height.
        iw (int): Target image width.
        n_classes (int): Number of class directories to use.
        class_id (int): When not ``-1`` (training only), keep only images
            whose path contains the directory name mapped from this label.
        noise (bool): When true (training only), add uniform noise drawn from
            ``[0, 1/128)`` to the normalized image.
        normalize (callable): Applied to the channel-first image array; the
            default maps pixel values with ``x / 128.0 - 1.0``.
        train (bool): Selects training samples ``(image, label)`` or
            validation samples ``(image, empty array)``.
        shuffle (bool): Passed through to ``data_iterator_simple``.
        rng: Passed through to ``data_iterator_simple``.

    Returns:
        The iterator built by ``data_iterator_simple`` with the file cache
        disabled.
    """
    # Class directories and their images are listed the same way for both
    # modes.
    dir_paths = sorted(glob.glob("{}/*".format(img_path)))[0:n_classes]
    imgs = []
    for dir_path in dir_paths:
        imgs += glob.glob("{}/*.JPEG".format(dir_path))

    def load_image(path):
        """Read, resize, convert to RGB, go channel-first and normalize."""
        img = Image.open(path).resize((iw, ih),
                                      Image.BILINEAR).convert("RGB")
        img = np.asarray(img).transpose((2, 0, 1))
        # Honour the caller-supplied normalizer instead of a hard-coded one.
        return normalize(img)

    if train:
        # Dirname to Label map
        dirname_to_label, label_to_dirname = create_dirname_label_maps(
            dirname_to_label_path)

        # Filter by class_id
        if class_id != -1:
            dirname = label_to_dirname[class_id]
            imgs = [path for path in imgs if dirname in path]

        def load_func(i):
            img = load_image(imgs[i])
            if noise:
                img += np.random.uniform(size=img.shape,
                                         low=0.0,
                                         high=1.0 / 128)
            # The label is looked up from the image's parent directory name.
            dname = imgs[i].rstrip().split("/")[-2]
            return img, np.array(dirname_to_label[dname])
    else:
        def load_func(i):
            return load_image(imgs[i]), np.array([])

    return data_iterator_simple(load_func,
                                len(imgs),
                                batch_size,
                                shuffle=shuffle,
                                rng=rng,
                                with_file_cache=False)
Exemple #21
0
def data_iterator(num_examples, batch_size, img_left, img_right, img_disp, train, shuffle, dataset, rng=None):
    """Create an iterator yielding (left image, right image, disparity).

    Args:
        num_examples (int): Number of samples exposed by the iterator.
        batch_size (int): Number of samples per batch.
        img_left: Indexable collection of left-image paths.
        img_right: Indexable collection of right-image paths.
        img_disp: Indexable collection of disparity paths (read with
            ``readPFM`` for "SceneFlow", with ``imread`` for "Kitti").
        train (bool): When true a random crop of
            ``args.crop_height x args.crop_width`` is taken; otherwise a fixed
            ``args.im_height x args.im_width`` window is used (top-left for
            "SceneFlow", bottom-right for "Kitti").
        shuffle (bool): Passed through to ``data_iterator_simple``.
        dataset (str): Either "SceneFlow" or "Kitti".
        rng: Passed through to ``data_iterator_simple``. Cropping draws from
            the ``random`` module, not from ``rng``.

    Returns:
        The iterator built by ``data_iterator_simple`` with both the file
        cache and the memory cache disabled.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
    """
    if dataset not in ("SceneFlow", "Kitti"):
        # Without this check the loader would fail later on unbound locals.
        raise ValueError(
            "Unsupported dataset: {!r} (expected 'SceneFlow' or 'Kitti')".format(dataset))

    # Per-channel constants, shaped to broadcast over channel-first images.
    mean_imagenet = np.asarray([0.485, 0.456, 0.406]).astype(
        np.float32).reshape(3, 1, 1)
    std_imagenet = np.asarray([0.229, 0.224, 0.225]).astype(
        np.float32).reshape(3, 1, 1)

    def _standardize(image):
        """Scale a channel-first image by 1/255, then apply mean and std."""
        image = (image/255).astype(np.float32)
        image -= mean_imagenet
        image /= std_imagenet
        return image

    def dataset_load_func(i):
        # get images from the list
        image_left = imread(img_left[i]).astype('float32')
        image_right = imread(img_right[i]).astype('float32')
        # The argument parser lives in a different script per dataset and is
        # imported lazily here.
        if dataset == "SceneFlow":
            from main import parser
            args = parser.parse_args()
            image_disp, _ = readPFM(img_disp[i])
            image_disp = np.ascontiguousarray(image_disp, dtype=np.float32)
        else:  # "Kitti"
            from finetune import parser
            args = parser.parse_args()
            image_disp = imread(img_disp[i]).astype('float32')
        # Give the disparity map an explicit trailing channel axis.
        image_disp = image_disp.reshape(
            image_disp.shape[0], image_disp.shape[1], 1)

        if train:
            w, h = image_left.shape[1], image_left.shape[0]
            th, tw = args.crop_height, args.crop_width
            x1 = random.randint(0, w - tw)
            y1 = random.randint(0, h - th)
            # The same random window is cut from all three arrays.
            image_left = image_left[y1:y1 + th, x1:x1 + tw]
            image_right = image_right[y1:y1 + th, x1:x1 + tw]
            if dataset == "Kitti":
                image_disp = np.ascontiguousarray(
                    image_disp, dtype=np.float32)/256
            image_disp = image_disp[y1:y1 + th, x1:x1 + tw]
        elif dataset == "SceneFlow":
            # Fixed top-left window.
            image_left = image_left[:args.im_height, :args.im_width, :]
            image_right = image_right[:args.im_height, :args.im_width, :]
            image_disp = image_disp[:args.im_height, :args.im_width, :]
        else:  # "Kitti"
            # Fixed bottom-right window.
            w, h = image_left.shape[1], image_left.shape[0]
            top, left = h - args.im_height, w - args.im_width
            image_left = image_left[top:h, left:w, :]
            image_right = image_right[top:h, left:w, :]
            image_disp = image_disp[top:h, left:w, :]
            image_disp = np.ascontiguousarray(
                image_disp, dtype=np.float32)/256

        # HWC -> CHW, then normalize the two colour images; the disparity
        # map is only re-laid out.
        image_left = _standardize(np.rollaxis(image_left, 2))
        image_right = _standardize(np.rollaxis(image_right, 2))
        image_disp = np.rollaxis(image_disp, 2)

        return image_left, image_right, image_disp
    return data_iterator_simple(dataset_load_func, num_examples, batch_size, shuffle=shuffle, rng=rng,
                                with_file_cache=False, with_memory_cache=False)
Exemple #22
0
def main():
    """
    Main script: semi-supervised MNIST training with virtual adversarial noise.

    Steps:
    * Get and set context.
    * Load Dataset
    * Initialize DataIterator.
    * Create Networks
    *   Net for Labeled Data
    *   Net for Unlabeled Data
    *   Net for Validation Data
    * Create Solver.
    * Training Loop.
    *   Validation (every args.val_interval iterations)
    *   Training
    *     by Labeled Data
    *       Calculate Supervised Loss
    *     by Unlabeled Data
    *       Calculate Virtual Adversarial Noise
    *       Calculate Unsupervised Loss
    *   Learning rate decay (every args.iter_per_epoch iterations)
    * Final validation and parameter saving.

    All settings are read from the object returned by get_args().
    Returns nothing; results are written through the monitors and to a
    parameter file under args.model_save_path.
    """

    args = get_args()

    # Get context.
    from nnabla.ext_utils import get_extension_context
    logger.info("Running in %s" % args.context)
    ctx = get_extension_context(args.context,
                                device_id=args.device_id,
                                type_config=args.type_config)
    nn.set_default_context(ctx)

    # Per-sample input shape used for every network input below.
    shape_x = (1, 28, 28)
    n_h = args.n_units
    n_y = args.n_class

    # Load MNIST Dataset
    from mnist_data import load_mnist, data_iterator_mnist
    images, labels = load_mnist(train=True)
    # Fixed seed so the permutation (and thus the labeled subset) is
    # reproducible between runs.
    rng = np.random.RandomState(706)
    inds = rng.permutation(len(images))

    # Both feeders read through the same permutation `inds`; the labeled
    # iterator is limited to the first args.n_labeled permuted indices and
    # the unlabeled one to the first args.n_train.
    def feed_labeled(i):
        j = inds[i]
        return images[j], labels[j]

    # NOTE(review): identical to feed_labeled. The label it returns is
    # discarded by the training loop (`xu.d, _ = di_u.next()`).
    def feed_unlabeled(i):
        j = inds[i]
        return images[j], labels[j]

    di_l = data_iterator_simple(feed_labeled,
                                args.n_labeled,
                                args.batchsize_l,
                                shuffle=True,
                                rng=rng,
                                with_file_cache=False)
    di_u = data_iterator_simple(feed_unlabeled,
                                args.n_train,
                                args.batchsize_u,
                                shuffle=True,
                                rng=rng,
                                with_file_cache=False)
    di_v = data_iterator_mnist(args.batchsize_v, train=False)

    # Create networks
    # feed-forward-net building function; every call goes through mlp_net
    # with the same n_h / n_y.
    def forward(x, test=False):
        return mlp_net(x, n_h, n_y, test)

    # Net for learning labeled data: mean softmax cross entropy.
    xl = nn.Variable((args.batchsize_l, ) + shape_x, need_grad=False)
    yl = forward(xl, test=False)
    tl = nn.Variable((args.batchsize_l, 1), need_grad=False)
    loss_l = F.mean(F.softmax_cross_entropy(yl, tl))

    # Net for learning unlabeled data
    xu = nn.Variable((args.batchsize_u, ) + shape_x, need_grad=False)
    yu = forward(xu, test=False)
    # y1 is the clean prediction used as a fixed target: it is unlinked from
    # yu's graph and marked as not needing gradients.
    y1 = yu.get_unlinked_variable()
    y1.need_grad = False

    noise = nn.Variable((args.batchsize_u, ) + shape_x, need_grad=True)
    # Scale the noise of each sample to unit L2 norm (sum over axes 1-3).
    r = noise / (F.sum(noise**2, [1, 2, 3], keepdims=True))**0.5
    # presumably keeps r's buffers alive across the clear_buffer passes so
    # r.grad can be read in the power iteration below -- TODO confirm
    r.persistent = True
    # y2: prediction under the small perturbation (xi) used to estimate the
    # adversarial direction; y3: prediction under the perturbation (eps)
    # used for the actual unsupervised loss.
    y2 = forward(xu + args.xi_for_vat * r, test=False)
    y3 = forward(xu + args.eps_for_vat * r, test=False)
    loss_k = F.mean(distance(y1, y2))
    loss_u = F.mean(distance(y1, y3))

    # Net for evaluating validation data: mean top-1 error.
    xv = nn.Variable((args.batchsize_v, ) + shape_x, need_grad=False)
    hv = forward(xv, test=True)
    tv = nn.Variable((args.batchsize_v, 1), need_grad=False)
    err = F.mean(F.top_n_error(hv, tv, n=1))

    # Create solver over every parameter registered so far.
    solver = S.Adam(args.learning_rate)
    solver.set_parameters(nn.get_parameters())

    # Monitor training and validation stats.
    import nnabla.monitor as M
    monitor = M.Monitor(args.model_save_path)
    monitor_verr = M.MonitorSeries("Test error", monitor, interval=240)
    monitor_time = M.MonitorTimeElapsed("Elapsed time", monitor, interval=240)

    # Training Loop.
    # NOTE(review): t0 is assigned but never read in this function.
    t0 = time.time()

    for i in range(args.max_iter):

        # Validation Test
        if i % args.val_interval == 0:
            valid_error = calc_validation_error(di_v, xv, tv, err,
                                                args.val_iter)
            monitor_verr.add(i, valid_error)

        #################################
        ## Training by Labeled Data #####
        #################################

        # forward, backward and update
        xl.d, tl.d = di_l.next()
        # Divide the raw image values by 255 before feeding the network.
        xl.d = xl.d / 255
        solver.zero_grad()
        loss_l.forward(clear_no_need_grad=True)
        loss_l.backward(clear_buffer=True)
        solver.weight_decay(args.weight_decay)
        solver.update()

        #################################
        ## Training by Unlabeled Data ###
        #################################

        # Calculate y without noise, only once.
        # The label returned by the unlabeled iterator is ignored.
        xu.d, _ = di_u.next()
        xu.d = xu.d / 255
        yu.forward(clear_buffer=True)

        ##### Calculate Adversarial Noise #####
        # Do power method iteration: start from Gaussian noise, then
        # repeatedly replace the noise with the gradient of loss_k w.r.t. r.
        noise.d = np.random.normal(size=xu.shape).astype(np.float32)
        for k in range(args.n_iter_for_power_method):
            r.grad.zero()
            loss_k.forward(clear_no_need_grad=True)
            loss_k.backward(clear_buffer=True)
            noise.data.copy_from(r.grad)

        ##### Calculate loss for unlabeled data #####
        # forward, backward and update
        solver.zero_grad()
        loss_u.forward(clear_no_need_grad=True)
        loss_u.backward(clear_buffer=True)
        solver.weight_decay(args.weight_decay)
        solver.update()

        ##### Learning rate update #####
        # Multiplies the learning rate by args.learning_rate_decay whenever
        # i is a multiple of args.iter_per_epoch (this includes i == 0).
        if i % args.iter_per_epoch == 0:
            solver.set_learning_rate(solver.learning_rate() *
                                     args.learning_rate_decay)
        monitor_time.add(i)

    # Evaluate the final model by the error rate with validation dataset
    # NOTE(review): `i` is the last loop index here, so it is undefined if
    # args.max_iter is 0.
    valid_error = calc_validation_error(di_v, xv, tv, err, args.val_iter)
    monitor_verr.add(i, valid_error)
    monitor_time.add(i)

    # Save the model.
    parameter_file = os.path.join(args.model_save_path,
                                  'params_%06d.h5' % args.max_iter)
    nn.save_parameters(parameter_file)
Exemple #23
0
def data_iterator_sr(conf,
                     num_samples,
                     sample_names,
                     tar_size,
                     shuffle,
                     rng=None):
    """
    Data iterator for TecoGAN training.

    Every sample is a sequence of ``conf.train.rnn_n`` frames. The frames
    are read, optionally replaced by shifted crops of the first frame,
    optionally randomly cropped and horizontally flipped (same crop/flip
    for all frames of a sample), and finally cropped and passed through
    ``preprocess`` to build the targets.

    Args:
        conf: configuration object; this function reads
            ``conf.train.movingFirstFrame``, ``rnn_n``, ``random_crop``,
            ``flip``, ``crop_size`` and ``batch_size``.
        num_samples: number of samples, forwarded to data_iterator_simple.
        sample_names: indexed as ``sample_names[frame_index][sample_index]``
            to obtain an image path.
        tar_size: side length of the square random crop.
        shuffle: forwarded to data_iterator_simple.
        rng: forwarded to data_iterator_simple.

    Returns:
        The iterator created by data_iterator_simple. Each loaded sample is
        the pair ``(hr_frames, target_frames)``, both lists of length
        ``conf.train.rnn_n``.
    """
    def populate_hr_data(i):
        """Read the frames of sample ``i`` as float32 arrays divided by 255."""
        hr_data = []

        # Data augmentation: mimic camera motion by replacing the sequence
        # with shifted crops of its first frame.
        if conf.train.movingFirstFrame:
            lefttop_pos, range_pos, is_move = moving_decision(conf)

        for f_i in range(conf.train.rnn_n):
            img_name = sample_names[f_i][i]
            img_data = cv.imread(img_name, 3).astype(np.float32)
            img_data = img_data / 255

            if conf.train.movingFirstFrame:
                if f_i == 0:
                    # Remember the first frame; later frames may be cut
                    # out of it instead of using their own pixels.
                    img_data_0 = img_data
                    target_size = img_data.shape

                # The shifted-first-frame variant is used only when
                # is_move is not below 0.7.
                if not is_move < 0.7:
                    left, top = lefttop_pos[f_i][0], lefttop_pos[f_i][1]
                    img_data = img_data_0[
                        top:target_size[0] - range_pos[1] + top,
                        left:target_size[1] - range_pos[0] + left, :]
            hr_data.append(img_data)

        return hr_data

    def dataset_load_func(i):
        """Build ``(hr_frames, target_frames)`` for sample ``i``."""
        hr_data = populate_hr_data(i)

        # Random crop: one offset per sample, applied to all its frames.
        if conf.train.random_crop is True:
            cur_size = hr_data[0].shape
            offset_h = np.floor(
                np.random.uniform(0, cur_size[0] - tar_size, [])).astype(int)
            offset_w = np.floor(
                np.random.uniform(0, cur_size[1] - tar_size, [])).astype(int)
            hr_data = [
                frame[offset_h:offset_h + tar_size,
                      offset_w:offset_w + tar_size, :]
                for frame in hr_data
            ]

        # Random horizontal flip: one decision per sample.
        if conf.train.flip is True:
            flip_decision = np.random.uniform(0, 1, []).astype(float)
            if flip_decision < 0.5:
                # np.fliplr returns a flipped view and leaves its argument
                # untouched, so the result has to be stored back. A
                # contiguous copy avoids handing negative-stride views to
                # the downstream preprocessing.
                hr_data = [
                    np.ascontiguousarray(np.fliplr(frame))
                    for frame in hr_data
                ]

        hr_frames = hr_data

        # Targets: skip a border of int(1.5 * 3.0) pixels at the top/left
        # and take a square of side crop_size * 4 from every frame.
        k_w_border = int(1.5 * 3.0)
        crop_end = k_w_border + conf.train.crop_size * 4
        target_frames = [
            preprocess(frame[k_w_border:crop_end, k_w_border:crop_end, :])
            for frame in hr_data
        ]

        return hr_frames, target_frames

    return data_iterator_simple(dataset_load_func,
                                num_samples,
                                conf.train.batch_size,
                                shuffle=shuffle,
                                rng=rng,
                                with_file_cache=False,
                                with_memory_cache=False)
Exemple #24
0
def train(args):
    """
    Train the generator/discriminator pair and translate test images.

    Builds the discriminator and generator graphs and their losses
    (adversarial, classification, gradient penalty, reconstruction),
    trains them alternately, linearly decays both learning rates during the
    second half of training, then saves the parameters plus the used
    configuration and finally writes translations of test images.

    Args:
        args: configuration namespace. Every public attribute is dumped
            into the saved JSON config. ``args.c_dim`` is overwritten with
            ``len(args.selected_attrs)`` when the two disagree.

    Returns:
        None. Outputs are written via the monitors, ``save_results`` and
        files under ``args.model_save_path``.
    """
    if args.c_dim != len(args.selected_attrs):
        print("c_dim must be the same as the num of selected attributes. Modified c_dim.")
        args.c_dim = len(args.selected_attrs)

    # Dump the config information.
    config = dict()
    print("Used config:")
    for k in args.__dir__():
        if not k.startswith("_"):
            config[k] = getattr(args, k)
            print("'{}' : {}".format(k, getattr(args, k)))

    # Prepare Generator and Discriminator based on user config.
    generator = functools.partial(
        model.generator, conv_dim=args.g_conv_dim, c_dim=args.c_dim, num_downsample=args.num_downsample, num_upsample=args.num_upsample, repeat_num=args.g_repeat_num)
    discriminator = functools.partial(model.discriminator, image_size=args.image_size,
                                      conv_dim=args.d_conv_dim, c_dim=args.c_dim, repeat_num=args.d_repeat_num)

    x_real = nn.Variable(
        [args.batch_size, 3, args.image_size, args.image_size])
    label_org = nn.Variable([args.batch_size, args.c_dim, 1, 1])
    label_trg = nn.Variable([args.batch_size, args.c_dim, 1, 1])

    with nn.parameter_scope("dis"):
        dis_real_img, dis_real_cls = discriminator(x_real)

    with nn.parameter_scope("gen"):
        x_fake = generator(x_real, label_trg)
    x_fake.persistent = True  # to retain its value during computation.

    # get an unlinked_variable of x_fake, so the discriminator graph can be
    # run and differentiated separately from the generator graph.
    x_fake_unlinked = x_fake.get_unlinked_variable()

    with nn.parameter_scope("dis"):
        dis_fake_img, dis_fake_cls = discriminator(x_fake_unlinked)

    # ---------------- Define Loss for Discriminator -----------------
    d_loss_real = (-1) * loss.gan_loss(dis_real_img)
    d_loss_fake = loss.gan_loss(dis_fake_img)
    d_loss_cls = loss.classification_loss(dis_real_cls, label_org)
    d_loss_cls.persistent = True

    # Gradient Penalty: evaluated at a random per-sample interpolation
    # between the real and the generated image.
    alpha = F.rand(shape=(args.batch_size, 1, 1, 1))
    x_hat = F.mul2(alpha, x_real) + \
        F.mul2(F.r_sub_scalar(alpha, 1), x_fake_unlinked)

    with nn.parameter_scope("dis"):
        dis_for_gp, _ = discriminator(x_hat)
    grads = nn.grad([dis_for_gp], [x_hat])

    l2norm = F.sum(grads[0] ** 2.0, axis=(1, 2, 3)) ** 0.5
    d_loss_gp = F.mean((l2norm - 1.0) ** 2.0)

    # total discriminator loss.
    d_loss = d_loss_real + d_loss_fake + args.lambda_cls * \
        d_loss_cls + args.lambda_gp * d_loss_gp

    # ---------------- Define Loss for Generator -----------------
    g_loss_fake = (-1) * loss.gan_loss(dis_fake_img)
    g_loss_cls = loss.classification_loss(dis_fake_cls, label_trg)
    g_loss_cls.persistent = True

    # Reconstruct Images: translate the fake image back with the original
    # labels.
    with nn.parameter_scope("gen"):
        x_recon = generator(x_fake_unlinked, label_org)
    x_recon.persistent = True

    g_loss_rec = loss.recon_loss(x_real, x_recon)
    g_loss_rec.persistent = True

    # total generator loss.
    g_loss = g_loss_fake + args.lambda_rec * \
        g_loss_rec + args.lambda_cls * g_loss_cls

    # -------------------- Solver Setup ---------------------
    d_lr = args.d_lr  # initial learning rate for Discriminator
    g_lr = args.g_lr  # initial learning rate for Generator
    solver_dis = S.Adam(alpha=args.d_lr, beta1=args.beta1, beta2=args.beta2)
    solver_gen = S.Adam(alpha=args.g_lr, beta1=args.beta1, beta2=args.beta2)

    # register parameters to each solver.
    with nn.parameter_scope("dis"):
        solver_dis.set_parameters(nn.get_parameters())

    with nn.parameter_scope("gen"):
        solver_gen.set_parameters(nn.get_parameters())

    # -------------------- Create Monitors --------------------
    monitor = Monitor(args.monitor_path)
    monitor_d_cls_loss = MonitorSeries(
        'real_classification_loss', monitor, args.log_step)
    monitor_g_cls_loss = MonitorSeries(
        'fake_classification_loss', monitor, args.log_step)
    monitor_loss_dis = MonitorSeries(
        'discriminator_loss', monitor, args.log_step)
    monitor_recon_loss = MonitorSeries(
        'reconstruction_loss', monitor, args.log_step)
    monitor_loss_gen = MonitorSeries('generator_loss', monitor, args.log_step)
    monitor_time = MonitorTimeElapsed("Training_time", monitor, args.log_step)

    # -------------------- Prepare / Split Dataset --------------------
    using_attr = args.selected_attrs
    dataset, attr2idx, idx2attr = get_data_dict(args.attr_path, using_attr)
    random.seed(313)  # use fixed seed.
    random.shuffle(dataset)  # shuffle dataset.
    test_dataset = dataset[-2000:]  # extract 2000 images for test

    if args.num_data:
        # Use training data partially, never touching the 2000 test items.
        training_dataset = dataset[:min(args.num_data, len(dataset) - 2000)]
    else:
        training_dataset = dataset[:-2000]
    print("Use {} images for training.".format(len(training_dataset)))

    # create data iterators.
    load_func = functools.partial(stargan_load_func, dataset=training_dataset,
                                  image_dir=args.celeba_image_dir, image_size=args.image_size, crop_size=args.celeba_crop_size)
    data_iterator = data_iterator_simple(load_func, len(
        training_dataset), args.batch_size, with_file_cache=False, with_memory_cache=False)

    load_func_test = functools.partial(stargan_load_func, dataset=test_dataset,
                                       image_dir=args.celeba_image_dir, image_size=args.image_size, crop_size=args.celeba_crop_size)
    test_data_iterator = data_iterator_simple(load_func_test, len(
        test_dataset), args.batch_size, with_file_cache=False, with_memory_cache=False)

    # Keep fixed test images for intermediate translation visualization.
    test_real_ndarray, test_label_ndarray = test_data_iterator.next()
    test_label_ndarray = test_label_ndarray.reshape(
        test_label_ndarray.shape + (1, 1))

    # -------------------- Training Loop --------------------
    one_epoch = data_iterator.size // args.batch_size
    num_max_iter = args.max_epoch * one_epoch

    for i in range(num_max_iter):
        # Get real images and labels.
        real_ndarray, label_ndarray = data_iterator.next()
        # Append two singleton axes so labels match label_org's 4-D shape.
        label_ndarray = label_ndarray.reshape(label_ndarray.shape + (1, 1))
        label_ndarray = label_ndarray.astype(float)
        x_real.d, label_org.d = real_ndarray, label_ndarray

        # Generate target domain labels randomly by permuting the batch.
        rand_idx = np.random.permutation(label_org.shape[0])
        label_trg.d = label_ndarray[rand_idx]

        # ---------------- Train Discriminator -----------------
        # generate fake image.
        x_fake.forward(clear_no_need_grad=True)
        d_loss.forward(clear_no_need_grad=True)
        solver_dis.zero_grad()
        d_loss.backward(clear_buffer=True)
        solver_dis.update()

        monitor_loss_dis.add(i, d_loss.d.item())
        monitor_d_cls_loss.add(i, d_loss_cls.d.item())
        monitor_time.add(i)

        # -------------- Train Generator --------------
        # The generator is updated once every args.n_critic iterations.
        if (i + 1) % args.n_critic == 0:
            g_loss.forward(clear_no_need_grad=True)
            solver_dis.zero_grad()
            solver_gen.zero_grad()
            x_fake_unlinked.grad.zero()
            g_loss.backward(clear_buffer=True)
            # Continue backpropagation from x_fake into the generator,
            # using the gradient accumulated on the unlinked variable.
            x_fake.backward(grad=None)
            solver_gen.update()
            monitor_loss_gen.add(i, g_loss.d.item())
            monitor_g_cls_loss.add(i, g_loss_cls.d.item())
            monitor_recon_loss.add(i, g_loss_rec.d.item())
            monitor_time.add(i)

            if (i + 1) % args.sample_step == 0:
                # save image.
                save_results(i, args, x_real, x_fake,
                             label_org, label_trg, x_recon)
                if args.test_during_training:
                    # translate images from test dataset.
                    x_real.d, label_org.d = test_real_ndarray, test_label_ndarray
                    label_trg.d = test_label_ndarray[rand_idx]
                    x_fake.forward(clear_no_need_grad=True)
                    save_results(i, args, x_real, x_fake, label_org,
                                 label_trg, None, is_training=False)

        # Learning rates get decayed linearly during the second half of
        # training, every args.lr_update_step iterations, clamped at 0.
        if (i + 1) > int(0.5 * num_max_iter) and (i + 1) % args.lr_update_step == 0:
            g_lr = max(0, g_lr - (args.lr_update_step *
                                  args.g_lr / float(0.5 * num_max_iter)))
            d_lr = max(0, d_lr - (args.lr_update_step *
                                  args.d_lr / float(0.5 * num_max_iter)))
            solver_gen.set_learning_rate(g_lr)
            solver_dis.set_learning_rate(d_lr)
            print('learning rates decayed, g_lr: {}, d_lr: {}.'.format(g_lr, d_lr))

    # Save parameters and training config.
    # The timestamp is taken once so the parameter file and its config file
    # always carry the same suffix, even if saving crosses a minute boundary.
    timestamp = datetime.datetime.today().strftime("%m%d%H%M")
    param_name = 'trained_params_{}.h5'.format(timestamp)
    param_path = os.path.join(args.model_save_path, param_name)
    nn.save_parameters(param_path)
    config["pretrained_params"] = param_name

    config_path = os.path.join(args.model_save_path,
                               "training_conf_{}.json".format(timestamp))
    with open(config_path, "w") as f:
        json.dump(config, f)

    # -------------------- Translation on test dataset --------------------
    for i in range(args.num_test):
        real_ndarray, label_ndarray = test_data_iterator.next()
        label_ndarray = label_ndarray.reshape(label_ndarray.shape + (1, 1))
        label_ndarray = label_ndarray.astype(float)
        x_real.d, label_org.d = real_ndarray, label_ndarray

        rand_idx = np.random.permutation(label_org.shape[0])
        label_trg.d = label_ndarray[rand_idx]

        x_fake.forward(clear_no_need_grad=True)
        save_results(i, args, x_real, x_fake, label_org,
                     label_trg, None, is_training=False)
Exemple #25
0
def data_iterator(conf, shuffle, rng=None):
    """
    Data iterator for Zooming SloMo training.

    Reads frame sequences from two LMDB databases (ground truth and low
    quality), takes a random square crop, applies one shared random
    flip/rotation to all frames of a sample and returns them as
    channel-first stacks.

    The neighbour frame indices are chosen once, when this function is
    called, and are reused for every sample of the returned iterator.

    Args:
        conf: configuration object; this function reads
            ``conf.data.{n_frames, cache_keys, lmdb_data_gt, lmdb_data_lq,
            interval_list, border_mode, random_reverse, gt_size, use_flip,
            use_rot}`` and ``conf.train.{only_slomo, batch_size}``.
        shuffle: forwarded to data_iterator_simple.
        rng: forwarded to data_iterator_simple.

    Returns:
        The iterator created by data_iterator_simple. Each sample is
        ``(img_lq_stack, img_gt_stack)``; when ``conf.train.only_slomo`` is
        set the first element is a placeholder integer instead of a stack.
    """
    assert conf.data.n_frames > 1, 'Error: Not enough LR frames to interpolate'
    half_n_frames = conf.data.n_frames // 2

    # determine the LQ frame list
    # N | frames
    # 1 | error
    # 3 | 0,2
    # 5 | 0,2,4
    # 7 | 0,2,4,6

    lr_index_list = [i * 2 for i in range(1 + half_n_frames)]

    # NOTE: pickle must only be used on trusted cache files.
    # The file is read inside a `with` block so the handle is closed.
    with open(conf.data.cache_keys, 'rb') as cache_file:
        paths_gt = pickle.load(cache_file)

    gt_lmdb = lmdb.open(conf.data.lmdb_data_gt,
                        readonly=True,
                        lock=False,
                        readahead=False,
                        meminit=False)
    lq_lmdb = lmdb.open(conf.data.lmdb_data_lq,
                        readonly=True,
                        lock=False,
                        readahead=False,
                        meminit=False)

    center_frame_idx = random.randint(2, 6)  # 2<= index <=6

    def determine_neighbor_list(central_frame_idx):
        """
        given central frame index, determine neighborhood frames
        """
        interval = random.choice(conf.data.interval_list)

        if conf.data.border_mode:
            direction = 1  # 1: forward; 0: backward
            # Read the flag from conf.data, matching the non-border branch
            # below and every other data option in this function.
            if conf.data.random_reverse and random.random() < 0.5:
                direction = random.choice([0, 1])
            # Force the direction that keeps the frames inside 1..7.
            if central_frame_idx + interval * (conf.data.n_frames - 1) > 7:
                direction = 0
            elif central_frame_idx - interval * (conf.data.n_frames - 1) < 1:
                direction = 1
            # get the neighbor list
            if direction == 1:
                neighbor_list = list(
                    range(central_frame_idx,
                          central_frame_idx + interval * conf.data.n_frames,
                          interval))
            else:
                neighbor_list = list(
                    range(central_frame_idx,
                          central_frame_idx - interval * conf.data.n_frames,
                          -interval))
        else:
            # ensure not exceeding the borders: redraw the centre until the
            # whole window fits into frames 1..7
            while (central_frame_idx + half_n_frames * interval > 7) or \
                    (central_frame_idx - half_n_frames * interval < 1):
                central_frame_idx = random.randint(2, 6)

            # get the neighbor list
            neighbor_list = list(
                range(central_frame_idx - half_n_frames * interval,
                      central_frame_idx + half_n_frames * interval + 1,
                      interval))
            if conf.data.random_reverse and random.random() < 0.5:
                neighbor_list.reverse()

        return neighbor_list

    neighbors = determine_neighbor_list(center_frame_idx)
    lq_frames_list = [neighbors[i] for i in lr_index_list]

    assert len(neighbors) == conf.data.n_frames, \
        'Wrong length of neighbor list: {}'.format(len(neighbors))

    # image read and augment functions

    def augment(img_list, flip=True, rot=True):
        """
        Apply the same random flip / rotation to every image in img_list.

        The random decisions are drawn once per call, so all frames of a
        sequence (and LQ/GT pairs passed in one list) stay aligned.
        """
        do_flip = flip and random.random() < 0.5
        do_rot = rot and random.random() < 0.5

        def _augment(img):
            if do_flip:
                # horizontal flip
                img = img[:, ::-1, :]
            if do_rot:
                # vertical flip and 90 degree rotation
                img = img[::-1, :, :]
                img = img.transpose(1, 0, 2)
            return img

        return [_augment(img) for img in img_list]

    def _read_img_from_lmdb(env, key, size):
        """
        read image from lmdb with key (w/ and w/o fixed size)
        size: (channels, height, width) tuple
        Returns a float32 (height, width, channels<=3) array divided by 255.
        """
        with env.begin(write=False) as txn:
            buf = txn.get(key.encode('ascii'))
        img_flat = np.frombuffer(buf, dtype=np.uint8)
        channels, height, width = size
        # The reshape always yields a 3-D array, so no 2-D handling is
        # needed afterwards.
        img = img_flat.reshape(height, width, channels)
        img = img.astype(np.float32) / 255.
        # some images have 4 channels
        if img.shape[2] > 3:
            img = img[:, :, :3]
        return img

    def load_zoomingslomo_data(i):
        """
        loads data, given the index -> primary function in data loader
        Returns (img_lq_stack, img_gt_stack), each in (N, C, H, W) order
        where N is the number of frames.
        """
        key = paths_gt[i]

        # get the GT image (as the center frame)
        img_gt_l = [
            _read_img_from_lmdb(gt_lmdb, key + '_{}'.format(v), (3, 256, 448))
            for v in neighbors
        ]

        # get Low Quality images
        lq_size_tuple = (3, 64, 112)
        img_lq_l = [
            _read_img_from_lmdb(lq_lmdb, key + '_{}'.format(v), lq_size_tuple)
            for v in lq_frames_list
        ]

        _, height, width = lq_size_tuple  # LQ size
        # randomly crop; the GT crop is the LQ crop scaled by `scale`
        scale = 4
        gt_size = conf.data.gt_size
        lr_size = gt_size // scale
        rnd_h = random.randint(0, max(0, height - lr_size))
        rnd_w = random.randint(0, max(0, width - lr_size))
        img_lq_l = [
            v[rnd_h:rnd_h + lr_size, rnd_w:rnd_w + lr_size, :]
            for v in img_lq_l
        ]
        rnd_h_highres, rnd_w_highres = int(rnd_h * scale), int(rnd_w * scale)
        img_gt_l = [
            v[rnd_h_highres:rnd_h_highres + gt_size,
              rnd_w_highres:rnd_w_highres + gt_size, :] for v in img_gt_l
        ]

        # augmentation - flip, rotate; LQ and GT frames go through augment
        # in one list so they receive the same transformation
        img_lq_l = img_lq_l + img_gt_l
        rlt = augment(img_lq_l, conf.data.use_flip, conf.data.use_rot)
        img_lq_l = rlt[0:-conf.data.n_frames]
        img_gt_l = rlt[-conf.data.n_frames:]

        # stack LQ and GT images in NHWC order, N is the frame number
        img_lq_stack = np.stack(img_lq_l, axis=0)
        img_gt_stack = np.stack(img_gt_l, axis=0)

        # numpy to tensor
        img_gt_stack = img_gt_stack[:, :, :, [2, 1, 0]]  # BGR to RGB
        img_lq_stack = img_lq_stack[:, :, :, [2, 1, 0]]  # BGR to RGB
        img_gt_stack = np.ascontiguousarray(
            np.transpose(img_gt_stack, (0, 3, 1, 2)))  # HWC to CHW
        img_lq_stack = np.ascontiguousarray(
            np.transpose(img_lq_stack, (0, 3, 1, 2)))  # HWC to CHW

        return img_lq_stack, img_gt_stack

    def load_slomo_data(i):
        """
        loads data, given the index -> primary function in data loader
        Only GT frames are loaded; the first returned element is a
        placeholder (the channel count of gt_size_tuple).
        """
        key = paths_gt[i]

        gt_size_tuple = (3, 256, 448)
        # get the GT image (as the center frame)
        img_gt_l = [
            _read_img_from_lmdb(gt_lmdb, key + '_{}'.format(v), gt_size_tuple)
            for v in neighbors
        ]

        placeholder, height, width = gt_size_tuple  # GT size
        # randomly crop
        gt_size = conf.data.gt_size
        rnd_h = random.randint(0, max(0, height - gt_size))
        rnd_w = random.randint(0, max(0, width - gt_size))

        img_gt_l = [
            v[rnd_h:rnd_h + gt_size, rnd_w:rnd_w + gt_size, :]
            for v in img_gt_l
        ]

        # augmentation - flip, rotate
        img_gt_l = augment(img_gt_l, conf.data.use_flip, conf.data.use_rot)

        # stack GT images in NHWC order, N is the frame number
        img_gt_stack = np.stack(img_gt_l, axis=0)
        # numpy to tensor
        img_gt_stack = img_gt_stack[:, :, :, [2, 1, 0]]  # BGR to RGB
        img_gt_stack = np.ascontiguousarray(
            np.transpose(img_gt_stack, (0, 3, 1, 2)))  # HWC to CHW

        # presumably the placeholder keeps the two-element sample layout of
        # load_zoomingslomo_data -- verify against the consumer
        return placeholder, img_gt_stack

    dataset_load_func = load_zoomingslomo_data if not conf.train.only_slomo else load_slomo_data

    return data_iterator_simple(dataset_load_func,
                                len(paths_gt),
                                conf.train.batch_size,
                                shuffle=shuffle,
                                rng=rng,
                                with_file_cache=False,
                                with_memory_cache=False)