def paste(self):
    """Paste the previously copied region into the editing canvas.

    The clip is cut from ``self.copy_canvas.rgb`` using the bounding box of
    ``self.copy_mask``. It is placed at the location derived (via
    ``centered_location``) from the mask currently drawn on
    ``self.editing_canvas``.

    Afterwards:
    * the drawn mask is cleared;
    * the canvas image is replaced;
    * the drawn mask is added to ``self.positive_masks`` for the edited view;
    * the view's thumbnail is refreshed;
    * ``self.from_editing_canvas()`` is called.

    If nothing has been copied yet, a message is shown and nothing changes.
    """
    if self.copy_mask is None:
        self.show_msg('Please select a region to copy first.')
        return

    canvas = self.editing_canvas
    # Both masks are decoded at a fixed 256x256 size; ``[0]`` keeps the first
    # leading slice (presumably the first channel -- TODO confirm).
    dest_mask = renormalize.from_url(canvas.mask, target='pt',
                                     size=(256, 256))[0]
    src_mask = renormalize.from_url(self.copy_mask, target='pt',
                                    size=(256, 256))[0]

    # Crop the copied mask to its bounding box (row range top:bottom,
    # column range left:right).
    top, left, bottom, right = positive_bounding_box(src_mask)
    clip_mask = src_mask[top:bottom, left:right]

    clip_rgb = self.copy_canvas.rgb
    # Channels-first -> channels-last for the paste helper.
    canvas_hwc = renormalize.from_url(canvas.image).permute(1, 2, 0)
    pasted = paste_clip_at_center(canvas_hwc, clip_rgb,
                                  centered_location(dest_mask), clip_mask)
    pasted_chw = pasted[0].permute(2, 0, 1)

    canvas.mask = ''
    canvas.image = renormalize.as_url(pasted_chw)

    view = canvas.index
    self.positive_masks[view] += dest_mask
    thumbnail = F.interpolate(pasted_chw.unsqueeze(dim=0),
                              size=(self.size, self.size)).squeeze()
    self.real_images_array[view].src = renormalize.as_url(thumbnail)
    self.from_editing_canvas()
def load(self):
    """Restore a saved edit session from ``<self.savedir>/<name>``.

    ``name`` is taken from ``self.editname_textbox.value``. The first line of
    ``edit_type.txt`` (stripped) becomes ``self.edit_type``.

    For every canvas index ``i`` in ``range(self.num_canvases)``, these files
    are optional; each one found replaces existing state:

    * ``{i}_rgb.png`` -> the full-size canvas image, its cached thumbnail and
      the grid thumbnail
    * ``{i}_pos.pt`` -> ``self.positive_masks[i]``
    * ``{i}_neg.pt`` -> ``self.real_canvas_array[i].negative_mask``

    Shows a message and returns early when no name is given or the directory
    does not exist.
    """
    name = self.editname_textbox.value
    if name == '':
        self.show_msg('Please enter a file name to load')
        return
    savedir = os.path.join(self.savedir, name)
    if not os.path.exists(savedir):
        self.show_msg(f'{savedir} does not exist')
        return

    with open(os.path.join(savedir, 'edit_type.txt')) as f:
        self.edit_type = f.readlines()[0].strip()

    trn = transforms.ToTensor()
    for i in range(self.num_canvases):
        rgb_path = os.path.join(savedir, f'{i}_rgb.png')
        pos_path = os.path.join(savedir, f'{i}_pos.pt')
        neg_path = os.path.join(savedir, f'{i}_neg.pt')

        if os.path.exists(rgb_path):
            # Use a context manager so the PNG's file handle is released
            # rather than left open until garbage collection.
            with Image.open(rgb_path) as pil_img:
                # ToTensor gives [0, 1]; rescale to the [-1, 1] range used by
                # the other canvases in this class.
                image = trn(pil_img) * 2 - 1
            canvas = self.real_canvas_array[i]
            canvas.image = renormalize.as_url(image)
            canvas.resized_image = renormalize.as_url(
                F.interpolate(image.unsqueeze(dim=0),
                              size=(self.size, self.size)).squeeze())
            self.real_images_array[i].src = canvas.resized_image

        # NOTE: torch.load unpickles its input; only load files written by
        # this tool's own save path.
        if os.path.exists(pos_path):
            self.positive_masks[i] = torch.load(pos_path)
        if os.path.exists(neg_path):
            self.real_canvas_array[i].negative_mask = torch.load(neg_path)
