def repaint_canvas_array(self):
    """Re-render every image canvas from the current model state.

    Renders ``self.sel`` in one batch via ``self.gw.render_image_batch``.
    When ``self.query_vis`` is set, the query key is passed along; the
    highlight level (the 0.999 quantile of ``self.query_rq``) is passed
    only when ``self.query_rq`` is also set, otherwise None.

    When ``self.show_original`` is set, the weights of
    ``self.original_model`` are loaded into ``self.gw.model`` for the
    render. The current (edited) weights are restored afterwards in a
    ``finally`` clause, so a failed render cannot discard the edits.

    Side effects:
        * The rendered batch is pickled to ``self.savedir + 'edited.pkl'``
          on every call.
        * Every canvas mask is cleared.
        * Canvases beyond ``len(self.sel)`` are blanked.
    """
    level = (self.query_rq.quantiles(0.999)[0]
             if (self.query_vis and self.query_rq) else None)
    query_key = self.query_key if self.query_vis else None

    saved_state_dict = None
    if self.show_original:
        # state_dict() tensors share storage with the live parameters and
        # load_state_dict copies into them in place, so a deep copy is
        # needed for the snapshot to survive the swap.
        saved_state_dict = copy.deepcopy(self.gw.model.state_dict())
    try:
        if saved_state_dict is not None:
            with torch.no_grad():
                self.gw.model.load_state_dict(
                    self.original_model.state_dict())
        images = self.gw.render_image_batch(
            self.sel, query_key, level, border_color=[255, 255, 255])
    finally:
        # Always put the edited weights back, even if rendering failed.
        if saved_state_dict is not None:
            with torch.no_grad():
                self.gw.model.load_state_dict(saved_state_dict)

    # Customization: persist the rendered batch on every repaint.
    # NOTE(review): plain string concatenation assumes savedir ends with
    # a path separator — confirm for gw.cachedir / mask_dir.
    with open(self.savedir + 'edited.pkl', 'wb') as f:
        pickle.dump(images, f)

    for i, canvas in enumerate(self.canvas_array):
        canvas.mask = ''
        size = (canvas.height, canvas.width) if canvas.height else None
        if i < len(self.sel):
            canvas.image = renormalize.as_url(images[i], size=size)
        else:
            canvas.image = ''
def old_repaint_key_tray(self):
    """Legacy repaint of the context-key tray.

    If ``self.request`` has no ``'key'`` entry, or that entry is empty, the
    tray div is hidden and nothing else happens.

    Otherwise the ``(imgnum, mask)`` pairs in ``self.request['key']`` are
    used as follows:
        * Their image numbers become the tray menu's choices.
        * The tray is shown.
        * The menu's current selection is kept if it is still one of those
          image numbers; otherwise the last choice is selected.
        * The selected image is rendered onto the tray canvas, with its
          mask.
    """
    entries = self.request['key'] if 'key' in self.request else []
    if len(entries) == 0:
        self.keytray_div.style = {'display': 'none'}
        return

    masks_by_image = OrderedDict((num, m) for num, m in entries)
    menu = self.keytray_menu
    menu.choices = list(masks_by_image)
    self.keytray_div.style = {'display': 'block'}

    current = menu.selection
    if current is not None and int(current) in masks_by_image:
        chosen = int(current)
    else:
        # Stale or missing selection: fall back to the most recent key.
        chosen = int(menu.choices[-1])
        menu.selection = chosen

    rendered = self.gw.render_image(chosen, None)
    self.keytray_canvas.image = renormalize.as_url(rendered)
    self.keytray_canvas.mask = masks_by_image[chosen]
def exec_object(self, obj_acts=None, obj_area=None, obj_output=None,
                bounds=None):
    """Make an object the current copy source and show it on the copy canvas.

    Args:
        obj_acts: Object activations. When None, the object is extracted
            from the ``(imgnum, mask)`` pair in ``self.request['object']``
            via ``self.gw.object_from_selection``. That call also supplies
            ``obj_output``, ``obj_area`` and ``bounds``, overriding any
            passed in.
        obj_area: Spatial area of the object. When None, an all-ones tensor
            of shape ``obj_acts.shape[2:]`` is used (presumably the spatial
            dims of a (batch, channel, H, W) tensor — TODO confirm), and no
            selection mask is shown.
        obj_output: Forwarded to ``self.gw.merge_target_output``.
        bounds: Forwarded to ``self.gw.merge_target_output``.

    Side effects:
        * Sets ``self.obj_acts`` and ``self.obj_area``.
        * Sets the copy canvas image to ``self.request_mask(thickness=3)``.
        * Sets the copy canvas mask.
        * Shows the message 'picked object'.
    """
    # Only a selection taken from the request carries a mask. Binding it
    # up front avoids an UnboundLocalError when the caller supplies both
    # obj_acts and obj_area.
    mask = None
    if obj_acts is None:
        imgnum, mask = self.request['object']
        obj_acts, obj_output, obj_area, bounds = (
            self.gw.object_from_selection(imgnum, mask))
    if obj_area is None:
        obj_area = torch.ones(obj_acts.shape[2:], device=obj_acts.device)
        mask = None
    self.obj_acts, self.obj_area = obj_acts, obj_area
    cropped_out = self.gw.merge_target_output(obj_output, obj_acts, bounds)
    # NOTE(review): the rendered object is not used below; the call is kept
    # in case render_object has side effects — confirm and drop if not.
    self.gw.render_object(cropped_out, obj_area)
    self.copy_canvas.image = renormalize.as_url(
        self.request_mask(thickness=3))
    self.copy_canvas.mask = mask
    self.show_msg('picked object')
