예제 #1
0
    def __getitem__(self, index):
        """Load the paired image at ``index`` and split it into its two halves.

        Parameters:
            index -- integer index into ``self.AB_paths``

        Returns a dictionary with keys:
            A       -- transformed left half of the image (input domain)
            B       -- transformed right half of the image (target domain)
            A_paths -- path of the paired image file
            B_paths -- same path as A_paths
        """
        path = self.AB_paths[index]
        paired = Image.open(path).convert('RGB')

        # The file stores A and B side by side; cut it at the horizontal midpoint.
        width, height = paired.size
        mid = width // 2
        left_half = paired.crop((0, 0, mid, height))
        right_half = paired.crop((mid, 0, width, height))

        # Both halves share one set of transform params; only the grayscale
        # flag differs, driven by the input/output channel counts.
        shared_params = get_params(self.opt, left_half.size)
        transform_a = get_transform(self.opt, shared_params, grayscale=self.input_nc == 1)
        transform_b = get_transform(self.opt, shared_params, grayscale=self.output_nc == 1)

        return {
            'A': transform_a(left_half),
            'B': transform_b(right_half),
            'A_paths': path,
            'B_paths': path,
        }
    def __getitem__(self, index):
        """Return a data point and its metadata information.

        Parameters:
            index -- integer index into the split selected by ``self.opt.phase``
                     (``self.train_split`` for 'train', ``self.val_split`` for 'val')

        Returns a dictionary with keys:
            A (tensor) -- transformed left half of the paired image (input domain)
            B (tensor) -- transformed right half of the paired image (target domain)
            A_paths (str) -- key of the paired image in ``self.img_dict``
            B_paths (str) -- same as A_paths
            cls_label -- ``self.cat_contiugous_ids`` entry for the integer prefix of
                         the file name (the text before the first '_')
            cat_emb -- ``self.cat_embeding_list[cls_label]``

        Raises:
            ValueError -- if ``self.opt.phase`` is neither 'train' nor 'val'.
        """
        # Select the path from the split for the current phase. Previously any
        # other phase left AB_path unbound and failed with UnboundLocalError.
        if self.opt.phase == 'train':
            AB_path = self.train_split[index]
        elif self.opt.phase == 'val':
            AB_path = self.val_split[index]
        else:
            raise ValueError(
                "unsupported phase %r: expected 'train' or 'val'" % (self.opt.phase,))

        # Images are pre-loaded in self.img_dict (presumably PIL images, given the
        # .size/.crop usage). crop() returns new images, so the cache is not mutated.
        AB = self.img_dict[AB_path]
        w, h = AB.size
        w2 = int(w / 2)
        A = AB.crop((0, 0, w2, h))
        B = AB.crop((w2, 0, w, h))

        # Apply the same transform params to both A and B; only the grayscale
        # flag differs per side.
        transform_params = get_params(self.opt, A.size)
        A_transform = get_transform(self.opt, transform_params, grayscale=(self.input_nc == 1))
        B_transform = get_transform(self.opt, transform_params, grayscale=(self.output_nc == 1))

        A = A_transform(A)
        B = B_transform(B)

        # The category id is encoded as the integer prefix of the file name,
        # e.g. "<cat>_....ext"; map it to the contiguous class label.
        cat_id = int(os.path.basename(AB_path).strip().split('_')[0])
        cls_label = self.cat_contiugous_ids[cat_id]
        cat_embedding_tensor = self.cat_embeding_list[cls_label]

        return {'A': A, 'B': B, 'A_paths': AB_path, 'B_paths': AB_path,
                'cls_label': cls_label, "cat_emb": cat_embedding_tensor}
    def data_process(self, src_img, cat_embedding, style_idx):
        """Prepare a single source image and category embedding as a model input.

        Parameters:
            src_img       -- source image; converted to RGB before transforming
            cat_embedding -- category embedding tensor; gets a leading batch dim
            style_idx     -- passed through unchanged as ``cls_label``

        Returns a dictionary with keys:
            B         -- transformed source image with a leading batch dim of 1, on CUDA
            cls_label -- ``style_idx`` unchanged
            cat_emb   -- ``cat_embedding`` with a leading batch dim of 1, on CUDA

        Side effect: the prepared image tensor is also stored on ``self.src_img``.
        """
        rgb = src_img.convert("RGB")
        params = get_params(self.option, rgb.size)
        transform = get_transform(self.option, params, grayscale=False)
        batched_img = transform(rgb).unsqueeze(0).cuda()

        batched_emb = cat_embedding.unsqueeze(0).cuda()
        self.src_img = batched_img

        return {"B": batched_img, "cls_label": style_idx, "cat_emb": batched_emb}
    def cls_process_imgs(self, imgs):
        """Transform each image into a single-item batch tensor on the GPU.

        Parameters:
            imgs -- iterable of images exposing ``convert`` and ``size``
                    (presumably PIL images); any iterable is accepted,
                    including generators

        Returns:
            list with one tensor per input image, in input order, each with a
            leading batch dimension of 1 (``unsqueeze(0)``) and moved to CUDA.
        """
        output_list = []
        for im in imgs:
            im = im.convert('RGB')
            # Transform params depend on each image's size, so recompute per image.
            transform_params = get_params(self.option, im.size)
            img_transform = get_transform(self.option, transform_params, grayscale=False)
            output_list.append(img_transform(im).unsqueeze(0).cuda())
        # The former `assert len(output_list) == len(imgs)` was always true for
        # sized inputs and raised TypeError for unsized iterables; it is dropped.
        return output_list