def display_addition_instance(self):
    """Fill the 'add shape' gallery with pre-rendered instance thumbnails.

    Gallery widget ``k`` shows ``<self.parentdir>/instances/{k:03d}.png``. The
    image is passed through ``self.trn`` and rescaled from [0, 1] to [-1, 1]
    before URL encoding, and the widget's ``index`` is set to ``k``.

    Iterates over however many widgets ``self.addition_instances_array``
    holds, instead of a hard-coded count.
    """
    for k, widget in enumerate(self.addition_instances_array):
        path = os.path.join(self.parentdir, 'instances',
                            '{:03d}.png'.format(k))
        # Close the file handle deterministically instead of leaking it.
        with Image.open(path) as pil_img:
            thumbnail = self.trn(pil_img) * 2 - 1
        widget.src = renormalize.as_url(thumbnail)
        widget.index = k
def update_cb(i, rgb):
    """Refresh grid thumbnail ``i`` from a freshly rendered ``rgb`` array.

    Does nothing unless ``update`` is truthy. ``update`` and ``self`` are free
    variables taken from the enclosing scope, which is not visible here.

    ``rgb`` is permuted with ``(2, 0, 1)``, so it is expected channels-last.
    It is presumably in [0, 1]; it is mapped to [-1, 1] before encoding.
    """
    if not update:
        return
    chw = torch.tensor(rgb).permute(2, 0, 1) * 2 - 1
    batch = chw.unsqueeze(dim=0)
    thumbnail = F.interpolate(batch, size=(self.size, self.size))
    self.real_images_array[i].src = renormalize.as_url(
        thumbnail.squeeze(dim=0))
def render_editing_canvas(self, style):
    """Re-render the view shown in the editing canvas using ``style``.

    Only the pose at ``self.editing_canvas.index`` is rendered. The render
    cache is read only while ``self.edit_type == 'color_from'`` and is never
    updated. The first returned image becomes the editing canvas image.
    """
    view = self.editing_canvas.index
    single_pose = self.poses[view].unsqueeze(dim=0)
    read_cache = self.edit_type == 'color_from'
    rendered = self.render(single_pose,
                           style,
                           verbose=False,
                           inds=[view],
                           use_cache=read_cache,
                           update_cache=False)
    self.editing_canvas.image = renormalize.as_url(rendered[0])
def target_transfer(self, instancenum, index):
    """Show view ``index`` of instance ``instancenum`` in the copy canvas.

    Clears the copy canvas mask and records ``index`` on it. It also stores
    the instance's style (with a leading batch dimension) as
    ``copy_canvas.instance_style``. The pose is rendered with that style,
    without the cache, and the result is resized to ``self.size`` x
    ``self.size`` for display.
    """
    copy_canvas = self.copy_canvas
    copy_canvas.mask = ''
    copy_canvas.index = index
    copy_canvas.instance_style = (
        self.all_instance_styles[instancenum].unsqueeze(dim=0))

    pose = self.poses[index].unsqueeze(dim=0)
    rgb = self.render(pose,
                      copy_canvas.instance_style.squeeze(dim=0),
                      verbose=False,
                      use_cache=False)
    shown = F.interpolate(rgb, size=(self.size, self.size))[0]
    copy_canvas.image = renormalize.as_url(shown)
