def process_imgs(self, img, tgt, batch_size, num_pixels, stride):
    tgt = int(to_numpy(tgt))
    # Evaluate the given explainer on the image and sort the pixel indices by their 'importance values'
    pixel_importance = to_numpy(self.explainer.attribute(img.cuda(), target=tgt))[0].sum(0)
    idcs = np.argsort(pixel_importance.flatten())
    # Direction of the sorting
    if self.config["direction"] == "most":
        idcs = idcs[::-1]
    # Only delete the first num_pixels
    idcs = idcs[:num_pixels]
    # Compute the corresponding masks for deleting pixels in the given order
    positions = np.array(np.unravel_index(idcs, pixel_importance.shape)).T
    # First mask uses all pixels
    masks = [torch.ones(1, *pixel_importance.shape)]
    for h, w in positions:
        # Delete one additional position at a time
        mask = masks[-1].clone()
        mask[0, h, w] = 0
        masks.append(mask)
    # To speed up the evaluation, only evaluate masks at a stride (skipping the masks in between)
    masks = torch.cat([m for m_idx, m in enumerate(masks) if (m_idx % stride) == 0], dim=0)[:, None]
    # Compute the probabilities of the target class for the masked images.
    # For efficiency, do this in batch mode.
    pert_out = torch.cat([
        self.masked_predict(img, masks[idx * batch_size: (idx + 1) * batch_size])
        for idx in range(int(np.ceil(len(masks) / batch_size)))
    ], dim=0)[None, :, tgt]
    return pert_out.cpu()

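# NOTE: `masked_predict` is defined elsewhere on this evaluation class and is not shown here.
# Purely as an illustration, a stand-alone sketch of such a helper is given below; the name
# `masked_predict_sketch` and its behaviour are assumptions, not the repository's actual
# implementation. It broadcasts each mask over the image, runs the model on the resulting
# batch of masked images, and returns the per-class probabilities.
def masked_predict_sketch(model, img, masks):
    with torch.no_grad():
        # img: (1, C, H, W), masks: (n_masks, 1, H, W) -> masked: (n_masks, C, H, W)
        masked = img.cuda() * masks.cuda()
        return model(masked).softmax(dim=1)  # (n_masks, num_classes)
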
def analysis(self):
    batch_size, n_imgs, sample_points, num_pixels = (
        self.config["batch_size"], self.config["n_imgs"],
        self.config["sample_points"], self.config["num_pixels"])
    trainer = self.trainer
    loader = trainer.data.get_test_loader()
    results = []
    img_count = 0
    img_dims = None
    for img, tgt in loader:
        if img_dims is None:
            img_dims = np.prod(img.shape[-2:])
            max_idx = int(num_pixels * img_dims)
            stride = int(max_idx // sample_points)
        new_count = img_count + len(img)
        # Only evaluate n_imgs
        if new_count > n_imgs:
            img = img[:-(new_count - n_imgs)]
            tgt = tgt[:-(new_count - n_imgs)]
        img_count += len(img)
        # Evaluate the pixel perturbation for the sampled images
        results.append(self.process_imgs(img, tgt.argmax(1), batch_size, max_idx, stride))
        print("Done with {percent:>5.2f}%".format(percent=img_count * 100 / n_imgs), flush=True)
        if img_count == n_imgs:
            break
    return {"perturbation_results": to_numpy(torch.cat(results, dim=0).mean(0))}

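# The mean perturbation curve returned above can be summarised further, e.g., by the area
# under the curve. This is only a sketch of one possible post-processing step and not part
# of the analysis itself; `results_dict` is a hypothetical name for the returned dictionary.
def perturbation_auc(results_dict):
    curve = np.asarray(results_dict["perturbation_results"])
    # Normalise the x-axis to [0, 1] so curves with different numbers of sample points stay comparable.
    return np.trapz(curve, dx=1.0 / max(len(curve) - 1, 1))
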
def select_classes(target, all_matrices, n):
    """
    ONLY USED TO REGULARISE MATRICES.
    Method for creating an indexing tensor to select a (sub)set of the classes
    to have their linear mapping reversed.
    Args:
        target: Ground truth target class as one-hot encoding.
        all_matrices (bool): If True, all matrices are selected.
        n: If not all_matrices, n - 1 classes are randomly chosen.

    Returns:
        An index array of shape (batch_size, n) with the n chosen classes per example.
        If all_matrices, the standard order of the class indices is kept.
        Otherwise, the ground truth class will be the first in the list.
    """
    num_classes = target.shape[1]
    target = to_numpy(target)
    if all_matrices:
        return len(target) * [slice(None, None)]
    if n == 1:
        return target.argmax(1)[:, None]
    # Initialise matrix for the choices per sample
    out = np.zeros_like(target, dtype=int)
    if not (target.sum(1) > 1).any():
        # Let the first entry always be the correct class.
        out[:, 0] = target.argmax(1)
    else:
        # Choose one of the possible (multi-label) entries for each batch item.
        # target is a numpy array at this point, so use numpy instead of torch here.
        active_tgts = np.stack(np.where(target)).T
        for batch_idx in range(len(target)):
            sample_tgts = active_tgts[active_tgts[:, 0] == batch_idx, 1]
            out[batch_idx, 0] = sample_tgts[np.random.randint(len(sample_tgts))]
    # Fill the remaining entries with a random permutation of all other classes.
    out[:, 1:] = np.array([
        np.random.permutation(np.r_[0:t, t + 1:num_classes])
        for t in out[:, 0]
    ])
    return out

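# Quick usage sketch (shapes and class counts are made up for illustration): for
# single-label one-hot targets, the first column of the result is the ground-truth class
# and the remaining columns are a random permutation of the other classes.
example_targets = torch.nn.functional.one_hot(torch.tensor([2, 0]), num_classes=5).float()
example_idcs = select_classes(example_targets, all_matrices=False, n=3)
print(example_idcs[:, 0])  # -> [2 0]
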
def attribute_selection(self, img, targets):
    """
    Calls the attribution method for all targets in the list of targets.
    """
    # Make sure the targets are a numpy array of shape (batch_size, number of targets per image)
    targets = np.array(to_numpy(targets), dtype=int).reshape(len(img), -1)
    # out is of size (bs, number of targets per image, in_channels, height, width)
    out = torch.zeros(*targets.shape[:2], *img.shape[1:]).type(torch.FloatTensor)
    for tgt_idx in range(targets.shape[1]):
        out[:, tgt_idx] = self.attribute(
            img, target=targets[:, tgt_idx].tolist()).detach().cpu()
    return out.reshape(-1, *img.shape[1:]).cuda()

def get_imgs_and_atts(trainer, video, class_idx):
    if class_idx == -1:
        class_idx = most_predicted(trainer, video)
    atts = []
    imgs = []
    for img in video:
        img = trainer.data.get_test_loader().dataset.transform(
            PIL.Image.fromarray(img)).cuda()[None]
        att = trainer.attribute(img, class_idx)[0].sum(0)
        atts.append(att2img(att))
        imgs.append(
            np.array(to_numpy(img[0].permute(1, 2, 0)) * 255, dtype=np.uint8))
    return imgs, atts

def plot_contribution_map(contribution_map, ax=None, vrange=None, vmin=None, vmax=None,
                          hide_ticks=True, cmap="bwr", percentile=100):
    """
    Visualises a contribution map, i.e., a matrix assigning individual weights to each
    spatial location. By default, this shows the contribution map with the "bwr" colormap
    and chooses vmin and vmax so that the map ranges from
    -max(abs(contribution_map)) to max(abs(contribution_map)).
    Args:
        contribution_map: (H, W) matrix to visualise as contributions.
        ax: Axis on which to plot. If None, a new figure is created.
        vrange: If None, the colormap ranges from -v to v, with v being the maximum
            absolute value in the map. If provided, it ranges from -vrange to vrange,
            unless one of the boundaries is overwritten by vmin or vmax.
        vmin: Manually overwrite the minimum value for the colormap range instead of using -vrange.
        vmax: Manually overwrite the maximum value for the colormap range instead of using vrange.
        hide_ticks: Sets the axis ticks to [].
        cmap: Colormap to use for the contribution map plot.
        percentile: Percentile of the absolute values used as a cut-off for the
            contribution map (100 keeps all values).

    Returns:
        The axis on which the contribution map was plotted and the corresponding image handle.
    """
    assert len(contribution_map.shape) == 2, \
        "Contribution map is supposed to only have spatial dimensions."
    contribution_map = to_numpy(contribution_map)
    cutoff = np.percentile(np.abs(contribution_map), percentile)
    contribution_map = np.clip(contribution_map, -cutoff, cutoff)
    if ax is None:
        fig, ax = plt.subplots(1)
    if vrange is None or vrange == "auto":
        vrange = np.max(np.abs(contribution_map.flatten()))
    im = ax.imshow(contribution_map, cmap=cmap,
                   vmin=-vrange if vmin is None else vmin,
                   vmax=vrange if vmax is None else vmax)
    if hide_ticks:
        ax.set_xticks([])
        ax.set_yticks([])
    return ax, im

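# Example usage with a random map (in practice the map would come from an explainer's
# attributions); the 99th-percentile cut-off clips a few extreme values for nicer contrast.
example_map = torch.randn(28, 28)
ax, im = plot_contribution_map(example_map, percentile=99)
plt.colorbar(im, ax=ax)
plt.show()
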
def attribute(self, img, target, return_all=False):
    with torch.no_grad():
        explanation = self.explain_instance(
            to_numpy(img[0].permute(1, 2, 0)),
            self.pred_f,
            labels=range(self.num_classes),
            top_labels=None,
            num_samples=self.num_samples,
            segmentation_fn=self.segmenter,
            batch_size=self.batch_size)
        if return_all:
            return torch.cat([
                torch.from_numpy(np.array(explanation.get_image_and_mask(
                    t, hide_rest=True, positive_only=True,
                    num_features=self.num_features)[1][None, None], dtype=float))
                for t in range(self.num_classes)], dim=0)
        return torch.from_numpy(np.array(explanation.get_image_and_mask(
            int(np.array(target)), hide_rest=True, positive_only=True,
            num_features=self.num_features)[1][None, None], dtype=float))

def __call__(self, model_out, model_in, target):
    """Computes the accuracy over the k top predictions for the specified values of k."""
    with torch.no_grad():
        target = target.argmax(1)
        maxk = max(self.topk)
        batch_size = target.size(0)
        _, pred = model_out.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))
        res = dict()
        for k in self.topk:
            # Use reshape instead of view: the sliced tensor is not necessarily contiguous.
            correct_k = correct[:k].reshape(-1).float().sum(0)
            _res = float(to_numpy(correct_k / batch_size))
            if k == 1:
                res["accuracy"] = _res
            else:
                res["Acc@{k}".format(k=k)] = _res
        return res

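# Small sanity check of the top-k logic on made-up logits and one-hot targets
# (assuming the enclosing metric class stores the k values as self.topk, e.g. (1, 2)):
example_out = torch.tensor([[0.1, 0.7, 0.2],
                            [0.6, 0.3, 0.1]])
example_tgt = torch.tensor([[0., 1., 0.],
                            [0., 0., 1.]])
example_pred = example_out.topk(2, 1, True, True)[1].t()
example_correct = example_pred.eq(example_tgt.argmax(1).view(1, -1).expand_as(example_pred))
print(example_correct[:1].reshape(-1).float().mean())     # top-1 accuracy -> 0.5
print(example_correct[:2].reshape(-1).float().sum() / 2)  # top-2 accuracy -> 0.5
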
def analysis(self):
    sample_size, n_imgs = self.config["sample_size"], self.config["n_imgs"]
    trainer = self.trainer
    loader = trainer.data.get_test_loader()
    fixed_indices = self.get_sorted_indices()
    metric = []
    explainer = self.explainer
    offset = 0
    single_shape = loader.dataset[0][0].shape[-1]
    for count in range(sample_size):
        multi_img, tgts, offset = self.make_multi_image(
            n_imgs, loader, offset=offset, fixed_indices=fixed_indices)
        # Calculate the attributions for all participating classes and only keep the positive contributions
        attributions = explainer.attribute_selection(multi_img, tgts).sum(
            1, keepdim=True).clamp(0)
        # Calculate the relative amount of attributions per region. Use avg_pool for simplicity.
        with torch.no_grad():
            contribs = F.avg_pool2d(attributions, single_shape, stride=single_shape).permute(
                0, 1, 3, 2).reshape(attributions.shape[0], -1)
            total = contribs.sum(1, keepdim=True)
            contribs = to_numpy(
                torch.where(total * contribs > 0, contribs / total, torch.zeros_like(contribs)))
        metric.append([contrib[idx] for idx, contrib in enumerate(contribs)])
        print("{:>6.2f}% of processing complete".format(
            100 * (count + 1.) / sample_size), flush=True)
    result = np.array(metric).flatten()
    print("Percentiles of localisation accuracy (25, 50, 75, 100): ",
          np.percentile(result, [25, 50, 75, 100]))
    return {"localisation_metric": result}

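# Small illustration of the pooling step above: for a 2x2 grid of 3x3 sub-images in which
# all attribution mass lies in the top-left sub-image, the normalised per-cell contributions
# become [1, 0, 0, 0]. The tensor here is made up purely for illustration.
example_atts = torch.zeros(1, 1, 6, 6)
example_atts[0, 0, :3, :3] = 1.0
example_contribs = F.avg_pool2d(example_atts, 3, stride=3).permute(0, 1, 3, 2).reshape(1, -1)
print(example_contribs / example_contribs.sum())  # -> tensor([[1., 0., 0., 0.]])
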
def att2img(attribution):
    # Map the attribution values from [-MAXV, MAXV] to [0, 1], colour them with the 'bwr'
    # colormap and convert the result to an 8-bit RGB image (dropping the alpha channel).
    return np.uint8(
        cm.bwr((np.clip(to_numpy(attribution) / MAXV, -1, 1) + 1) / 2) * 255)[:, :, :3]

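# Quick shape check (MAXV and cm come from this module's globals; the random map is only
# for illustration):
example_rgb = att2img(torch.randn(32, 32))
print(example_rgb.shape, example_rgb.dtype)  # -> (32, 32, 3) uint8
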
def pred_f(self, input_samples):
    return to_numpy(self.trainer.predict(self.make_input_tensor(input_samples)))

def default_eval_batch(model_out, model_in, tgt):
    _ = model_in
    return {"accuracy": to_numpy(tgt.argmax(1) == model_out.argmax(1)).mean()}

def eval_batch_cross_entropy(model_out, model_in, tgt):
    _ = model_in
    return {"accuracy": to_numpy(tgt.argmax(1) == model_out.argmax(1)).mean()}