Example #1
    def __init__(self, opt):

        BaseDataset.__init__(self, opt)

        # 100000 is the max dataset size passed to make_dataset

        self.dir_ourdataset = os.path.join(opt.dataroot, 'qualitative1')
        self.images_dir_all = sorted(make_dataset(self.dir_ourdataset, 100000))

        self.data_size = opt.load_size
        self.data_root = opt.dataroot
        self.dark_coef = opt.darken

        opt_merge = copy.deepcopy(opt)
        opt_merge.isTrain = False
        opt_merge.model = 'pix2pix4depth'
        self.mergenet = Pix2Pix4DepthModel(opt_merge)
        self.mergenet.save_dir = 'depthmerge/checkpoints/scaled_04_1024'
        self.mergenet.load_networks('latest')
        self.mergenet.eval()

        self.device = torch.device('cuda:0')

        midas_model_path = "midas/model-f46da743.pt"
        self.midasmodel = MidasNet(midas_model_path, non_negative=True)
        self.midasmodel.to(self.device)
        self.midasmodel.eval()

        torch.multiprocessing.set_start_method('spawn')
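
Note: torch.multiprocessing.set_start_method raises RuntimeError if the start
method has already been set, e.g. when two of these datasets are constructed in
the same process. A guarded variant, as a minimal sketch (set_start_method here
is the standard multiprocessing one re-exported by torch, so it accepts force):

import torch.multiprocessing

try:
    torch.multiprocessing.set_start_method('spawn')
except RuntimeError:
    pass  # the start method was already set earlier in this process
# or overwrite it unconditionally:
# torch.multiprocessing.set_start_method('spawn', force=True)
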
Example #2
    def __init__(self, opt):

        BaseDataset.__init__(self, opt)
        self.dir_AB = os.path.join(opt.dataroot,
                                   opt.phase)  # get the image directory
        self.AB_paths = sorted(make_dataset(
            self.dir_AB,
            opt.max_dataset_size))  # load images from '/path/to/data/' + opt.phase

        assert (self.opt.load_size >= self.opt.crop_size
                )  # crop_size should be smaller than the size of loaded image
        self.input_nc = self.opt.output_nc if self.opt.direction == 'BtoA' else self.opt.input_nc
        self.output_nc = self.opt.input_nc if self.opt.direction == 'BtoA' else self.opt.output_nc

        opt_merge = copy.deepcopy(opt)
        opt_merge.isTrain = False
        opt_merge.model = 'pix2pix4depth'
        self.mergenet = Pix2Pix4DepthModel(opt_merge)
        self.mergenet.save_dir = 'depthmerge/checkpoints/scaled_04_1024'
        self.mergenet.load_networks('latest')
        self.mergenet.eval()

        self.device = self.mergenet.device

        midas_model_path = "midas/model-f46da743.pt"
        self.midasmodel = MidasNet(midas_model_path, non_negative=True)
        self.midasmodel.to(self.device)
        self.midasmodel.eval()

        torch.multiprocessing.set_start_method('spawn')
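
Unlike Example #1, which hard-codes torch.device('cuda:0'), this example takes
the device from the loaded merge model, so the same dataset code can run on a
CPU-only machine. A portable fallback when no model is available (a sketch, not
part of the original snippet):

import torch

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
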
Example #3
    def __init__(self, opt):

        BaseDataset.__init__(self, opt)

        # 100000 is the max dataset size passed to make_dataset
        if opt.phase == 'test':
            self.dir_ourdataset = os.path.join(opt.dataroot,
                                               'our_dataset2_test')
            self.images_dir_ourdataset = sorted(
                make_dataset(self.dir_ourdataset + '/amb_0.5', 100000))

            self.dir_multidataset = os.path.join(opt.dataroot,
                                                 'multi_dataset_test')
            self.images_dir_multidataset = sorted(
                make_dataset(self.dir_multidataset + '/amb_0.5/1', 100000))
            # self.images_dir_multidataset = self.images_dir_multidataset * 4

            self.dir_portraitdataset = os.path.join(
                opt.dataroot, 'portrait_dataset_extra_test')
            self.images_dir_portraitdataset = sorted(
                make_dataset(self.dir_portraitdataset + '/amb_0.5/1', 100000))
            # self.images_dir_portraitdataset =  self.images_dir_portraitdataset*4
        else:
            self.dir_ourdataset = os.path.join(opt.dataroot, 'our_dataset2')
            self.images_dir_ourdataset = sorted(
                make_dataset(self.dir_ourdataset + '/amb_0.5', 100000))

            self.dir_multidataset = os.path.join(opt.dataroot, 'multi_dataset')
            self.images_dir_multidataset = sorted(
                make_dataset(self.dir_multidataset + '/amb_0.5/1', 100000))
            self.images_dir_multidataset = self.images_dir_multidataset * 4

            self.dir_portraitdataset = os.path.join(opt.dataroot,
                                                    'portrait_dataset_extra')
            self.images_dir_portraitdataset = sorted(
                make_dataset(self.dir_portraitdataset + '/amb_0.5/1', 100000))
            self.images_dir_portraitdataset = self.images_dir_portraitdataset * 4

        self.images_dir_all = self.images_dir_ourdataset + self.images_dir_multidataset + self.images_dir_portraitdataset

        self.data_size = opt.load_size
        self.data_root = opt.dataroot

        opt_merge = copy.deepcopy(opt)
        opt_merge.isTrain = False
        opt_merge.model = 'pix2pix4depth'
        self.mergenet = Pix2Pix4DepthModel(opt_merge)
        self.mergenet.save_dir = 'depthmerge/checkpoints/scaled_04_1024'
        self.mergenet.load_networks('latest')
        self.mergenet.eval()

        self.device = torch.device('cuda:0')

        midas_model_path = "midas/model-f46da743.pt"
        self.midasmodel = MidasNet(midas_model_path, non_negative=True)
        self.midasmodel.to(self.device)
        self.midasmodel.eval()

        torch.multiprocessing.set_start_method('spawn')
        #
        # for i in range(len(self.images_dir_all)):
        #     self.__getitem__(i)

        images_dir_all_picked = []

        for i in range(len(self.images_dir_all)):
            image_path_temp = self.images_dir_all[i]
            image_name = image_path_temp.split('/')[-1]

            amb_select = random.randint(0, 2)
            if amb_select == 0:
                amb_dir = '/amb_0.5'
            elif amb_select == 1:
                amb_dir = '/amb_0.75'
            elif amb_select == 2:
                amb_dir = '/amb_1'

            if self.opt.phase == 'test':
                if 'our_dataset2_test/' in image_path_temp:
                    image_path = self.data_root + '/our_dataset2_test' + amb_dir + '/{}'.format(
                        image_name)
                elif 'multi_dataset_test/' in image_path_temp:
                    multi_select = random.randint(1, 10)
                    image_path = self.data_root + '/multi_dataset_test' + amb_dir + '/{}'.format(
                        multi_select) + '/{}'.format(image_name)
                elif 'portrait_dataset_extra_test/' in image_path_temp:
                    portrait_select = random.randint(1, 20)
                    image_path = self.data_root + '/portrait_dataset_extra_test' + amb_dir + '/{}'.format(
                        portrait_select) + '/{}'.format(image_name)
            else:
                if 'our_dataset2/' in image_path_temp:
                    image_path = self.data_root + '/our_dataset2' + amb_dir + '/{}'.format(
                        image_name)
                elif 'multi_dataset/' in image_path_temp:
                    multi_select = random.randint(1, 10)
                    image_path = self.data_root + '/multi_dataset' + amb_dir + '/{}'.format(
                        multi_select) + '/{}'.format(image_name)
                elif 'portrait_dataset_extra/' in image_path_temp:
                    portrait_select = random.randint(1, 20)
                    image_path = self.data_root + '/portrait_dataset_extra' + amb_dir + '/{}'.format(
                        portrait_select) + '/{}'.format(image_name)
            print('Replacing | ', image_path_temp, '-->', image_path)
            images_dir_all_picked.append(image_path)

        self.images_dir_all = images_dir_all_picked
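
The amb_select chain above draws one of three exposure folders uniformly at
random; random.choice from the standard library expresses the same selection in
one line (an equivalent sketch, not the original code):

import random

amb_dir = random.choice(['/amb_0.5', '/amb_0.75', '/amb_1'])
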
Example #4
class FixedRandomDataset(BaseDataset):
    def __init__(self, opt):

        BaseDataset.__init__(self, opt)

        # 100000 is the max dataset size passed to make_dataset
        if opt.phase == 'test':
            self.dir_ourdataset = os.path.join(opt.dataroot,
                                               'our_dataset2_test')
            self.images_dir_ourdataset = sorted(
                make_dataset(self.dir_ourdataset + '/amb_0.5', 100000))

            self.dir_multidataset = os.path.join(opt.dataroot,
                                                 'multi_dataset_test')
            self.images_dir_multidataset = sorted(
                make_dataset(self.dir_multidataset + '/amb_0.5/1', 100000))
            # self.images_dir_multidataset = self.images_dir_multidataset * 4

            self.dir_portraitdataset = os.path.join(
                opt.dataroot, 'portrait_dataset_extra_test')
            self.images_dir_portraitdataset = sorted(
                make_dataset(self.dir_portraitdataset + '/amb_0.5/1', 100000))
            # self.images_dir_portraitdataset =  self.images_dir_portraitdataset*4
        else:
            self.dir_ourdataset = os.path.join(opt.dataroot, 'our_dataset2')
            self.images_dir_ourdataset = sorted(
                make_dataset(self.dir_ourdataset + '/amb_0.5', 100000))

            self.dir_multidataset = os.path.join(opt.dataroot, 'multi_dataset')
            self.images_dir_multidataset = sorted(
                make_dataset(self.dir_multidataset + '/amb_0.5/1', 100000))
            self.images_dir_multidataset = self.images_dir_multidataset * 4

            self.dir_portraitdataset = os.path.join(opt.dataroot,
                                                    'portrait_dataset_extra')
            self.images_dir_portraitdataset = sorted(
                make_dataset(self.dir_portraitdataset + '/amb_0.5/1', 100000))
            self.images_dir_portraitdataset = self.images_dir_portraitdataset * 4

        self.images_dir_all = self.images_dir_ourdataset + self.images_dir_multidataset + self.images_dir_portraitdataset

        self.data_size = opt.load_size
        self.data_root = opt.dataroot

        opt_merge = copy.deepcopy(opt)
        opt_merge.isTrain = False
        opt_merge.model = 'pix2pix4depth'
        self.mergenet = Pix2Pix4DepthModel(opt_merge)
        self.mergenet.save_dir = 'depthmerge/checkpoints/scaled_04_1024'
        self.mergenet.load_networks('latest')
        self.mergenet.eval()

        self.device = torch.device('cuda:0')

        midas_model_path = "midas/model-f46da743.pt"
        self.midasmodel = MidasNet(midas_model_path, non_negative=True)
        self.midasmodel.to(self.device)
        self.midasmodel.eval()

        torch.multiprocessing.set_start_method('spawn')
        #
        # for i in range(len(self.images_dir_all)):
        #     self.__getitem__(i)

        images_dir_all_picked = []

        for i in range(len(self.images_dir_all)):
            image_path_temp = self.images_dir_all[i]
            image_name = image_path_temp.split('/')[-1]

            amb_select = random.randint(0, 2)
            if amb_select == 0:
                amb_dir = '/amb_0.5'
            elif amb_select == 1:
                amb_dir = '/amb_0.75'
            elif amb_select == 2:
                amb_dir = '/amb_1'

            if self.opt.phase == 'test':
                if 'our_dataset2_test/' in image_path_temp:
                    image_path = self.data_root + '/our_dataset2_test' + amb_dir + '/{}'.format(
                        image_name)
                elif 'multi_dataset_test/' in image_path_temp:
                    multi_select = random.randint(1, 10)
                    image_path = self.data_root + '/multi_dataset_test' + amb_dir + '/{}'.format(
                        multi_select) + '/{}'.format(image_name)
                elif 'portrait_dataset_extra_test/' in image_path_temp:
                    portrait_select = random.randint(1, 20)
                    image_path = self.data_root + '/portrait_dataset_extra_test' + amb_dir + '/{}'.format(
                        portrait_select) + '/{}'.format(image_name)
            else:
                if 'our_dataset2/' in image_path_temp:
                    image_path = self.data_root + '/our_dataset2' + amb_dir + '/{}'.format(
                        image_name)
                elif 'multi_dataset/' in image_path_temp:
                    multi_select = random.randint(1, 10)
                    image_path = self.data_root + '/multi_dataset' + amb_dir + '/{}'.format(
                        multi_select) + '/{}'.format(image_name)
                elif 'portrait_dataset_extra/' in image_path_temp:
                    portrait_select = random.randint(1, 20)
                    image_path = self.data_root + '/portrait_dataset_extra' + amb_dir + '/{}'.format(
                        portrait_select) + '/{}'.format(image_name)
            print('Replacing | ', image_path_temp, '-->', image_path)
            images_dir_all_picked.append(image_path)

        self.images_dir_all = images_dir_all_picked

    def __getitem__(self, index):

        image_path = self.images_dir_all[index]

        if 'our_dataset' in image_path:
            image_pair = Image.open(image_path)
            hyper_des = int(PngImageFile(image_path).text['des'])

            A, B = self.divide_imagepair(image_pair)
            ambient = skimage.img_as_float(A)
            flash = skimage.img_as_float(B)

            if hyper_des == 21:
                flash = self.changeTemp(flash, 48, hyper_des)
                ambient = self.changeTemp(ambient, 48, hyper_des)
            flashPhoto = flash + ambient
            flashPhoto[flashPhoto < 0] = 0
            flashPhoto[flashPhoto > 1] = 1
            flashPhoto = self.xyztorgb(flashPhoto, hyper_des)
            ambient = self.xyztorgb(ambient, hyper_des)

            flashphoto_depth = self.getDepth(flashPhoto, image_path, 'flash')
            ambient_depth = self.getDepth(ambient, image_path, 'ambient')

        elif 'multi_dataset' in image_path:
            image_pair = Image.open(image_path)

            A, B = self.divide_imagepair(image_pair)
            ambient = skimage.img_as_float(A)
            flash = skimage.img_as_float(B)

            flashPhoto = flash + ambient
            flashPhoto[flashPhoto < 0] = 0
            flashPhoto[flashPhoto > 1] = 1
            ambient = Image.fromarray((ambient * 255).astype('uint8'))
            flashPhoto = Image.fromarray((flashPhoto * 255).astype('uint8'))

            flashphoto_depth = self.getDepth(flashPhoto, image_path, 'flash')
            ambient_depth = self.getDepth(ambient, image_path, 'ambient')

        elif 'portrait_dataset' in image_path:
            image_pair = Image.open(image_path)

            A, B = self.divide_imagepair(image_pair)
            ambient = skimage.img_as_float(A)
            flash = skimage.img_as_float(B)

            ambient = self.lin(ambient)
            flash = self.lin(flash)

            flashPhoto = flash + ambient
            flashPhoto[flashPhoto < 0] = 0
            flashPhoto[flashPhoto > 1] = 1
            ambient = Image.fromarray((ambient * 255).astype('uint8'))
            flashPhoto = Image.fromarray((flashPhoto * 255).astype('uint8'))

            flashphoto_depth = self.getDepth(flashPhoto, image_path, 'flash')
            ambient_depth = self.getDepth(ambient, image_path, 'ambient')

        torch.cuda.empty_cache()

        ambient_orgsize = skimage.img_as_float(ambient)
        flashPhoto_orgsize = skimage.img_as_float(flashPhoto)

        ambient = ambient.resize((self.data_size, self.data_size))
        flashPhoto = flashPhoto.resize((self.data_size, self.data_size))
        ambient_depth = ambient_depth.resize((self.data_size, self.data_size))
        flashphoto_depth = flashphoto_depth.resize(
            (self.data_size, self.data_size))

        transform_params = get_params(self.opt, ambient.size)
        rgb_transform = get_transform(self.opt,
                                      transform_params,
                                      grayscale=False)
        depth_transform = get_transform(self.opt,
                                        transform_params,
                                        grayscale=True)

        ambient = rgb_transform(ambient)
        flashPhoto = rgb_transform(flashPhoto)

        ambient_depth = depth_transform(ambient_depth)
        flashphoto_depth = depth_transform(flashphoto_depth)

        return {
            'A': flashPhoto,
            'B': ambient,
            'A_org': flashPhoto_orgsize,
            'B_org': ambient_orgsize,
            'depth_A': flashphoto_depth,
            'depth_B': ambient_depth,
            'A_paths': image_path,
            'B_paths': image_path
        }

    def __len__(self):
        """Return the total number of images in the dataset."""
        return len(self.images_dir_all)

    def getDepth(self, image, image_path, id):
        # Load the cached depth map for this image if one exists; otherwise
        # estimate it, save it under <dataroot>/depth/..., and return it.
        image_path_beheaded = image_path.replace('.png', '')
        image_path_beheaded = image_path_beheaded.replace('.jpg', '')
        image_path_beheaded = image_path_beheaded.replace('.jpeg', '')

        depth_path = image_path_beheaded[:len(
            self.data_root)] + '/depth' + image_path_beheaded[
                len(self.data_root):] + '_' + id + '.png'
        if os.path.exists(depth_path):
            depth = Image.open(depth_path)
        else:
            depth = self.estimateDepth(np.asarray(image) / 255)
            depth_dir = depth_path.replace(depth_path.split('/')[-1], '')
            if not os.path.exists(depth_dir):
                os.makedirs(depth_dir)
            depth = (depth * 255).astype('uint8')
            cv2.imwrite(depth_path, depth)
            depth = Image.fromarray(depth)
            # depth.save(depth_path)
            print('Depth file cached |', depth_path)

        return depth

    def divide_imagepair(self, image_pair):
        w, h = image_pair.size
        w2 = int(w / 2)
        A = image_pair.crop((0, 0, w2, h))
        B = image_pair.crop((w2, 0, w, h))
        # A = A.resize((self.data_size, self.data_size))
        # B = B.resize((self.data_size, self.data_size))
        return A, B

    def gama_corect(self, rgb):
        srgb = np.zeros_like(rgb)
        mask1 = (rgb > 0) * (rgb < 0.0031308)
        mask2 = (1 - mask1).astype(bool)
        srgb[mask1] = 12.92 * rgb[mask1]
        srgb[mask2] = 1.055 * np.power(rgb[mask2], 0.41666) - 0.055
        srgb[srgb < 0] = 0
        return srgb

    def doubleestimate(self, img, size1, size2):
        # Estimate depth twice: a low-resolution pass (consistent global
        # structure) and a high-resolution pass (sharper details), then merge
        # the two with the pix2pix merge network.
        estimate1 = self.singleestimate(img, size1)
        estimate1 = cv2.resize(estimate1, (1024, 1024),
                               interpolation=cv2.INTER_CUBIC)

        estimate2 = self.singleestimate(img, size2)
        estimate2 = cv2.resize(estimate2, (1024, 1024),
                               interpolation=cv2.INTER_CUBIC)

        self.mergenet.set_input(estimate1, estimate2)
        self.mergenet.test()
        torch.cuda.empty_cache()
        visuals = self.mergenet.get_current_visuals()
        prediction_mapped = visuals['fake_B']
        prediction_mapped = (prediction_mapped + 1) / 2
        prediction_mapped = (prediction_mapped - torch.min(prediction_mapped)
                             ) / (torch.max(prediction_mapped) -
                                  torch.min(prediction_mapped))
        prediction_mapped = prediction_mapped.squeeze().cpu().numpy()

        prediction_end_res = cv2.resize(prediction_mapped,
                                        (img.shape[1], img.shape[0]),
                                        interpolation=cv2.INTER_CUBIC)

        return prediction_end_res

    def singleestimate(self, img, msize):
        return self.estimateMidas(img, msize)

    def estimateMidas(self, img, msize):
        transform = Compose([
            Resize(
                msize,
                msize,
                resize_target=None,
                keep_aspect_ratio=True,
                ensure_multiple_of=32,
                resize_method="upper_bound",
                image_interpolation_method=cv2.INTER_CUBIC,
            ),
            NormalizeImage(mean=[0.485, 0.456, 0.406],
                           std=[0.229, 0.224, 0.225]),
            PrepareForNet(),
        ])

        img_input = transform({"image": img})["image"]
        # compute
        with torch.no_grad():
            sample = torch.from_numpy(img_input).to(self.device).unsqueeze(0)
            prediction = self.midasmodel.forward(sample)
            torch.cuda.empty_cache()

        prediction = prediction.squeeze().cpu().numpy()
        prediction = cv2.resize(prediction, (img.shape[1], img.shape[0]),
                                interpolation=cv2.INTER_CUBIC)

        depth_min = prediction.min()
        depth_max = prediction.max()

        if depth_max - depth_min > np.finfo("float").eps:
            prediction = (prediction - depth_min) / (depth_max - depth_min)
        else:
            prediction = 0

        return prediction

    def estimateDepth(self, rgb_mix):
        # encode linear RGB to sRGB before running the depth networks
        rgb_mix = self.gama_corect(rgb_mix)
        depth_temp = self.doubleestimate(rgb_mix, 384, 768)
        return depth_temp

    def xyztorgb(self, image, des):
        # Convert an XYZ image to 8-bit RGB: apply chromatic adaptation for
        # the calibration illuminant, then the XYZ->sRGB matrix (transposed
        # here because pixels are multiplied as row vectors).
        illum = self.chromaticityAdaptation(des)
        mat = [[3.2404542, -0.9692660, 0.0556434],
               [-1.5371385, 1.8760108, -0.2040259],
               [-0.4985314, 0.0415560, 1.0572252]]
        image = np.matmul(image, illum)
        image = np.matmul(image, mat)
        image = np.where(image < 0, 0, image)
        image = np.where(image > 1, 1, image)
        out = (image * 255).astype('uint8')
        out = Image.fromarray(out)
        return out

    def chromaticityAdaptation(self, calibrationIlluminant):
        # Diagonal chromatic adaptation matrix for each supported
        # calibration illuminant.
        if (calibrationIlluminant == 17):
            illum = [[0.8652435, 0.0000000, 0.0000000],
                     [0.0000000, 1.0000000, 0.0000000],
                     [0.0000000, 0.0000000, 3.0598005]]
        elif (calibrationIlluminant == 19):
            illum = [[0.9691356, 0.0000000, 0.0000000],
                     [0.0000000, 1.0000000, 0.0000000],
                     [0.0000000, 0.0000000, 0.9209267]]
        elif (calibrationIlluminant == 20):
            illum = [[0.9933634, 0.0000000, 0.0000000],
                     [0.0000000, 1.0000000, 0.0000000],
                     [0.0000000, 0.0000000, 1.1815972]]
        elif (calibrationIlluminant == 21):
            illum = [[1, 0, 0], [0, 1, 0], [0, 0, 1]]
        elif (calibrationIlluminant == 23):
            illum = [[1.0077340, 0.0000000, 0.0000000],
                     [0.0000000, 1.0000000, 0.0000000],
                     [0.0000000, 0.0000000, 0.8955170]]
        return illum

    def getRatio(self, t, low, high):
        # position of t within [low, high], expressed in percent
        dist = t - low
        range = (high - low) / 100
        return dist / range

    def changeTemp(self, image, tempChange, des):
        # Shift the chromaticity of an XYZ image to simulate a color
        # temperature change for the given calibration illuminant (des).
        if (des == 17):
            t1 = 5500
            if tempChange == 44:
                tempChange = -400
            elif tempChange == 40:
                tempChange = -450
            elif tempChange == 52:
                tempChange = 234
            elif tempChange == 54:
                tempChange = 468
            t = t1 + tempChange
            if t <= 10000 and t > 6500:
                r = self.getRatio(t, 6500, 10000)
                xD = 0.3
                yD = 0.3
                xS = 0.3118
                yS = 0.3224
                r_x = (xD - xS) / 100
                xD = xS + r * r_x
                r_y = (yD - yS) / 100
                yD = yS + r * r_y

            elif 6500 >= t and t > 5500:
                r = self.getRatio(t, 5500, 6500)
                xD = 0.3118
                yD = 0.3224
                xS = 0.3580
                yS = 0.3239
                r_x = (xD - xS) / 100
                xD = xS + r * r_x
                r_y = (yD - yS) / 100
                yD = yS + r * r_y

            elif 5000 <= t and t < 5500:
                r = self.getRatio(t, 5500, 5000)
                xD = 0.3752
                yD = 0.3238
                xS = 0.3580
                yS = 0.3239
                r_x = (xD - xS) / 100
                xD = xS + r * r_x
                r_y = (yD - yS) / 100
                yD = yS + r * r_y

            elif (t >= 4500 and t < 5000):
                r = self.getRatio(t, 5000, 4500)
                xD = 0.4231
                yD = 0.3304
                xS = 0.3752
                yS = 0.3238
                r_x = (xD - xS) / 100
                xD = xS + r * r_x
                r_y = (yD - yS) / 100
                yD = yS + r * r_y

            elif t < 4500 and t >= 4000:
                r = self.getRatio(t, 4500, 4000)
                xD = 0.4949
                yD = 0.3564
                xS = 0.4231
                yS = 0.3304
                r_x = (xD - xS) / 100
                xD = xS + r * r_x
                r_y = (yD - yS) / 100
                yD = yS + r * r_y

            elif t >= 0 and t < 4000:
                r = self.getRatio(t, 4000, 0)
                xD = 0.5041
                yD = 0.3334
                xS = 0.4949
                yS = 0.3564
                r_x = (xD - xS) / 100
                xD = xS + r * r_x
                r_y = (yD - yS) / 100
                yD = yS + r * r_y

            chromaticity_x = 0.3580
            chromaticity_y = 0.3239
        elif des == 21:
            t1 = 4500
            if tempChange == 48:
                tempChange = 700
            if tempChange == 44:
                tempChange = 400
            elif tempChange == 40:
                tempChange = 250
            elif tempChange == 52:
                tempChange = 800
            elif tempChange == 54:
                tempChange = 1000

            t = tempChange + t1
            if t >= 4500 and t <= 7500:
                r = self.getRatio(t, 4500, 7500)
                xD = 0.17
                yD = 0.17
                xS = 0.4231
                yS = 0.3304
                r_x = (xD - xS) / 100
                xD = xS + r * r_x
                r_y = (yD - yS) / 100
                yD = yS + r * r_y

            elif (t < 4500 and t >= 4000):
                r = self.getRatio(t, 4500, 4000)
                xD = 0.4949
                yD = 0.3564
                xS = 0.4231
                yS = 0.3304
                r_x = (xD - xS) / 100
                xD = xS + r * r_x
                r_y = (yD - yS) / 100
                yD = yS + r * r_y

            elif (t >= 3500 and t < 4000):
                r = self.getRatio(t, 4000, 3500)
                xD = 0.5141
                yD = 0.3434
                xS = 0.4949
                yS = 0.3564
                r_x = (xD - xS) / 100
                xD = xS + r * r_x
                r_y = (yD - yS) / 100
                yD = yS + r * r_y
            elif (t > 0 and t < 3500):
                r = self.getRatio(t, 3500, 0)
                xD = 0.5189
                yD = 0.3063
                xS = 0.5141
                yS = 0.3434
                r_x = (xD - xS) / 100
                xD = xS + r * r_x
                r_y = (yD - yS) / 100
                yD = yS + r * r_y

            chromaticity_x = 0.4231

            chromaticity_y = 0.3304

        offset_x = xD / chromaticity_x
        offset_y = yD / chromaticity_y

        out = image
        h, w, c = image.shape
        img0 = image[:, :, 0]
        img1 = image[:, :, 1]
        img2 = image[:, :, 2]
        sumImage = img0 + img1 + img2
        x_pix = np.zeros((h, w))
        y_pix = np.zeros((h, w))

        nonZeroSum = np.where(sumImage != 0)
        x_pix[nonZeroSum] = img0[nonZeroSum] / sumImage[nonZeroSum]
        y_pix[nonZeroSum] = img1[nonZeroSum] / sumImage[nonZeroSum]

        x_pix = x_pix * offset_x
        y_pix = y_pix * offset_y

        out0 = np.zeros((h, w))
        out2 = np.zeros((h, w))

        nonZeroY = np.where(y_pix != 0)
        ones = np.ones((h, w))
        out0[nonZeroY] = x_pix[nonZeroY] * img1[nonZeroY] / y_pix[nonZeroY]
        out2[nonZeroY] = (ones[nonZeroY] - x_pix[nonZeroY] -
                          y_pix[nonZeroY]) * img1[nonZeroY] / y_pix[nonZeroY]
        out[:, :, 0] = out0
        out[:, :, 2] = out2

        return out

    def lin(self, srgb):
        srgb = srgb.astype(np.float64)  # np.float was removed in NumPy 1.24
        rgb = np.zeros_like(srgb)
        mask1 = srgb <= 0.04045
        mask2 = ~mask1
        rgb[mask1] = srgb[mask1] / 12.92
        rgb[mask2] = ((srgb[mask2] + 0.055) / 1.055)**2.4
        return rgb
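
gama_corect and lin above are inverse sRGB transfer functions: gama_corect
encodes linear RGB to sRGB (the exponent 0.41666 approximates 1/2.4) and lin
decodes sRGB back to linear. A standalone round-trip check, assuming only
numpy (the function names here are illustrative, not from the repository):

import numpy as np

def srgb_encode(rgb):
    # linear -> sRGB, same piecewise curve as gama_corect
    return np.where(rgb < 0.0031308, 12.92 * rgb,
                    1.055 * np.power(rgb, 1 / 2.4) - 0.055)

def srgb_decode(srgb):
    # sRGB -> linear, same piecewise curve as lin
    return np.where(srgb <= 0.04045, srgb / 12.92,
                    ((srgb + 0.055) / 1.055) ** 2.4)

x = np.linspace(0, 1, 11)
assert np.allclose(srgb_decode(srgb_encode(x)), x)
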
Example #5
opt_merge.num_threads = 0  # test code only supports num_threads = 0
opt_merge.batch_size = 1  # test code only supports batch_size = 1
opt_merge.serial_batches = True  # disable data shuffling; comment this line if results on randomly chosen images are needed.
opt_merge.no_flip = True  # no flip; comment this line if results on flipped images are needed.
opt_merge.display_id = -1  # no visdom display; the test code saves the results to a HTML file.
opt_merge.isTrain = False
opt_merge.model = 'pix2pix4depth'
mergenet = Pix2Pix4DepthModel(opt_merge)
mergenet.save_dir = 'depthmerge/checkpoints/scaled_04_1024'
mergenet.load_networks('latest')
mergenet.eval()

device = torch.device('cuda:0')

midas_model_path = "midas/model-f46da743.pt"
midasmodel = MidasNet(midas_model_path, non_negative=True)
midasmodel.to(device)
midasmodel.eval()

# define hyper-parameters
ref_size = 512

# define image to tensor transform
im_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# create MODNet and load the pre-trained ckpt
modnet = MODNet(backbone_pretrained=False)
modnet = nn.DataParallel(modnet).cuda()
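
The snippet stops right after constructing MODNet; in the MODNet demo the
pretrained weights are then loaded and the matte predicted roughly as below
(checkpoint and image paths are placeholders; the demo also rescales the input
so both sides are multiples of 32 near ref_size, omitted here):

modnet.load_state_dict(torch.load('modnet_photographic_portrait_matting.ckpt'))
modnet.eval()

im = Image.open('portrait.jpg').convert('RGB')
im_tensor = im_transform(im).unsqueeze(0).cuda()
with torch.no_grad():
    _, _, matte = modnet(im_tensor, True)  # inference mode returns only the matte
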
Example #6
class AlignedLabDataset(BaseDataset):
    def __init__(self, opt):

        BaseDataset.__init__(self, opt)
        self.dir_AB = os.path.join(opt.dataroot,
                                   opt.phase)  # get the image directory
        self.AB_paths = sorted(make_dataset(
            self.dir_AB,
            opt.max_dataset_size))  # load images from '/path/to/data/' + opt.phase

        assert (self.opt.load_size >= self.opt.crop_size
                )  # crop_size should be smaller than the size of loaded image
        self.input_nc = self.opt.output_nc if self.opt.direction == 'BtoA' else self.opt.input_nc
        self.output_nc = self.opt.input_nc if self.opt.direction == 'BtoA' else self.opt.output_nc

        opt_merge = copy.deepcopy(opt)
        opt_merge.isTrain = False
        opt_merge.model = 'pix2pix4depth'
        self.mergenet = Pix2Pix4DepthModel(opt_merge)
        self.mergenet.save_dir = 'depthmerge/checkpoints/scaled_04_1024'
        self.mergenet.load_networks('latest')
        self.mergenet.eval()

        self.device = self.mergenet.device

        midas_model_path = "midas/model-f46da743.pt"
        self.midasmodel = MidasNet(midas_model_path, non_negative=True)
        self.midasmodel.to(self.device)
        self.midasmodel.eval()

        torch.multiprocessing.set_start_method('spawn')

    def gama_corect(self, rgb):
        srgb = np.zeros_like(rgb)
        mask1 = (rgb > 0) * (rgb < 0.0031308)
        mask2 = (1 - mask1).astype(bool)
        srgb[mask1] = 12.92 * rgb[mask1]
        srgb[mask2] = 1.055 * np.power(rgb[mask2], 0.41666) - 0.055
        srgb[srgb < 0] = 0
        return srgb

    def doubleestimate(self, img, size1, size2):
        # Estimate depth twice: a low-resolution pass (consistent global
        # structure) and a high-resolution pass (sharper details), then merge
        # the two with the pix2pix merge network.
        estimate1 = self.singleestimate(img, size1)
        estimate1 = cv2.resize(estimate1, (1024, 1024),
                               interpolation=cv2.INTER_CUBIC)

        estimate2 = self.singleestimate(img, size2)
        estimate2 = cv2.resize(estimate2, (1024, 1024),
                               interpolation=cv2.INTER_CUBIC)

        self.mergenet.set_input(estimate1, estimate2)
        self.mergenet.test()
        visuals = self.mergenet.get_current_visuals()
        prediction_mapped = visuals['fake_B']
        prediction_mapped = (prediction_mapped + 1) / 2
        prediction_mapped = (prediction_mapped - torch.min(prediction_mapped)
                             ) / (torch.max(prediction_mapped) -
                                  torch.min(prediction_mapped))
        prediction_mapped = prediction_mapped.squeeze().cpu().numpy()

        prediction_end_res = cv2.resize(prediction_mapped,
                                        (img.shape[1], img.shape[0]),
                                        interpolation=cv2.INTER_CUBIC)

        return prediction_end_res

    def singleestimate(self, img, msize):
        return self.estimateMidas(img, msize)

    def estimateMidas(self, img, msize):
        transform = Compose([
            Resize(
                msize,
                msize,
                resize_target=None,
                keep_aspect_ratio=True,
                ensure_multiple_of=32,
                resize_method="upper_bound",
                image_interpolation_method=cv2.INTER_CUBIC,
            ),
            NormalizeImage(mean=[0.485, 0.456, 0.406],
                           std=[0.229, 0.224, 0.225]),
            PrepareForNet(),
        ])

        img_input = transform({"image": img})["image"]
        # compute
        with torch.no_grad():
            sample = torch.from_numpy(img_input).to(self.device).unsqueeze(0)
            prediction = self.midasmodel.forward(sample)

        prediction = prediction.squeeze().cpu().numpy()
        prediction = cv2.resize(prediction, (img.shape[1], img.shape[0]),
                                interpolation=cv2.INTER_CUBIC)

        depth_min = prediction.min()
        depth_max = prediction.max()

        if depth_max - depth_min > np.finfo("float").eps:
            prediction = (prediction - depth_min) / (depth_max - depth_min)
        else:
            prediction = 0

        return prediction

    def estimateDepth(self, rgb_mix):
        # rgb_mix = (rgb_mix + 1) / 2
        # rgb_mix = rgb_mix.cpu().numpy()
        # rgb_mix = np.transpose(rgb_mix, (1, 2, 0))
        rgb_mix = self.gama_corect(rgb_mix)
        # print(rgb_mix.shape)
        # showImage(rgb_mix)
        depth_temp = self.doubleestimate(rgb_mix, 256, 512)
        # depth_temp = self.doubleestimate(rgb_mix, 384, 768)
        # showImage(depth_temp)
        return depth_temp

    def __getitem__(self, index):
        # read an image given a random integer index
        AB_path = self.AB_paths[index]
        # print(AB_path)

        # print(index)
        if 'flash' in AB_path.split('/')[-1]:
            roomDir = 'Rooms'

            roomDir = os.path.join(self.opt.dataroot, roomDir)
            # print(roomDir)
            # print("here")
            roomImages = os.listdir(roomDir)
            flash_color_adjustment_ratio = [1.23, 0.8, 1.04]
            room_image = roomImages[index % (len(roomImages) - 1)]
            path_room = os.path.join(roomDir, room_image)
            room = Image.open(path_room)
            targetImage = PngImageFile(path_room)
            des = int(targetImage.text['des'])
            w, h = room.size
            w2 = int(w / 2)
            A = room.crop((0, 0, w2, h))
            B = room.crop((w2, 0, w, h))
            A = A.crop((0, 0, 256, 256))
            B = B.crop((0, 0, 256, 256))

            flash = skimage.img_as_float(B)
            ambient = skimage.img_as_float(A)

            flash = makeBright(flash, 1.5)
            #
            # flash_avg = getAverage(flash[:, :, 0]) + getAverage(flash[:, :, 1]) + getAverage(flash[:, :, 2])
            # ambient_avg = getAverage(ambient[:, :, 0]) + getAverage(ambient[:, :, 1]) + getAverage(ambient[:, :, 2])
            # flash = flash * 2 * ambient_avg / (flash_avg + 0.0001)

            if des == 21:
                flash = changeTemp(flash, 48, des)
                ambient = changeTemp(ambient, 48, des)
            flash = flash + ambient
            flash = xyztorgb(flash, des)
            ambient = xyztorgb(ambient, des)

            # path_people = os.path.join(peopleDir, path)
            people_pic = Image.open(AB_path)
            w, h = people_pic.size
            w2 = int(w / 2)
            ambient_pic = people_pic.crop((0, 0, w2, h))
            flash_pic = people_pic.crop((w2, 0, w, h))

            flash_pic = flash_pic.resize((256, 256))
            ambient_pic = ambient_pic.resize((256, 256))
            flash_color_adjustment_ratio = flash_color_adjustment_ratio / np.max(
                flash_color_adjustment_ratio)
            flash_pic = lin(skimage.img_as_float(flash_pic))
            flash_pic[:, :,
                      0] = flash_pic[:, :, 0] * flash_color_adjustment_ratio[0]
            flash_pic[:, :,
                      1] = flash_pic[:, :, 1] * flash_color_adjustment_ratio[1]
            flash_pic[:, :,
                      2] = flash_pic[:, :, 2] * flash_color_adjustment_ratio[2]
            ambient_pic_lin = ambient_pic.copy()
            ambient_pic_lin = lin(skimage.img_as_float(ambient_pic_lin))
            flash_lin = lin(skimage.img_as_float(flash.copy()))
            ambient_lin = lin(skimage.img_as_float(ambient.copy()))

            avg_ambient_red_bg = getAverage(ambient_lin[:, :, 0])
            avg_ambient_green_bg = getAverage(ambient_lin[:, :, 1])
            avg_ambient_blue_bg = getAverage(ambient_lin[:, :, 2])

            avg_flash_red_bg = getAverage(flash_lin[:, :, 0])
            avg_flash_green_bg = getAverage(flash_lin[:, :, 1])
            avg_flash_blue_bg = getAverage(flash_lin[:, :, 2])

            bg_ratio_red = avg_ambient_red_bg / avg_flash_red_bg
            bg_ratio_green = avg_ambient_green_bg / avg_flash_green_bg
            bg_ratio_blue = avg_ambient_blue_bg / avg_flash_blue_bg

            ambient_pic_lin[:, :, 0] = ambient_pic_lin[:, :, 0] * bg_ratio_red
            ambient_pic_lin[:, :,
                            1] = ambient_pic_lin[:, :, 1] * bg_ratio_green
            ambient_pic_lin[:, :, 2] = ambient_pic_lin[:, :, 2] * bg_ratio_blue

            # ambient_pic_adjust = gama_corect(ambient_pic_lin)
            ambient_pic_adjust = Image.fromarray(
                (ambient_pic_lin * 255).astype('uint8'))

            # flash_pic = gama_corect(flash_pic)
            flash_pic = Image.fromarray((flash_pic * 255).astype('uint8'))

            flash_out = alpha_blend(flash_pic, flash)
            ambient_out_adjust = alpha_blend(ambient_pic_adjust, ambient)

            flash_out = (flash_out * 255).astype('uint8')
            flash_out = Image.fromarray(flash_out)
            ambient_out_adjust = (ambient_out_adjust * 255).astype('uint8')
            ambient_out_adjust = Image.fromarray(ambient_out_adjust)

            A = flash_out
            B = ambient_out_adjust
        elif 'multi' in AB_path.split('/')[-1]:
            AB = Image.open(AB_path)
            w, h = AB.size
            w2 = int(w / 2)
            A = AB.crop((0, 0, w2, h))
            B = AB.crop((w2, 0, w, h))
            A = A.resize((256, 256))
            B = B.resize((256, 256))
            flash = lin(skimage.img_as_float(B))
            ambient = lin(skimage.img_as_float(A))

            rest_path = AB_path.replace('train', 'rest_multi').replace(
                '.jpg', '') + '_ambient.jpg'
            rest_Ambient = Image.open(rest_path)
            w, h = rest_Ambient.size
            w4 = int(w / 4)
            A2 = rest_Ambient.crop((0, 0, w4, h))
            A3 = rest_Ambient.crop((w4, 0, 2 * w4, h))
            A4 = rest_Ambient.crop((2 * w4, 0, 3 * w4, h))
            A5 = rest_Ambient.crop((3 * w4, 0, w, h))

            ambient2 = lin(skimage.img_as_float(A2))
            ambient3 = lin(skimage.img_as_float(A3))
            ambient4 = lin(skimage.img_as_float(A4))
            ambient5 = lin(skimage.img_as_float(A5))

            transform_params = get_params(self.opt, A.size)
            depth_transform = get_transform(self.opt,
                                            transform_params,
                                            grayscale=True)
            transform = get_transform(self.opt,
                                      transform_params,
                                      grayscale=(self.input_nc == 1))

            depth_flash = Image.fromarray(
                (self.estimateDepth(flash) * 255).astype('uint8'))
            depth_flash = depth_transform(depth_flash).unsqueeze(0)

            depth_ambient1 = Image.fromarray(
                (self.estimateDepth(ambient) * 255).astype('uint8'))
            depth_ambient2 = Image.fromarray(
                (self.estimateDepth(ambient2) * 255).astype('uint8'))
            depth_ambient3 = Image.fromarray(
                (self.estimateDepth(ambient3) * 255).astype('uint8'))
            depth_ambient4 = Image.fromarray(
                (self.estimateDepth(ambient4) * 255).astype('uint8'))
            depth_ambient5 = Image.fromarray(
                (self.estimateDepth(ambient5) * 255).astype('uint8'))
            depth_ambient1 = depth_transform(depth_ambient1).unsqueeze(0)
            depth_ambient2 = depth_transform(depth_ambient2).unsqueeze(0)
            depth_ambient3 = depth_transform(depth_ambient3).unsqueeze(0)
            depth_ambient4 = depth_transform(depth_ambient4).unsqueeze(0)
            depth_ambient5 = depth_transform(depth_ambient5).unsqueeze(0)

            flash = (flash * 255).astype('uint8')
            ambient = (ambient * 255).astype('uint8')
            ambient2 = (ambient2 * 255).astype('uint8')
            ambient3 = (ambient3 * 255).astype('uint8')
            ambient4 = (ambient4 * 255).astype('uint8')
            ambient5 = (ambient5 * 255).astype('uint8')

            flash = Image.fromarray(flash)
            ambient = Image.fromarray(ambient)
            ambient2 = Image.fromarray(ambient2)
            ambient3 = Image.fromarray(ambient3)
            ambient4 = Image.fromarray(ambient4)
            ambient5 = Image.fromarray(ambient5)

            flash = transform(flash).unsqueeze(0)
            ambient = transform(ambient).unsqueeze(0)
            ambient2 = transform(ambient2).unsqueeze(0)
            ambient3 = transform(ambient3).unsqueeze(0)
            ambient4 = transform(ambient4).unsqueeze(0)
            ambient5 = transform(ambient5).unsqueeze(0)

            A_final = torch.cat((flash, flash, flash, flash, flash), dim=0)
            B_final = torch.cat(
                (ambient, ambient2, ambient3, ambient4, ambient5), dim=0)

            depth_A_final = torch.cat((depth_flash, depth_flash, depth_flash,
                                       depth_flash, depth_flash),
                                      dim=0)
            depth_B_final = torch.cat(
                (depth_ambient1, depth_ambient2, depth_ambient3,
                 depth_ambient4, depth_ambient5),
                dim=0)

            return {
                'A': A_final,
                'B': B_final,
                'depth_A': depth_A_final,
                'depth_B': depth_B_final,
                'A_paths': AB_path,
                'B_paths': AB_path
            }

        else:
            AB = Image.open(AB_path)
            targetImage = PngImageFile(AB_path)
            des = int(targetImage.text['des'])

            w, h = AB.size
            w2 = int(w / 2)
            A = AB.crop((0, 0, w2, h))
            B = AB.crop((w2, 0, w, h))
            A = A.resize((self.opt.load_size, self.opt.load_size))
            B = B.resize((self.opt.load_size, self.opt.load_size))
            flash = skimage.img_as_float(B)
            ambient = skimage.img_as_float(A)
            flash = makeBright(flash, 1.2)

            # flash_avg = getAverage(flash[:,:,0]) + getAverage(flash[:,:,1]) + getAverage(flash[:,:,2])
            # ambient_avg = getAverage(ambient[:,:,0]) + getAverage(ambient[:,:,1]) + getAverage(ambient[:,:,2])
            # flash = flash * 2 * ambient_avg / (flash_avg + 0.001)

            # ambient = makeBright(ambient,0.5)
            if des == 21:
                flash = changeTemp(flash, 48, des)
                ambient = changeTemp(ambient, 48, des)

            A = flash + ambient
            B = ambient
            A = xyztorgb(A, des)
            B = xyztorgb(B, des)

        transform_params = get_params(self.opt, A.size)
        # apply the same transform to both A and B
        A_transform = get_transform(self.opt,
                                    transform_params,
                                    grayscale=(self.input_nc == 1))
        B_transform = get_transform(self.opt,
                                    transform_params,
                                    grayscale=(self.output_nc == 1))
        depth_transform = get_transform(self.opt,
                                        transform_params,
                                        grayscale=True)

        depth_A = Image.fromarray(
            (self.estimateDepth(np.asarray(A) / 255) * 255).astype('uint8'))
        depth_B = Image.fromarray(
            (self.estimateDepth(np.asarray(B) / 255) * 255).astype('uint8'))
        # showImage(depth_A)
        # showImage(depth_B)
        depth_A = depth_transform(depth_A)
        depth_B = depth_transform(depth_B)

        A = A_transform(A)
        B = B_transform(B)

        # print(torch.shape(A))
        return {
            'A': A,
            'B': B,
            'depth_A': depth_A,
            'depth_B': depth_B,
            'A_paths': AB_path,
            'B_paths': AB_path
        }

    def __len__(self):
        """Return the total number of images in the dataset."""
        return len(self.AB_paths)
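
Example #6 relies on helpers defined elsewhere in the project: makeBright,
getAverage and alpha_blend, plus module-level lin, changeTemp and xyztorgb that
mirror the methods shown in Example #4. Plausible minimal stand-ins for the two
arithmetic helpers, inferred from their call sites (assumptions, not the
project's actual code):

import numpy as np

def makeBright(img, factor):
    # scale a float image by a brightness factor and clip back to [0, 1]
    return np.clip(img * factor, 0, 1)

def getAverage(channel):
    # mean intensity of a single color channel
    return np.mean(channel)
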
Example #7
def run(dataset, option):

    # Load merge network
    opt = TestOptions().parse()
    global pix2pixmodel
    pix2pixmodel = Pix2Pix4DepthModel(opt)
    pix2pixmodel.save_dir = './pix2pix/checkpoints/mergemodel'
    pix2pixmodel.load_networks('latest')
    pix2pixmodel.eval()

    # Decide which depth estimation network to load
    if option.depthNet == 0:
        midas_model_path = "midas/model.pt"
        global midasmodel
        midasmodel = MidasNet(midas_model_path, non_negative=True)
        midasmodel.to(device)
        midasmodel.eval()
    elif option.depthNet == 1:
        global srlnet
        srlnet = DepthNet.DepthNet()
        srlnet = torch.nn.DataParallel(srlnet, device_ids=[0]).cuda()
        checkpoint = torch.load('structuredrl/model.pth.tar')
        srlnet.load_state_dict(checkpoint['state_dict'])
        srlnet.eval()
    elif option.depthNet == 2:
        global leresmodel
        leres_model_path = "res101.pth"
        checkpoint = torch.load(leres_model_path)
        leresmodel = RelDepthModel(backbone='resnext101')
        leresmodel.load_state_dict(strip_prefix_if_present(checkpoint['depth_model'], "module."),
                                    strict=True)
        del checkpoint
        torch.cuda.empty_cache()
        leresmodel.to(device)
        leresmodel.eval()

    # Generating required directories
    result_dir = option.output_dir
    os.makedirs(result_dir, exist_ok=True)

    if option.savewholeest:
        whole_est_outputpath = option.output_dir + '_wholeimage'
        os.makedirs(whole_est_outputpath, exist_ok=True)

    if option.savepatchs:
        patchped_est_outputpath = option.output_dir + '_patchest'
        os.makedirs(patchped_est_outputpath, exist_ok=True)

    # Generate the mask used to smoothly blend the local patch estimations
    # into the base estimate. It is arbitrarily large to avoid artifacts
    # during rescaling for each crop.
    mask_org = generatemask((3000, 3000))
    mask = mask_org.copy()

    # Value x of R_x defined in section 5 of the main paper.
    r_threshold_value = 0.2
    if option.R0:
        r_threshold_value = 0
    elif option.R20:
        r_threshold_value = 0.2

    # Go through all images in input directory
    print("start processing")
    for image_ind, images in enumerate(dataset):
        print('processing image', image_ind, ':', images.name)

        # Load image from dataset
        img = images.rgb_image
        input_resolution = img.shape

        scale_threshold = 3  # Allows up-scaling with a scale up to 3

        # Find the best input resolution R-x. The resolution search is
        # described in section 5 (double estimation) of the main paper and in
        # section B of the supplementary material.
        whole_image_optimal_size, patch_scale = calculateprocessingres(img, option.net_receptive_field_size,
                                                                       r_threshold_value, scale_threshold,
                                                                       whole_size_threshold)

        print('\t wholeImage being processed in :', whole_image_optimal_size)

        # Generate the base estimate using the double estimation.
        whole_estimate = doubleestimate(img, option.net_receptive_field_size, whole_image_optimal_size,
                                        option.pix2pixsize, option.depthNet)
        if option.R0 or option.R20:
            path = os.path.join(result_dir, images.name)
            if option.output_resolution == 1:
                midas.utils.write_depth(path, cv2.resize(whole_estimate, (input_resolution[1], input_resolution[0]),
                                                         interpolation=cv2.INTER_CUBIC),
                                        bits=2, colored=option.colorize_results)
            else:
                midas.utils.write_depth(path, whole_estimate, bits=2, colored=option.colorize_results)
            continue

        # Output double estimation if required
        if option.savewholeest:
            path = os.path.join(whole_est_outputpath, images.name)
            if option.output_resolution == 1:
                midas.utils.write_depth(path,
                                        cv2.resize(whole_estimate, (input_resolution[1], input_resolution[0]),
                                                   interpolation=cv2.INTER_CUBIC), bits=2,
                                        colored=option.colorize_results)
            else:
                midas.utils.write_depth(path, whole_estimate, bits=2, colored=option.colorize_results)

        # Compute the multiplier described in section 6 of the main paper to make sure our initial patch can select
        # small high-density regions of the image.
        global factor
        factor = max(min(1, 4 * patch_scale * whole_image_optimal_size / whole_size_threshold), 0.2)
        print('Adjust factor is:', 1/factor)

        # Check if Local boosting is beneficial.
        if option.max_res < whole_image_optimal_size:
            print("No Local boosting. Specified Max Res is smaller than R20")
            path = os.path.join(result_dir, images.name)
            if option.output_resolution == 1:
                midas.utils.write_depth(path,
                                        cv2.resize(whole_estimate,
                                                   (input_resolution[1], input_resolution[0]),
                                                   interpolation=cv2.INTER_CUBIC), bits=2,
                                        colored=option.colorize_results)
            else:
                midas.utils.write_depth(path, whole_estimate, bits=2,
                                        colored=option.colorize_results)
            continue

        # Compute the default target resolution.
        if img.shape[0] > img.shape[1]:
            a = 2 * whole_image_optimal_size
            b = round(2 * whole_image_optimal_size * img.shape[1] / img.shape[0])
        else:
            a = round(2 * whole_image_optimal_size * img.shape[0] / img.shape[1])
            b = 2 * whole_image_optimal_size
        b = int(round(b / factor))
        a = int(round(a / factor))

        # recompute a, b and saturate to max res.
        if max(a,b) > option.max_res:
            print('Default Res is higher than max-res: Reducing final resolution')
            if img.shape[0] > img.shape[1]:
                a = option.max_res
                b = round(option.max_res * img.shape[1] / img.shape[0])
            else:
                a = round(option.max_res * img.shape[0] / img.shape[1])
                b = option.max_res
            b = int(b)
            a = int(a)

        img = cv2.resize(img, (b, a), interpolation=cv2.INTER_CUBIC)

        # Extract selected patches for local refinement
        base_size = option.net_receptive_field_size*2
        patchset = generatepatchs(img, base_size)

        print('Target resolution: ', img.shape)

        # Compute a scale in case the user asked to generate the results at
        # the same resolution as the input. Note that our method's output
        # resolution is independent of the input resolution; this parameter
        # only enables a scaling operation during the local patch merge so
        # that the results match the input resolution.
        if option.output_resolution == 1:
            mergein_scale = input_resolution[0] / img.shape[0]
            print('Dynamically change merged-in resolution; scale:', mergein_scale)
        else:
            mergein_scale = 1

        imageandpatchs = ImageandPatchs(option.data_dir, images.name, patchset, img, mergein_scale)
        whole_estimate_resized = cv2.resize(whole_estimate, (round(img.shape[1]*mergein_scale),
                                            round(img.shape[0]*mergein_scale)), interpolation=cv2.INTER_CUBIC)
        imageandpatchs.set_base_estimate(whole_estimate_resized.copy())
        imageandpatchs.set_updated_estimate(whole_estimate_resized.copy())

        print('\t Resulting depthmap resolution will be :', whole_estimate_resized.shape[:2])
        print('patches to process: '+str(len(imageandpatchs)))

        # Enumerate through all patches, generate their estimations, and refine the base estimate.
        for patch_ind in range(len(imageandpatchs)):
            
            # Get patch information
            patch = imageandpatchs[patch_ind] # patch object
            patch_rgb = patch['patch_rgb'] # rgb patch
            patch_whole_estimate_base = patch['patch_whole_estimate_base'] # corresponding patch from base
            rect = patch['rect'] # patch size and location
            patch_id = patch['id'] # patch ID
            org_size = patch_whole_estimate_base.shape # the original size from the unscaled input
            print('\t processing patch', patch_ind, '|', rect)

            # We apply double estimation for patches. The high resolution value is fixed to twice the receptive
            # field size of the network for patches to accelerate the process.
            patch_estimation = doubleestimate(patch_rgb, option.net_receptive_field_size, option.patch_netsize,
                                              option.pix2pixsize, option.depthNet)

            # Output patch estimation if required
            if option.savepatchs:
                path = os.path.join(patchped_est_outputpath, imageandpatchs.name + '_{:04}'.format(patch_id))
                midas.utils.write_depth(path, patch_estimation, bits=2, colored=option.colorize_results)

            patch_estimation = cv2.resize(patch_estimation, (option.pix2pixsize, option.pix2pixsize),
                                          interpolation=cv2.INTER_CUBIC)

            patch_whole_estimate_base = cv2.resize(patch_whole_estimate_base, (option.pix2pixsize, option.pix2pixsize),
                                                   interpolation=cv2.INTER_CUBIC)

            # Merging the patch estimation into the base estimate using our merge network:
            # We feed the patch estimation and the same region from the updated base estimate to the merge network
            # to generate the target estimate for the corresponding region.
            pix2pixmodel.set_input(patch_whole_estimate_base, patch_estimation)

            # Run merging network
            pix2pixmodel.test()
            visuals = pix2pixmodel.get_current_visuals()

            prediction_mapped = visuals['fake_B']
            prediction_mapped = (prediction_mapped+1)/2
            prediction_mapped = prediction_mapped.squeeze().cpu().numpy()

            mapped = prediction_mapped

            # We use a simple linear polynomial to make sure the result of the
            # merge network matches the value range of the base estimate
            p_coef = np.polyfit(mapped.reshape(-1), patch_whole_estimate_base.reshape(-1), deg=1)
            merged = np.polyval(p_coef, mapped.reshape(-1)).reshape(mapped.shape)

            merged = cv2.resize(merged, (org_size[1],org_size[0]), interpolation=cv2.INTER_CUBIC)

            # Get patch size and location
            w1 = rect[0]
            h1 = rect[1]
            w2 = w1 + rect[2]
            h2 = h1 + rect[3]

            # To speed up the implementation, we only generate the Gaussian mask once with a sufficiently large size
            # and resize it to our needed size while merging the patches.
            if mask.shape != org_size:
                mask = cv2.resize(mask_org, (org_size[1],org_size[0]), interpolation=cv2.INTER_LINEAR)

            tobemergedto = imageandpatchs.estimation_updated_image

            # Update the whole estimation:
            # We use a simple Gaussian mask to blend the merged patch region with the base estimate to ensure seamless
            # blending at the boundaries of the patch region.
            tobemergedto[h1:h2, w1:w2] = np.multiply(tobemergedto[h1:h2, w1:w2], 1 - mask) + np.multiply(merged, mask)
            imageandpatchs.set_updated_estimate(tobemergedto)

        # Output the result
        path = os.path.join(result_dir, imageandpatchs.name)
        if option.output_resolution == 1:
            midas.utils.write_depth(path,
                                    cv2.resize(imageandpatchs.estimation_updated_image,
                                               (input_resolution[1], input_resolution[0]),
                                               interpolation=cv2.INTER_CUBIC), bits=2, colored=option.colorize_results)
        else:
            midas.utils.write_depth(path, imageandpatchs.estimation_updated_image, bits=2, colored=option.colorize_results)

    print("finished")