def update_canvas(self, images, disps=None):
    """Push freshly rendered views into the thumbnail grid and canvases.

    For each ``images[i]``:
    * the grid image ``self.real_images_array[i]`` receives a
      ``self.size`` x ``self.size`` thumbnail;
    * ``self.real_canvas_array[i]`` receives both the full-size image and the
      cached thumbnail.

    When ``disps`` is given, the matching depth image (full size and
    thumbnail) is also stored on each canvas.

    Finally, if a view is selected (``editing_canvas.index >= 0``), the
    editing canvas is refreshed from that view.
    """
    thumb_size = (self.size, self.size)
    for i, image in enumerate(images):
        canvas = self.real_canvas_array[i]
        resized_rgb = F.interpolate(image.unsqueeze(dim=0),
                                    size=thumb_size).squeeze(dim=0)
        # Encode the thumbnail once and reuse the URL for both the grid image
        # and the canvas's cached thumbnail.
        resized_url = renormalize.as_url(resized_rgb)
        self.real_images_array[i].src = resized_url
        canvas.image = renormalize.as_url(image)
        canvas.resized_image = resized_url
        if disps is not None:
            # 8-bit depth visualisation scaled back to [0, 1], with a leading
            # channel dimension.
            disp_img = torch.from_numpy(to8b(to_disp_img(
                disps[i]))).unsqueeze(dim=0) / 255.
            resized_disp = F.interpolate(disp_img.unsqueeze(dim=0),
                                         size=thumb_size).squeeze(dim=0)
            canvas.resized_disp = renormalize.as_url(resized_disp)
            canvas.disp = renormalize.as_url(disp_img)
    if self.editing_canvas.index >= 0:
        self.editing_canvas.image = self.real_canvas_array[
            self.editing_canvas.index].image
def change_mask(self, ev):
    # Handle a 'mask' event (a finished brush stroke) from a paint widget.
    # What happens depends on self.mask_type:
    #   'positive' -> blend self.color into the stroked pixels of the
    #                 editing canvas (edit_type becomes 'color');
    #   'sigma'    -> fill the stroked pixels with 1.0 (edit_type becomes
    #                 'removal');
    #   'negative' -> store the editing canvas's drawn mask as the negative
    #                 mask of the widget that fired the event;
    #   otherwise  -> paste the copied region if there is one, else clear the
    #                 firing widget's negative mask (when it has an image).
    # ev.target is the widget that fired the event; its .index is the view.
    if self.mask_type == 'positive' or self.mask_type == 'sigma':
        i = self.editing_canvas.index
        orig_img = renormalize.from_url(self.editing_canvas.image)
        # "/ 2 + 0.5" maps the decoded mask to [0, 1], assuming from_url
        # returns values in [-1, 1] as the rest of this class does (TODO
        # confirm). The mask is then resized to the editing resolution of
        # 2 * self.size per side.
        mask = renormalize.from_url(self.editing_canvas.mask) / 2 + 0.5
        mask = F.interpolate(mask.unsqueeze(dim=0),
                             size=(self.size * 2, self.size * 2)).squeeze()
        if self.mask_type == 'positive':
            self.edit_type = 'color'
            if self.color is None:
                # NOTE(review): edit_type is already set to 'color' even
                # though the edit is abandoned here.
                self.show_msg('Please select a color.')
                if ev.target.image != '':
                    self.real_canvas_array[
                        ev.target.index].negative_mask = ''
                return
            # Linear blend: masked pixels take self.color, others keep the
            # original value.
            edited_img = orig_img * (1 - mask) + mask * self.color
        elif self.mask_type == 'sigma':
            self.edit_type = 'removal'
            # Masked pixels are filled with 1.0, presumably white in the
            # [-1, 1] image range used by this class.
            edited_img = orig_img * (1 - mask) + mask * torch.ones(
                (3, self.size * 2, self.size * 2)).to(mask.device)
        # Accumulate this stroke into the view's running positive mask.
        self.positive_masks[i] += mask
        # Write the edited image to the view's canvas, the editing canvas,
        # and the view's thumbnail, then clear the drawn stroke.
        self.real_canvas_array[i].image = renormalize.as_url(edited_img)
        self.editing_canvas.image = renormalize.as_url(edited_img)
        self.real_images_array[i].src = renormalize.as_url(
            F.interpolate(edited_img.unsqueeze(dim=0),
                          size=(self.size, self.size)).squeeze())
        self.editing_canvas.mask = ''
    elif self.mask_type == 'negative':
        i = ev.target.index
        # The negative mask is stored as the editing canvas's mask value
        # (a URL string), not a decoded tensor.
        self.real_canvas_array[i].negative_mask = self.editing_canvas.mask
    elif self.copy_mask is not None:
        self.paste()
    else:
        if ev.target.image != '':
            self.real_canvas_array[ev.target.index].negative_mask = ''
