def plot_example(input, saliency, method, category_id, show_plot=False, save_path=None):
    """Plot an example.

    Each row of the figure shows an input image next to its saliency map.

    Args:
        input (:class:`torch.Tensor`): 4D tensor containing input images.
        saliency (:class:`torch.Tensor`): 4D tensor containing saliency maps.
        method (str): name of saliency method.
        category_id (int or list of int): ID(s) of ImageNet category. A list
            is cycled over the batch (image ``i`` uses
            ``category_id[i % len(category_id)]``).
        show_plot (bool, optional): If True, show plot. Default: ``False``.
        save_path (str, optional): Path to save figure to. The format is taken
            from the file extension; when there is none, matplotlib's default
            format is used. Default: ``None``.
    """
    from torchray.utils import imsc
    from torchray.benchmark.datasets import IMAGENET_CLASSES

    if isinstance(category_id, int):
        category_id = [category_id]

    batch_size = len(input)

    plt.clf()
    for i in range(batch_size):
        class_i = category_id[i % len(category_id)]

        plt.subplot(batch_size, 2, 1 + 2 * i)
        imsc(input[i])
        plt.title('input image', fontsize=8)

        plt.subplot(batch_size, 2, 2 + 2 * i)
        imsc(saliency[i], interpolation='none')
        plt.title('{} for category {} ({})'.format(
            method, IMAGENET_CLASSES[class_i], class_i), fontsize=8)

    # Save figure if path is specified.
    if save_path:
        save_dir = os.path.dirname(os.path.abspath(save_path))
        # exist_ok avoids the check-then-create race of os.path.exists().
        os.makedirs(save_dir, exist_ok=True)
        ext = os.path.splitext(save_path)[1].strip('.')
        # An empty format string is rejected by savefig; None lets matplotlib
        # fall back to its configured default format.
        plt.savefig(save_path, format=ext or None, bbox_inches='tight')

    # Show plot if desired.
    if show_plot:
        plt.show()
def image2tensor(image_path):
    """Load an image file and convert it to a torch tensor of shape (1, 3, 224, 224).

    The image is resized to 224x224 and passed to ``imsc`` with
    ``quiet=False``. The first value ``imsc`` returns is reshaped to
    ``(1, 3, 224, 224)``.

    Args:
        image_path (str): path of the image file to load.

    Returns:
        torch.Tensor: the image tensor with shape ``(1, 3, 224, 224)``.
    """
    # ``Resize`` replaces the deprecated ``Scale`` (same arguments, same result).
    resize = transforms.Compose([transforms.Resize((224, 224))])
    # Use the context manager so the underlying file handle is released.
    with Image.open(image_path) as pil_img:
        img_tensor, _ = imsc(resize(pil_img), quiet=False)
    return torch.reshape(img_tensor, (1, 3, 224, 224))
