def __init__(self, input, num_levels=8, max_blur=20, type=BLUR_PERTURBATION):
    """Precompute a pyramid of progressively perturbed copies of ``input``.

    The pyramid has ``num_levels`` levels indexed by a fraction ``t`` that
    runs linearly from 0 to 1 (``torch.linspace(0, 1, num_levels)``):

    * ``BLUR_PERTURBATION``: level ``t`` is ``imsmooth(input, sigma)`` with
      ``sigma = (1 - t) * max_blur``, i.e. the blur amount shrinks linearly
      from ``max_blur`` (first level) to 0 (last level).
    * ``FADE_PERTURBATION``: level ``t`` is ``input * t``, i.e. the input is
      scaled linearly from 0 (first level) to 1 (last level).

    The levels are concatenated along dimension 0 and stored in
    ``self.pyramid`` (presumably ``num_levels x C x H x W`` when ``input``
    has batch size 1 -- TODO confirm against callers).

    Args:
        input (:class:`torch.Tensor`): tensor to perturb.
        num_levels (int, optional): number of pyramid levels; must be at
            least 2. Defaults to 8.
        max_blur (float, optional): largest smoothing ``sigma`` used by the
            blur perturbation; must be positive. Defaults to 20.
        type (str, optional): ``BLUR_PERTURBATION`` or
            ``FADE_PERTURBATION``. Defaults to ``BLUR_PERTURBATION``.

    Raises:
        ValueError: if ``num_levels < 2``, ``max_blur <= 0`` or ``type`` is
            not a supported perturbation.
    """
    # Validate explicitly: ``assert`` is stripped under ``python -O``, which
    # previously led to an UnboundLocalError or a degenerate pyramid.
    if num_levels < 2:
        raise ValueError(f"num_levels must be at least 2, got {num_levels}")
    if max_blur <= 0:
        raise ValueError(f"max_blur must be positive, got {max_blur}")
    if type not in (BLUR_PERTURBATION, FADE_PERTURBATION):
        raise ValueError(f"Unknown perturbation type {type!r}")

    self.type = type
    self.num_levels = num_levels
    with torch.no_grad():
        levels = []
        # ``t`` is the interpolation fraction in [0, 1], not a blur sigma.
        for t in torch.linspace(0, 1, self.num_levels):
            if type == BLUR_PERTURBATION:
                levels.append(imsmooth(input, sigma=(1 - t) * max_blur))
            else:  # FADE_PERTURBATION
                levels.append(input * t)
        self.pyramid = torch.cat(levels, dim=0)
def _plot_extremal_debug(hist, mask, x, areas, variant):
    """Draw one debug figure per area for :func:`extremal_perturbation`.

    Each figure shows the reward/regulariser history, the current mask and
    the perturbed input(s) that were fed to the model.
    """
    for i, a in enumerate(areas):
        plt.figure(i, figsize=(20, 6))
        plt.clf()
        ncols = 4 if variant == DUAL_VARIANT else 3
        plt.subplot(1, ncols, 1)
        plt.plot(hist[i, 0].numpy())
        plt.plot(hist[i, 1].numpy())
        plt.plot(hist[i].sum(dim=0).numpy())
        plt.legend(('energy', 'regul', 'both'))
        plt.title(f'target area:{a:.2f}')
        plt.subplot(1, ncols, 2)
        imsc(mask[i], lim=[0, 1])
        plt.title(f"min:{mask[i].min().item():.2f}"
                  f" max:{mask[i].max().item():.2f}"
                  f" area:{mask[i].sum() / mask[i].numel():.2f}")
        plt.subplot(1, ncols, 3)
        imsc(x[i])
        if variant == DUAL_VARIANT:
            # The "delete" half of the dual batch follows the "preserve" half.
            plt.subplot(1, ncols, 4)
            imsc(x[i + len(areas)])
        plt.pause(0.001)


