Ejemplo n.º 1
0
 def __init__(self, input, num_levels=8, max_blur=20, type=BLUR_PERTURBATION):
     """Precompute a stack of increasingly perturbed copies of ``input``.

     ``num_levels`` values ``s`` are spaced linearly over ``[0, 1]``. For each
     one a perturbed copy of ``input`` is produced, and the copies are
     concatenated along dimension 0 into :attr:`pyramid`:

     * ``BLUR_PERTURBATION``: ``input`` smoothed with ``imsmooth`` using
       standard deviation ``(1 - s) * max_blur``. The first level is the most
       blurred; the last level uses a standard deviation of zero.
     * ``FADE_PERTURBATION``: ``input`` multiplied by ``s``. The first level
       is all zeros; the last level equals ``input``.

     Args:
         input (:class:`torch.Tensor`): tensor to perturb. Presumably of
             shape ``1 x C x H x W``, so that :attr:`pyramid` holds one level
             per row along dimension 0 (TODO confirm with callers).
         num_levels (int, optional): number of levels; must be at least 2.
             Defaults to 8.
         max_blur (float, optional): largest blur standard deviation; must be
             positive. Defaults to 20.
         type (optional): perturbation type, either ``BLUR_PERTURBATION`` or
             ``FADE_PERTURBATION``. Defaults to ``BLUR_PERTURBATION``.

     Raises:
         ValueError: if ``num_levels < 2``, ``max_blur <= 0`` or ``type`` is
             not a known perturbation type.
     """
     # Explicit checks instead of ``assert``: asserts are stripped under
     # ``python -O``, which would let an unknown type fail later with an
     # obscure unbound-variable error.
     if num_levels < 2:
         raise ValueError(f"num_levels must be at least 2, got {num_levels}")
     if max_blur <= 0:
         raise ValueError(f"max_blur must be positive, got {max_blur}")
     if type not in (BLUR_PERTURBATION, FADE_PERTURBATION):
         raise ValueError(f"Unknown perturbation type: {type!r}")

     self.type = type
     self.num_levels = num_levels

     levels = []
     with torch.no_grad():
         for sigma in torch.linspace(0, 1, num_levels):
             if type == BLUR_PERTURBATION:
                 levels.append(imsmooth(input, sigma=(1 - sigma) * max_blur))
             else:  # FADE_PERTURBATION
                 levels.append(input * sigma)
         self.pyramid = torch.cat(levels, dim=0)