def __init__(self,
             instance,
             config,
             use_cached=True,
             expname=None,
             edit_type=None,
             num_canvases=9,
             shape_params='fusion_shape_branch',
             color_params='color_branch',
             randneg=8192,
             device='cuda:0'):
    """Build the editing UI for one object instance and render its views.

    Args:
        instance: index of the instance to edit. Selects ``styles[instance]``
            and names the save sub-directory ``<expname>/<instance>``.
        config: configuration passed to ``load_model``, ``load_dataset`` and
            ``load_config``.
        use_cached: forwarded to ``load_dataset``.
        expname: experiment directory; defaults to the ``expname`` from the
            loaded config.
        edit_type: initial value of ``self.edit_type``.
        num_canvases: number of views, forwarded to ``load_dataset``.
        shape_params, color_params, randneg: stored on ``self`` unchanged.
        device: device that poses, hwfs and the instance style are moved to.
            When it equals ``'cuda:0'`` the default tensor type is CUDA float;
            otherwise it is CPU float.
    """
    super().__init__(style=dict(
        border="3px solid gray", padding="8px", display="inline-block"))
    # set_default_tensor_type takes a tensor-type name: the CPU float type is
    # 'torch.FloatTensor'. A bare 'cpu' is not a valid name and raises.
    torch.set_default_tensor_type('torch.cuda.FloatTensor'
                                  if device == 'cuda:0' else
                                  'torch.FloatTensor')
    self.edit_type = edit_type
    self.instance = instance
    self.num_canvases = num_canvases
    self.shape_params = shape_params
    self.color_params = color_params
    self.size = IMG_SIZE
    self.randneg = randneg
    self.device = device
    self.msg_out = labwidget.Div()

    # Large paint canvas for the currently selected view (index -1 = none)
    # and a smaller canvas used as a copy source.
    self.editing_canvas = paintwidget.PaintWidget(image='',
                                                  width=self.size * 3,
                                                  height=self.size * 3).on(
                                                      'mask', self.change_mask)
    self.editing_canvas.index = -1
    self.copy_canvas = paintwidget.PaintWidget(image='',
                                               width=self.size * 2,
                                               height=self.size * 2).on(
                                                   'mask', self.copy)
    self.copy_mask = None

    # Toolbar buttons and brush-size input.
    inline = dict(display='inline', border="2px solid gray")
    self.toggle_rgbs_disps_btn = labwidget.Button(
        'show depth', style=inline).on('click', self.toggle_rgb_disps)
    self.positive_mask_btn = labwidget.Button(self.pad('edit color'),
                                              style=inline).on(
                                                  'click', self.positive_mask)
    self.addition_mask_btn = labwidget.Button(self.pad('add shape'),
                                              style=inline).on(
                                                  'click', self.add)
    self.sigma_mask_btn = labwidget.Button(self.pad('remove shape'),
                                           style=inline).on(
                                               'click', self.sigma_mask)
    self.color_from_btn = labwidget.Button(self.pad('transfer color'),
                                           style=inline).on(
                                               'click', self.color_from)
    self.shape_from_btn = labwidget.Button(self.pad('transfer shape'),
                                           style=inline).on(
                                               'click', self.shape_from)
    self.execute_btn = labwidget.Button(self.pad('execute'),
                                        style=inline).on(
                                            'click', self.execute_edit)
    self.brushsize_textbox = labwidget.Textbox(5, desc='brushsize: ',
                                               size=3).on(
                                                   'value',
                                                   self.change_brushsize)
    self.target = None
    self.use_color_cache = True

    # Colour palette: entry 0 is the background swatch from bg.png; the rest
    # are solid 32x32 swatches, one per entry in mean_colors.colors.
    self.color_style = dict(display='inline', border="2px solid white")
    trn = transforms.Compose(
        [transforms.Resize(32), transforms.ToTensor()])
    # Context manager releases the bg.png file handle.
    with Image.open('bg.png') as bg_file:
        bg_img = trn(bg_file.convert('RGB'))
    bg_img = renormalize.as_url(bg_img * 2 - 1)
    self.color_pallete = [
        labwidget.Image(src=bg_img,
                        style=self.color_style).on('click', self.set_color)
    ]
    self.color_pallete[-1].index = 0
    self.color_pallete[-1].color_type = 'bg'
    for color in mean_colors.colors.values():
        image = torch.zeros(3, 32, 32)
        image[0, :, :] = color[0]
        image[1, :, :] = color[1]
        image[2, :, :] = color[2]
        # 0..255 channel values -> [-1, 1].
        image = image / 255. * 2 - 1
        self.color_pallete.append(
            labwidget.Image(src=renormalize.as_url(image),
                            style=self.color_style).on(
                                'click', self.set_color))
        self.color_pallete[-1].index = len(self.color_pallete) - 1
        self.color_pallete[-1].color_type = 'color'
    # TODO: Highlight the white box with black for clarity
    self.color = None
    self.mask_type = None
    self.real_canvas_array = []
    self.real_images_array = []
    self.positive_masks = []

    # Model, dataset and output locations.
    train, test, optimizer, styles = load_model(instance,
                                                config,
                                                expname=expname)
    poses, hwfs, cache, args = load_dataset(instance,
                                            config,
                                            num_canvases=num_canvases,
                                            N_instances=styles.shape[0],
                                            expname=expname,
                                            use_cached=use_cached)
    self.parentdir = load_config(config).expname
    self.expname = expname if expname else self.parentdir
    self.savedir = os.path.join(self.expname, str(instance))
    os.makedirs(self.savedir, exist_ok=True)
    self.poses = poses.to(device)
    self.cache = cache
    self.chunk = args.chunk
    self.near = args.blender_near
    self.far = args.blender_far
    self.nfs = [[self.near, self.far]] * self.poses.shape[0]
    self.hwfs = hwfs.to(device)
    # Snapshot of the fine network's parameters taken from a deep copy, so
    # later training does not modify this snapshot.
    self.old_fine_network = dict(
        copy.deepcopy(test['network_fine']).named_parameters())
    self.train_kwargs = train
    self.test_kwargs = test
    self.optimizer = None
    self.all_instance_styles = styles
    self.instance_style = styles[instance].unsqueeze(dim=0).to(device)
    if cache is not None:
        self.weights = cache['weights']
        self.alphas = cache['alphas']
        self.features = cache['features']
    else:
        self.weights = None
        self.alphas = None
        self.features = None

    # Instance galleries (12 slots each) for transfer / addition targets.
    self.trn = transforms.Compose(
        [transforms.Resize(128), transforms.ToTensor()])
    self.transfer_instances_array = [
        labwidget.Image(src='').on('click', self.change_target)
        for _ in range(12)
    ]
    self.addition_instances_array = [
        labwidget.Image(src='').on('click', self.change_target)
        for _ in range(12)
    ]

    # Render every view once and build its thumbnail and paint canvas.
    images, disps = self.render(self.poses,
                                self.instance_style,
                                verbose=False,
                                get_disps=True)
    for i, image in enumerate(images):
        resized = F.interpolate(image.unsqueeze(dim=0),
                                size=(self.size, self.size)).squeeze(dim=0)
        disp_img = torch.from_numpy(to8b(to_disp_img(
            disps[i]))).unsqueeze(dim=0) / 255.
        resized_disp = F.interpolate(disp_img.unsqueeze(dim=0),
                                     size=(self.size, self.size)).squeeze(dim=0)
        # Encode each image once; the same URL is reused for every attribute
        # that shows it.
        image_url = renormalize.as_url(image)
        resized_url = renormalize.as_url(resized)
        self.real_images_array.append(
            labwidget.Image(src=resized_url).on('click',
                                                self.set_editing_canvas))
        self.real_images_array[-1].index = i
        self.real_canvas_array.append(
            paintwidget.PaintWidget(image=image_url,
                                    width=self.size * 3,
                                    height=self.size * 3).on(
                                        'mask', self.change_mask))
        self.real_canvas_array[-1].index = i
        self.real_canvas_array[-1].negative_mask = ''
        self.real_canvas_array[-1].resized_image = resized_url
        self.real_canvas_array[-1].resized_disp = renormalize.as_url(
            resized_disp)
        self.real_canvas_array[-1].disp = renormalize.as_url(disp_img)
        self.real_canvas_array[-1].orig = image_url
        self.positive_masks.append(torch.zeros(image.shape).cpu())
    self.show_rgbs = True
    self.change_brushsize()

    # Save / load controls.
    self.editname_textbox = labwidget.Datalist(choices=self.saved_names(),
                                               style=inline)
    self.save_btn = labwidget.Button('save', style=inline).on('click',
                                                              self.save)
    self.load_btn = labwidget.Button('load', style=inline).on('click',
                                                              self.load)