def extremal_perturbation(model,
                          input,
                          target,
                          areas=[0.1],
                          perturbation=BLUR_PERTURBATION,
                          max_iter=800,
                          num_levels=8,
                          step=7,
                          sigma=21,
                          jitter=True,
                          variant=PRESERVE_VARIANT,
                          print_iter=None,
                          debug=False,
                          reward_func=simple_reward,
                          resize=False,
                          resize_mode='bilinear',
                          smooth=0):
    r"""Compute a set of extremal perturbations.

    The function takes a :attr:`model`, an :attr:`input` tensor :math:`x`
    of size :math:`1\times C\times H\times W`, and a :attr:`target`
    activation channel. It produces one mask per entry of :attr:`areas`,
    stacked along dimension 0; the mask parameters have a single channel, so
    the result is presumably :math:`K\times 1\times H'\times W'`, where
    :math:`K` is the number of specified :attr:`areas`.

    Each mask, which has approximately the specified area, is searched in
    order to maximise the (spatial average of the) activations in channel
    :attr:`target`. Alternative objectives can be specified via
    :attr:`reward_func`.

    The ``requires_grad`` flags of the model parameters are switched off
    while the mask is optimised. They are restored before the function
    returns, including when an exception is raised.

    Args:
        model (:class:`torch.nn.Module`): model.
        input (:class:`torch.Tensor`): input tensor.
        target (int): target channel.
        areas (float or list of floats, optional): list of target areas for
            saliency masks. Defaults to `[0.1]`.
        perturbation (str, optional): :ref:`ep_perturbations`.
        max_iter (int, optional): number of iterations for optimizing the
            masks; must be at least 1.
        num_levels (int, optional): number of buckets with which to
            discretize and linearly interpolate the perturbation (see
            :class:`Perturbation`). Defaults to 8.
        step (int, optional): mask step (see :class:`MaskGenerator`).
            Defaults to 7.
        sigma (float, optional): mask smoothing (see
            :class:`MaskGenerator`). Defaults to 21.
        jitter (bool, optional): randomly flip the image horizontally at
            each iteration. Defaults to True.
        variant (str, optional): :ref:`ep_variants`. Defaults to
            :attr:`PRESERVE_VARIANT`.
        print_iter (int, optional): frequency with which to print losses.
            Defaults to None.
        debug (bool, optional): If True, generate debug plots.
        reward_func (function, optional): function that generates reward
            tensor to backpropagate.
        resize (bool, optional): If True, upsamples the masks the same size
            as :attr:`input`. It is also possible to specify a pair
            (width, height) for a different size. Defaults to False.
        resize_mode (str, optional): Upsampling method to use. Defaults to
            ``'bilinear'``.
        smooth (float, optional): Apply Gaussian smoothing to the masks after
            computing them. The standard deviation is :attr:`smooth` times
            the minimum spatial dimension of the mask. Defaults to 0.

    Returns:
        A tuple ``(masks, hist)``. ``masks`` is a :class:`torch.Tensor`
        holding one mask per area along dimension 0. It is detached from the
        autograd graph, and resized and smoothed if requested. ``hist`` is a
        CPU tensor of size :math:`K\times 2\times` :attr:`max_iter`.
        ``hist[i, 0, t]`` is the reward and ``hist[i, 1, t]`` is the area
        regularisation term for area ``i`` at iteration ``t``.

    Raises:
        ValueError: if :attr:`variant` is unknown or :attr:`max_iter` < 1.
    """
    if isinstance(areas, float):
        areas = [areas]
    # Validate before touching the model. ``assert`` would vanish under -O,
    # and max_iter == 0 used to crash on an undefined mask after the loop.
    if variant not in (DELETE_VARIANT, PRESERVE_VARIANT, DUAL_VARIANT):
        raise ValueError(f"Unknown variant {variant!r}")
    if max_iter < 1:
        raise ValueError(f"max_iter must be at least 1, got {max_iter}")

    momentum = 0.9
    learning_rate = 0.01
    regul_weight = 300
    device = input.device

    # Regulariser weight at the last debug plot: a new plot is drawn whenever
    # the (growing) weight has doubled since then.
    regul_weight_last = max(regul_weight / 2, 1)

    if debug:
        print(f"extremal_perturbation:\n"
              f"- target: {target}\n"
              f"- areas: {areas}\n"
              f"- variant: {variant}\n"
              f"- max_iter: {max_iter}\n"
              f"- step/sigma: {step}, {sigma}\n"
              f"- image size: {list(input.shape)}\n"
              f"- reward function: {reward_func.__name__}")

    # Only the mask is optimised: freeze the model parameters, remembering
    # their flags so they can be restored on exit.
    orig_requires_grad = {}
    for name, p in model.named_parameters():
        orig_requires_grad[name] = p.requires_grad
        p.requires_grad_(False)

    try:
        # Pyramid of perturbed copies of the input, interpolated by the mask.
        perturbation_op = Perturbation(
            input, num_levels=num_levels, type=perturbation).to(device)
        if debug:
            perturbation_str = '\n '.join(
                perturbation_op.__str__().split('\n'))
            print(f"- {perturbation_str}")

        # Mask generator built for the spatial size of the pyramid; pmask
        # holds its low-resolution, single-channel parameters.
        shape = perturbation_op.pyramid.shape[2:]
        mask_generator = MaskGenerator(shape, step, sigma).to(device)
        h, w = mask_generator.shape_in
        pmask = torch.ones(len(areas), 1, h, w).to(device)
        if debug:
            print(f"- mask resolution:\n {pmask.shape}")

        # Area reference: row i has its first int(max_area * (1 - a)) entries
        # set to 0 and the rest to 1, i.e. an ideal sorted mask of area a.
        max_area = np.prod(mask_generator.shape_out)
        reference = torch.ones(len(areas), max_area).to(device)
        for i, a in enumerate(areas):
            reference[i, :int(max_area * (1 - a))] = 0

        optimizer = optim.SGD([pmask],
                              lr=learning_rate,
                              momentum=momentum,
                              dampening=momentum)
        hist = torch.zeros((len(areas), 2, 0))

        for t in range(max_iter):
            pmask.requires_grad_(True)

            # Generate the mask.
            mask_, mask = mask_generator.generate(pmask)

            # Apply the mask.
            if variant == DELETE_VARIANT:
                x = perturbation_op.apply(1 - mask_)
            elif variant == PRESERVE_VARIANT:
                x = perturbation_op.apply(mask_)
            else:  # DUAL_VARIANT: preserve batch followed by delete batch.
                x = torch.cat((
                    perturbation_op.apply(mask_),
                    perturbation_op.apply(1 - mask_),
                ), dim=0)

            # Jitter: flip the masked data horizontally every other iteration.
            if jitter and t % 2 == 0:
                x = torch.flip(x, dims=(3, ))

            # Evaluate the model and average the reward per area.
            y = model(x)
            reward = reward_func(y, target, variant=variant)
            reward = reward.reshape(len(areas), -1).mean(dim=1)

            # Area regularisation: penalise sorted mask values that deviate
            # from the reference.
            mask_sorted = mask.reshape(len(areas), -1).sort(dim=1)[0]
            regul = -((mask_sorted - reference)**2).mean(dim=1) * regul_weight
            energy = (reward + regul).sum()

            # Gradient ascent on the energy, then keep pmask within [0, 1].
            optimizer.zero_grad()
            (-energy).backward()
            optimizer.step()
            pmask.data = pmask.data.clamp(0, 1)

            # Record energy.
            step_hist = torch.cat((reward.detach().cpu().view(-1, 1, 1),
                                   regul.detach().cpu().view(-1, 1, 1)),
                                  dim=1)
            hist = torch.cat((hist, step_hist), dim=2)

            # Progressively tighten the area constraint.
            regul_weight *= 1.0035

            # Diagnostics.
            debug_this_iter = debug and (t in (0, max_iter - 1)
                                         or regul_weight / regul_weight_last >= 2)

            if (print_iter is not None and t % print_iter == 0) or debug_this_iter:
                print("[{:04d}/{:04d}]".format(t + 1, max_iter), end="")
                for i, area in enumerate(areas):
                    print(" [area:{:.2f} loss:{:.2f} reg:{:.2f}]".format(
                        area, hist[i, 0, -1], hist[i, 1, -1]), end="")
                print()

            if debug_this_iter:
                regul_weight_last = regul_weight
                _plot_extremal_debug(hist, mask, x, areas, variant)
    finally:
        # Undo the parameter freezing, even if the optimisation failed.
        for name, p in model.named_parameters():
            p.requires_grad_(orig_requires_grad[name])

    mask_ = mask_.detach()

    # Resize saliency map.
    mask_ = resize_saliency(input, mask_, resize, mode=resize_mode)

    # Smooth saliency map.
    if smooth > 0:
        mask_ = imsmooth(mask_,
                         sigma=smooth * min(mask_.shape[2:]),
                         padding_mode='constant')

    return mask_, hist