def torchray_multimodal_explain(image_path, text):
    """Explain the MMBT hateful-memes prediction for an image/text pair.

    Runs ``multi_extremal_perturbation`` with the contrastive reward on target
    ``0``. It saves an image to ``torchray.png``, prints the textual summary,
    and returns the word-level conclusion.

    Args:
        image_path (str): path of the meme image.
        text (str): the meme text.

    Returns:
        The ``conclusion`` returned by ``multi_extremal_perturbation``. According
        to the original comments, this is a list of words and their weights.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = MMBT.from_pretrained("mmbt.hateful_memes.images")
    model = model.to(device)

    image_tensor = image2tensor(image_path).to(device)

    # Per the original comments:
    #   summary       - higher level explanation in terms of a sentence
    #   conclusion    - list of words and their weights
    #   output_tensor - the masked image
    mask_, hist_, output_tensor, summary, conclusion = multi_extremal_perturbation(
        model, image_tensor, image_path, text, 0,
        reward_func=contrastive_reward, debug=True, areas=[0.12])

    # Named ``display_image`` rather than ``Image`` so the PIL module name is
    # not shadowed anywhere in this function.
    # NOTE(review): this saves the rescaled *input* image, not output_tensor —
    # confirm that is intended.
    display_image = transforms.ToPILImage()(
        imsc(image_tensor[0], quiet=False)[0]).convert("RGB")
    display_image.save("torchray.png")
    print(summary)
    return conclusion
def _ep(m, x, i):
    """Extremal-perturbation saliency for target ``i`` of model ``m`` on input ``x``.

    The contrastive reward is used and debug output is on. ``imsc`` rescales
    the first mask, and the result is returned with a leading batch axis.

    NOTE(review): ``area`` is a free variable. It is not a parameter and is not
    defined here, so it must come from an enclosing or module scope; confirm
    it exists.
    """
    # Extremal perturbation backprop.
    masks, _energy = extremal_perturbation(
        m,
        x,
        i,
        reward_func=contrastive_reward,
        debug=True,
        areas=[area],
    )
    rescaled = imsc(masks[0], quiet=True)[0]
    return rescaled[None, ...]
def classify(self, image, text_input, image_tensor=None):
    '''Classify an image/text pair, falling back to MMBT if the ONNX output is unusable.

    Args:
        image: input image. It is resized with ``transforms.Resize`` when
            ``image_tensor`` is not given (presumably a PIL image — confirm
            against callers).
        text_input (str): the text input.
        image_tensor (torch.Tensor, optional): image tensor of size
            (1, 3, 224, 224). When given, ``image`` is not preprocessed and the
            raw softmax scores are returned.

    Returns:
        If ``image_tensor`` was given, the softmax score tensor. Otherwise a
        dict ``{"label": int, "confidence": float}`` for the arg-max class.
    '''
    # Identity check: tensors must not be compared to None with ``!=``.
    return_scores = image_tensor is not None
    if not return_scores:
        # ``Resize`` replaces the deprecated ``Scale`` (same arguments).
        resize = transforms.Compose([transforms.Resize((224, 224))])
        image, _ = imsc(resize(image), quiet=True)
        image_tensor = torch.reshape(image, (1, 3, 224, 224))
    logits = self.onnx_model_forward(image_tensor, text_input)

    # If the ONNX forward pass does not yield a single row of two logits,
    # fall back to the lazily loaded PyTorch MMBT model.
    if list(torch.tensor(logits).size()) != [1, 2]:
        if self.defaultmodel is None:
            self.defaultmodel = MMBT.from_pretrained("mmbt.hateful_memes.images")
            self.defaultmodel.to(self.device)
        logits = self.defaultmodel.classify(
            image, text_input,
            image_tensor=torch.squeeze(image_tensor.to(self.device), 0))

    scores = nn.functional.softmax(torch.tensor(logits), dim=1)
    if return_scores:
        return scores
    confidence, label = torch.max(scores, dim=1)
    return {"label": label.item(), "confidence": confidence.item()}
def hierarchical_perturbation(model, input, target, vis=False, interp_mode='nearest', resize=None,
                              batch_size=32, perturbation_type='fade', num_cells=4):
    """Compute a saliency map by Hierarchical Perturbation.

    The input is covered by overlapping masks of 2x2 cells on a grid whose
    resolution doubles at each depth. At depth 1 every mask is evaluated. At
    later depths a mask is evaluated only when the saliency already
    accumulated inside it reaches the mid-range threshold
    ``min + (max - min) / 2``. Perturbing a region lowers the target's softmax
    probability; that drop, clipped at 0, is added to the saliency of the
    region.

    Args:
        model (:class:`torch.nn.Module`): classifier. Its output is
            soft-maxed over dim 1.
        input (:class:`torch.Tensor`): input batch, unpacked as
            ``(N, C, H, W)``. Presumably ``N == 1`` — confirm with callers.
        target (int): index of the output class to explain.
        vis (bool, optional): if True, plot and save a preview after each
            batch of masks. Default: ``False``.
        interp_mode (str, optional): interpolation mode used to upsample masks
            and to resize the result. Default: ``'nearest'``.
        resize (tuple, optional): ``(width, height)`` to resize the returned
            saliency to. Default: ``None`` (input size).
        batch_size (int, optional): masked inputs per forward pass. It is
            forced to 1 when ``input`` is on the CPU. Default: 32.
        perturbation_type (str, optional): one of the following. Default:
            ``'fade'``.
            - ``'fade'``: multiply the input by the mask.
            - ``'blur'``: copy the region from ``blur(input)``.
            - ``'mean'``: replace the region by its per-channel mean.
        num_cells (int, optional): initial grid resolution. It is doubled
            before the first depth is evaluated. Default: 4.

    Returns:
        tuple: ``(saliency, depth_saliencies, total_masks)``.
            - ``saliency``: shape ``(1, 1, H, W)``, or resized if ``resize``
              is given.
            - ``depth_saliencies``: shape ``(max_depth + 1, 1, 1, H, W)``,
              holding each depth's contribution.
            - ``total_masks``: number of masks evaluated.

    Raises:
        ValueError: if ``perturbation_type`` is not supported.
    """
    if perturbation_type not in ('fade', 'blur', 'mean'):
        raise ValueError(
            "Unknown perturbation_type: {!r}".format(perturbation_type))

    with torch.no_grad():
        dev = input.device
        # ``input.device`` is a torch.device, which never compares equal to a
        # plain string; compare its ``type`` so this rule actually applies.
        if dev.type == 'cpu':
            batch_size = 1
        _, _, input_y_dim, input_x_dim = input.shape
        dim = min(input_x_dim, input_y_dim)
        total_masks = 0
        depth = 0
        max_depth = int(np.log2(dim / num_cells))
        saliency = torch.zeros((1, 1, input_y_dim, input_x_dim), device=dev)
        depth_saliencies = torch.zeros(
            (max_depth + 1, 1, 1, input_y_dim, input_x_dim), device=dev)
        max_batch = batch_size

        # Unperturbed probability of the target class.
        output = F.softmax(model(input), dim=1)[:, target]

        if perturbation_type == 'blur':
            pre_b_image = blur(input.clone().cpu()).to(dev)

        while depth < max_depth:
            masks_list = []
            b_list = []
            num_cells *= 2
            depth += 1
            threshold = torch.min(saliency) + (
                (torch.max(saliency) - torch.min(saliency)) / 2)

            print('Depth: {}, {} x {} Cell'.format(
                depth, input_y_dim // num_cells, input_x_dim // num_cells))
            print('Threshold: {:.1f}'.format(threshold))
            print('Range {:.1f} to {:.1f}'.format(saliency.min(), saliency.max()))

            x_cell_dim = input_x_dim // num_cells
            y_cell_dim = input_y_dim // num_cells
            pos_masks = 0

            # Starting at -1 also yields the clipped, one-cell-wide border masks.
            for x in range(-1, num_cells):
                for y in range(-1, num_cells):
                    x1, y1 = max(0, x), max(0, y)
                    x2, y2 = min(x + 2, num_cells), min(y + 2, num_cells)
                    pos_masks += 1

                    mask = torch.zeros((1, 1, num_cells, num_cells), device=dev)
                    mask[:, :, y1:y2, x1:x2] = 1.0

                    if depth > 1:
                        local_saliency = torch.max(F.interpolate(
                            mask, (input_y_dim, input_x_dim), mode=interp_mode) * saliency)
                    else:
                        # Nothing measured yet: every region is evaluated.
                        local_saliency = 0

                    # Refine only regions at least as salient as the threshold.
                    if local_saliency >= threshold:
                        masks_list.append(abs(mask - 1))
                        rows = slice(y1 * y_cell_dim, y2 * y_cell_dim)
                        cols = slice(x1 * x_cell_dim, x2 * x_cell_dim)
                        if perturbation_type == 'blur':
                            b_image = input.clone()
                            b_image[:, :, rows, cols] = pre_b_image[:, :, rows, cols]
                            b_list.append(b_image)
                        elif perturbation_type == 'mean':
                            b_image = input.clone()
                            b_image[:, :, rows, cols] = torch.mean(
                                b_image[:, :, rows, cols], dim=(-1, -2), keepdim=True)
                            b_list.append(b_image)

            num_masks = len(masks_list)
            print('Selected {}/{} masks at depth {}'.format(num_masks, pos_masks, depth))
            if num_masks == 0:
                depth -= 1
                break
            total_masks += num_masks

            while masks_list:
                m_ix = min(len(masks_list), max_batch)
                if perturbation_type != 'fade':
                    b_imgs = torch.cat(b_list[:m_ix])
                    del b_list[:m_ix]
                masks = torch.cat(masks_list[:m_ix])
                del masks_list[:m_ix]

                # Resize the low-resolution masks to the input size.
                masks = F.interpolate(masks, (input_y_dim, input_x_dim), mode=interp_mode)

                if perturbation_type == 'fade':
                    perturbed_logits = model(input * masks)
                else:
                    perturbed_logits = model(b_imgs)
                perturbed_outputs = torch.relu(
                    output - F.softmax(perturbed_logits, dim=1)[:, target])

                # One weight per mask: reshape to (m, 1, 1) so it lines up with
                # the mask axis of ``masks.transpose(0, 1)`` (1, m, H, W) rather
                # than broadcasting against the width axis.
                sal = perturbed_outputs.reshape(-1, 1, 1) * torch.abs(
                    masks.transpose(0, 1) - 1)
                saliency += torch.sum(sal, dim=(0, 1))
                depth_saliencies[depth] += torch.sum(sal, dim=(0, 1))

                if vis:
                    clear_output(wait=True)
                    print('Saving image...')
                    plt.figure(figsize=(8, 4))
                    plt.subplot(1, 2, 1)
                    plt.title('Depth: {}, {} x {} Mask\nThreshold: {:.1f}'.format(
                        depth, num_cells, num_cells, threshold))
                    if perturbation_type == 'fade':
                        imsc((input * masks)[0])
                    else:
                        imsc(b_imgs[0])
                    plt.subplot(1, 2, 2)
                    imsc(saliency[0])
                    # Save before showing: show() may release the figure, and
                    # a later savefig would then write a blank image.
                    plt.savefig('data/attribution_benchmarks/preview')
                    plt.show()
                    plt.close()

        print('Used {} masks in total.'.format(total_masks))
        if resize is not None:
            saliency = F.interpolate(saliency, (resize[1], resize[0]), mode=interp_mode)
        return saliency, depth_saliencies, total_masks
def decnn(m, x, i):
    """Deconvnet saliency for target ``i`` of model ``m`` on input ``x``.

    The first map is rescaled with ``imsc`` and returned with a leading
    batch axis.
    """
    first_map = deconvnet(m, x, i)[0]
    rescaled = imsc(first_map, quiet=True)[0]
    return rescaled[None, ...]
def excite_l1(m, x, i):
    """Excitation-backprop saliency taken at ``model.encoder_q.layer1``.

    ``model`` is read from the enclosing/module scope, not from ``m``. The
    first map is rescaled with ``imsc`` and returned with a leading batch
    axis.
    """
    layer = model.encoder_q.layer1
    raw = excitation_backprop(m, x, i, saliency_layer=layer)
    rescaled = imsc(raw[0], quiet=True)[0]
    return rescaled[None, ...]
def do_saliency(x, k, one_vs_all=False):
    """Compute colorized saliency maps for the query encoder of a contrastive model.

    A bias-free linear head whose weight rows are the keys ``k`` is stacked on
    ``model.encoder_q`` (``model`` comes from the enclosing/module scope),
    followed by ``L2Normalize`` and a softmax. Image ``i`` of ``x`` is then
    explained with respect to output ``i`` of that head.

    Args:
        x (torch.Tensor): batch of images. The result is resized to
            ``(x.shape[-2], x.shape[-1])`` as (rows, cols), so this assumes
            ``(N, C, H, W)`` — TODO confirm with callers.
        k (torch.Tensor): keys. The head's weight is ``(k.shape[-2],
            k.shape[-1])``, i.e. ``(num_keys, embedding_dim)``.
        one_vs_all (bool): if True, ``x`` must hold a single image and RISE
            is computed against all keys. NOTE(review): every entry of the
            returned dict is currently commented out, so it is empty.

    Returns:
        dict: method name -> numpy array of shape ``(N, 3, H, W)`` built from
        the colorized maps.
    """
    head = torch.nn.Linear(k.shape[-1], k.shape[-2], bias=False)
    head.weight.data[:] = k.data[:]
    head = head.cuda()
    contrast_model = torch.nn.Sequential(
        model.encoder_q, L2Normalize(), head, torch.nn.Softmax(dim=-1))
    x = x.cuda()

    def colorize(saliency):
        # cv2.resize expects dsize as (width, height) = (W, H); passing (H, W)
        # transposed the result for non-square inputs.
        resized = cv2.resize(
            saliency.permute(0, 2, 3, 1)[0].detach().cpu().numpy(),
            (x.shape[-1], x.shape[-2]),
            interpolation=cv2.INTER_LANCZOS4) * 255
        # Keep the first three channels of the colour-mapped result.
        return color(resized)[..., :3]

    def _saliency(x, func):
        s = []
        for i in range(x.shape[0]):
            saliency = func(contrast_model, x[i][None].cuda(), i)
            s.append(colorize(saliency))
        # (N, H, W, 3) -> (N, 3, H, W)
        return np.stack(s).transpose(0, -1, 1, 2)

    def gcam(m, x, i):
        return grad_cam(m, x, i, saliency_layer=model.encoder_q.layer4)

    def gcam_l3(m, x, i):
        return grad_cam(m, x, i, saliency_layer=model.encoder_q.layer3)

    def excite_c1(m, x, i):
        saliency = excitation_backprop(
            m, x, i, saliency_layer=model.encoder_q.conv1)
        return imsc(saliency[0], quiet=True)[0][None]

    def excite_l1(m, x, i):
        saliency = excitation_backprop(
            m, x, i, saliency_layer=model.encoder_q.layer1)
        return imsc(saliency[0], quiet=True)[0][None]

    def ceb(m, x, i):
        # Contrastive excitation backprop.
        return contrastive_excitation_backprop(
            m, x, i,
            saliency_layer=model.encoder_q.layer2[-1],
            contrast_layer=model.encoder_q.layer4[-1],
            classifier_layer=model.encoder_q.fc[-1])

    def ep(area):
        # Bind ``area`` in a closure so each call gets its own target area.
        def _ep(m, x, i):
            # Extremal perturbation backprop.
            masks_1, _ = extremal_perturbation(
                m, x, i,
                reward_func=contrastive_reward,
                debug=True,
                areas=[area],
            )
            return imsc(masks_1[0], quiet=True)[0][None]
        return _ep

    def decnn(m, x, i):
        saliency = deconvnet(m, x, i)
        return imsc(saliency[0], quiet=True)[0][None]

    if one_vs_all:
        # Compare the first x to all k.
        import time
        assert x.shape[0] == 1, 'expected just one instance'

        t0 = time.time()
        rise_saliency = rise(contrast_model, x).transpose(0, 1)
        print('rise took ************', time.time() - t0)
        rise_saliency = torch.stack(
            [imsc(rs, quiet=True)[0] for rs in rise_saliency])
        rise_saliency = np.stack(
            [rise_saliency.detach().cpu().numpy()] * 3, axis=-1)[:, 0]
        rise_saliency = rise_saliency.transpose(0, -1, 1, 2)

        return dict(
            # rise=rise_saliency,
            # grad_cam=_saliency(torch.cat([x]*k.shape[0]), gcam),
        )

    out = dict(
        # grad_cam=_saliency(x, gcam),
        # grad_cam_l3=_saliency(x, gcam_l3),
        contrastive_excitation_backprop=_saliency(x, ceb),
        # excite_c1=_saliency(x, excite_c1),
        # excite_l1=_saliency(x, excite_l1),
        # deconvnet=_saliency(x, decnn),
    )
    for k_ep in [0.05]:  # other areas tried: 0.12
        out['extremal_perturbation_%s' % k_ep] = _saliency(x, ep(k_ep))
    return out
def __next__(self):
    """Evaluate the pointing game on the next sample of the data iterator.

    For each class id of the sample, a point is obtained with the attribution
    method ``self.experiment.method``. The point is then scored with
    ``self.pointing`` and ``self.pointing_difficult``. When ``self.log > 0``,
    results are first looked up in the Mongo log and new ones are saved to
    it. When ``self.log > 1``, the per-class ``info`` dict is saved as well.

    Returns:
        dict: ``{'pointing': {class_id: ...}, 'pointing_difficult': {class_id: ...}}``.

    Raises:
        StopIteration: when ``self.data_iterator`` is exhausted.
        ProcessingError: wrapping any exception raised while processing the
            sample.
    """
    self._lazy_init()
    x, y = next(self.data_iterator)
    torch.manual_seed(self.seed)

    if self.log:
        from torchray.benchmark.logging import mongo_load, mongo_save, \
            data_from_mongo, data_to_mongo

    # Bound before the ``try`` so the ``except`` clause can always construct a
    # ProcessingError. Otherwise a failure before these are assigned raises
    # UnboundLocalError and hides the real error.
    class_id = None
    image_size = None

    try:
        assert len(x) == 1
        x = x.to(self.device)
        class_ids = self.data.as_class_ids(y[0])
        image_size = self.data.as_image_size(y[0])

        results = {'pointing': {}, 'pointing_difficult': {}}
        info = {}
        rise_saliency = None

        for class_id in class_ids:
            # Try to recover this result from the log.
            if self.log > 0:
                image_name = self.data.as_image_name(y[0])
                data = mongo_load(
                    self.db,
                    self.experiment.name,
                    f"{image_name}-{class_id}",
                )
                if data is not None:
                    data = data_from_mongo(data)
                    results['pointing'][class_id] = data['pointing']
                    results['pointing_difficult'][class_id] = data['pointing_difficult']
                    if self.debug:
                        print(f'{image_name}-{class_id} loaded from log')
                    continue

            # TODO(av): should now be obsolete
            if x.grad is not None:
                x.grad.data.zero_()

            method = self.experiment.method
            if method == "center":
                w, h = image_size
                point = torch.tensor([[w / 2, h / 2]])

            elif method == "gradient":
                saliency = gradient(
                    self.model, x, class_id,
                    resize=image_size,
                    smooth=0.02,
                    get_backward_gradient=get_pointing_gradient)
                point = _saliency_to_point(saliency)
                info['saliency'] = saliency

            elif method == "deconvnet":
                saliency = deconvnet(
                    self.model, x, class_id,
                    resize=image_size,
                    smooth=0.02,
                    get_backward_gradient=get_pointing_gradient)
                point = _saliency_to_point(saliency)
                info['saliency'] = saliency

            elif method == "guided_backprop":
                saliency = guided_backprop(
                    self.model, x, class_id,
                    resize=image_size,
                    smooth=0.02,
                    get_backward_gradient=get_pointing_gradient)
                point = _saliency_to_point(saliency)
                info['saliency'] = saliency

            elif method == "grad_cam":
                saliency = grad_cam(
                    self.model, x, class_id,
                    saliency_layer=self.gradcam_layer,
                    resize=image_size,
                    get_backward_gradient=get_pointing_gradient)
                point = _saliency_to_point(saliency)
                info['saliency'] = saliency

            elif method == "excitation_backprop":
                saliency = excitation_backprop(
                    self.model, x, class_id, self.saliency_layer,
                    resize=image_size,
                    get_backward_gradient=get_pointing_gradient)
                point = _saliency_to_point(saliency)
                info['saliency'] = saliency

            elif method == "contrastive_excitation_backprop":
                saliency = contrastive_excitation_backprop(
                    self.model, x, class_id,
                    saliency_layer=self.saliency_layer,
                    contrast_layer=self.contrast_layer,
                    resize=image_size,
                    get_backward_gradient=get_pointing_gradient)
                point = _saliency_to_point(saliency)
                info['saliency'] = saliency

            elif method == "rise":
                # For RISE, compute the saliency map for all classes once and
                # reuse it for every class id of this sample.
                if rise_saliency is None:
                    rise_saliency = rise(self.model, x, resize=image_size, seed=self.seed)
                saliency = rise_saliency[:, class_id, :, :].unsqueeze(1)
                point = _saliency_to_point(saliency)
                info['saliency'] = saliency

            elif method == "extremal_perturbation":
                if self.experiment.dataset == 'voc_2007':
                    areas = [0.025, 0.05, 0.1, 0.2]
                else:
                    areas = [0.018, 0.025, 0.05, 0.1]

                if self.experiment.boom:
                    raise RuntimeError("BOOM!")

                mask, energy = elp.extremal_perturbation(
                    self.model, x, class_id,
                    areas=areas,
                    num_levels=8,
                    step=7,
                    sigma=7 * 3,
                    max_iter=800,
                    debug=self.debug > 0,
                    jitter=True,
                    smooth=0.09,
                    resize=image_size,
                    perturbation='blur',
                    reward_func=elp.simple_reward,
                    variant=elp.PRESERVE_VARIANT,
                )
                saliency = mask.sum(dim=0, keepdim=True)
                point = _saliency_to_point(saliency)
                info = {
                    'saliency': saliency,
                    'mask': mask,
                    'areas': areas,
                    'energy': energy,
                }

            else:
                raise ValueError(f"Unknown attribution method: {method!r}")

            results['pointing'][class_id] = self.pointing.evaluate(
                y[0], class_id, point[0])
            results['pointing_difficult'][class_id] = self.pointing_difficult.evaluate(
                y[0], class_id, point[0])

            if self.log > 0:
                image_name = self.data.as_image_name(y[0])
                mongo_save(
                    self.db,
                    self.experiment.name,
                    f"{image_name}-{class_id}",
                    data_to_mongo({
                        'image_name': image_name,
                        'class_id': class_id,
                        'pointing': results['pointing'][class_id],
                        'pointing_difficult': results['pointing_difficult'][class_id],
                    }))

                if self.log > 1:
                    mongo_save(self.db,
                               str(self.experiment.name) + "-details",
                               f"{image_name}-{class_id}",
                               data_to_mongo(info))

        return results

    except Exception as ex:
        raise ProcessingError(
            self, self.experiment, self.model, x, y, class_id, image_size) from ex
def extremal_perturbation(model,
                          input,
                          target,
                          areas=(0.1,),
                          perturbation=BLUR_PERTURBATION,
                          max_iter=800,
                          num_levels=8,
                          step=7,
                          sigma=21,
                          jitter=True,
                          variant=PRESERVE_VARIANT,
                          print_iter=None,
                          debug=False,
                          reward_func=simple_reward,
                          resize=False,
                          resize_mode='bilinear',
                          smooth=0):
    r"""Compute a set of extremal perturbations.

    The function takes a :attr:`model`, an :attr:`input` tensor :math:`x`
    of size :math:`1\times C\times H\times W`, and a :attr:`target`
    activation channel. It produces one mask per entry of :attr:`areas`,
    stacked along the first dimension. Each mask has approximately the
    specified area and is optimized to maximise the (spatial average of the)
    activations in channel :attr:`target`. Alternative objectives can be
    specified via :attr:`reward_func`.

    The ``requires_grad`` flags of the model parameters are switched off
    while the masks are optimized. They are restored to their previous values
    when the function returns or raises.

    Args:
        model (:class:`torch.nn.Module`): model.
        input (:class:`torch.Tensor`): input tensor.
        target (int): target channel.
        areas (float or sequence of floats, optional): target areas for the
            saliency masks. Defaults to ``(0.1,)``.
        perturbation (str, optional): :ref:`ep_perturbations`.
        max_iter (int, optional): number of iterations for optimizing the
            masks. Must be at least 1.
        num_levels (int, optional): number of buckets with which to discretize
            and linearly interpolate the perturbation
            (see :class:`Perturbation`). Defaults to 8.
        step (int, optional): mask step (see :class:`MaskGenerator`).
            Defaults to 7.
        sigma (float, optional): mask smoothing (see :class:`MaskGenerator`).
            Defaults to 21.
        jitter (bool, optional): flip the image horizontally on every even
            iteration. Defaults to True.
        variant (str, optional): :ref:`ep_variants`. Defaults to
            :attr:`PRESERVE_VARIANT`.
        print_iter (int, optional): frequency with which to print losses.
            Defaults to None.
        debug (bool, optional): If True, generate debug plots.
        reward_func (function, optional): function that generates the reward
            tensor to backpropagate.
        resize (bool, optional): If True, upsamples the masks to the same size
            as :attr:`input`. It is also possible to specify a pair
            (width, height) for a different size. Defaults to False.
        resize_mode (str, optional): Upsampling method to use. Defaults to
            ``'bilinear'``.
        smooth (float, optional): Apply Gaussian smoothing to the masks after
            computing them. Defaults to 0.

    Returns:
        A tuple ``(masks, hist)``.
            - ``masks``: one mask per area along dimension 0, resized per
              :attr:`resize` and smoothed per :attr:`smooth`.
            - ``hist``: a CPU tensor of shape ``(len(areas), 2, max_iter)``.
              At every iteration it records the reward (index 0) and the
              area-regularization term (index 1).

    Raises:
        ValueError: if ``max_iter < 1`` or ``variant`` is unknown.
    """
    if isinstance(areas, (int, float)):
        areas = [areas]
    else:
        areas = list(areas)

    if max_iter < 1:
        raise ValueError(f"max_iter must be at least 1, got {max_iter}")
    if variant not in (DELETE_VARIANT, PRESERVE_VARIANT, DUAL_VARIANT):
        raise ValueError(f"Unknown variant: {variant!r}")

    momentum = 0.9
    learning_rate = 0.01
    regul_weight = 300
    device = input.device

    regul_weight_last = max(regul_weight / 2, 1)

    if debug:
        print(
            f"extremal_perturbation:\n"
            f"- target: {target}\n"
            f"- areas: {areas}\n"
            f"- variant: {variant}\n"
            f"- max_iter: {max_iter}\n"
            f"- step/sigma: {step}, {sigma}\n"
            f"- image size: {list(input.shape)}\n"
            f"- reward function: {reward_func.__name__}"
        )

    # Disable gradients for model parameters while optimizing the mask. The
    # previous flags are restored in the ``finally`` clause below.
    params = list(model.parameters())
    saved_requires_grad = [p.requires_grad for p in params]
    for p in params:
        p.requires_grad_(False)

    try:
        # Get the perturbation operator.
        # The perturbation can be applied at any layer of the network (depth).
        perturbation = Perturbation(
            input,
            num_levels=num_levels,
            type=perturbation
        ).to(device)

        perturbation_str = '\n  '.join(perturbation.__str__().split('\n'))
        if debug:
            print(f"- {perturbation_str}")

        # Prepare the mask generator.
        shape = perturbation.pyramid.shape[2:]
        mask_generator = MaskGenerator(shape, step, sigma).to(device)
        h, w = mask_generator.shape_in
        pmask = torch.ones(len(areas), 1, h, w).to(device)
        if debug:
            print(f"- mask resolution:\n  {pmask.shape}")

        # Prepare the reference area vector: a sorted 0/1 profile whose
        # trailing fraction ``a`` is 1.
        max_area = np.prod(mask_generator.shape_out)
        reference = torch.ones(len(areas), max_area).to(device)
        for i, a in enumerate(areas):
            reference[i, :int(max_area * (1 - a))] = 0

        # Initialize optimizer.
        optimizer = optim.SGD([pmask],
                              lr=learning_rate,
                              momentum=momentum,
                              dampening=momentum)
        hist = torch.zeros((len(areas), 2, 0))

        for t in range(max_iter):
            pmask.requires_grad_(True)

            # Generate the mask.
            mask_, mask = mask_generator.generate(pmask)

            # Apply the mask.
            if variant == DELETE_VARIANT:
                x = perturbation.apply(1 - mask_)
            elif variant == PRESERVE_VARIANT:
                x = perturbation.apply(mask_)
            else:  # DUAL_VARIANT (validated above)
                x = torch.cat((
                    perturbation.apply(mask_),
                    perturbation.apply(1 - mask_),
                ), dim=0)

            # Apply jitter to the masked data.
            if jitter and t % 2 == 0:
                x = torch.flip(x, dims=(3,))

            # Evaluate the model on the masked data.
            y = model(x)

            # Get reward.
            reward = reward_func(y, target, variant=variant)

            # Reshape reward and average over spatial dimensions.
            reward = reward.reshape(len(areas), -1).mean(dim=1)

            # Area regularization: match the sorted mask values to the
            # reference profile.
            mask_sorted = mask.reshape(len(areas), -1).sort(dim=1)[0]
            regul = - ((mask_sorted - reference)**2).mean(dim=1) * regul_weight
            energy = (reward + regul).sum()

            # Gradient step.
            optimizer.zero_grad()
            (- energy).backward()
            optimizer.step()

            pmask.data = pmask.data.clamp(0, 1)

            # Record energy.
            hist = torch.cat(
                (hist,
                 torch.cat((
                     reward.detach().cpu().view(-1, 1, 1),
                     regul.detach().cpu().view(-1, 1, 1)
                 ), dim=1)), dim=2)

            # Adjust the regulariser/area constraint weight.
            regul_weight *= 1.0035

            # Diagnostics.
            debug_this_iter = debug and (
                t in (0, max_iter - 1)
                or regul_weight / regul_weight_last >= 2
            )

            if (print_iter is not None and t % print_iter == 0) or debug_this_iter:
                print("[{:04d}/{:04d}]".format(t + 1, max_iter), end="")
                for i, area in enumerate(areas):
                    print(" [area:{:.2f} loss:{:.2f} reg:{:.2f}]".format(
                        area,
                        hist[i, 0, -1],
                        hist[i, 1, -1]), end="")
                print()

            if debug_this_iter:
                regul_weight_last = regul_weight
                for i, a in enumerate(areas):
                    plt.figure(i, figsize=(20, 6))
                    plt.clf()
                    ncols = 4 if variant == DUAL_VARIANT else 3
                    plt.subplot(1, ncols, 1)
                    plt.plot(hist[i, 0].numpy())
                    plt.plot(hist[i, 1].numpy())
                    plt.plot(hist[i].sum(dim=0).numpy())
                    plt.legend(('energy', 'regul', 'both'))
                    plt.title(f'target area:{a:.2f}')
                    plt.subplot(1, ncols, 2)
                    imsc(mask[i], lim=[0, 1])
                    plt.title(
                        f"min:{mask[i].min().item():.2f}"
                        f" max:{mask[i].max().item():.2f}"
                        f" area:{mask[i].sum() / mask[i].numel():.2f}")
                    plt.subplot(1, ncols, 3)
                    imsc(x[i])
                    if variant == DUAL_VARIANT:
                        plt.subplot(1, ncols, 4)
                        imsc(x[i + len(areas)])
                    plt.pause(0.001)

        mask_ = mask_.detach()

        # Resize saliency map.
        mask_ = resize_saliency(input, mask_, resize, mode=resize_mode)

        # Smooth saliency map.
        if smooth > 0:
            mask_ = imsmooth(
                mask_,
                sigma=smooth * min(mask_.shape[2:]),
                padding_mode='constant'
            )

        return mask_, hist
    finally:
        # Undo the side effect on the caller's model.
        for p, flag in zip(params, saved_requires_grad):
            p.requires_grad_(flag)