コード例 #1
0
 def old_repaint_key_tray(self):
     """Refresh the legacy key-tray widgets from ``self.request['key']``.

     The tray is hidden when there is no ``'key'`` entry or it is empty.
     Otherwise the menu choices become the image numbers of the stored
     ``(imgnum, mask)`` pairs (a later pair for the same image number
     replaces the earlier mask).  The current menu selection is kept if it
     still names a stored entry; otherwise the last choice is selected.
     The canvas then shows that image together with its stored mask.
     """
     entries = self.request['key'] if 'key' in self.request else ()
     if len(entries) == 0:
         self.keytray_div.style = {'display': 'none'}
         return
     masks_by_img = OrderedDict()
     for img_index, img_mask in entries:
         masks_by_img[img_index] = img_mask
     self.keytray_menu.choices = list(masks_by_img)
     self.keytray_div.style = {'display': 'block'}
     current = self.keytray_menu.selection
     if current is not None and int(current) in masks_by_img:
         chosen = int(current)
     else:
         # Selection missing or stale: fall back to the most recent entry.
         chosen = int(self.keytray_menu.choices[-1])
         self.keytray_menu.selection = chosen
     self.keytray_canvas.image = renormalize.as_url(
         self.gw.render_image(chosen, None))
     self.keytray_canvas.mask = masks_by_img[chosen]
コード例 #2
0
    def repaint_canvas_array(self):
        """Re-render every canvas in ``self.canvas_array``.

        Images for ``self.sel`` are rendered in one batch through
        ``self.gw.render_image_batch``, and image ``i`` is shown on canvas
        ``i``.  Canvases beyond ``len(self.sel)`` are blanked, and every
        canvas has its mask cleared.

        When ``self.query_vis`` is set, ``self.query_key`` is passed to the
        renderer with a level taken from the 0.999 quantile of
        ``self.query_rq`` (``None`` when there are no query stats).

        When ``self.show_original`` is set, ``self.original_model``'s
        weights are loaded into ``self.gw.model`` only for the duration of
        the render.  The current (edited) weights are restored afterwards,
        even if loading or rendering raises, so a failed render cannot
        discard the user's edits.
        """
        level = (self.query_rq.quantiles(0.999)[0]
                if (self.query_vis and self.query_rq) else None)
        query_key = self.query_key if self.query_vis else None
        # Read the flag once so we restore exactly when we swapped.
        show_original = self.show_original
        if show_original:
            # Snapshot the edited weights before anything is overwritten.
            saved_state_dict = copy.deepcopy(self.gw.model.state_dict())
        try:
            if show_original:
                with torch.no_grad():
                    self.gw.model.load_state_dict(
                            self.original_model.state_dict())
            images = self.gw.render_image_batch(self.sel, query_key, level,
                    border_color=[255, 255, 255])
        finally:
            if show_original:
                # Always put the edited weights back, even on failure.
                with torch.no_grad():
                    self.gw.model.load_state_dict(saved_state_dict)

        for i, canvas in enumerate(self.canvas_array):
            canvas.mask = ''
            size = (canvas.height, canvas.width) if canvas.height else None
            if i < len(self.sel):
                canvas.image = renormalize.as_url(images[i], size=size)
            else:
                canvas.image = ''
コード例 #3
0
 def exec_object(self, obj_acts=None, obj_area=None, obj_output=None,
         bounds=None):
     """Pick the object to copy and show it on the copy canvas.

     When ``obj_acts`` is None, the object is extracted from the current
     ``self.request['object']`` selection via
     ``self.gw.object_from_selection(imgnum, mask)``.  That call supplies
     the activations, output, area and bounds.  When ``obj_area`` is None,
     an all-ones area over ``obj_acts.shape[2:]`` is used and no mask is
     drawn.

     Side effects: sets ``self.obj_acts`` and ``self.obj_area``, updates
     ``self.copy_canvas`` (image from ``self.request_mask(thickness=3)``
     plus the selection mask, or no mask when none is known), and posts
     a status message.

     Args:
         obj_acts: object activations; presumably (N, C, H, W) given the
             ``shape[2:]`` spatial slice -- TODO confirm.
         obj_area: spatial weighting of the object, or None for all-ones.
         obj_output: output passed to ``self.gw.merge_target_output``.
         bounds: bounds passed to ``self.gw.merge_target_output``.
     """
     # Selection mask for the copy canvas.  It is only known when the object
     # comes from the current request.  Initialised here so callers passing
     # both obj_acts and obj_area no longer hit an UnboundLocalError below.
     mask = None
     if obj_acts is None:
         imgnum, mask = self.request['object']
         obj_acts, obj_output, obj_area, bounds = (
                 self.gw.object_from_selection(imgnum, mask))
     if obj_area is None:
         # No area given: weight every spatial position equally, show no mask.
         obj_area = torch.ones(obj_acts.shape[2:], device=obj_acts.device)
         mask = None
     self.obj_acts, self.obj_area = obj_acts, obj_area
     # NOTE(review): the rendered object is not displayed anywhere in this
     # method (its preview was disabled).  The merge/render calls are kept in
     # case they validate obj_output/bounds or have side effects; drop them
     # once gw confirms they are pure.
     cropped_out = self.gw.merge_target_output(obj_output, obj_acts, bounds)
     self.gw.render_object(cropped_out, obj_area)
     self.copy_canvas.image = renormalize.as_url(
             self.request_mask(thickness=3))
     self.copy_canvas.mask = mask
     self.show_msg('picked object')
