Beispiel #1
0
def rotate_and_flip(tensor,device,p=0.5):
    """Randomly rotate and flip every entry along the first axis, in place.

    For each index ``i`` of the first axis, three independent uniform draws
    are made, and each augmentation is applied when its draw is below ``p``:

    1. rotation by one random integer angle drawn from ``[-90, 90)`` degrees,
       shared by all ``tensor.shape[1]`` slices of ``tensor[i]``;
    2. horizontal flip (``kornia.hflip``);
    3. vertical flip (``kornia.vflip``).

    Args:
        tensor: Tensor to augment. It is overwritten entry by entry.
            NOTE(review): the rotation centre takes x from ``shape[3]`` and
            y from ``shape[2]`` while ``shape[1]`` is used as the batch size
            passed to ``kornia.rotate`` -- confirm the expected layout with
            the caller.
        device: Device on which the rotation centre and angle are placed.
        p: Probability of applying each of the three augmentations.

    Returns:
        The same ``tensor`` object that was passed in.
    """
    for i in range(tensor.shape[0]):
        if np.random.uniform() < p:
            n_slices = tensor.shape[1]
            center = torch.ones(n_slices, 2).to(device)
            center[:, 0] = tensor.shape[3] / 2  # x
            center[:, 1] = tensor.shape[2] / 2  # y
            degrees = float(np.random.randint(-90, 90))
            # Always build a 1-D angle of length n_slices. The previous
            # ``torch.tensor([...]).squeeze()`` collapsed to a 0-d tensor
            # when n_slices == 1, which no longer matches ``center``.
            angle = torch.full(
                (n_slices,), degrees, dtype=torch.float32).to(device)
            tensor[i] = kornia.rotate(tensor[i], angle, center)

        if np.random.uniform() < p:
            tensor[i] = kornia.hflip(tensor[i])

        if np.random.uniform() < p:
            tensor[i] = kornia.vflip(tensor[i])

    return tensor
Beispiel #2
0
    def _get_transformed_images(images, hflip):
        """Optionally flip ``images`` horizontally, then normalize them.

        Args:
            images: Input passed to ``K.hflip`` / ``K.normalize``.
            hflip: When truthy, ``K.hflip`` is applied before normalizing.

        Returns:
            The result of ``K.normalize(..., 0.5, 0.5)``.
        """
        flipped = K.hflip(images) if hflip else images
        # Normalize with mean 0.5 and std 0.5.
        return K.normalize(flipped, 0.5, 0.5)
Beispiel #3
0
    def _get_transformed_frames(frames, hflip):
        """Optionally flip ``frames``, normalize them and swap the first two axes.

        Args:
            frames: Input passed to ``K.hflip`` / ``K.normalize``; the
                normalized result must be 4-dimensional for the permute.
            hflip: When truthy, ``K.hflip`` is applied before normalizing.

        Returns:
            The normalized frames with axes reordered as ``(1, 0, 2, 3)``
            (labelled "CTHW" by the original author).
        """
        flipped = K.hflip(frames) if hflip else frames
        # Normalize with mean 0.5 and std 0.5.
        normalized = K.normalize(flipped, 0.5, 0.5)
        # Permute to CTHW.
        return normalized.permute(1, 0, 2, 3)
