def rotate_and_flip(tensor, device, p=0.5):
    """Randomly rotate and flip every entry along dim 0 of ``tensor``, in place.

    For each index ``i`` along the first dimension, three independent draws
    from ``np.random.uniform()`` decide (each with probability ``p``) whether
    ``tensor[i]`` is rotated by a random whole-degree angle, flipped
    horizontally, and flipped vertically.

    Args:
        tensor: Tensor to augment. It is modified in place. The code reads
            ``tensor.shape[1]``, ``shape[2]`` and ``shape[3]``, so at least
            four dimensions are required.
            # assumes shape[1] is the batch dim expected by kornia.rotate
            # for tensor[i] -- TODO confirm the intended layout.
        device: Device on which the rotation centre and angle are created.
        p: Probability of applying each of the three augmentations.

    Returns:
        The same ``tensor`` object that was passed in.
    """
    n = tensor.shape[1]

    # The rotation centre does not depend on ``i``; build it once.
    # NOTE(review): x uses shape[3] and y uses shape[2]. If ``tensor`` is
    # 5-D, the spatial dims of tensor[i] would be shape[3]/shape[4] instead;
    # confirm against the callers.
    center = torch.tensor(
        [tensor.shape[3] / 2, tensor.shape[2] / 2], device=device
    ).repeat(n, 1)

    for i in range(tensor.shape[0]):
        if np.random.uniform() < p:
            # One integer angle in [-90, 90) shared by all ``n`` entries.
            degrees = float(np.random.randint(-90, 90))
            # torch.full keeps the angle 1-D even when n == 1; the previous
            # ``.squeeze()`` collapsed it to a 0-d tensor in that case.
            angle = torch.full(
                (n,), degrees, dtype=torch.float32, device=device
            )
            tensor[i] = kornia.rotate(tensor[i], angle, center)

        if np.random.uniform() < p:
            tensor[i] = kornia.hflip(tensor[i])

        if np.random.uniform() < p:
            tensor[i] = kornia.vflip(tensor[i])

    return tensor
def _get_transformed_images(images, hflip):
    """Optionally mirror ``images`` horizontally, then normalize them.

    Args:
        images: Input passed to ``K.hflip`` / ``K.normalize``.
        hflip: When truthy, ``K.hflip`` is applied before normalization.

    Returns:
        The result of ``K.normalize(..., 0.5, 0.5)`` on the (possibly
        flipped) input.
    """
    maybe_flipped = K.hflip(images) if hflip else images
    # Normalize with mean 0.5 and std 0.5.
    return K.normalize(maybe_flipped, 0.5, 0.5)
def _get_transformed_frames(frames, hflip):
    """Optionally mirror ``frames`` horizontally, normalize, and swap axes.

    Args:
        frames: 4-D input (``permute`` is called with four axes).
            # presumably laid out as (T, C, H, W) -- TODO confirm.
        hflip: When truthy, ``K.hflip`` is applied before normalization.

    Returns:
        The normalized frames with the first two axes exchanged (the
        original comment labels the result "CTHW").
    """
    clip = K.hflip(frames) if hflip else frames
    # Normalize with mean 0.5 and std 0.5.
    clip = K.normalize(clip, 0.5, 0.5)
    # Swap axes 0 and 1; the spatial axes stay where they are.
    return clip.permute(1, 0, 2, 3)
def op_script(data: torch.Tensor) -> torch.Tensor:
    """Return ``kornia.hflip(data)``.

    Thin, fully type-annotated wrapper around ``kornia.hflip``.
    # NOTE(review): the name suggests it exists to be handed to
    # torch.jit.script -- verify against the caller.
    """
    return kornia.hflip(data)
