예제 #1
0
    def repaint_canvas_array(self):
        """Re-render the selected images and push them to the canvas grid.

        When ``self.show_original`` is set, the batch is rendered with the
        weights of ``self.original_model`` loaded into ``self.gw.model``.
        The current (edited) weights are restored afterwards, also when
        rendering raises. The rendered batch is pickled to ``edited.pkl``
        inside ``self.savedir``. Every canvas mask is cleared. Canvases at
        index ``>= len(self.sel)`` get an empty image.
        """
        import os  # local import: the file-level import block is not visible

        level = (self.query_rq.quantiles(0.999)[0]
                 if (self.query_vis and self.query_rq) else None)
        saved_state_dict = None
        if self.show_original:
            saved_state_dict = copy.deepcopy(self.gw.model.state_dict())
            with torch.no_grad():
                self.gw.model.load_state_dict(self.original_model.state_dict())
        try:
            images = self.gw.render_image_batch(
                self.sel,
                self.query_key if self.query_vis else None,
                level,
                border_color=[255, 255, 255])
        finally:
            # Always put the edited weights back; otherwise an exception while
            # rendering would leave the original weights loaded in gw.model.
            if saved_state_dict is not None:
                with torch.no_grad():
                    self.gw.model.load_state_dict(saved_state_dict)

        # Custom addition: persist the rendered batch. os.path.join makes this
        # work whether or not savedir ends with a path separator.
        with open(os.path.join(self.savedir, 'edited.pkl'), 'wb') as f:
            pickle.dump(images, f)

        for i, canvas in enumerate(self.canvas_array):
            canvas.mask = ''
            size = (canvas.height, canvas.width) if canvas.height else None
            if i < len(self.sel):
                canvas.image = renormalize.as_url(images[i], size=size)
            else:
                canvas.image = ''
예제 #2
0
 def old_repaint_key_tray(self):
     """Sync the key-tray widgets with the ``'key'`` entries of the request.

     ``self.request['key']`` is iterated as ``(imgnum, mask)`` pairs. With no
     entries the tray is hidden. Otherwise the tray is shown and the menu
     lists each image number. The current selection is kept if it is still
     listed; otherwise the last listed image number is selected. The tray
     canvas then shows that image with its stored mask.
     """
     entries = self.request['key'] if 'key' in self.request else ()
     if len(entries) == 0:
         self.keytray_div.style = {'display': 'none'}
         return

     # Later pairs with a repeated image number overwrite earlier masks.
     mask_by_imgnum = OrderedDict((num, m) for num, m in entries)
     self.keytray_menu.choices = list(mask_by_imgnum)
     self.keytray_div.style = {'display': 'block'}

     current = self.keytray_menu.selection
     if current is not None and int(current) in mask_by_imgnum:
         chosen = int(current)
     else:
         # Stale or missing selection: fall back to the last listed entry.
         chosen = int(self.keytray_menu.choices[-1])
         self.keytray_menu.selection = chosen

     self.keytray_canvas.image = renormalize.as_url(
         self.gw.render_image(chosen, None))
     self.keytray_canvas.mask = mask_by_imgnum[chosen]
예제 #3
0
 def exec_object(self, obj_acts=None, obj_area=None, obj_output=None,
                 bounds=None):
     """Select the object to copy and preview it on the copy canvas.

     Args:
         obj_acts: object activations. When ``None``, the object is read from
             the pending ``self.request['object']`` selection, an
             ``(imgnum, mask)`` pair, via ``self.gw.object_from_selection``.
             That call also supplies ``obj_output``, ``obj_area`` and
             ``bounds``.
         obj_area: object area. When ``None``, an all-ones tensor of shape
             ``obj_acts.shape[2:]`` on ``obj_acts.device`` is used, and the
             copy-canvas mask is cleared.
         obj_output: passed to ``self.gw.merge_target_output``.
         bounds: passed to ``self.gw.merge_target_output``.

     Side effects: sets ``self.obj_acts`` and ``self.obj_area``, updates the
     image and mask of ``self.copy_canvas``, and shows 'picked object'.
     """
     # Bind ``mask`` up front. When both obj_acts and obj_area are supplied,
     # neither branch below assigns it, and the final copy_canvas.mask
     # assignment would otherwise raise UnboundLocalError.
     mask = None
     if obj_acts is None:
         imgnum, mask = self.request['object']
         obj_acts, obj_output, obj_area, bounds = (
             self.gw.object_from_selection(imgnum, mask))
     if obj_area is None:
         # NOTE(review): assumes dims 2+ of obj_acts are spatial
         # (e.g. N, C, H, W) -- confirm.
         obj_area = torch.ones(obj_acts.shape[2:], device=obj_acts.device)
         mask = None
     self.obj_acts, self.obj_area = obj_acts, obj_area
     cropped_out = self.gw.merge_target_output(obj_output, obj_acts, bounds)
     imgout = self.gw.render_object(cropped_out, obj_area)
     # NOTE(review): this visualizer is constructed and immediately discarded,
     # so it (and imgout) looks like dead code -- confirm before removing.
     imgviz.ImageVisualizer((imgout.height, imgout.width))
     self.copy_canvas.image = renormalize.as_url(
         self.request_mask(thickness=3))
     self.copy_canvas.mask = mask
     self.show_msg('picked object')