def saliency(model,
             input,
             target,
             saliency_layer='',
             resize=False,
             resize_mode='bilinear',
             smooth=0,
             context_builder=NullContext,
             gradient_to_saliency=gradient_to_saliency,
             get_backward_gradient=get_backward_gradient,
             debug=False):
    """Apply a backprop-based attribution method to an image.

    The saliency method is specified by a suitable context factory
    :attr:`context_builder`. This context is used to modify the
    backpropagation algorithm to match a given visualization method. This:

    1. Attaches a probe to the output tensor of :attr:`saliency_layer`,
       which must be a layer in :attr:`model`. If no such layer is specified,
       it selects the input tensor to :attr:`model`.
    2. Uses the function :attr:`get_backward_gradient` to obtain a gradient
       for the output tensor of the model. This function is passed as input
       the output tensor as well as the parameter :attr:`target`. By default,
       the :func:`get_backward_gradient` function is used. The latter
       generates as gradient a one-hot vector selecting :attr:`target`,
       usually the index of the class predicted by :attr:`model`.
    3. Evaluates :attr:`model` on :attr:`input` and then computes the
       pseudo-gradient of the model with respect to the selected tensor. This
       calculation is controlled by :attr:`context_builder`.
    4. Extracts the pseudo-gradient at the selected tensor as a raw saliency
       map.
    5. Calls :attr:`gradient_to_saliency` to obtain an actual saliency map.
       This defaults to :func:`gradient_to_saliency`, which takes the maximum
       absolute value along the channel dimension of the pseudo-gradient
       tensor.
    6. Optionally resizes the saliency map thus obtained. By default, this
       uses bilinear interpolation and resizes the saliency to the same
       spatial dimension of :attr:`input`.
    7. Optionally applies a Gaussian filter to the resized saliency map. The
       standard deviation of this filter is :attr:`smooth` times the maximum
       spatial dimension of the resized saliency map.
    8. Removes the probe.
    9. Returns the saliency map or optionally a tuple with the saliency map
       and an OrderedDict of Probe objects for all modules in the model,
       which can be used for debugging.

    The ``requires_grad`` flags of the model parameters and the model's
    training mode are temporarily changed. They are restored, and the probe
    removed, before returning, including when an exception is raised.

    Args:
        model (:class:`torch.nn.Module`): a model.
        input (:class:`torch.Tensor`): input tensor.
        target (int or :class:`torch.Tensor`): target label(s).
        saliency_layer (str or :class:`torch.nn.Module`, optional): name of
            the saliency layer (str) or the layer itself
            (:class:`torch.nn.Module`) in the model at which to visualize.
            Default: ``''`` (visualize at input).
        resize (bool or tuple, optional): if True, upsample saliency map to
            the same size as :attr:`input`. It is also possible to specify a
            pair (width, height) for a different size. Default: ``False``.
        resize_mode (str, optional): upsampling method to use. Default:
            ``'bilinear'``.
        smooth (float, optional): amount of Gaussian smoothing to apply to the
            saliency map. Default: ``0``.
        context_builder (type, optional): type of context to use. Default:
            :class:`NullContext`.
        gradient_to_saliency (function, optional): function that converts the
            pseudo-gradient signal to a saliency map. Default:
            :func:`gradient_to_saliency`.
        get_backward_gradient (function, optional): function that generates
            gradient tensor to backpropagate. Default:
            :func:`get_backward_gradient`.
        debug (bool, optional): if True, also return an
            :class:`collections.OrderedDict` of :class:`Probe` objects for
            all modules in the model. Default: ``False``.

    Returns:
        :class:`torch.Tensor` or tuple: If :attr:`debug` is False, returns a
        :class:`torch.Tensor` saliency map at :attr:`saliency_layer`.
        Otherwise, returns a tuple of a :class:`torch.Tensor` saliency map
        at :attr:`saliency_layer` and an :class:`collections.OrderedDict`
        of :class:`Probe` objects for all modules in the model.

    Raises:
        ValueError: if :attr:`saliency_layer` cannot be found in the model.
    """
    # Clear any existing gradient.
    if input.grad is not None:
        input.grad.data.zero_()

    # Disable gradients for model parameters, remembering the original flags.
    orig_requires_grad = {}
    for name, param in model.named_parameters():
        orig_requires_grad[name] = param.requires_grad
        param.requires_grad_(False)

    was_training = model.training
    probe = None
    try:
        # Set model to eval mode.
        if was_training:
            model.eval()

        # Attach debug probes to every module.
        debug_probes = attach_debug_probes(model, debug=debug)

        # Attach a probe to the saliency layer (its input when no layer is
        # named, otherwise its output).
        probe_target = 'input' if saliency_layer == '' else 'output'
        layer = get_module(model, saliency_layer)
        if layer is None:
            # Explicit check: an ``assert`` would be stripped under -O.
            raise ValueError('We could not find the saliency layer')
        probe = Probe(layer, target=probe_target)

        # Do a forward and backward pass.
        with context_builder():
            output = model(input)
            backward_gradient = get_backward_gradient(output, target)
            output.backward(backward_gradient)

        # Get saliency map from gradient.
        saliency_map = gradient_to_saliency(probe.data[0])

        # Resize saliency map.
        saliency_map = resize_saliency(input,
                                       saliency_map,
                                       resize,
                                       mode=resize_mode)

        # Smooth saliency map.
        if smooth > 0:
            saliency_map = imsmooth(
                saliency_map,
                sigma=smooth * max(saliency_map.shape[2:]),
                padding_mode='replicate'
            )
    finally:
        # Remove probe.
        if probe is not None:
            probe.remove()

        # Restore gradient saving for model parameters.
        for name, param in model.named_parameters():
            param.requires_grad_(orig_requires_grad[name])

        # Restore model's original mode.
        if was_training:
            model.train()

    if debug:
        return saliency_map, debug_probes
    return saliency_map