def extremal_perturbation(model,
                          input,
                          target,
                          areas=None,
                          perturbation=BLUR_PERTURBATION,
                          max_iter=800,
                          num_levels=8,
                          step=7,
                          sigma=21,
                          jitter=True,
                          variant=PRESERVE_VARIANT,
                          print_iter=None,
                          debug=False,
                          reward_func=simple_reward,
                          resize=False,
                          resize_mode='bilinear',
                          smooth=0):
    r"""Compute a set of extremal perturbations.

    The function takes a :attr:`model`, an :attr:`input` tensor :math:`x`
    of size :math:`1\times C\times H\times W`, and a :attr:`target`
    activation channel. It produces one saliency mask for each of the
    :math:`K` specified :attr:`areas`, stacked along the first dimension.

    Each mask, which has approximately the specified area, is searched
    in order to maximise the (spatial average of the) activations
    in channel :attr:`target`. Alternative objectives can be specified
    via :attr:`reward_func`.

    The masks are optimised with SGD (learning rate 0.01, momentum and
    dampening 0.9). The area constraint is a penalty on the squared
    difference between the sorted mask values and a reference step profile.
    Its weight starts at 300 and is multiplied by 1.0035 at every iteration.

    While the function runs, ``requires_grad`` is disabled on every parameter
    of :attr:`model`. The original flags are restored on exit, including when
    an exception is raised.

    Args:
        model (:class:`torch.nn.Module`): model.
        input (:class:`torch.Tensor`): input tensor.
        target (int): target channel.
        areas (float or list of floats, optional): list of target areas for
            saliency masks. A single float is treated as a one-element list.
            Defaults to ``[0.1]``.
        perturbation (str, optional): :ref:`ep_perturbations`.
        max_iter (int, optional): number of iterations for optimizing the
            masks; must be at least 1. Defaults to 800.
        num_levels (int, optional): number of buckets with which to discretize
            and linearly interpolate the perturbation
            (see :class:`Perturbation`). Defaults to 8.
        step (int, optional): mask step (see :class:`MaskGenerator`).
            Defaults to 7.
        sigma (float, optional): mask smoothing (see :class:`MaskGenerator`).
            Defaults to 21.
        jitter (bool, optional): if True, flip the perturbed image
            horizontally on every other iteration. Defaults to True.
        variant (str, optional): :ref:`ep_variants`. One of
            :attr:`DELETE_VARIANT`, :attr:`PRESERVE_VARIANT` or
            :attr:`DUAL_VARIANT`. Defaults to :attr:`PRESERVE_VARIANT`.
        print_iter (int, optional): print losses every ``print_iter``
            iterations. Defaults to None (never).
        debug (bool, optional): If True, print diagnostics and generate
            debug plots.
        reward_func (function, optional): function that generates reward
            tensor to backpropagate. It is called as
            ``reward_func(y, target, variant=variant)`` on the model output
            ``y``. Defaults to :func:`simple_reward`.
        resize (bool, optional): If True, upsamples the masks the same size
            as :attr:`input`. It is also possible to specify a pair
            (width, height) for a different size. Defaults to False.
        resize_mode (str, optional): Upsampling method to use. Defaults to
            ``'bilinear'``.
        smooth (float, optional): Apply Gaussian smoothing to the masks after
            computing them. The standard deviation is ``smooth`` times the
            smaller spatial dimension of the masks. Defaults to 0 (no
            smoothing).

    Returns:
        tuple: ``(masks, hist)``.

        ``masks`` is a :class:`torch.Tensor` holding the :math:`K` optimised
        masks along its first dimension, resized and/or smoothed if
        requested.

        ``hist`` is a :math:`K\times 2\times T` CPU tensor, with :math:`T`
        equal to :attr:`max_iter`. At every iteration it records the reward
        (``hist[:, 0]``) and the area regulariser (``hist[:, 1]``).

    Raises:
        ValueError: if :attr:`variant` is unknown or :attr:`max_iter` is
            smaller than 1.
    """
    # ``None`` sentinel instead of a mutable list default.
    if areas is None:
        areas = [0.1]
    elif isinstance(areas, float):
        areas = [areas]

    # Validate on entry, before any side effect on the model. An unknown
    # variant used to be caught only inside the loop (by ``assert False``,
    # stripped under -O). ``max_iter < 1`` used to leave ``mask_`` unbound.
    if variant not in (DELETE_VARIANT, PRESERVE_VARIANT, DUAL_VARIANT):
        raise ValueError(f"Unknown extremal perturbation variant: {variant!r}")
    if max_iter < 1:
        raise ValueError(f"max_iter must be at least 1, got {max_iter}")

    momentum = 0.9
    learning_rate = 0.01
    regul_weight = 300
    device = input.device

    # Regulariser weight at the last debug report. With ``debug`` on, a
    # report is produced whenever the weight has doubled since then.
    regul_weight_last = max(regul_weight / 2, 1)

    if debug:
        print(f"extremal_perturbation:\n"
              f"- target: {target}\n"
              f"- areas: {areas}\n"
              f"- variant: {variant}\n"
              f"- max_iter: {max_iter}\n"
              f"- step/sigma: {step}, {sigma}\n"
              f"- image size: {list(input.shape)}\n"
              f"- reward function: {reward_func.__name__}")

    # Disable gradients for model parameters. Remember the original flags so
    # they can be restored when the function exits (see ``finally`` below).
    orig_requires_grad = {}
    for name, param in model.named_parameters():
        orig_requires_grad[name] = param.requires_grad
        param.requires_grad_(False)

    try:
        # Get the perturbation operator.
        # The perturbation can be applied at any layer of the network (depth).
        perturbation_op = Perturbation(input,
                                       num_levels=num_levels,
                                       type=perturbation).to(device)
        if debug:
            perturbation_str = '\n  '.join(str(perturbation_op).split('\n'))
            print(f"- {perturbation_str}")

        # Prepare the mask generator.
        shape = perturbation_op.pyramid.shape[2:]
        mask_generator = MaskGenerator(shape, step, sigma).to(device)
        h, w = mask_generator.shape_in
        pmask = torch.ones(len(areas), 1, h, w).to(device)
        if debug:
            print(f"- mask resolution:\n  {pmask.shape}")

        # Prepare reference area vector. For each area ``a`` this is a sorted
        # step profile: the first ``(1 - a)`` fraction of entries is 0 and
        # the rest is 1.
        max_area = np.prod(mask_generator.shape_out)
        reference = torch.ones(len(areas), max_area).to(device)
        for i, a in enumerate(areas):
            reference[i, :int(max_area * (1 - a))] = 0

        # Initialize optimizer.
        optimizer = optim.SGD([pmask],
                              lr=learning_rate,
                              momentum=momentum,
                              dampening=momentum)
        hist = torch.zeros((len(areas), 2, 0))

        for t in range(max_iter):
            pmask.requires_grad_(True)

            # Generate the mask.
            mask_, mask = mask_generator.generate(pmask)

            # Apply the mask (``variant`` was validated on entry).
            if variant == DELETE_VARIANT:
                x = perturbation_op.apply(1 - mask_)
            elif variant == PRESERVE_VARIANT:
                x = perturbation_op.apply(mask_)
            else:  # DUAL_VARIANT: preserve and delete stacked along dim 0.
                x = torch.cat((perturbation_op.apply(mask_),
                               perturbation_op.apply(1 - mask_)),
                              dim=0)

            # Apply jitter to the masked data (horizontal flip on even steps).
            if jitter and t % 2 == 0:
                x = torch.flip(x, dims=(3, ))

            # Evaluate the model on the masked data.
            y = model(x)

            # Get reward.
            reward = reward_func(y, target, variant=variant)

            # Reshape reward and average over spatial dimensions.
            reward = reward.reshape(len(areas), -1).mean(dim=1)

            # Area regularization: penalise the distance between the sorted
            # mask values and the reference step profile.
            mask_sorted = mask.reshape(len(areas), -1).sort(dim=1)[0]
            regul = -((mask_sorted - reference)**2).mean(dim=1) * regul_weight
            energy = (reward + regul).sum()

            # Gradient step: maximise the energy by minimising its negation.
            optimizer.zero_grad()
            (-energy).backward()
            optimizer.step()

            pmask.data = pmask.data.clamp(0, 1)

            # Record energy: append a K x 2 x 1 slice (reward, regul) to hist.
            record = torch.stack((reward.detach().cpu(), regul.detach().cpu()),
                                 dim=1)
            hist = torch.cat((hist, record.unsqueeze(2)), dim=2)

            # Adjust the regulariser/area constraint weight.
            regul_weight *= 1.0035

            # Diagnostics.
            debug_this_iter = debug and (t in (0, max_iter - 1)
                                         or regul_weight / regul_weight_last >= 2)

            if (print_iter is not None and t % print_iter == 0) or debug_this_iter:
                print("[{:04d}/{:04d}]".format(t + 1, max_iter), end="")
                for i, area in enumerate(areas):
                    print(" [area:{:.2f} loss:{:.2f} reg:{:.2f}]".format(
                        area, hist[i, 0, -1], hist[i, 1, -1]),
                          end="")
                print()

            if debug_this_iter:
                regul_weight_last = regul_weight
                for i, a in enumerate(areas):
                    plt.figure(i, figsize=(20, 6))
                    plt.clf()
                    ncols = 4 if variant == DUAL_VARIANT else 3
                    plt.subplot(1, ncols, 1)
                    plt.plot(hist[i, 0].numpy())
                    plt.plot(hist[i, 1].numpy())
                    plt.plot(hist[i].sum(dim=0).numpy())
                    plt.legend(('energy', 'regul', 'both'))
                    plt.title(f'target area:{a:.2f}')
                    plt.subplot(1, ncols, 2)
                    imsc(mask[i], lim=[0, 1])
                    plt.title(f"min:{mask[i].min().item():.2f}"
                              f" max:{mask[i].max().item():.2f}"
                              f" area:{mask[i].sum() / mask[i].numel():.2f}")
                    plt.subplot(1, ncols, 3)
                    imsc(x[i])
                    if variant == DUAL_VARIANT:
                        plt.subplot(1, ncols, 4)
                        imsc(x[i + len(areas)])
                    plt.pause(0.001)

        mask_ = mask_.detach()

        # Resize saliency map.
        mask_ = resize_saliency(input, mask_, resize, mode=resize_mode)

        # Smooth saliency map.
        if smooth > 0:
            mask_ = imsmooth(mask_,
                             sigma=smooth * min(mask_.shape[2:]),
                             padding_mode='constant')

        return mask_, hist
    finally:
        # Restore gradient saving for model parameters.
        for name, param in model.named_parameters():
            param.requires_grad_(orig_requires_grad[name])
