Example #1
    def __getitem__(self, index):
        """Return a data point and its metadata information.

        Parameters:
            index -- a random integer for data indexing

        Returns a dictionary that contains A, B, noise_mask and AB_path
            A (tensor) -- an image in the input domain
            B (tensor) -- its corresponding image in the target domain
            noise_mask (tensor) -- single-channel mask cropped from the right third of the combined image
            AB_path (str) -- path of the combined image file
        """
        # read an image given a random integer index
        AB_path = self.AB_paths[index % self.dataset_len]
        try:
            AB = Image.open(AB_path).convert('RGB')
        except Exception:
            # skip unreadable images and fall back to the next sample
            return self.__getitem__(index + 1)
        # mask_path = self.mask_paths[index % self.dataset_len]
        # with open(mask_path, "r") as fr:
        #     mask_data = json.loads(fr.read())

        # split AB image into A and B
        w, h = AB.size
        w3 = int(w / 3)
        A = AB.crop((0, 0, w3, h))
        # B = AB.crop((w3, 0, 2 * w3, h)).convert("L")
        B = AB.crop((w3, 0, 2 * w3, h))
        mask = AB.crop((2 * w3, 0, w, h))
        mask, g, b = mask.split()  # keep only the first (red) channel as the mask

        # crop_boxes = _handler_mask_data(mask_data)
        # if len(crop_boxes) > 0:
        #     crop_box = random.choice(crop_boxes)
        # else:
        #     xs = random.randint(0, w3 - 256)
        #     ys = random.randint(0, h - 256)
        #     crop_box = np.array([xs, ys, xs + 256, ys + 256])
        current_transform = default_transform()
        gray_transform = default_transform(grayscale=True)

        A = current_transform(A)
        B = current_transform(B)
        noise_mask = gray_transform(mask)
        # Early in training, keep mask_img > 0 and do not fill the mask into the input
        # (the 0.2 threshold applied to the mask here could also be a random value).
        # Later in training, the mask should be filled into the input so that the network
        # captures more structural information rather than color information; the network
        # still needs to learn the color information.
        # Threshold after normalization to avoid interpolation issues when a resize is involved.
        # noise_mask[noise_mask >= -0.2] = 1.0
        # noise_mask[noise_mask < -0.2] = -1.0
        return {'A': A, 'B': B, "noise_mask": noise_mask,
                "AB_path": AB_path}
Example #2
    def __getitem__(self, index):
        """Return a data point and its metadata information.

        Parameters:
            index -- a random integer for data indexing

        Returns a dictionary that contains A, B and AB_path
            A (tensor) -- an image in the input domain
            B (tensor) -- its corresponding image in the target domain
            AB_path (str) -- path of the combined image file
        """
        # read an image given a random integer index
        AB_path = self.AB_paths[index % self.dataset_len]
        try:
            AB = Image.open(AB_path).convert('RGB')
        except Exception:
            # skip unreadable images and fall back to the next sample
            return self.__getitem__(index + 1)

        # split AB image into A and B
        w, h = AB.size
        w2 = int(w / 2)
        A = AB.crop((0, 0, w2, h))
        # B = AB.crop((w3, 0, 2 * w3, h)).convert("L")
        B = AB.crop((w2, 0, w, h))

        current_transform = default_transform()

        A = current_transform(A)
        B = current_transform(B)
        return {'A': A, 'B': B,
                "AB_path": AB_path}
Example #3
    def __getitem__(self, index):
        """Return a data point and its metadata information.

        Parameters:
            index -- a random integer for data indexing

        Returns a dictionary with one blurry/clear image pair per pyramid level
            blur{i} (PIL.Image) -- the blurry crop at level i
            clear{i} (PIL.Image) -- the corresponding clear crop at level i
        """
        # read an image given a random integer index
        AB_path = self.AB_paths[index % self.dataset_len]
        try:
            AB = Image.open(AB_path).convert('RGB')
        except Exception:
            # skip unreadable images and fall back to the next sample
            return self.__getitem__(index + 1)

        # split AB image into A and B
        w, h = AB.size
        w3 = int(w / 3)
        blurry_img = AB.crop((0, 0, w3, h))
        # B = AB.crop((w3, 0, 2 * w3, h)).convert("L")
        clear_img = AB.crop((w3, 0, 2 * w3, h))

        current_transform = default_transform()  # created here but not applied in the loop below

        data = {}
        blur_key = "blur"
        clear_key = "clear"

        for i in range(self.n_levels):
            if i > 0:
                # levels above 0 keep the full-resolution crops
                blur = blurry_img
                clear = clear_img
            else:
                # level 0 is resized by a factor of self.scale
                scale = self.scale**(i + 1)
                nw = int(w3 * scale)
                nh = int(h * scale)
                blur = blurry_img.resize((nw, nh), resample=self.resample)
                clear = clear_img.resize((nw, nh), resample=self.resample)
            data.update({
                "{}{}".format(blur_key, i): blur,
                "{}{}".format(clear_key, i): clear,
            })

        return data
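
An illustrative walk-through of the loop above with assumed values (none of the numbers come from the source):

# Suppose AB is 768x256, so w3 = 256 and each crop is 256x256,
# with n_levels = 2 and scale = 0.5.
#
#   i = 0 -> else branch: scale**(0 + 1) = 0.5, so blur0 / clear0 are resized to 128x128
#   i = 1 -> if branch:   blur1 / clear1 keep the original 256x256 crops
#
# data == {"blur0": <128x128 image>, "clear0": <128x128 image>,
#          "blur1": <256x256 image>, "clear1": <256x256 image>}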
Example #4
    def __init__(self, config):
        self.experiment_name = config.pop('name')
        self.random_seed = config.get('random_seed', 30)
        model_name = get_model_name(config["arch"].pop("type"))
        self.model = getattr(models, model_name)(config)
        logger.info("model init success")
        self.transform = default_transform()
        self.long_side = config["predictor"]["long_side"]
        if "test_img_dir" in config["predictor"]:
            # batch mode: collect every test image under the given directory
            self.batch_flag = True
            test_img_dir = config["predictor"]["test_img_dir"]
            self.input_img_paths = get_file_list(test_img_dir, p_postfix=['.jpg', '.png', ".tif", ".JPG", ".jpeg"])
        else:
            self.batch_flag = False
        out_img_dir = config["predictor"]["out_img_dir"]
        os.makedirs(out_img_dir, exist_ok=True)
        self.out_img_dir = out_img_dir
        self.combine_res = config["predictor"].get("combine_res", True)
        self.model.set_mode()
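
Reading this __init__ back, the config it expects is a nested dictionary containing at least the keys accessed above. A minimal sketch of such a config; the key names come from the code, while every concrete value is an assumption for illustration:

config = {
    "name": "my_experiment",                     # popped into self.experiment_name
    "random_seed": 30,                           # optional, defaults to 30
    "arch": {
        "type": "SomeModel",                     # popped and mapped through get_model_name
        # ... any remaining arch options are passed on to the model constructor
    },
    "predictor": {
        "long_side": 1024,                       # stored as self.long_side
        "test_img_dir": "/path/to/test/images",  # optional; its presence enables batch mode
        "out_img_dir": "/path/to/output",        # created with os.makedirs if missing
        "combine_res": True,                     # optional, defaults to True
    },
}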