def __getitem__(self, index):
        """Load, crop and tensorize the image pair / flow sample at ``index``.

        ``index`` is wrapped with ``% self.size``, so indices past the number
        of stored pairs (e.g. from replicating the dataset) map back onto one.

        Both frames and the flow field are passed through the *same* cropper
        instance (presumably one fixed window, so they stay aligned):
        a ``StaticRandomCrop`` of ``self.crop_size`` when ``self.is_cropped``
        is set, otherwise a ``StaticCenterCrop`` of ``self.render_size``.

        Returns:
            ``([images], [flow])`` -- two one-element lists. ``images`` is a
            float32 tensor of the two stacked frames reordered with
            ``transpose(3, 0, 1, 2)``; assuming frames are (H, W, C) -- TODO
            confirm -- that is (C, 2, H, W). ``flow`` is a float32 tensor with
            its last axis moved to the front.
        """
        sample = index % self.size
        pair_paths = self.image_list[sample]

        frames = [frame_utils.read_gen(pair_paths[k]) for k in (0, 1)]
        flow_field = frame_utils.read_gen(self.flow_list[sample])

        # The first two axes of the first frame are handed to the cropper as
        # the image size.
        frame_hw = frames[0].shape[:2]
        if self.is_cropped:
            crop_cls, crop_dims = StaticRandomCrop, self.crop_size
        else:
            crop_cls, crop_dims = StaticCenterCrop, self.render_size
        cropper = crop_cls(frame_hw, crop_dims)

        # Crop the frames first, then the flow (same order as always).
        stacked = np.array([cropper(frame) for frame in frames])
        flow_cropped = cropper(flow_field)

        # Move the trailing (channel) axis to the front for both arrays.
        stacked = stacked.transpose(3, 0, 1, 2)
        flow_cropped = flow_cropped.transpose(2, 0, 1)

        image_tensor = torch.from_numpy(stacked.astype(np.float32))
        flow_tensor = torch.from_numpy(flow_cropped.astype(np.float32))
        return [image_tensor], [flow_tensor]
    def __init__(self,
                 args,
                 is_cropped,
                 root="/path/to/frames/only/folder",
                 iext="png",
                 replicates=1):
        """Index consecutive frame pairs from a folder of images (no flow).

        Every file matching ``*.<iext>`` directly under ``root`` is sorted by
        name and paired with its successor: (f0, f1), (f1, f2), ...

        Args:
            args: Options object; ``crop_size`` and ``inference_size`` are
                read, and ``inference_size`` may be updated (see below).
            is_cropped: Stored as ``self.is_cropped``.
            root: Folder containing the frames.
            iext: Image file extension, without the leading dot.
            replicates: Stored as ``self.replicates``.

        Raises:
            ValueError: If fewer than two matching images exist, so not even
                one pair can be formed.

        Side effects:
            ``self.render_size`` *is* ``args.inference_size`` (same object).
            When either entry is negative, or the first frame's first two
            dimensions are not multiples of 64, both entries are overwritten
            in place with those dimensions rounded down to a multiple of 64.
            ``args.inference_size`` is then (re)assigned that same object.
        """
        self.args = args
        self.is_cropped = is_cropped
        self.crop_size = args.crop_size
        # Deliberate alias: in-place edits below are visible through args.
        self.render_size = args.inference_size
        self.replicates = replicates

        images = sorted(glob(join(root, "*." + iext)))
        # Pair each frame with the one that follows it in sorted order.
        self.image_list = [[im1, im2] for im1, im2 in zip(images, images[1:])]

        self.size = len(self.image_list)
        if self.size == 0:
            # Fail with a clear message instead of an IndexError on
            # image_list[0] below.
            raise ValueError(
                "Need at least two '*.%s' images in %r to form a pair, "
                "found %d" % (iext, root, len(images)))

        self.frame_size = frame_utils.read_gen(self.image_list[0][0]).shape

        # Use the first frame's size rounded down to a multiple of 64 when a
        # requested dimension is negative (presumably "unset") or the frame
        # is not 64-aligned. NOTE: a frame dimension below 64 rounds to 0.
        if ((self.render_size[0] < 0) or (self.render_size[1] < 0)
                or (self.frame_size[0] % 64) or (self.frame_size[1] % 64)):
            self.render_size[0] = ((self.frame_size[0]) // 64) * 64
            self.render_size[1] = ((self.frame_size[1]) // 64) * 64

        args.inference_size = self.render_size
    def __init__(self,
                 args,
                 is_cropped,
                 root="/path/to/FlyingChairs_release/data",
                 replicates=1):
        """Index FlyingChairs image pairs and their ``.flo`` ground truth.

        All ``*.ppm`` and ``*.flo`` files directly under ``root`` are sorted
        by name. The i-th flow is matched with sorted images ``2*i`` and
        ``2*i + 1``, which relies on each pair's two frames sorting next to
        each other (presumably ``*_img1.ppm`` / ``*_img2.ppm`` -- confirm
        against the dataset layout).

        Args:
            args: Options object; ``crop_size`` and ``inference_size`` are
                read, and ``inference_size`` may be updated (see below).
            is_cropped: Stored as ``self.is_cropped``.
            root: Folder holding the ``.ppm`` images and ``.flo`` flows.
            replicates: Stored as ``self.replicates``.

        Raises:
            ValueError: If no ``.flo`` files are found, or the number of
                ``.ppm`` images is not exactly twice the number of flows.

        Side effects:
            ``self.render_size`` *is* ``args.inference_size`` (same object).
            When either entry is negative, or the first frame's first two
            dimensions are not multiples of 64, both entries are overwritten
            in place with those dimensions rounded down to a multiple of 64.
            ``args.inference_size`` is then (re)assigned that same object.
        """
        self.args = args
        self.is_cropped = is_cropped
        self.crop_size = args.crop_size
        # Deliberate alias: in-place edits below are visible through args.
        self.render_size = args.inference_size
        self.replicates = replicates

        images = sorted(glob(join(root, "*.ppm")))
        self.flow_list = sorted(glob(join(root, "*.flo")))

        if not self.flow_list:
            raise ValueError("No .flo files found in %r" % root)
        # Require exactly two images per flow. A single stray image would
        # shift every later pair by one frame, and an ``assert`` would be
        # skipped entirely under ``python -O``.
        if len(images) != 2 * len(self.flow_list):
            raise ValueError(
                "Expected two .ppm images per .flo file in %r, found %d "
                "images and %d flows"
                % (root, len(images), len(self.flow_list)))

        # Consecutive sorted images form one pair: (0, 1), (2, 3), ...
        self.image_list = [[im1, im2]
                           for im1, im2 in zip(images[0::2], images[1::2])]

        self.size = len(self.image_list)

        self.frame_size = frame_utils.read_gen(self.image_list[0][0]).shape

        # Use the first frame's size rounded down to a multiple of 64 when a
        # requested dimension is negative (presumably "unset") or the frame
        # is not 64-aligned. NOTE: a frame dimension below 64 rounds to 0.
        if ((self.render_size[0] < 0) or (self.render_size[1] < 0)
                or (self.frame_size[0] % 64) or (self.frame_size[1] % 64)):
            self.render_size[0] = ((self.frame_size[0]) // 64) * 64
            self.render_size[1] = ((self.frame_size[1]) // 64) * 64

        args.inference_size = self.render_size
    def __init__(self,
                 args,
                 is_cropped=False,
                 root="",
                 dstype="clean",
                 replicates=1):
        """Index MPI-Sintel-style frame pairs with ``.flo`` ground truth.

        Flows are read from ``root/flow/<subdir>/*.flo``. A flow file's name
        must end in a four-digit frame number N (the slicing below assumes
        e.g. ``frame_0001.flo``). Its frames are the same-prefix files
        ``root/<dstype>/<subdir>/...N.png`` and ``...(N+1).png``.

        Flows are skipped when their path *below* ``root/flow`` contains
        ``"test"``, or when either frame (or the flow itself) is not a file.

        Args:
            args: Options object; ``crop_size`` and ``inference_size`` are
                read, and ``inference_size`` may be updated (see below).
            is_cropped: Stored as ``self.is_cropped``.
            root: Dataset root holding ``flow/`` and ``<dstype>/``.
            dstype: Name of the image sub-folder (e.g. ``"clean"``).
            replicates: Stored as ``self.replicates``.

        Raises:
            ValueError: If no usable flow/frame pair is found, or a flow file
                name does not end in a four-digit number.

        Side effects:
            ``self.render_size`` *is* ``args.inference_size`` (same object).
            When either entry is negative, or the first frame's first two
            dimensions are not multiples of 64, both entries are overwritten
            in place with those dimensions rounded down to a multiple of 64.
            ``args.inference_size`` is then (re)assigned that same object.
        """
        self.args = args
        self.is_cropped = is_cropped
        self.crop_size = args.crop_size
        # Deliberate alias: in-place edits below are visible through args.
        self.render_size = args.inference_size
        self.replicates = replicates

        flow_root = join(root, "flow")
        image_root = join(root, dstype)

        self.flow_list = []
        self.image_list = []

        for flow_path in sorted(glob(join(flow_root, "*/*.flo"))):
            # Path relative to flow_root, e.g. "<subdir>/frame_0001.flo".
            fbase = flow_path[len(flow_root) + 1:]
            # Inspect only the part below flow_root, so a "test" anywhere in
            # ``root`` itself does not silently discard every sample.
            if "test" in fbase:
                continue

            fprefix = fbase[:-8]
            fnum = int(fbase[-8:-4])

            img1 = join(image_root, fprefix + "%04d" % (fnum + 0) + ".png")
            img2 = join(image_root, fprefix + "%04d" % (fnum + 1) + ".png")

            if not isfile(img1) or not isfile(img2) or not isfile(flow_path):
                continue

            self.image_list += [[img1, img2]]
            self.flow_list += [flow_path]

        self.size = len(self.image_list)
        if self.size == 0:
            # Fail clearly instead of an IndexError on image_list[0] below.
            raise ValueError(
                "No usable flow/frame pairs found under %r (dstype=%r)"
                % (root, dstype))

        self.frame_size = frame_utils.read_gen(self.image_list[0][0]).shape

        # Use the first frame's size rounded down to a multiple of 64 when a
        # requested dimension is negative (presumably "unset") or the frame
        # is not 64-aligned. NOTE: a frame dimension below 64 rounds to 0.
        if ((self.render_size[0] < 0) or (self.render_size[1] < 0)
                or (self.frame_size[0] % 64) or (self.frame_size[1] % 64)):
            self.render_size[0] = ((self.frame_size[0]) // 64) * 64
            self.render_size[1] = ((self.frame_size[1]) // 64) * 64

        args.inference_size = self.render_size

        assert len(self.image_list) == len(self.flow_list)
    def __init__(self,
                 args,
                 is_cropped,
                 root="/path/to/flyingthings3d",
                 dstype="frames_cleanpass",
                 replicates=1):
        """Index FlyingThings3D frame pairs with ``.flo`` ground truth.

        Image directories are ``root/<dstype>/TRAIN/*/*/{left,right}`` and
        flow directories are
        ``root/optical_flow_flo_format/TRAIN/*/*/into_future/{left,right}``.
        Both lists are sorted and matched by position. Within a matched pair
        of directories, sorted flow ``i`` is paired with sorted images ``i``
        and ``i + 1``.

        Args:
            args: Options object; ``crop_size`` and ``inference_size`` are
                read, and ``inference_size`` may be updated (see below).
            is_cropped: Stored as ``self.is_cropped``.
            root: Dataset root.
            dstype: Name of the image pass folder.
            replicates: Stored as ``self.replicates``.

        Raises:
            ValueError: If the image and flow directory counts differ, if a
                directory that has flows has fewer than ``len(flows) + 1``
                images, or if no pairs are found at all.

        Side effects:
            ``self.render_size`` *is* ``args.inference_size`` (same object).
            When either entry is negative, or the first frame's first two
            dimensions are not multiples of 64, both entries are overwritten
            in place with those dimensions rounded down to a multiple of 64.
            ``args.inference_size`` is then (re)assigned that same object.
        """
        self.args = args
        self.is_cropped = is_cropped
        self.crop_size = args.crop_size
        # Deliberate alias: in-place edits below are visible through args.
        self.render_size = args.inference_size
        self.replicates = replicates

        image_dirs = sorted(glob(join(root, dstype, "TRAIN/*/*")))
        image_dirs = sorted([join(f, "left") for f in image_dirs] +
                            [join(f, "right") for f in image_dirs])

        flow_dirs = sorted(
            glob(join(root, "optical_flow_flo_format/TRAIN/*/*")))
        flow_dirs = sorted([join(f, "into_future/left") for f in flow_dirs] +
                           [join(f, "into_future/right") for f in flow_dirs])

        # Explicit check (not ``assert``, which ``python -O`` strips): with
        # mismatched counts the zip below would silently truncate.
        if len(image_dirs) != len(flow_dirs):
            raise ValueError(
                "Found %d image directories but %d flow directories under %r"
                % (len(image_dirs), len(flow_dirs), root))

        self.image_list = []
        self.flow_list = []

        for idir, fdir in zip(image_dirs, flow_dirs):
            images = sorted(glob(join(idir, "*.png")))
            flows = sorted(glob(join(fdir, "*.flo")))
            # Flow i is paired with frames i and i + 1, so a directory with
            # flows needs at least one more image than flows. The error names
            # the offending directories instead of a bare IndexError.
            if flows and len(images) <= len(flows):
                raise ValueError(
                    "%s has %d images but %s has %d flows; need at least "
                    "one more image than flows"
                    % (idir, len(images), fdir, len(flows)))
            for im1, im2, flow in zip(images, images[1:], flows):
                self.image_list.append([im1, im2])
                self.flow_list.append(flow)

        assert len(self.image_list) == len(self.flow_list)

        self.size = len(self.image_list)
        if self.size == 0:
            # Fail clearly instead of an IndexError on image_list[0] below.
            raise ValueError(
                "No flow/frame pairs found under %r (dstype=%r)"
                % (root, dstype))

        self.frame_size = frame_utils.read_gen(self.image_list[0][0]).shape

        # Use the first frame's size rounded down to a multiple of 64 when a
        # requested dimension is negative (presumably "unset") or the frame
        # is not 64-aligned. NOTE: a frame dimension below 64 rounds to 0.
        if ((self.render_size[0] < 0) or (self.render_size[1] < 0)
                or (self.frame_size[0] % 64) or (self.frame_size[1] % 64)):
            self.render_size[0] = ((self.frame_size[0]) // 64) * 64
            self.render_size[1] = ((self.frame_size[1]) // 64) * 64

        args.inference_size = self.render_size
    def __init__(self,
                 args,
                 is_cropped,
                 root="/path/to/chairssdhom/data",
                 dstype="train",
                 replicates=1):
        """Index ChairsSDHom frame pairs with ``.flo`` ground truth.

        First frames come from ``root/<dstype>/t0/*.png``, second frames from
        ``root/<dstype>/t1/*.png`` and flows from ``root/<dstype>/flow/*.flo``.
        Each list is sorted, and the three are matched purely by position.

        Args:
            args: Options object; ``crop_size`` and ``inference_size`` are
                read, and ``inference_size`` may be updated (see below).
            is_cropped: Stored as ``self.is_cropped``.
            root: Dataset root.
            dstype: Split sub-folder (e.g. ``"train"``).
            replicates: Stored as ``self.replicates``.

        Raises:
            ValueError: If no ``.flo`` files are found, or ``t0``, ``t1`` and
                ``flow`` do not contain the same number of files.

        Side effects:
            ``self.render_size`` *is* ``args.inference_size`` (same object).
            When either entry is negative, or the first frame's first two
            dimensions are not multiples of 64, both entries are overwritten
            in place with those dimensions rounded down to a multiple of 64.
            ``args.inference_size`` is then (re)assigned that same object.
        """
        self.args = args
        self.is_cropped = is_cropped
        self.crop_size = args.crop_size
        # Deliberate alias: in-place edits below are visible through args.
        self.render_size = args.inference_size
        self.replicates = replicates

        image1 = sorted(glob(join(root, dstype, "t0/*.png")))
        image2 = sorted(glob(join(root, dstype, "t1/*.png")))
        self.flow_list = sorted(glob(join(root, dstype, "flow/*.flo")))

        if not self.flow_list:
            raise ValueError(
                "No .flo files found in %r" % join(root, dstype, "flow"))
        # Files are matched only by sorted position, so t1 must be checked as
        # well as t0: a short t1 would raise IndexError, and an extra file
        # would misalign pairs. ``assert`` would be stripped under -O.
        n_flows = len(self.flow_list)
        if len(image1) != n_flows or len(image2) != n_flows:
            raise ValueError(
                "Expected equal counts in t0/t1/flow under %r, found "
                "%d/%d/%d" % (join(root, dstype), len(image1), len(image2),
                              n_flows))

        self.image_list = [[im1, im2] for im1, im2 in zip(image1, image2)]

        self.size = len(self.image_list)

        self.frame_size = frame_utils.read_gen(self.image_list[0][0]).shape

        # Use the first frame's size rounded down to a multiple of 64 when a
        # requested dimension is negative (presumably "unset") or the frame
        # is not 64-aligned. NOTE: a frame dimension below 64 rounds to 0.
        if ((self.render_size[0] < 0) or (self.render_size[1] < 0)
                or (self.frame_size[0] % 64) or (self.frame_size[1] % 64)):
            self.render_size[0] = ((self.frame_size[0]) // 64) * 64
            self.render_size[1] = ((self.frame_size[1]) // 64) * 64

        args.inference_size = self.render_size