Ejemplo n.º 3
0
def saliency(model,
             input,
             target,
             saliency_layer='',
             resize=False,
             resize_mode='bilinear',
             smooth=0,
             context_builder=NullContext,
             gradient_to_saliency=gradient_to_saliency,
             get_backward_gradient=get_backward_gradient,
             debug=False):
    """Apply a backprop-based attribution method to an image.

    The saliency method is specified by a suitable context factory
    :attr:`context_builder`. This context is used to modify the backpropagation
    algorithm to match a given visualization method. This:

    1. Attaches a probe to the output tensor of :attr:`saliency_layer`,
       which must be a layer in :attr:`model`. If no such layer is specified,
       it selects the input tensor to :attr:`model`.

    2. Uses the function :attr:`get_backward_gradient` to obtain a gradient
       for the output tensor of the model. This function is passed
       as input the output tensor as well as the parameter :attr:`target`.
       By default, the :func:`get_backward_gradient` function is used.
       The latter generates as gradient a one-hot vector selecting
       :attr:`target`, usually the index of the class predicted by
       :attr:`model`.

    3. Evaluates :attr:`model` on :attr:`input` and then computes the
       pseudo-gradient of the model with respect to the selected tensor. This
       calculation is controlled by :attr:`context_builder`.

    4. Extracts the pseudo-gradient at the selected tensor as a raw saliency
       map.

    5. Calls :attr:`gradient_to_saliency` to obtain an actual saliency map.
       This defaults to :func:`gradient_to_saliency` that takes the maximum
       absolute value along the channel dimension of the pseudo-gradient
       tensor.

    6. Optionally resizes the saliency map thus obtained. By default,
       this uses bilinear interpolation and resizes the saliency to the same
       spatial dimension of :attr:`input`.

    7. Optionally applies a Gaussian filter to the resized saliency map.
       The standard deviation of this filter is :attr:`smooth` times the
       maximum spatial dimension of the resized saliency map.

    8. Removes the probe. It also restores the model parameters'
       ``requires_grad`` flags and, if the model was in training mode,
       switches it back to training mode. This clean-up also runs if any
       earlier step raises.

    9. Returns the saliency map or optionally a tuple with the saliency map
       and an OrderedDict of Probe objects for all modules in the model, which
       can be used for debugging.

    Args:
        model (:class:`torch.nn.Module`): a model.
        input (:class:`torch.Tensor`): input tensor.
        target (int or :class:`torch.Tensor`): target label(s).
        saliency_layer (str or :class:`torch.nn.Module`, optional): name of the
            saliency layer (str) or the layer itself (:class:`torch.nn.Module`)
            in the model at which to visualize. Default: ``''`` (visualize
            at input).
        resize (bool or tuple, optional): if True, upsample saliency map to the
            same size as :attr:`input`. It is also possible to specify a pair
            (width, height) for a different size. Default: ``False``.
        resize_mode (str, optional): upsampling method to use. Default:
            ``'bilinear'``.
        smooth (float, optional): amount of Gaussian smoothing to apply to the
            saliency map. Default: ``0``.
        context_builder (type, optional): type of context to use. Default:
            :class:`NullContext`.
        gradient_to_saliency (function, optional): function that converts the
            pseudo-gradient signal to a saliency map. Default:
            :func:`gradient_to_saliency`.
        get_backward_gradient (function, optional): function that generates
            gradient tensor to backpropagate. Default:
            :func:`get_backward_gradient`.
        debug (bool, optional): if True, also return an
            :class:`collections.OrderedDict` of :class:`Probe` objects for
            all modules in the model. Default: ``False``.

    Returns:
        :class:`torch.Tensor` or tuple: If :attr:`debug` is False, returns a
        :class:`torch.Tensor` saliency map at :attr:`saliency_layer`.
        Otherwise, returns a tuple of a :class:`torch.Tensor` saliency map
        at :attr:`saliency_layer` and an :class:`collections.OrderedDict`
        of :class:`Probe` objects for all modules in the model.
    """

    # Clear any existing gradient.
    if input.grad is not None:
        input.grad.data.zero_()

    # Disable gradients for model parameters, remembering the original flags.
    orig_requires_grad = {}
    for name, param in model.named_parameters():
        orig_requires_grad[name] = param.requires_grad
        param.requires_grad_(False)

    # Set model to eval mode. Only call ``eval()`` when the model is training,
    # so that submodule flags are left untouched otherwise.
    orig_is_training = model.training
    if orig_is_training:
        model.eval()

    # Everything after the model-state changes above runs inside try/finally.
    # An exception (missing layer, failing forward/backward, ...) must not
    # leak the probe hook or leave the model with gradients disabled and in
    # eval mode.
    probe = None
    try:
        # Attach debug probes to every module.
        debug_probes = attach_debug_probes(model, debug=debug)

        # Attach a probe to the saliency layer.
        probe_target = 'input' if saliency_layer == '' else 'output'
        saliency_layer = get_module(model, saliency_layer)
        assert saliency_layer is not None, 'We could not find the saliency layer'
        probe = Probe(saliency_layer, target=probe_target)

        # Do a forward and backward pass.
        with context_builder():
            output = model(input)
            backward_gradient = get_backward_gradient(output, target)
            output.backward(backward_gradient)

        # Get saliency map from gradient.
        saliency_map = gradient_to_saliency(probe.data[0])

        # Resize saliency map.
        saliency_map = resize_saliency(input,
                                       saliency_map,
                                       resize,
                                       mode=resize_mode)

        # Smooth saliency map.
        if smooth > 0:
            saliency_map = imsmooth(saliency_map,
                                    sigma=smooth * max(saliency_map.shape[2:]),
                                    padding_mode='replicate')
    finally:
        # Remove probe (if it was attached).
        if probe is not None:
            probe.remove()

        # Restore gradient saving for model parameters.
        for name, param in model.named_parameters():
            param.requires_grad_(orig_requires_grad[name])

        # Restore model's original mode.
        if orig_is_training:
            model.train()

    if debug:
        return saliency_map, debug_probes
    return saliency_map