Beispiel #4
0
        def op_script(data: torch.Tensor) -> torch.Tensor:
            # Thin wrapper that forwards ``data`` to ``kornia.hflip``.
            flipped = kornia.hflip(data)
            return flipped
    def infer(self,
              img_path,
              cont_path=None,
              mode=None,
              color=None,
              text=None,
              mask_path=None,
              gt_path=None,
              output_dir=None):
        """Contaminate one image according to ``mode`` and run ``mpn``/``rin`` on it.

        The image at ``img_path`` is resized to ``DATASET.SIZE`` squared,
        blended with a contaminant through a mask, passed through
        ``self.mpn`` and ``self.rin``, and the result is saved or displayed.

        Modes that require ``cont_path`` (asserted to be in 1, 5, 6, 7, 8):
            6: the contaminant is shrunk to ``SIZE // 8`` and pasted at a
               random position; the mask covers exactly that patch.
            7: ``text`` is drawn centred on the contaminant image and on a
               blank mask.
            8: the ``SIZE // 2`` centre crop of the contaminant is pasted in
               the centre; when ``mask_path`` is None the mask is the
               generated mask resized to ``SIZE // 2`` combined with its
               horizontal flip.
            1, 5: the contaminant image is used as is.

        Modes used when ``cont_path`` is None (asserted to be in 2, 3, 4):
            2: uniform random noise.
            3: a solid colour looked up in ``COLORS``.
            4: the input image itself.

        Args:
            img_path: Path of the image to process.
            cont_path: Path of the contaminant image, or None.
            mode: Mode number; defaults to ``self.opt.TEST.MODE``.
            color: Key into ``COLORS`` (upper-cased before lookup). In mode 3
                it defaults to ``self.opt.TEST.BRUSH_COLOR``.
            text: Text drawn in mode 7; defaults to ``self.opt.TEST.TEXT``.
            mask_path: Path of a mask image; when None a mask is generated
                with ``self.mask_generator``.
            gt_path: Optional ground-truth image, only displayed when
                ``output_dir`` is None.
            output_dir: Directory for the result image. When None the
                intermediate and final images are shown instead.

        Returns:
            None. The result is written to ``output_dir`` or displayed.

        Raises:
            AssertionError: If ``mode`` is not valid for the given
                ``cont_path``.
        """
        mode = self.opt.TEST.MODE if mode is None else mode
        text = self.opt.TEST.TEXT if text is None else text
        img_size = self.opt.DATASET.SIZE

        with torch.no_grad():
            im = Image.open(img_path).convert("RGB")
            im = im.resize((img_size, img_size))
            im_t = linear_scaling(
                transforms.ToTensor()(im).unsqueeze(0).cuda())

            if gt_path is not None:
                gt = Image.open(gt_path).convert("RGB")
                gt = gt.resize((img_size, img_size))

            if mask_path is None:
                masks = torch.from_numpy(
                    self.mask_generator.generate(
                        img_size, img_size)).float().cuda()
            else:
                masks = Image.open(mask_path).convert("L")
                masks = masks.resize((img_size, img_size))
                masks = self.tensorize(masks).unsqueeze(0).float().cuda()

            if cont_path is not None:
                assert mode in [1, 5, 6, 7, 8]
                c_im = Image.open(cont_path).convert("RGB")
                c_im = c_im.resize((img_size, img_size))
                if mode == 6:
                    patch_size = img_size // 8
                    c_im = c_im.resize((patch_size, patch_size))
                    c_im_t = self.tensorize(c_im).unsqueeze(0).cuda()
                    r_c_im_t = torch.zeros(
                        (1, 3, img_size, img_size)).cuda()
                    # The loaded/generated mask is discarded: only the
                    # pasted patch is marked as contaminated.
                    masks = torch.zeros((1, 1, img_size, img_size)).cuda()
                    coord_x, coord_y = np.random.randint(
                        img_size - patch_size, size=(2, ))
                    r_c_im_t[:, :, coord_x:coord_x + c_im_t.size(2),
                             coord_y:coord_y + c_im_t.size(3)] = c_im_t
                    masks[:, :, coord_x:coord_x + c_im_t.size(2),
                          coord_y:coord_y +
                          c_im_t.size(3)] = torch.ones_like(c_im_t[0, 0])
                    c_im_t = linear_scaling(r_c_im_t)
                elif mode == 7:
                    mask = self.to_pil(torch.zeros((img_size, img_size)))
                    d = ImageDraw.Draw(c_im)
                    d_m = ImageDraw.Draw(mask)
                    font = ImageFont.truetype(self.opt.TEST.FONT,
                                              self.opt.TEST.FONT_SIZE)
                    # NOTE(review): ``ImageDraw.textsize`` is not available
                    # in recent Pillow releases -- confirm the pinned
                    # Pillow version before upgrading.
                    font_w, font_h = d.textsize(text, font=font)
                    c_w = (img_size - font_w) // 2
                    c_h = (img_size - font_h) // 2
                    # NOTE(review): ``color`` has no default in this mode;
                    # None would look up the key "NONE" -- verify callers.
                    d.text((c_w, c_h),
                           text,
                           font=font,
                           fill=tuple([
                               int(a * 255)
                               for a in COLORS["{}".format(color).upper()]
                           ]))
                    d_m.text((c_w, c_h), text, font=font, fill=255)
                    masks = self.tensorize(mask).unsqueeze(0).float().cuda()
                    c_im_t = linear_scaling(self.tensorize(c_im).cuda())
                elif mode == 8:
                    half = img_size // 2
                    center_cropper = transforms.CenterCrop((half, half))
                    crop = self.tensorize(center_cropper(c_im))
                    # The paste window must match the crop size. It used to
                    # be a hard-coded 128, which only worked when
                    # SIZE // 2 == 128.
                    coord_x = coord_y = (img_size - half) // 2
                    r_c_im_t = torch.zeros(
                        (1, 3, img_size, img_size)).cuda()
                    r_c_im_t[:, :, coord_x:coord_x + half,
                             coord_y:coord_y + half] = crop
                    if mask_path is None:
                        tmp = kornia.resize(masks, half)
                        masks = torch.zeros(
                            (1, 1, img_size, img_size)).cuda()
                        masks[:, :, coord_x:coord_x + half,
                              coord_y:coord_y + half] = tmp
                        # Add the mirrored mask on top, then clamp the sum
                        # back into [0, 1].
                        tmp = kornia.hflip(tmp)
                        masks[:, :, coord_x:coord_x + half,
                              coord_y:coord_y + half] += tmp
                        masks = torch.clamp(masks, min=0., max=1.)
                    c_im_t = linear_scaling(r_c_im_t)
                else:
                    c_im_t = linear_scaling(
                        transforms.ToTensor()(c_im).unsqueeze(0).cuda())
            else:
                assert mode in [2, 3, 4]
                if mode == 2:
                    c_im_t = linear_scaling(torch.rand_like(im_t))
                elif mode == 3:
                    color = self.opt.TEST.BRUSH_COLOR if color is None else color
                    brush = torch.tensor(
                        list(COLORS["{}".format(color).upper()])).unsqueeze(
                            0).unsqueeze(-1).unsqueeze(-1).cuda()
                    c_im_t = linear_scaling(torch.ones_like(im_t) * brush)
                elif mode == 4:
                    c_im_t = im_t

            # Smoothing is applied to generated masks (and always in mode
            # 5), but never in mode 8.
            if (mask_path is None or mode == 5) and mode != 8:
                smooth_masks = self.mask_smoother(1 - masks) + masks
                smooth_masks = torch.clamp(smooth_masks, min=0., max=1.)
            else:
                smooth_masks = masks

            masked_imgs = c_im_t * smooth_masks + im_t * (1. - smooth_masks)
            pred_masks, neck = self.mpn(masked_imgs)
            # In mode 8 the prediction is restricted to the known mask.
            pred_masks = pred_masks if mode != 8 else torch.clamp(
                pred_masks * smooth_masks, min=0., max=1.)
            masked_imgs_embraced = masked_imgs * (1. - pred_masks)
            output = self.rin(masked_imgs_embraced, pred_masks, neck)
            output = torch.clamp(output, max=1., min=0.)

            if output_dir is not None:
                if mode == 8:
                    # Four panels concatenated along dim 2.
                    panels = [
                        linear_unscaling(im_t).squeeze().cpu(),
                        self.tensorize(c_im).squeeze().cpu(),
                        linear_unscaling(masked_imgs).squeeze().cpu(),
                        output.squeeze().cpu()
                    ]
                    cat_dim = 2
                else:
                    # Two panels concatenated along dim 1.
                    panels = [
                        linear_unscaling(masked_imgs).squeeze().cpu(),
                        output.squeeze().cpu()
                    ]
                    cat_dim = 1
                out_name = "out{}_{}_{}".format(
                    mode, color, os.path.basename(img_path))
                self.to_pil(torch.cat(panels, dim=cat_dim)).save(
                    os.path.join(output_dir, out_name))
            else:
                self.to_pil(output.squeeze().cpu()).show()
                self.to_pil(pred_masks.squeeze().cpu()).show()
                self.to_pil(
                    linear_unscaling(masked_imgs).squeeze().cpu()).show()
                self.to_pil(smooth_masks.squeeze().cpu()).show()
                self.to_pil(linear_unscaling(im_t).squeeze().cpu()).show()
                if gt_path is not None:
                    gt.show()