Example #1
def prepare_CEILNET():
    # split the raw CEILNET test set into ground-truth transmission layers
    # and blended (reflection-contaminated) inputs
    datadir = '/data1/kangfu/haoyu/Data/Reflection/testdata_CEILNET_table2_raw/'
    transmission_layer_dir = '/data1/kangfu/haoyu/Data/Reflection/testdata_CEILNET_table2/transmission_layer/'
    blended_dir = '/data1/kangfu/haoyu/Data/Reflection/testdata_CEILNET_table2/blended/'
    # create the output directories if they do not already exist
    os.makedirs('/data1/kangfu/haoyu/Data/Reflection/testdata_CEILNET_table2/',
                exist_ok=True)
    os.makedirs(transmission_layer_dir, exist_ok=True)
    os.makedirs(blended_dir, exist_ok=True)
    # "-label1" files are the transmission-layer targets; "-input" files
    # are the blended images; copy each under its bare name
    for root, _, fnames in sorted(os.walk(datadir)):
        for fname in fnames:
            if is_image_file(fname) and 'label1' in fname:
                path = os.path.join(root, fname)
                target_path = os.path.join(transmission_layer_dir,
                                           fname.replace('-label1', ''))
                copyfile(path, target_path)
            if is_image_file(fname) and 'input' in fname:
                path = os.path.join(root, fname)
                target_path = os.path.join(blended_dir,
                                           fname.replace('-input', ''))
                copyfile(path, target_path)
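All of these examples lean on the repository's is_image_file helper, which is not shown on this page. A minimal sketch of the usual definition, assuming the extension-whitelist form of the CycleGAN-style image_folder module:

IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG',
                  '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP']

def is_image_file(filename):
    # anything whose suffix is not whitelisted (e.g. .avi, .npy) is
    # treated as a non-image and handled by the video branches below
    return any(filename.endswith(ext) for ext in IMG_EXTENSIONS)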
def make_volume_dataset(dir, max_dataset_size=float("inf")):
    """
    Assemble a nested list of the slices in each observation.
    Adapted from data.image_folder.make_dataset().
    """
    images = []
    regex = re.compile(r"(\d+)_(\d+)\..*")  # <dataset_id>_<z-slice>.<ext>
    dataset_images = defaultdict(list)
    assert os.path.isdir(dir), '%s is not a valid directory' % dir
    for root, _, fnames in sorted(os.walk(dir)):
        for fname in fnames:
            if is_image_file(fname):
                path = os.path.join(root, fname)
                images.append(path)
    # sorting first by dataset_id, then by z-slice
    images = sorted(images,
                    key=lambda x:
                    (int(re.match(regex, os.path.basename(x)).group(1)),
                     int(re.match(regex, os.path.basename(x)).group(2))))
    for impath in images:
        bp = os.path.basename(impath)
        res = re.match(regex, bp)
        # group(1) is the dataset id; the z-slice (group 2) is already
        # ordered by the sort above
        dataset_images[res.group(1)].append(impath)
    if len(dataset_images) > max_dataset_size:
        # truncating to max_dataset_size is not implemented
        raise NotImplementedError
    # return just a list of the z-slice filepaths for each dataset
    return list(dataset_images.values())
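A quick sanity check of the grouping, with hypothetical slice files named <dataset_id>_<z>.png:

# hypothetical layout:
#   vols/0_0.png  vols/0_1.png  vols/1_0.png
volumes = make_volume_dataset('vols')
# -> [['vols/0_0.png', 'vols/0_1.png'], ['vols/1_0.png']]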
Example #3
    def initialize(self, opt):
        self.opt = opt
        # frame offsets of the long-term reference frames
        self.long_term = [0, 2, 4]  #, 8]
        self.root = opt.dataroot
        self.dir_A = os.path.join(opt.dataroot, opt.phase + 'A_video')
        self.dir_B = os.path.join(opt.dataroot, opt.phase + 'B')

        self.A_paths = make_dataset(self.dir_A)
        self.B_paths = make_dataset(self.dir_B)

        self.A_paths = sorted(self.A_paths)
        self.B_paths = sorted(self.B_paths)
        # self.A_paths = list(filter(lambda x : random.randint(0,2)==0 or x.find('.avi') != -1, self.A_paths))
        self.A_length = [0]  # prefix sums of per-entry sample counts
        self.trans = get_transform(opt)
        for i in self.A_paths:
            if is_image_file(i):
                self.A_length.append(self.A_length[-1] + 1)
            else:
                cap = cv2.VideoCapture(i)
                # a video contributes one sample per 10 frames
                self.A_length.append(
                    self.A_length[-1] +
                    math.floor(cap.get(cv2.CAP_PROP_FRAME_COUNT) / 10))
                cap.release()  # only the frame count is needed here
        self.A_size = self.A_length[-1]
        self.B_size = len(self.B_paths)
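A_length is a prefix-sum table: entry k is the number of samples contributed by the first k entries of A_paths. A small self-contained sketch of the lookup that __getitem__ (Example #9) performs with it:

import bisect

A_length = [0, 1, 13, 20]  # e.g. one image, then 12- and 7-sample videos

def source_index(flat_index):
    # bisect_right finds the first boundary strictly past flat_index;
    # stepping back one gives the contributing entry of A_paths
    return bisect.bisect_right(A_length, flat_index % A_length[-1]) - 1

assert source_index(0) == 0   # the single image
assert source_index(5) == 1   # a sample inside the first video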
Example #4
    def initialize(self, opt):
        self.opt = opt
        self.root = opt.dataroot

        im_paths = []
        assert os.path.isdir(self.root), '%s is not a valid directory' % self.root
        with open(opt.anno_file, 'r') as F:
            anno_info = []
            for line in F:
                # CelebA attribute lines: a 10-char file name, one space,
                # then whitespace-separated +/-1 flags; [11:-1] drops the
                # name and the trailing newline
                anno = [int(x) for x in line[11:-1].split(' ') if x != '']
                anno_info.append(anno)

        if opt.phase == 'train':    # 1:162770
            start_id, end_id = 1, opt.train_size
            anno_info = anno_info[:opt.train_size]

        elif opt.phase == 'test':   # 182638:202599
            start_id, end_id = 182638, 202599
            anno_info = anno_info[start_id-1:end_id]

        for i in range(start_id, end_id+1):
            fname = '{:06d}.jpg'.format(i)
            if is_image_file(fname):
                path = os.path.join(self.root, fname)
                im_paths.append(path)

        # self.im_paths = sorted(im_paths)
        self.im_paths = im_paths
        self.anno_info = anno_info
        self.dataset_size = len(self.im_paths)       # self.A_size
        self.transform = get_transform(opt)
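The slicing above assumes CelebA-style attribute lines: a 10-character file name, one space, then the flags. A small check with a made-up line:

line = "000001.jpg  -1  1  1 -1\n"
anno = [int(x) for x in line[11:-1].split(' ') if x != '']
assert anno == [-1, 1, 1, -1]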
Example #5
    def __getitem__(self, index):
        index_A = index
        A_path, l, r = self.splited[index_A]
        print("load", A_path)
        if is_image_file(A_path):
            img = Image.open(A_path)
            frames = np.array(img)
            frames = frames.astype(np.float32)[np.newaxis, :]
            frames, _ = util.transformbyopt(self.opt, frames)
        else:
            cap = cv2.VideoCapture(A_path)
            frames = list()
            cnt = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))  # total frame count
            seed = None
            cap.set(cv2.CAP_PROP_POS_FRAMES, l)
            for _ in range(min(r - l, 10000)):
                ok, f = cap.read()
                if ok:
                    f = f[:, :, -1::-1].astype(np.float32)  # BGR -> RGB
                    # clips tagged _8/_16/_17 in the file name get an extra
                    # 0.7x down-scale; the returned seed is fed back so every
                    # frame receives the same random transform
                    tmpa, seed = util.transformbyopt(
                        self.opt,
                        f,
                        seed,
                        additional_scale=0.7 if
                        (A_path.find("_8") != -1 or A_path.find("_16") != -1
                         or A_path.find("_17") != -1) else 1)
                    frames.append(tmpa)
                else:
                    print("read error {}/{}frame in {}".format(
                        _ + l, cnt, A_path))
                del f
            cap.release()
            frames = torch.stack(frames).float()
        d = dict()

        if self.opt.which_direction == 'BtoA':
            input_nc = self.opt.output_nc
            output_nc = self.opt.input_nc
        else:
            input_nc = self.opt.input_nc
            output_nc = self.opt.output_nc

        if input_nc == 1:  # RGB to gray (BT.601 luma weights)
            frames = (frames[:, 0:1, ...] * 0.299 +
                      frames[:, 1:2, ...] * 0.587 +
                      frames[:, 2:3, ...] * 0.114)
        print("load finish", A_path)
        print("load:", frames.shape)
        # print(first.shape, opticalflow.shape)
        d['frames'] = frames
        tmp = os.path.splitext(A_path)
        # d['A_paths'] = (tmp[0], '_{}_{}'.format(l,r), tmp[1])
        d['A_paths'] = (tmp[0], '', tmp[1])
        d['index'] = index
        return d
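The 0.299/0.587/0.114 weights in the channel reduction are the ITU-R BT.601 luma coefficients. A minimal standalone version of the same reduction on an (N, 3, H, W) batch:

import torch

frames = torch.rand(2, 3, 4, 4)  # hypothetical RGB batch
gray = (frames[:, 0:1, ...] * 0.299 +
        frames[:, 1:2, ...] * 0.587 +
        frames[:, 2:3, ...] * 0.114)
assert gray.shape == (2, 1, 4, 4)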
Example #6
    def make_kitti_dataset(self, dir):
        images = []
        assert os.path.isdir(dir), '%s is not a valid directory' % dir

        for root, _, fnames in sorted(os.walk(dir)):

            for fname in fnames:
                # image_02 holds the left color camera in the KITTI layout
                if is_image_file(fname) and (root.find('image_02') >= 0):
                    path = os.path.join(root, fname)
                    images.append(path)

        return images
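The root.find('image_02') >= 0 test keeps only files that sit under an image_02 directory. A tiny illustration of the filter, with hypothetical paths:

roots = ['/kitti/2011_09_26/image_02/data',
         '/kitti/2011_09_26/image_03/data']
kept = [r for r in roots if r.find('image_02') >= 0]
assert kept == ['/kitti/2011_09_26/image_02/data']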
Example #7
    def initialize(self, opt):
        self.opt = opt
        self.root = opt.dataroot
        self.dir_A = os.path.join(opt.dataroot, "testA")
        self.A_paths = make_dataset(self.dir_A)
        # drop auxiliary .npy files, keeping only image/video paths
        self.A_paths = sorted(p for p in self.A_paths
                              if p.find(".npy") == -1)
        self.transform = get_transform(opt)
        self.splited = list()
        for i in range(len(self.A_paths)):
            if is_image_file(self.A_paths[i]):
                self.splited.append((self.A_paths[i], 0, 1))
            else:
                cap = cv2.VideoCapture(self.A_paths[i])
                cnt = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
                cap.release()
                n = math.ceil(cnt / 60)
                n = min(round(cnt / n), 65)

                ################################
                # override: take each clip as one chunk; the exit(0) below
                # guards against a second chunk ever being produced
                n = cnt

                now = 0
                while now < cnt:
                    if now > 0:
                        exit(0)
                    tmp = now + n
                    if cnt - now < 70 or cnt - now <= n:
                        tmp = cnt
                    elif cnt - now < 2 * n:
                        tmp = (cnt - now) // 2 + now
                    self.splited.append((self.A_paths[i], now, tmp))
                    now = tmp
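Ignoring that override, the intended chunking aims for roughly 60-frame spans capped at 65, absorbing or splitting a short tail so no final chunk is tiny. A self-contained sketch of the arithmetic, using a made-up frame count:

import math

cnt = 130                        # hypothetical frame count
n = math.ceil(cnt / 60)          # target number of chunks
n = min(round(cnt / n), 65)      # frames per chunk, capped at 65
spans, now = [], 0
while now < cnt:
    if cnt - now < 70 or cnt - now <= n:
        tmp = cnt                          # absorb a short tail
    elif cnt - now < 2 * n:
        tmp = (cnt - now) // 2 + now       # split the remainder evenly
    else:
        tmp = now + n
    spans.append((now, tmp))
    now = tmp
assert spans == [(0, 43), (43, 86), (86, 130)]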
Example #8
        break
    model.set_input(data)
    with torch.no_grad():
        out = model.test()
    if isinstance(out, bool):
        print(model.get_image_paths()[0], "Failed")
    else:
        out, boxed, sc = out
        all_path = model.get_image_paths()
        img_path = all_path[0] + all_path[2]
        fname = os.path.split(img_path)[-1]

        # NCHW -> NHWC, then map [-1, 1] floats back to uint8
        img = ((model.frames.data.cpu().permute(
            (0, 2, 3, 1)).numpy() + 1) * 127.5).astype(np.uint8)

        if is_image_file(img_path):
            outname = os.path.splitext(fname)[0] + "_ink.png"
            im = Image.fromarray(out[0])
            im.save(os.path.join(opt.results_dir, outname))

            if boxed is not None:
                outname = os.path.splitext(fname)[0] + "_boxed.png"
                im = Image.fromarray(boxed[0])
                im.save(os.path.join(opt.results_dir, outname))

            im = Image.fromarray(img[0])
            im.save(os.path.join(opt.results_dir, fname))
        else:
            cap = cv2.VideoCapture(img_path)
            fps = cap.get(cv2.CAP_PROP_FPS)
            forcc = cv2.VideoWriter_fourcc(*"mp4v")
Example #9
    def __getitem__(self, index):
        d = dict()
        # map the flat index back to its source entry via the A_length
        # prefix sums built in initialize()
        index_A = bisect.bisect_right(self.A_length,
                                      index % self.A_length[-1]) - 1
        A_path = self.A_paths[index_A]
        if is_image_file(A_path):
            print("load image from {}".format(A_path))
            img = np.array(Image.open(A_path))
            img, _ = util.transformbyopt(self.opt, img.astype(np.float32))
            frames = torch.unsqueeze(img, 0)
        else:
            # each video contributes one sample per 10 frames; jitter the
            # exact frame within the chosen 10-frame window
            frame_num = (index % self.A_length[-1] -
                         self.A_length[index_A]) * 10 + random.randint(0, 9)
            print("load {} frame from {}".format(frame_num, A_path))
            cap = cv2.VideoCapture(A_path)
            frame_num = min(frame_num, int(cap.get(cv2.CAP_PROP_FRAME_COUNT)))
            frames = list()
            seed = None
            # gather the current frame plus long-term references from
            # self.long_term frames back, reusing the transform seed so all
            # frames are transformed consistently
            for i in self.long_term:
                if i <= frame_num:
                    cap.set(cv2.CAP_PROP_POS_FRAMES, frame_num - i)
                    _, f = cap.read()
                    # BGR -> RGB before the transform
                    f, seed = util.transformbyopt(
                        self.opt, f[:, :, -1::-1].astype(np.float32), seed)
                    frames.append(f)
            cap.release()
            frames = torch.stack(frames)
        if self.opt.serial_batches:
            index_B = index % self.B_size
        else:
            index_B = random.randint(0, self.B_size - 1)
        B_path = self.B_paths[index_B]
        B_img = np.array(Image.open(B_path)).astype(np.float32)
        # print("load B")
        B, _ = util.transformbyopt(self.opt, B_img)

        if self.opt.which_direction == 'BtoA':
            input_nc = self.opt.output_nc
            output_nc = self.opt.input_nc
        else:
            input_nc = self.opt.input_nc
            output_nc = self.opt.output_nc

        if input_nc == 1:  # RGB to gray (BT.601 luma weights)
            frames = (frames[:, 0:1, ...] * 0.299 +
                      frames[:, 1:2, ...] * 0.587 +
                      frames[:, 2:3, ...] * 0.114)

        if output_nc == 1:  # RGB to gray
            B = B[:, 0:1, ...] * 0.299 + B[:, 1:2, ...] * \
                0.587 + B[:, 2:3, ...] * 0.114
        # print(first.shape, opticalflow.shape)

        d['frames'] = frames
        d['B'] = B
        d['A_paths'] = A_path
        d['B_paths'] = B_path
        d['index'] = index
        return d
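A minimal sketch of driving a dataset like this, assuming initialize() has been called on a hypothetical instance named dataset and that a standard torch DataLoader wraps it:

from torch.utils.data import DataLoader

loader = DataLoader(dataset, batch_size=1, shuffle=True)
for d in loader:
    frames, B = d['frames'], d['B']   # (1, T, C, H, W) frames, batched B
    break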