def infer(self, img_path, cont_path=None, mode=None, color=None, text=None, mask_path=None, gt_path=None, output_dir=None):
    """Run one inference pass on a single image and save or display the result.

    The input image is blended with a "contaminant" (``c_im_t``) under a
    mask, passed through ``self.mpn`` and ``self.rin``, and the clamped
    output is either written to ``output_dir`` or shown on screen.
    All tensors are moved with ``.cuda()``, so a CUDA device is required.

    How the contaminant is built depends on ``mode``:

    * With ``cont_path`` (mode must be one of 1, 5, 6, 7, 8):
        - 6: the content image is shrunk to 1/8 of the side length and
          pasted at a random position; the mask becomes that square.
        - 7: ``text`` is drawn centred on the content image in ``color``;
          the mask is the rendered text.
        - 8: the centre crop (half the side length) of the content image
          is pasted at the centre; if no ``mask_path`` was given, the
          generated mask is resized to that window and merged with its
          horizontal mirror.
        - otherwise (1, 5): the whole content image is used.
    * Without ``cont_path`` (mode must be one of 2, 3, 4):
        - 2: uniform random noise, 3: a solid brush colour,
          4: the input image itself.

    Args:
        img_path: Path of the image to process.
        cont_path: Optional path of the content (contaminant) image.
        mode: Test mode; defaults to ``self.opt.TEST.MODE``.
        color: Key into ``COLORS`` (upper-cased). For mode 3 it defaults
            to ``self.opt.TEST.BRUSH_COLOR``. Also used in the file name.
        text: Text for mode 7; defaults to ``self.opt.TEST.TEXT``.
        mask_path: Optional mask image; otherwise
            ``self.mask_generator`` produces one.
        gt_path: Optional ground-truth image, only displayed when no
            ``output_dir`` is given.
        output_dir: Directory for the saved comparison image. When
            ``None`` the intermediate and final images are shown instead.

    Raises:
        AssertionError: If ``mode`` is not valid for the given
            ``cont_path`` setting.
    """
    mode = self.opt.TEST.MODE if mode is None else mode
    text = self.opt.TEST.TEXT if text is None else text
    side = self.opt.DATASET.SIZE
    half = side // 2

    with torch.no_grad():
        im = Image.open(img_path).convert("RGB")
        im = im.resize((side, side))
        im_t = linear_scaling(
            transforms.ToTensor()(im).unsqueeze(0).cuda())

        if gt_path is not None:
            gt = Image.open(gt_path).convert("RGB")
            gt = gt.resize((side, side))

        if mask_path is None:
            masks = torch.from_numpy(
                self.mask_generator.generate(side, side)).float().cuda()
        else:
            masks = Image.open(mask_path).convert("L")
            masks = masks.resize((side, side))
            masks = self.tensorize(masks).unsqueeze(0).float().cuda()

        if cont_path is not None:
            assert mode in [1, 5, 6, 7, 8]
            c_im = Image.open(cont_path).convert("RGB")
            c_im = c_im.resize((side, side))
            if mode == 6:
                # Paste a 1/8-size thumbnail at a random location.
                c_im = c_im.resize((side // 8, side // 8))
                c_im_t = self.tensorize(c_im).unsqueeze(0).cuda()
                r_c_im_t = torch.zeros((1, 3, side, side)).cuda()
                masks = torch.zeros((1, 1, side, side)).cuda()
                coord_x, coord_y = np.random.randint(
                    side - side // 8, size=(2, ))
                r_c_im_t[:, :,
                         coord_x:coord_x + c_im_t.size(2),
                         coord_y:coord_y + c_im_t.size(3)] = c_im_t
                masks[:, :,
                      coord_x:coord_x + c_im_t.size(2),
                      coord_y:coord_y + c_im_t.size(3)] = torch.ones_like(c_im_t[0, 0])
                c_im_t = linear_scaling(r_c_im_t)
            elif mode == 7:
                # Draw the text on the content image and on a blank mask.
                mask = self.to_pil(torch.zeros((side, side)))
                d = ImageDraw.Draw(c_im)
                d_m = ImageDraw.Draw(mask)
                font = ImageFont.truetype(self.opt.TEST.FONT,
                                          self.opt.TEST.FONT_SIZE)
                if hasattr(d, "textsize"):
                    font_w, font_h = d.textsize(text, font=font)
                else:
                    # Pillow >= 10 removed ImageDraw.textsize; derive the
                    # extent from the text bounding box instead.
                    left, top, right, bottom = d.textbbox(
                        (0, 0), text, font=font)
                    font_w, font_h = right - left, bottom - top
                c_w = (side - font_w) // 2
                c_h = (side - font_h) // 2
                d.text((c_w, c_h), text, font=font,
                       fill=tuple([
                           int(a * 255)
                           for a in COLORS["{}".format(color).upper()]
                       ]))
                d_m.text((c_w, c_h), text, font=font, fill=255)
                masks = self.tensorize(mask).unsqueeze(0).float().cuda()
                c_im_t = linear_scaling(self.tensorize(c_im).cuda())
            elif mode == 8:
                center_cropper = transforms.CenterCrop((half, half))
                crop = self.tensorize(center_cropper(c_im))
                # The paste window is sized from the crop (half the side)
                # rather than a fixed 128 px, so any DATASET.SIZE works;
                # the result is identical when SIZE == 256.
                coord_x = coord_y = (side - half) // 2
                r_c_im_t = torch.zeros((1, 3, side, side)).cuda()
                r_c_im_t[:, :,
                         coord_x:coord_x + half,
                         coord_y:coord_y + half] = crop
                if mask_path is None:
                    tmp = kornia.resize(masks, half)
                    masks = torch.zeros((1, 1, side, side)).cuda()
                    masks[:, :,
                          coord_x:coord_x + half,
                          coord_y:coord_y + half] = tmp
                    # Union of the mask with its horizontal mirror.
                    tmp = kornia.hflip(tmp)
                    masks[:, :,
                          coord_x:coord_x + half,
                          coord_y:coord_y + half] += tmp
                    masks = torch.clamp(masks, min=0., max=1.)
                c_im_t = linear_scaling(r_c_im_t)
            else:
                c_im_t = linear_scaling(
                    transforms.ToTensor()(c_im).unsqueeze(0).cuda())
        else:
            assert mode in [2, 3, 4]
            if mode == 2:
                c_im_t = linear_scaling(torch.rand_like(im_t))
            elif mode == 3:
                color = self.opt.TEST.BRUSH_COLOR if color is None else color
                brush = torch.tensor(
                    list(COLORS["{}".format(color).upper()])).unsqueeze(
                        0).unsqueeze(-1).unsqueeze(-1).cuda()
                c_im_t = linear_scaling(torch.ones_like(im_t) * brush)
            elif mode == 4:
                c_im_t = im_t

        if (mask_path is None or mode == 5) and mode != 8:
            smooth_masks = self.mask_smoother(1 - masks) + masks
            smooth_masks = torch.clamp(smooth_masks, min=0., max=1.)
        else:
            smooth_masks = masks

        # Blend the contaminant into the image under the (smoothed) mask.
        masked_imgs = c_im_t * smooth_masks + im_t * (1. - smooth_masks)
        pred_masks, neck = self.mpn(masked_imgs)
        pred_masks = pred_masks if mode != 8 else torch.clamp(
            pred_masks * smooth_masks, min=0., max=1.)
        masked_imgs_embraced = masked_imgs * (1. - pred_masks)
        output = self.rin(masked_imgs_embraced, pred_masks, neck)
        output = torch.clamp(output, max=1., min=0.)

        if output_dir is not None:
            out_name = "out{}_{}_{}".format(
                mode, color, img_path.split("/")[-1])
            if mode == 8:
                # input | content | masked input | output, side by side.
                panels = torch.cat([
                    linear_unscaling(im_t).squeeze().cpu(),
                    self.tensorize(c_im).squeeze().cpu(),
                    linear_unscaling(masked_imgs).squeeze().cpu(),
                    output.squeeze().cpu()
                ], dim=2)
            else:
                # masked input stacked above the output.
                panels = torch.cat([
                    linear_unscaling(masked_imgs).squeeze().cpu(),
                    output.squeeze().cpu()
                ], dim=1)
            self.to_pil(panels).save(os.path.join(output_dir, out_name))
        else:
            self.to_pil(output.squeeze().cpu()).show()
            self.to_pil(pred_masks.squeeze().cpu()).show()
            self.to_pil(
                linear_unscaling(masked_imgs).squeeze().cpu()).show()
            self.to_pil(smooth_masks.squeeze().cpu()).show()
            self.to_pil(linear_unscaling(im_t).squeeze().cpu()).show()
            if gt_path is not None:
                gt.show()