Example #1
 def erase_from_selection(self, imgnum, mask, context_mask_pairs, rank):
     """Builds a (goal_in, goal_out) training pair whose goal output is
     rendered with the top-`rank` dissected units zeroed, so optimizing
     the edited layer toward goal_out erases the selected concept."""
     k_area = renormalize.from_url(mask, target='pt',
                                   size=self.k_shape[2:])[0]
     area = renormalize.from_url(mask, target='pt',
                                 size=self.v_shape[2:])[0]
     source_outputs = self.context_model(self.get_z(imgnum))
     source_acts = self.context_acts(source_outputs)
     unchanged_outputs = self.target_model(source_outputs)
     source_acts_without_units = source_acts.clone()
     d_units = self.normdissect_units(context_mask_pairs, rank)
     source_acts_without_units[:, d_units] = 0.0
     d_erased_in = self.merge_target_output(source_outputs,
                                            source_acts_without_units, None)
     d_erased_out = self.target_model(d_erased_in)
     target_acts = self.target_acts(d_erased_out)
     if self.tight_paste:
         source_bounds = positive_bounding_box(k_area)
         target_bounds = positive_bounding_box(area)
     else:
         source_bounds, target_bounds = None, None
     goal_in = self.merge_target_output(source_outputs, source_acts,
                                        source_bounds)
     goal_out = self.merge_target_output(unchanged_outputs, target_acts,
                                         target_bounds)
     return goal_in, goal_out
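
What comes back is a training pair: goal_in reproduces the original
context, while goal_out is rendered with the selected units zeroed. A
hypothetical call sketch (the instance name gw, the image indices, and
the mask data-URLs are illustrative assumptions, not from the source):

    # Hypothetical names: gw holds the models; masks are data-URL selections.
    goal_in, goal_out = gw.erase_from_selection(
        imgnum=12,                          # image whose output is edited
        mask=object_mask_url,               # region to erase
        context_mask_pairs=[(3, ctx_url)],  # selections defining the concept
        rank=1)                             # number of units to zero
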
Example #2
 def repaint_key_tray(self):
     """Redraws the tray of key-selection thumbnails: one widget per
     (imgnum, mask) entry in self.request['key'], growing the widget
     array as needed and blanking widgets beyond the current set."""
     if 'key' not in self.request:
         keymasks = {}
     else:
         keymasks = OrderedDict([(imgnum, mask)
                                 for imgnum, mask in self.request['key']])
     if len(self.context_img_array) < len(keymasks):
         while len(self.context_img_array) < len(keymasks):
             self.context_img_array.append(
                 labwidget.Image(
                     style=dict(maxWidth='%spx' % int(self.size * 3 // 8),
                                maxHeight='%spx' % int(self.size * 3 // 8),
                                border='1px solid white')).on(
                                    'click', self.click_context_img))
         self.context_out.show(*[[imgw] for imgw in self.context_img_array])
     for i, (imgnum, mask) in enumerate(keymasks.items()):
         imgw = self.context_img_array[i]
         area = (renormalize.from_url(
             mask, target='pt', size=self.gw.x_shape[2:])[0] > 0.25)
         imgw.render(
             self.gw.render_image(imgnum,
                                  mask=area,
                                  thickness=0,
                                  outside_bright=1.0,
                                  inside_color=[255, 255, 255]))
         imgw.imgnum = imgnum
     for i in range(len(keymasks), len(self.context_img_array)):
         self.context_img_array[i].src = ''
         self.context_img_array[i].imgnum = None
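
A hypothetical trigger for the repaint, assuming an app object whose
request['key'] holds (imgnum, mask) pairs (all names illustrative):

    # Hypothetical usage: store two key selections, then refresh the tray.
    app.request['key'] = [(3, mask_url_a), (7, mask_url_b)]
    app.repaint_key_tray()  # one highlighted thumbnail per selection
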
Example #3
 def rgb_from_selection(self, imgnum, mask):
     area = renormalize.from_url(mask, target='pt',
                                 size=self.x_shape[2:])[0]
     with torch.no_grad():
         x_output = self.model(self.get_z(imgnum))
     t, l, b, r = positive_bounding_box(area)
     rgb_clip = x_output[:, :, t:b, l:r]
     obj_area = area[t:b, l:r]
     return rgb_clip, x_output, obj_area, (t, l, b, r)
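
The return value is the RGB crop, the full rendering, the cropped mask,
and the crop bounds. A hedged usage sketch (gw and the mask URL are
assumed names):

    rgb_clip, full_rgb, obj_area, (t, l, b, r) = gw.rgb_from_selection(
        imgnum=5, mask=mask_url)
    # rgb_clip has shape (1, 3, b - t, r - l): the selection's bounding box.
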
Example #4
 def rgbpaste_from_selection(self, imgnum, mask, obj_rgb, obj_area):
     with torch.no_grad():
         area = renormalize.from_url(mask,
                                     target='pt',
                                     size=self.x_shape[2:])[0]
         source_z = self.get_z(imgnum)
         unchanged_rgb = self.model(source_z)
         changed_rgb, bounds = paste_clip_at_center(unchanged_rgb, obj_rgb,
                                                    centered_location(area),
                                                    obj_area)
     return source_z, changed_rgb, bounds
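
Combined with rgb_from_selection above, this supports cut-and-paste in
RGB space; a hypothetical round trip (all names assumed):

    # Cut a clip from image 5, paste it centered on a selection in image 9.
    obj_rgb, _, obj_area, _ = gw.rgb_from_selection(5, cut_mask_url)
    z, pasted_rgb, bounds = gw.rgbpaste_from_selection(
        9, paste_mask_url, obj_rgb, obj_area)
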
Example #5
 def object_from_selection(self, imgnum, mask):
     area = renormalize.from_url(mask, target='pt',
                                 size=self.v_shape[2:])[0]
     with torch.no_grad():
         k_output = self.context_model(self.get_z(imgnum))
         v_output = self.target_model(k_output)
         v_acts = self.target_acts(v_output)
     t, l, b, r = positive_bounding_box(area)
     obj_activations = v_acts[:, :, t:b, l:r]
     obj_area = area[t:b, l:r]
     return obj_activations, v_output, obj_area, (t, l, b, r)
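
This is the activation-space analogue of rgb_from_selection: it crops
target-layer activations rather than pixels. A hypothetical call
(names assumed):

    obj_acts, v_output, obj_area, bounds = gw.object_from_selection(
        imgnum=5, mask=mask_url)
    # obj_acts holds the target-layer features inside the mask's bounding box.
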
Example #6
 def query_key_from_selection(self, imgnum, mask):
     area = renormalize.from_url(mask, target='pt',
                                 size=self.k_shape[2:])[0]
     with torch.no_grad():
         k_outs = self.context_model(self.get_z(imgnum))
         k_acts = self.context_acts(k_outs)
         mean = (k_acts[0] * area[None].to(self.device)).sum(2).sum(1) / (
             1e-10 + area.sum())
     k = self.covariance_adjusted_query_key(mean)
     k = k / (1e-10 + k.norm(2))
     return k
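
The result is a single unit-norm key direction; a minimal sketch,
assuming a gw instance and a mask data-URL:

    k = gw.query_key_from_selection(imgnum=3, mask=mask_url)
    print(k.norm(2))  # ~1.0: the key is normalized before being returned
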
Example #7
 def request_mask(self, field='object', index=None, **kwargs):
     # For generating high-resolution figures: directly visualize a mask.
     if field not in self.request:
         print(f'No {field} selected')
         return
     if field == 'key':
         if index is None or index >= len(self.request[field]):
             print(f'No {index}th entry in key')
             return
         imgnum, mask = self.request[field][index]
     else:
         imgnum, mask = self.request[field]
     area = (renormalize.from_url(
         mask, target='pt', size=self.gw.x_shape[2:])[0] > 0.25)
     imgout = self.gw.render_image(imgnum, mask=area, **kwargs)
     return imgout
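
A hypothetical figure-generation call, assuming the app's request dict
already holds the selections (extra kwargs pass through to
render_image, as in the repaint example above):

    img = app.request_mask(field='object', thickness=2)
    key_img = app.request_mask(field='key', index=0)  # 'key' stores a list
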
Example #8
 def normdissect_units(self, imgnum_mask_pairs, rank):
     """Scores each unit by its mask-weighted mean squared activation,
     normalized by the unit's typical square scale, and returns the
     indices of the top-`rank` units."""
     with torch.no_grad():
         accumulated_obs = []
         for imgnum, mask in imgnum_mask_pairs:
             k_outs = self.context_model(self.get_z(imgnum))
             k_acts = self.context_acts(k_outs)
             area = renormalize.from_url(mask,
                                         target='pt',
                                         size=self.k_shape[2:])[0]
             accumulated_obs.append(
                 (k_acts.permute(0, 2, 3, 1).reshape(-1, k_acts.shape[1]),
                  area.view(-1)[:, None].to(k_acts.device)))
         all_obs = torch.cat([obs for obs, _ in accumulated_obs])
         all_weight = torch.cat([w for _, w in accumulated_obs])
         square_scale = self.square_scales_for_units().to(all_obs.device)
         all_logscore = all_obs.pow(2) / square_scale[None, :]
         mean_logscore = ((all_logscore * all_weight).sum(0) /
                          all_weight.sum())
         top_coords = mean_logscore.sort(descending=True)[1][:rank]
         return top_coords
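
A hypothetical call that picks the three units most responsible for
the masked regions (the indices and mask URLs are assumed):

    units = gw.normdissect_units([(3, mask_a_url), (7, mask_b_url)], rank=3)
    # `units` indexes the top-3 channels by mask-weighted normalized energy.
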
Example #9
 def paste_from_selection(self, imgnum, mask, obj_acts, obj_area):
     """Pastes obj_acts at the center of the masked region and builds a
     (goal_in, goal_out) training pair, plus a full-size visualization."""
     area = renormalize.from_url(mask, target='pt',
                                 size=self.v_shape[2:])[0]
     source_outputs = self.context_model(self.get_z(imgnum))
     source_acts = self.context_acts(source_outputs)
     unchanged_outputs = self.target_model(source_outputs)
     unchanged_acts = self.target_acts(unchanged_outputs)
     target_acts, bounds = paste_clip_at_center(
         unchanged_acts, obj_acts, centered_location(area),
         obj_area if self.alpha_area else None)
     full_target_acts = target_acts
     if self.tight_paste:
         source_acts, target_acts, source_bounds, target_bounds = (
             crop_clip_to_bounds(source_acts, target_acts, bounds))
     else:
         source_bounds, target_bounds = None, None
     goal_in = self.merge_target_output(source_outputs, source_acts,
                                        source_bounds)
     goal_out = self.merge_target_output(unchanged_outputs, target_acts,
                                         target_bounds)
     viz_out = self.merge_target_output(unchanged_outputs, full_target_acts,
                                        None)
     return goal_in, goal_out, viz_out, bounds
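
Analogous to the RGB round trip, but in activation space, producing a
training pair plus a visualization output; a hedged sketch (names
assumed):

    obj_acts, _, obj_area, _ = gw.object_from_selection(5, cut_mask_url)
    goal_in, goal_out, viz_out, bounds = gw.paste_from_selection(
        9, paste_mask_url, obj_acts, obj_area)
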
Example #10
 def is_empty_mask(self, mask):
     area = renormalize.from_url(mask, target='pt')[0]
     return area.sum() == 0.0
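
A minimal guard sketch, assuming a gw instance and a mask data-URL:

    if gw.is_empty_mask(mask_url):
        print('Nothing selected; skipping the edit.')
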
Example #11
 def multi_key_from_selection(self,
                              imgnum_mask_pairs,
                              rank=1,
                              key_method=None):
     """Computes a rank-row matrix of key directions summarizing the
     masked context activations, using the 'zca', 'gandissect', 'svd',
     or 'mean' key_method."""
     if key_method is None:
         key_method = self.key_method
     with torch.no_grad():
         if key_method in ['zca']:
             accumulated_obs = []
             for imgnum, mask in imgnum_mask_pairs:
                 k_outs = self.context_model(self.get_z(imgnum))
                 k_acts = self.context_acts(k_outs)
                 area = renormalize.from_url(mask,
                                             target='pt',
                                             size=self.k_shape[2:])[0]
                 accumulated_obs.append(
                     (k_acts.permute(0, 2, 3, 1).reshape(-1, k_acts.shape[1]),
                      k_outs,
                      area.view(-1)[:, None].to(k_acts.device)))
             with warnings.catch_warnings():
                 # nonzero() prints a warning about a new signature we can ignore.
                 warnings.simplefilter('ignore', UserWarning)
                 all_obs = torch.cat([
                     obs[(w > 0).nonzero()[:, 0], :]
                     for obs, _, w in accumulated_obs
                 ])
                 all_weight = torch.cat(
                     [w[w > 0] for _, _, w in accumulated_obs])
                 all_zca_k = torch.cat([
                     (w * self.zca_whitened_query_key(obs))[
                         (w > 0).nonzero()[:, 0], :]
                     for obs, outs, w in accumulated_obs
                 ])
             # all_zca_k is already transposed
             _, _, q = all_zca_k.svd(compute_uv=True)
             # Get the top rank e_vecs in whitened space
             top_e_vec = q[:, :rank]
             # Transform them into rowspace. (Same as multiplying
             # by ZCA matrix a 2nd time.)
             row_dirs = self.zca_whitened_query_key(top_e_vec.t())
             just_avg = (all_zca_k).sum(0)
             # Orthogonalize row_dirs
             q, r = torch.qr(row_dirs.permute(1, 0))
             # Flip each direction's sign to agree with the avg direction.
             signs = (q * just_avg[:, None]).sum(0).sign()
             q = q * signs[None, :]
             return q.permute(1, 0)
         if key_method == 'gandissect':
             # Unit-wise-keys select D using a gandissect-like rule.
             # We score a unit by how unusual the visible values are.
             # i.e., log probability = weighted-sum log probabilities
             # Here we use explicitly counted quantiles to estimate probs.
             accumulated_obs = []
             for imgnum, mask in imgnum_mask_pairs:
                 k_outs = self.context_model(self.get_z(imgnum))
                 k_acts = self.context_acts(k_outs)
                 area = renormalize.from_url(mask,
                                             target='pt',
                                             size=self.k_shape[2:])[0]
                 accumulated_obs.append(
                     (k_acts.permute(0, 2, 3, 1).reshape(-1, k_acts.shape[1]),
                      area.view(-1)[:, None].to(k_acts.device)))
             all_obs = torch.cat([obs for obs, _ in accumulated_obs])
             all_weight = torch.cat([w for _, w in accumulated_obs])
             rq = self.quantiles_for_units()
             all_logscore = -torch.log(
                 1.0 - rq.normalize(all_obs.permute(1, 0))).permute(1, 0)
             mean_logscore = ((all_logscore * all_weight).sum(0) /
                              all_weight.sum())
             top_coords = mean_logscore.sort(descending=True)[1][:rank]
             result = torch.zeros(rank,
                                  all_obs.shape[1],
                                  device=all_obs.device)
             result[torch.arange(rank), top_coords] = 1.0
             return result
         # Old SVD method
         assert key_method in ['svd', 'mean']
         accumulated_k = []
         for imgnum, mask in imgnum_mask_pairs:
             k_outs = self.context_model(self.get_z(imgnum))
             k_acts = self.context_acts(k_outs)
             area = renormalize.from_url(mask,
                                         target='pt',
                                         size=self.k_shape[2:])[0]
             weighted_k = (k_acts[0] * area[None].to(self.device)).permute(
                 1, 2, 0).view(-1, k_acts.shape[1])
             nonzero_k = weighted_k[weighted_k.norm(2, dim=1) > 0]
             accumulated_k.append((nonzero_k, k_outs))
         all_k = torch.cat([
             self.covariance_adjusted_key(nonzero_k, k_outs)
             for nonzero_k, k_outs in accumulated_k
         ])
         just_avg = all_k.mean(0)
         if key_method == 'mean':
             assert rank == 1
             return just_avg[None, :] / just_avg.norm()
         u, s, v = torch.svd(all_k.permute(1, 0), some=False)
         if (just_avg * u[:, 0]).sum() < 0:
             # Flip the first singular vectors to agree with avg direction
             u[:, 0] = -u[:, 0]
             v[:, 0] = -v[:, 0]
         assert u.shape[1] >= rank
         return u.permute(1, 0)[:rank]
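
A hypothetical call extracting a rank-3 key basis with the ZCA method
(indices and mask URLs assumed):

    keys = gw.multi_key_from_selection(
        [(3, mask_a_url), (7, mask_b_url)], rank=3, key_method='zca')
    # keys has shape (3, channels); each row's sign is chosen to agree
    # with the average key direction over the selections.
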