예제 #4
0
 def exec_paste(self):
     """Paste the copied object at the pending paste selection and preview it.

     Reads an ``(imgnum, mask)`` pair from ``self.request['paste']``. It then
     calls ``self.gw.paste_from_selection`` with the previously copied
     ``self.obj_acts`` / ``self.obj_area``. The rendered result is shown on
     ``self.paste_canvas``.
     """
     target_imgnum, target_mask = self.request['paste']
     # Four values come back; only the last two (visualization output and
     # its bounds) are needed for the preview.
     _, _, preview, box = self.gw.paste_from_selection(
         target_imgnum, target_mask, self.obj_acts, self.obj_area)
     self.paste_canvas.image = renormalize.as_url(
         self.gw.render_object(preview, box=box))
예제 #5
0
 def __init__(self, gw, mask_dir=None, size=256, num_canvases=9):
     """Build the editor's widgets and initialize its editing state.

     Args:
         gw: model wrapper. This constructor reads ``gw.cachedir`` and
             ``gw.model``, and calls ``gw.render_image(i)`` for each initial
             canvas.
         mask_dir: path stored as ``self.savedir``; when ``None``,
             ``gw.cachedir`` is used instead.
         size: base pixel size for canvases and panels.
         num_canvases: number of image canvases. The initial selection
             ``self.sel`` is ``0 .. num_canvases - 1``.
     """
     # Outer container: inline-block, 1000px wide, no border/padding/margin,
     # CSS class 'rwa'.
     super().__init__(style=dict(border="0",
                                 padding="0",
                                 display="inline-block",
                                 width="1000px",
                                 left="0",
                                 margin="0"),
                      className='rwa')
     self.gw = gw
     self.size = size
     self.savedir = self.gw.cachedir if mask_dir is None else mask_dir
     # Independent deep copy of gw.model as it was at construction time;
     # later in-place changes to gw.model do not affect it.
     self.original_model = copy.deepcopy(gw.model)
     # Pending user requests. NOTE(review): presumably filled by the
     # change_*_mask handlers -- confirm.
     self.request = {}
     # Image-number range text box, initially '0-<num_canvases - 1>'.
     self.imgnum_textbox = labwidget.Textbox(
         '0-%d' % (num_canvases - 1)).on('value', self.change_numbers)
     # Output areas (message / loss / query, judging by the names).
     self.msg_out = labwidget.Div()
     self.loss_out = labwidget.Div()
     self.query_out = labwidget.Div()
     # Copy-source and paste-target paint canvases, 3/4 of `size` per side.
     self.copy_canvas = paintwidget.PaintWidget(image='',
                                                width=self.size * 0.75,
                                                height=self.size * 0.75).on(
                                                    'mask',
                                                    self.change_copy_mask)
     # The paste canvas starts fully transparent and is created oneshot.
     self.paste_canvas = paintwidget.PaintWidget(
         image='',
         width=self.size * 0.75,
         height=self.size * 0.75,
         opacity=0.0,
         oneshot=True,
     ).on('mask', self.change_paste_mask)
     # Two size x size inline-block panels.
     self.object_out = labwidget.Div(
         style={
             'display': 'inline-block',
             'vertical-align': 'top',
             'width': '%spx' % size,
             'height': '%spx' % size
         })
     self.target_out = labwidget.Div(
         style={
             'display': 'inline-block',
             'vertical-align': 'top',
             'width': '%spx' % size,
             'height': '%spx' % size
         })
     # Context tray: fixed-size, non-wrapping, horizontally scrolling panel.
     self.context_out = labwidget.Div(style={
         'display':
         'inline-block',
         'vertical-align':
         'top',
         'text-align':
         'left',
         'width':
         '%spx' % ((size + 2) * 3 // 2),
         'height':
         '%spx' % (size * 3 // 8 + 20),
         'white-space':
         'nowrap',
         'overflow-x':
         'scroll'
     },
                                      className='ctx_tray')
     self.context_img_array = []
     # Key tray: initially hidden (display: none). Selecting a menu entry
     # triggers repaint_key_tray; the buttons call the keytray_* handlers.
     self.keytray_div = labwidget.Div(style={'display': 'none'})
     self.keytray_menu = labwidget.Menu().on('selection',
                                             self.repaint_key_tray)
     self.keytray_removebtn = labwidget.Button('remove').on(
         'click', self.keytray_remove)
     self.keytray_showbtn = labwidget.Button('show').on(
         'click', self.keytray_show)
     self.keytray_querybtn = labwidget.Button('query').on(
         'click', self.keytray_query)
     self.keytray_zerobtn = labwidget.Button('zero').on(
         'click', self.keytray_zero)
     # size x size canvas for the key tray, created with disabled=True.
     self.keytray_canvas = paintwidget.PaintWidget(width=self.size,
                                                   height=self.size,
                                                   vanishing=False,
                                                   disabled=True)
     # Tray layout: menu + buttons, then the canvas (presumably one row per
     # inner list -- confirm against labwidget).
     self.keytray_div.show([[
         self.keytray_menu,
         self.keytray_removebtn,
         self.keytray_showbtn,
         self.keytray_zerobtn,
         self.keytray_querybtn,
     ], [self.keytray_canvas]])
     inline = dict(display='inline')
     # Action buttons, each dispatching to a method on this widget.
     self.query_btn = labwidget.Button('Match Sel', style=inline).on(
         'click', self.query)
     # 'Search' shares the keytray_query handler with the tray's 'query' button.
     self.context_querybtn = labwidget.Button('Search', style=inline).on(
         'click', self.keytray_query)
     self.highlight_btn = labwidget.Button('Show Context Matches',
                                           style=inline).on(
                                               'click',
                                               self.toggle_highlight)
     self.original_btn = labwidget.Button('Toggle Original',
                                          style=inline).on(
                                              'click', self.toggle_original)
     self.object_btn = labwidget.Button('Copy', style=inline).on(
         'click', self.pick_object)
     self.key_btn = labwidget.Button('Add to Context', style=inline).on(
         'click', self.key_add)
     self.paste_btn = labwidget.Button('Paste', style=inline).on(
         'click', self.paste)
     self.snap_btn = labwidget.Button('Snap').on('click',
                                                 self.snapshot_images)
     # Settings entered as text: brush size (changes call change_brushsize),
     # rank, and the paste parameters labeled 'paste niter', 'proj every'
     # and 'paste lr'.
     self.brushsize_textbox = labwidget.Textbox(
         10, desc='brush: ', size=3).on('value', self.change_brushsize)
     self.rank_textbox = labwidget.Textbox('1',
                                           desc='rank: ',
                                           size=4,
                                           style=inline)
     self.paste_niter_textbox = labwidget.Textbox('2001',
                                                  desc='paste niter: ',
                                                  size=8)
     self.paste_piter_textbox = labwidget.Textbox('10',
                                                  desc='proj every: ',
                                                  size=4)
     self.paste_lr_textbox = labwidget.Textbox('0.05',
                                               desc='paste lr: ',
                                               size=8)
     self.erase_btn = labwidget.Button('Erase').on('click', self.exec_erase)
     # 'Execute Change' is highlighted with a dark green background.
     self.exec_btn = labwidget.Button(
         'Execute Change',
         style=dict(display='inline',
                    background='darkgreen')).on('click', self.exec_request)
     self.overfit_btn = labwidget.Button('Overfit').on(
         'click', self.exec_overfit)
     self.revert_btn = labwidget.Button('Revert', style=inline).on(
         'click', self.revert)
     # Datalist populated from self.saved_names(), with Load / Save buttons.
     self.saved_list = labwidget.Datalist(choices=self.saved_names(),
                                          style=inline)
     self.load_btn = labwidget.Button('Load', style=inline).on(
         'click', self.tryload)
     self.save_btn = labwidget.Button('Save',
                                      style=inline).on('click', self.save)
     # Editing state. The selection starts as image numbers 0..num_canvases-1.
     self.sel = list(range(num_canvases))
     self.overwriting = True
     self.obj_acts = None
     self.query_key = None
     self.query_vis = False
     self.show_original = False
     self.query_rq = None
     self.query_key_valid = True
     self.clipped_activations = None
     self.canvas_array = []
     self.snap_image_array = []
     # One paint canvas per slot, initially showing gw.render_image(i), plus a
     # matching snapshot image capped at size x size. Both record their index.
     for i in range(num_canvases):
         self.canvas_array.append(
             paintwidget.PaintWidget(
                 image=renormalize.as_url(self.gw.render_image(i)),
                 # width=self.size * 3 // 4, height=self.size * 3 // 4
                 width=self.size,
                 height=self.size).on('mask', self.change_mask))
         self.canvas_array[-1].index = i
         self.snap_image_array.append(
             labwidget.Image(
                 style={
                     'margin-top': 0,
                     'max-width': '%dpx' % self.size,
                     'max-height': '%dpx' % self.size
                 }))
         self.snap_image_array[-1].index = i
     self.current_mask_item = None