コード例 #4
0
 def exec_paste(self):
     """Paste the stored object into the ``self.request['paste']`` selection.

     The selected image number and mask are passed, together with the
     stored ``self.obj_acts`` and ``self.obj_area``, to
     ``self.gw.paste_from_selection``.  Its visualisation output is rendered
     within the returned bounds and shown on ``self.paste_canvas``.  The
     first two returned values (goal_in, goal_out) are not used here.
     """
     target_img, target_mask = self.request['paste']
     pasted = self.gw.paste_from_selection(
         target_img, target_mask, self.obj_acts, self.obj_area)
     _goal_in, _goal_out, preview, region = pasted
     rendered = self.gw.render_object(preview, box=region)
     self.paste_canvas.image = renormalize.as_url(rendered)
コード例 #5
0
 def __init__(self, gw, mask_dir=None, size=256, num_canvases=9):
     """Build the editing widget and all of its child widgets.

     Creates the text boxes, buttons, paint canvases and output divs.
     Handler methods are attached with ``.on(...)``, and the editing state
     is initialised.

     Args:
         gw: model wrapper.  This constructor reads ``gw.cachedir`` and
             ``gw.model``, and calls ``gw.render_image(i)``.
         mask_dir: directory stored as ``self.savedir``; when None,
             ``gw.cachedir`` is used instead.
         size: pixel width/height of each paint canvas and of the
             object/target output divs.
         num_canvases: number of image canvases.  Canvas ``i`` starts out
             showing ``gw.render_image(i)``, and ``self.sel`` starts as
             ``[0, ..., num_canvases - 1]``.

     Note: ``gw.model`` is deep-copied and ``num_canvases`` images are
     rendered here, so construction cost grows with model size and
     canvas count.
     """
     # Outer container styling for this widget.
     super().__init__(style=dict(border="3px solid gray", padding="8px",
         display="inline-block"))
     self.gw = gw
     self.size = size
     self.savedir = self.gw.cachedir if mask_dir is None else mask_dir
     # Full deep copy of the model as it was at construction time, kept
     # separately from gw.model.
     self.original_model = copy.deepcopy(gw.model)
     # Current user selections; presumably keyed by request type
     # ('key', 'object', 'paste', ...) -- confirm against the handlers.
     self.request = {}
     # Image-number text box, default '0-<num_canvases-1>'; value changes
     # go to change_numbers.
     self.imgnum_textbox = labwidget.Textbox('0-%d' % (num_canvases-1)
             ).on('value', self.change_numbers)
     # Empty output divs.
     self.msg_out = labwidget.Div()
     self.loss_out = labwidget.Div()
     self.query_out = labwidget.Div()
     # Copy canvas: drawn masks go to change_copy_mask.
     self.copy_canvas = paintwidget.PaintWidget(
             image='', width=self.size, height=self.size
                  ).on('mask', self.change_copy_mask)
     # Paste canvas: transparent (opacity 0.0), oneshot; masks go to
     # change_paste_mask.
     self.paste_canvas = paintwidget.PaintWidget(
             image='', width=self.size, height=self.size,
             opacity=0.0, oneshot=True,
                  ).on('mask', self.change_paste_mask)
     # Fixed-size inline boxes.  context_out is slightly larger than the
     # others and scrolls vertically.
     self.object_out = labwidget.Div(
             style={'display': 'inline-block',
                 'vertical-align': 'top',
                 'width': '%spx' % size,
                 'height': '%spx' % size})
     self.target_out = labwidget.Div(
             style={'display': 'inline-block',
                 'vertical-align': 'top',
                 'width': '%spx' % size,
                 'height': '%spx' % size})
     self.context_out = labwidget.Div(
             style={'display': 'inline-block',
                 'vertical-align': 'top',
                 'width': '%spx' % (size + 24),
                 'height': '%spx' % (size + 4),
                 'overflow-y': 'scroll'})
     self.context_img_array = []
     # Key tray: hidden initially (display: none).  It holds a selection
     # menu, remove/show/zero/query buttons and a disabled, non-vanishing
     # canvas.
     self.keytray_div = labwidget.Div(style={'display': 'none'})
     self.keytray_menu = labwidget.Menu(
             ).on('selection', self.repaint_key_tray)
     self.keytray_removebtn = labwidget.Button('remove'
             ).on('click', self.keytray_remove)
     self.keytray_showbtn = labwidget.Button('show'
             ).on('click', self.keytray_show)
     self.keytray_querybtn = labwidget.Button('query'
             ).on('click', self.keytray_query)
     self.keytray_zerobtn = labwidget.Button('zero'
             ).on('click', self.keytray_zero)
     self.keytray_canvas = paintwidget.PaintWidget(
             width=self.size, height=self.size,
             vanishing=False, disabled=True)
     # Layout: the menu and buttons on one row, the canvas below.
     self.keytray_div.show([
         [
             self.keytray_menu,
             self.keytray_removebtn,
             self.keytray_showbtn,
             self.keytray_zerobtn,
             self.keytray_querybtn,
         ],
         [self.keytray_canvas]])
     # Toolbar buttons.  Note that 'match context' reuses the keytray_query
     # handler, the same as the key tray's 'query' button.
     inline = dict(display='inline')
     self.query_btn = labwidget.Button('match selection', style=inline
             ).on('click', self.query)
     self.context_querybtn = labwidget.Button('match context', style=inline
             ).on('click', self.keytray_query)
     self.highlight_btn = labwidget.Button('highlight', style=inline
             ).on('click', self.toggle_highlight)
     self.original_btn = labwidget.Button('toggle original', style=inline
             ).on('click', self.toggle_original)
     self.object_btn = labwidget.Button('copy', style=inline
             ).on('click', self.pick_object)
     self.key_btn = labwidget.Button('context', style=inline
             ).on('click', self.key_add)
     self.paste_btn = labwidget.Button('paste', style=inline
             ).on('click', self.paste)
     self.snap_btn = labwidget.Button('snap'
             ).on('click', self.snapshot_images)
     self.brushsize_textbox = labwidget.Textbox(10, desc='brush: ', size=3
             ).on('value', self.change_brushsize)
     # Numeric settings entered as text (defaults given as strings).
     self.rank_textbox = labwidget.Textbox(
             '1', desc='rank: ', size=4, style=inline)
     self.paste_niter_textbox = labwidget.Textbox(
             '2001', desc='paste niter: ', size=8)
     self.paste_piter_textbox = labwidget.Textbox(
             '10', desc='proj every: ', size=4)
     self.paste_lr_textbox = labwidget.Textbox(
             '0.05', desc='paste lr: ', size=8)
     self.erase_btn = labwidget.Button('erase').on('click', self.exec_erase)
     self.exec_btn = labwidget.Button('execute change', style=inline
             ).on('click', self.exec_request)
     self.overfit_btn = labwidget.Button('overfit').on(
             'click', self.exec_overfit)
     self.revert_btn = labwidget.Button('revert', style=inline
             ).on('click', self.revert)
     # Saved-state picker (choices from saved_names()) with load/save.
     self.saved_list = labwidget.Datalist(choices=self.saved_names(),
             style=inline)
     self.load_btn = labwidget.Button('load', style=inline
             ).on('click', self.tryload)
     self.save_btn = labwidget.Button('save', style=inline
             ).on('click', self.save)
     # Mutable editing state.
     self.sel = list(range(num_canvases))
     self.overwriting = True
     self.obj_acts = None
     self.query_key = None
     self.query_vis = False
     self.show_original = False
     self.query_rq = None
     self.clipped_activations = None
     # One paint canvas per image slot, initially showing
     # gw.render_image(i) (rendered eagerly here), plus a matching snapshot
     # image.  Each widget records its slot number in .index.
     self.canvas_array = []
     self.snap_image_array = []
     for i in range(num_canvases):
         self.canvas_array.append(paintwidget.PaintWidget(
             image=renormalize.as_url(
                 self.gw.render_image(i)),
                  width=self.size, height=self.size
                  ).on('mask', self.change_mask))
         self.canvas_array[-1].index = i
         self.snap_image_array.append(
                 labwidget.Image(style={'margin-top':0,
                 'max-width':'%dpx' % self.size,
                 'max-height':'%dpx' % self.size}))
         self.snap_image_array[-1].index = i
     self.current_mask_item = None