def exec_paste(self):
    """Preview pasting the copied object into the requested paste region.

    Steps:
        1. Take the ``(imgnum, mask)`` pair stored in
           ``self.request['paste']``.
        2. Hand it to ``self.gw.paste_from_selection``, together with
           ``self.obj_acts`` and ``self.obj_area``.
        3. Render the returned visualization with the returned bounds as
           its box, and display it on the paste canvas.
    """
    target_img, target_mask = self.request['paste']
    pasted = self.gw.paste_from_selection(
        target_img, target_mask, self.obj_acts, self.obj_area)
    # Only the visualization and its bounds are needed for the preview;
    # the first two results are ignored here.
    _, _, preview, box = pasted
    rendered = self.gw.render_object(preview, box=box)
    self.paste_canvas.image = renormalize.as_url(rendered)
def __init__(self, gw, mask_dir=None, size=256, num_canvases=9):
    """Build the widgets and initial state for the rewriting app.

    Args:
        gw: Model wrapper. This constructor reads ``gw.model`` and
            ``gw.cachedir``, and calls ``gw.render_image(i)`` once per
            canvas.
        mask_dir: Stored as ``self.savedir``; when None,
            ``gw.cachedir`` is used instead.
        size: Base size in pixels for the canvases and output panels.
        num_canvases: Number of image canvases; canvas ``i`` initially
            shows image ``i`` for ``i`` in ``0 .. num_canvases - 1``.

    The Load list is populated from ``self.saved_names()``.
    """
    super().__init__(style=dict(border="0", padding="0",
                                display="inline-block", width="1000px",
                                left="0", margin="0"),
                     className='rwa')
    self.gw = gw
    self.size = size
    self.savedir = self.gw.cachedir if mask_dir is None else mask_dir
    # Snapshot of the model before any edits; repaint_canvas_array loads
    # these weights temporarily when self.show_original is set.
    self.original_model = copy.deepcopy(gw.model)
    self.request = {}
    self.imgnum_textbox = labwidget.Textbox(
        '0-%d' % (num_canvases - 1)).on('value', self.change_numbers)
    self.msg_out = labwidget.Div()
    self.loss_out = labwidget.Div()
    self.query_out = labwidget.Div()
    # Copy and paste canvases are drawn at 3/4 of the base size.
    self.copy_canvas = paintwidget.PaintWidget(image='',
                                               width=self.size * 0.75,
                                               height=self.size * 0.75).on(
        'mask', self.change_copy_mask)
    self.paste_canvas = paintwidget.PaintWidget(
        image='',
        width=self.size * 0.75,
        height=self.size * 0.75,
        opacity=0.0,
        oneshot=True,
    ).on('mask', self.change_paste_mask)
    self.object_out = labwidget.Div(
        style={
            'display': 'inline-block',
            'vertical-align': 'top',
            'width': '%spx' % size,
            'height': '%spx' % size
        })
    self.target_out = labwidget.Div(
        style={
            'display': 'inline-block',
            'vertical-align': 'top',
            'width': '%spx' % size,
            'height': '%spx' % size
        })
    # Context tray: fixed-size, no line wrapping, horizontal scrollbar.
    self.context_out = labwidget.Div(style={
        'display': 'inline-block',
        'vertical-align': 'top',
        'text-align': 'left',
        'width': '%spx' % ((size + 2) * 3 // 2),
        'height': '%spx' % (size * 3 // 8 + 20),
        'white-space': 'nowrap',
        'overflow-x': 'scroll'
    }, className='ctx_tray')
    self.context_img_array = []
    # Key tray starts hidden; its menu and buttons form one row with the
    # (non-drawable) canvas below.
    self.keytray_div = labwidget.Div(style={'display': 'none'})
    self.keytray_menu = labwidget.Menu().on('selection', self.repaint_key_tray)
    self.keytray_removebtn = labwidget.Button('remove').on(
        'click', self.keytray_remove)
    self.keytray_showbtn = labwidget.Button('show').on(
        'click', self.keytray_show)
    self.keytray_querybtn = labwidget.Button('query').on(
        'click', self.keytray_query)
    self.keytray_zerobtn = labwidget.Button('zero').on(
        'click', self.keytray_zero)
    self.keytray_canvas = paintwidget.PaintWidget(width=self.size,
                                                  height=self.size,
                                                  vanishing=False,
                                                  disabled=True)
    self.keytray_div.show([[
        self.keytray_menu,
        self.keytray_removebtn,
        self.keytray_showbtn,
        self.keytray_zerobtn,
        self.keytray_querybtn,
    ], [self.keytray_canvas]])
    inline = dict(display='inline')
    self.query_btn = labwidget.Button('Match Sel', style=inline).on(
        'click', self.query)
    # Shares its click handler with the key tray's 'query' button.
    self.context_querybtn = labwidget.Button('Search', style=inline).on(
        'click', self.keytray_query)
    self.highlight_btn = labwidget.Button('Show Context Matches',
                                          style=inline).on(
        'click', self.toggle_highlight)
    self.original_btn = labwidget.Button('Toggle Original',
                                         style=inline).on(
        'click', self.toggle_original)
    self.object_btn = labwidget.Button('Copy', style=inline).on(
        'click', self.pick_object)
    self.key_btn = labwidget.Button('Add to Context', style=inline).on(
        'click', self.key_add)
    self.paste_btn = labwidget.Button('Paste', style=inline).on(
        'click', self.paste)
    self.snap_btn = labwidget.Button('Snap').on('click', self.snapshot_images)
    # NOTE(review): brush size is given as the int 10 while the other
    # textboxes get strings — confirm labwidget.Textbox accepts both.
    self.brushsize_textbox = labwidget.Textbox(
        10, desc='brush: ', size=3).on('value', self.change_brushsize)
    self.rank_textbox = labwidget.Textbox('1', desc='rank: ', size=4,
                                          style=inline)
    # Paste settings are entered as text; defaults are string literals.
    self.paste_niter_textbox = labwidget.Textbox('2001', desc='paste niter: ',
                                                 size=8)
    self.paste_piter_textbox = labwidget.Textbox('10', desc='proj every: ',
                                                 size=4)
    self.paste_lr_textbox = labwidget.Textbox('0.05', desc='paste lr: ',
                                              size=8)
    self.erase_btn = labwidget.Button('Erase').on('click', self.exec_erase)
    self.exec_btn = labwidget.Button(
        'Execute Change',
        style=dict(display='inline', background='darkgreen')).on(
        'click', self.exec_request)
    self.overfit_btn = labwidget.Button('Overfit').on(
        'click', self.exec_overfit)
    self.revert_btn = labwidget.Button('Revert', style=inline).on(
        'click', self.revert)
    self.saved_list = labwidget.Datalist(choices=self.saved_names(),
                                         style=inline)
    self.load_btn = labwidget.Button('Load', style=inline).on(
        'click', self.tryload)
    self.save_btn = labwidget.Button('Save', style=inline).on('click',
                                                              self.save)
    # Image numbers rendered onto the canvases (repaint_canvas_array
    # renders exactly these, in order).
    self.sel = list(range(num_canvases))
    self.overwriting = True
    self.obj_acts = None
    self.query_key = None
    self.query_vis = False
    self.show_original = False
    self.query_rq = None
    self.query_key_valid = True
    self.clipped_activations = None
    self.canvas_array = []
    self.snap_image_array = []
    # One paint canvas per image, rendered eagerly here, plus a matching
    # image widget (presumably filled by snapshot_images). Each widget
    # records its position in an `index` attribute.
    for i in range(num_canvases):
        self.canvas_array.append(
            paintwidget.PaintWidget(
                image=renormalize.as_url(self.gw.render_image(i)),
                # width=self.size * 3 // 4, height=self.size * 3 // 4
                width=self.size,
                height=self.size).on('mask', self.change_mask))
        self.canvas_array[-1].index = i
        self.snap_image_array.append(
            labwidget.Image(
                style={
                    'margin-top': 0,
                    'max-width': '%dpx' % self.size,
                    'max-height': '%dpx' % self.size
                }))
        self.snap_image_array[-1].index = i
    self.current_mask_item = None