Example #1
def plot_example(input,
                 saliency,
                 method,
                 category_id,
                 show_plot=False,
                 save_path=None):
    """Plot an example.

    Args:
        input (:class:`torch.Tensor`): 4D tensor containing input images.
        saliency (:class:`torch.Tensor`): 4D tensor containing saliency maps.
        method (str): name of saliency method.
        category_id (int): ID of ImageNet category.
        show_plot (bool, optional): If True, show plot. Default: ``False``.
        save_path (str, optional): Path to save figure to. Default: ``None``.
    """
    import os
    import matplotlib.pyplot as plt
    from torchray.utils import imsc
    from torchray.benchmark.datasets import IMAGENET_CLASSES

    if isinstance(category_id, int):
        category_id = [category_id]

    batch_size = len(input)

    plt.clf()
    for i in range(batch_size):
        class_i = category_id[i % len(category_id)]

        plt.subplot(batch_size, 2, 1 + 2 * i)
        imsc(input[i])
        plt.title('input image', fontsize=8)

        plt.subplot(batch_size, 2, 2 + 2 * i)
        imsc(saliency[i], interpolation='none')
        plt.title('{} for category {} ({})'.format(method,
                                                   IMAGENET_CLASSES[class_i],
                                                   class_i),
                  fontsize=8)

    # Save figure if path is specified.
    if save_path:
        save_dir = os.path.dirname(os.path.abspath(save_path))
        # Create directory if necessary.
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        ext = os.path.splitext(save_path)[1].strip('.')
        plt.savefig(save_path, format=ext, bbox_inches='tight')

    # Show plot if desired.
    if show_plot:
        plt.show()
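A minimal usage sketch for plot_example; the tensors below are random placeholders rather than real images, and 281/282 are arbitrary ImageNet category IDs (assumes torchray is installed):

# Usage sketch: random placeholder tensors stand in for a real image
# batch and its saliency maps (both 4D: N x C x H x W); the category
# IDs are arbitrary ImageNet classes.
import torch

images = torch.rand(2, 3, 224, 224)
saliency_maps = torch.rand(2, 1, 224, 224)
plot_example(images, saliency_maps, 'grad-cam', [281, 282], show_plot=True)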
Example #2
def image2tensor(image_path):
    # Convert an image file to a torch tensor with shape (1, 3, 224, 224).
    import torch
    from PIL import Image
    from torchvision import transforms
    from torchray.utils import imsc
    img = Image.open(image_path).convert('RGB')
    # transforms.Scale was removed from torchvision; Resize replaces it, and
    # ToTensor is needed because imsc expects a tensor, not a PIL image.
    p = transforms.Compose([transforms.Resize((224, 224)),
                            transforms.ToTensor()])
    img, i = imsc(p(img), quiet=False)
    return torch.reshape(img, (1, 3, 224, 224))
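For reference, a hypothetical call (the file path is a placeholder):

# Hypothetical call; 'meme.jpg' is a placeholder path. With quiet=False
# inside image2tensor, imsc also displays the rescaled image.
tensor = image2tensor('meme.jpg')
print(tensor.shape)  # torch.Size([1, 3, 224, 224])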
Example #3
def torchray_multimodal_explain(image_path, text):
    # image_path = "static\\" + image_path
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = MMBT.from_pretrained("mmbt.hateful_memes.images")
    model = model.to(device)

    image_tensor = image2tensor(image_path).to(device)

    mask_, hist_, output_tensor, summary, conclusion = multi_extremal_perturbation(
        model,
        image_tensor,
        image_path,
        text,
        0,
        reward_func=contrastive_reward,
        debug=True,
        areas=[0.12])
    # summary is a higher-level explanation as a sentence.
    # conclusion is a list containing words and their weights.
    # output_tensor is the masked image.

    # Use a name that does not shadow PIL's Image class.
    result_image = transforms.ToPILImage()(
        imsc(image_tensor[0], quiet=False)[0]).convert("RGB")
    result_image.save("torchray.png")
    print(summary)
    return conclusion
Example #4
            def _ep(m, x, i):
                # Extremal perturbation backprop.
                masks_1, _ = extremal_perturbation(
                    m,
                    x,
                    i,
                    reward_func=contrastive_reward,
                    debug=True,
                    areas=[area],
                )
                masks_1 = imsc(masks_1[0], quiet=True)[0][None]

                return masks_1
Example #5
	def classify(self, image, text_input, image_tensor=None):
		'''
		Args:
			image: the input PIL image
			text_input: the input text (str)
			image_tensor: the image torch.Tensor with size (1, 3, 224, 224)

		Returns:
			label of model prediction and the corresponding confidence
		'''
		scoreFlag = False
		if image_tensor is not None:
			scoreFlag = True
			logits = self.onnx_model_forward(image_tensor, text_input)
		else:
			# transforms.Scale was removed from torchvision; Resize replaces
			# it, and ToTensor is needed because imsc expects a tensor.
			p = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
			image, i = imsc(p(image), quiet=True)
			image_tensor = torch.reshape(image, (1, 3, 224, 224))
			logits = self.onnx_model_forward(image_tensor, text_input)

		# Fall back to the default MMBT model if the ONNX output has the
		# wrong shape.
		if list(torch.tensor(logits).size()) != [1, 2]:
			if self.defaultmodel is None:
				self.defaultmodel = MMBT.from_pretrained("mmbt.hateful_memes.images")
				self.defaultmodel.to(self.device)
			logits = self.defaultmodel.classify(image, text_input, image_tensor=torch.squeeze(image_tensor.to(self.device), 0))

		scores = nn.functional.softmax(torch.tensor(logits), dim=1)

		if scoreFlag:
			return scores

		confidence, label = torch.max(scores, dim=1)
		return {"label": label.item(), "confidence": confidence.item()}
Example #6
def hierarchical_perturbation(model,
                              input,
                              target,
                              vis=False,
                              interp_mode='nearest',
                              resize=None,
                              batch_size=32,
                              perturbation_type='fade',
                              num_cells=4):
    # Requires numpy as np, torch.nn.functional as F, matplotlib.pyplot as
    # plt, torchray.utils.imsc and, for perturbation_type='blur', a `blur`
    # transform defined elsewhere.
    with torch.no_grad():
        # Get device of input (CPU or GPU).
        dev = input.device
        if dev.type == 'cpu':
            batch_size = 1
        bn, channels, input_y_dim, input_x_dim = input.shape
        dim = min(input_x_dim, input_y_dim)
        total_masks = 0
        depth = 0
        max_depth = int(np.log2(dim / num_cells))
        saliency = torch.zeros((1, 1, input_y_dim, input_x_dim), device=dev)
        depth_saliencies = torch.zeros(
            (max_depth + 1, 1, 1, input_y_dim, input_x_dim), device=dev)
        max_batch = batch_size

        output = F.softmax(model(input), dim=1)[:, target]

        if perturbation_type == 'blur':
            pre_b_image = blur(input.clone().cpu()).to(dev)

        while depth < max_depth:

            masks_list = []
            b_list = []
            num_cells *= 2
            depth += 1
            threshold = torch.min(saliency) + (
                (torch.max(saliency) - torch.min(saliency)) / 2)

            print('Depth: {}, {} x {} Cell'.format(depth,
                                                   input_y_dim // num_cells,
                                                   input_x_dim // num_cells))
            print('Threshold: {:.1f}'.format(threshold))
            print('Range {:.1f} to {:.1f}'.format(saliency.min(),
                                                  saliency.max()))

            y_ixs = range(-1, num_cells)
            x_ixs = range(-1, num_cells)
            x_cell_dim = input_x_dim // num_cells
            y_cell_dim = input_y_dim // num_cells

            pos_masks = 0

            for x in x_ixs:
                for y in y_ixs:
                    x1, y1 = max(0, x), max(0, y)
                    x2, y2 = min(x + 2, num_cells), min(y + 2, num_cells)
                    pos_masks += 1

                    mask = torch.zeros((1, 1, num_cells, num_cells),
                                       device=dev)
                    mask[:, :, y1:y2, x1:x2] = 1.0
                    local_saliency = F.interpolate(mask,
                                                   (input_y_dim, input_x_dim),
                                                   mode=interp_mode) * saliency

                    if depth > 1:
                        local_saliency = torch.max(local_saliency)
                    else:
                        local_saliency = 0

                    # If the salience of the region exceeds the threshold
                    # (the midpoint of the current saliency range), generate
                    # a higher-resolution mask for it.
                    if local_saliency >= threshold:

                        masks_list.append(abs(mask - 1))

                        if perturbation_type == 'blur':

                            b_image = input.clone()
                            b_image[:, :, y1 * y_cell_dim:y2 * y_cell_dim,
                                    x1 * x_cell_dim:x2 *
                                    x_cell_dim] = pre_b_image[:, :, y1 *
                                                              y_cell_dim:y2 *
                                                              y_cell_dim, x1 *
                                                              x_cell_dim:x2 *
                                                              x_cell_dim]
                            b_list.append(b_image)

                        if perturbation_type == 'mean':
                            b_image = input.clone()
                            mean = torch.mean(
                                b_image[:, :, y1 * y_cell_dim:y2 * y_cell_dim,
                                        x1 * x_cell_dim:x2 * x_cell_dim],
                                dim=(-1, -2),
                                keepdim=True)

                            b_image[:, :, y1 * y_cell_dim:y2 * y_cell_dim,
                                    x1 * x_cell_dim:x2 * x_cell_dim] = mean
                            b_list.append(b_image)

            num_masks = len(masks_list)
            print('Selected {}/{} masks at depth {}'.format(
                num_masks, pos_masks, depth))

            if num_masks == 0:
                depth -= 1
                break
            total_masks += num_masks

            while len(masks_list) > 0:
                m_ix = min(len(masks_list), max_batch)
                if perturbation_type != 'fade':
                    b_imgs = torch.cat(b_list[:m_ix])
                    del b_list[:m_ix]
                masks = torch.cat(masks_list[:m_ix])
                del masks_list[:m_ix]

                # resize low-res masks to input size
                masks = F.interpolate(masks, (input_y_dim, input_x_dim),
                                      mode=interp_mode)

                if perturbation_type == 'fade':
                    perturbed_outputs = torch.relu(
                        output -
                        F.softmax(model(input * masks), dim=1)[:, target])
                else:
                    perturbed_outputs = torch.relu(
                        output - F.softmax(model(b_imgs), dim=1)[:, target])

                # Weight each inverted mask by the output drop it induces;
                # reshape to (m, 1, 1, 1) so the product broadcasts over the
                # (m, 1, H, W) masks.
                sal = perturbed_outputs.reshape(-1, 1, 1, 1) * torch.abs(masks - 1)
                saliency += torch.sum(sal, dim=(0, 1))
                depth_saliencies[depth] += torch.sum(sal, dim=(0, 1))

                if vis:
                    clear_output(wait=True)
                    print('Saving image...')
                    plt.figure(figsize=(8, 4))
                    plt.subplot(1, 2, 1)
                    plt.title(
                        'Depth: {}, {} x {} Mask\nThreshold: {:.1f}'.format(
                            depth, num_cells, num_cells, threshold))
                    if perturbation_type == 'fade':
                        imsc((input * masks)[0])
                    else:
                        imsc(b_imgs[0])
                    plt.subplot(1, 2, 2)
                    imsc(saliency[0])
                    # Save before show(), which clears the current figure.
                    plt.savefig('data/attribution_benchmarks/preview')
                    plt.show()
                    plt.close()

        print('Used {} masks in total.'.format(total_masks))
        if resize is not None:
            saliency = F.interpolate(saliency, (resize[1], resize[0]),
                                     mode=interp_mode)
        return saliency, depth_saliencies, total_masks
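A hedged usage sketch for hierarchical_perturbation, assuming the imports noted at the top of the function are in place; the model weights and input are placeholders chosen only to exercise the function:

# Usage sketch (assumptions: torchvision is available; weights=None gives
# random weights and the input is a random placeholder, so the resulting
# saliency is meaningless but the code paths are exercised).
import torch
from torchvision import models

model = models.resnet18(weights=None).eval()
x = torch.rand(1, 3, 224, 224)
sal, per_depth_sal, n_masks = hierarchical_perturbation(
    model, x, target=0, perturbation_type='fade')
print(sal.shape, n_masks)  # saliency is (1, 1, 224, 224)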
Example #7
        def decnn(m, x, i):
            saliency = deconvnet(m, x, i)
            saliency = imsc(saliency[0], quiet=True)[0][None]

            return saliency
Example #8
        def excite_l1(m, x, i):
            saliency = excitation_backprop(
                m, x, i, saliency_layer=model.encoder_q.layer1)
            saliency = imsc(saliency[0], quiet=True)[0][None]

            return saliency
Example #9
    def do_saliency(x, k, one_vs_all=False):
        head = torch.nn.Linear(k.shape[-1], k.shape[-2], bias=False)
        head.weight.data[:] = k.data[:]
        head = head.cuda()

        contrast_model = torch.nn.Sequential(model.encoder_q, L2Normalize(),
                                             head, torch.nn.Softmax(dim=-1))

        x = x.cuda()

        def colorize(saliency):
            # cv2.resize takes dsize as (width, height).
            resized = cv2.resize(
                saliency.permute(0, 2, 3, 1)[0].detach().cpu().numpy(),
                (x.shape[-1], x.shape[-2]),
                interpolation=cv2.INTER_LANCZOS4) * 255
            return color(resized)[..., :3]

        def _saliency(x, func):
            s = []
            for i in range(x.shape[0]):
                saliency = func(contrast_model, x[i][None].cuda(), i)
                saliency = colorize(saliency)
                s.append(saliency)
            return np.stack(s).transpose(0, -1, 1, 2)

        def gcam(m, x, i):
            g = grad_cam(m, x, i, saliency_layer=model.encoder_q.layer4)
            return g

        def gcam_l3(m, x, i):
            return grad_cam(m, x, i, saliency_layer=model.encoder_q.layer3)

        def excite_c1(m, x, i):
            saliency = excitation_backprop(
                m, x, i, saliency_layer=model.encoder_q.conv1)
            saliency = imsc(saliency[0], quiet=True)[0][None]

            return saliency

        def excite_l1(m, x, i):
            saliency = excitation_backprop(
                m, x, i, saliency_layer=model.encoder_q.layer1)
            saliency = imsc(saliency[0], quiet=True)[0][None]

            return saliency

        def ceb(m, x, i):
            # Contrastive excitation backprop.
            return contrastive_excitation_backprop(
                m,
                x,
                i,
                saliency_layer=model.encoder_q.layer2[-1],
                contrast_layer=model.encoder_q.layer4[-1],
                classifier_layer=model.encoder_q.fc[-1])


        def ep(area):
            def _ep(m, x, i):
                # Extremal perturbation backprop.
                masks_1, _ = extremal_perturbation(
                    m,
                    x,
                    i,
                    reward_func=contrastive_reward,
                    debug=True,
                    areas=[area],
                )
                masks_1 = imsc(masks_1[0], quiet=True)[0][None]

                return masks_1

            return _ep

        def decnn(m, x, i):
            saliency = deconvnet(m, x, i)
            saliency = imsc(saliency[0], quiet=True)[0][None]

            return saliency

        if one_vs_all:
            # Compare the first x to all k.
            import time

            assert x.shape[0] == 1, 'expected just one instance'

            t0 = time.time()
            rise_saliency = rise(contrast_model, x).transpose(0, 1)
            print('rise took ************', time.time() - t0)

            rise_saliency = torch.stack(
                [imsc(rs, quiet=True)[0] for rs in rise_saliency])

            # Replicate the single-channel saliency across three channels.
            rise_saliency = np.stack(
                [rise_saliency.detach().cpu().numpy()] * 3, axis=-1)[:, 0]

            rise_saliency = rise_saliency.transpose(0, -1, 1, 2)

            return dict(
                # rise=rise_saliency,
                # grad_cam=_saliency(torch.cat([x]*k.shape[0]), gcam),
            )

        else:
            out = dict(
                # grad_cam=_saliency(x, gcam),
                # grad_cam_l3=_saliency(x, gcam_l3),
                contrastive_excitation_backprop=_saliency(x, ceb),

                # excite_c1=_saliency(x, excite_c1),
                # excite_l1=_saliency(x, excite_l1),
                # deconvnet=_saliency(x, decnn),
                # rise=rise_saliency
            )

            for k_ep in [0.05]:  #, 0.05, 0.12]:
                out['extremal_perturbation_%s' % k_ep] = _saliency(x, ep(k_ep))

            return out
Example #10
    def __next__(self):
        self._lazy_init()
        x, y = next(self.data_iterator)
        torch.manual_seed(self.seed)

        if self.log:
            from torchray.benchmark.logging import mongo_load, mongo_save, \
                data_from_mongo, data_to_mongo

        try:
            assert len(x) == 1
            x = x.to(self.device)
            class_ids = self.data.as_class_ids(y[0])
            image_size = self.data.as_image_size(y[0])

            results = {'pointing': {}, 'pointing_difficult': {}}
            info = {}
            rise_saliency = None

            for class_id in class_ids:

                # Try to recover this result from the log.
                if self.log > 0:
                    image_name = self.data.as_image_name(y[0])
                    data = mongo_load(
                        self.db,
                        self.experiment.name,
                        f"{image_name}-{class_id}",
                    )
                    if data is not None:
                        data = data_from_mongo(data)
                        results['pointing'][class_id] = data['pointing']
                        results['pointing_difficult'][class_id] = data[
                            'pointing_difficult']
                        if self.debug:
                            print(f'{image_name}-{class_id} loaded from log')
                        continue

                # TODO(av): should now be obsolete
                if x.grad is not None:
                    x.grad.data.zero_()

                if self.experiment.method == "center":
                    w, h = image_size
                    point = torch.tensor([[w / 2, h / 2]])

                elif self.experiment.method == "gradient":
                    saliency = gradient(
                        self.model,
                        x,
                        class_id,
                        resize=image_size,
                        smooth=0.02,
                        get_backward_gradient=get_pointing_gradient)
                    point = _saliency_to_point(saliency)
                    info['saliency'] = saliency

                elif self.experiment.method == "deconvnet":
                    saliency = deconvnet(
                        self.model,
                        x,
                        class_id,
                        resize=image_size,
                        smooth=0.02,
                        get_backward_gradient=get_pointing_gradient)
                    point = _saliency_to_point(saliency)
                    info['saliency'] = saliency

                elif self.experiment.method == "guided_backprop":
                    saliency = guided_backprop(
                        self.model,
                        x,
                        class_id,
                        resize=image_size,
                        smooth=0.02,
                        get_backward_gradient=get_pointing_gradient)
                    point = _saliency_to_point(saliency)
                    info['saliency'] = saliency

                elif self.experiment.method == "grad_cam":
                    saliency = grad_cam(
                        self.model,
                        x,
                        class_id,
                        saliency_layer=self.gradcam_layer,
                        resize=image_size,
                        get_backward_gradient=get_pointing_gradient)
                    point = _saliency_to_point(saliency)
                    info['saliency'] = saliency

                elif self.experiment.method == "excitation_backprop":
                    saliency = excitation_backprop(
                        self.model,
                        x,
                        class_id,
                        self.saliency_layer,
                        resize=image_size,
                        get_backward_gradient=get_pointing_gradient)
                    point = _saliency_to_point(saliency)
                    info['saliency'] = saliency

                elif self.experiment.method == "contrastive_excitation_backprop":
                    saliency = contrastive_excitation_backprop(
                        self.model,
                        x,
                        class_id,
                        saliency_layer=self.saliency_layer,
                        contrast_layer=self.contrast_layer,
                        resize=image_size,
                        get_backward_gradient=get_pointing_gradient)
                    point = _saliency_to_point(saliency)
                    info['saliency'] = saliency

                elif self.experiment.method == "rise":
                    # For RISE, compute saliency map for all classes.
                    if rise_saliency is None:
                        rise_saliency = rise(self.model,
                                             x,
                                             resize=image_size,
                                             seed=self.seed)
                    saliency = rise_saliency[:, class_id, :, :].unsqueeze(1)
                    point = _saliency_to_point(saliency)
                    info['saliency'] = saliency

                elif self.experiment.method == "extremal_perturbation":

                    if self.experiment.dataset == 'voc_2007':
                        areas = [0.025, 0.05, 0.1, 0.2]
                    else:
                        areas = [0.018, 0.025, 0.05, 0.1]

                    if self.experiment.boom:
                        raise RuntimeError("BOOM!")

                    mask, energy = elp.extremal_perturbation(
                        self.model,
                        x,
                        class_id,
                        areas=areas,
                        num_levels=8,
                        step=7,
                        sigma=7 * 3,
                        max_iter=800,
                        debug=self.debug > 0,
                        jitter=True,
                        smooth=0.09,
                        resize=image_size,
                        perturbation='blur',
                        reward_func=elp.simple_reward,
                        variant=elp.PRESERVE_VARIANT,
                    )

                    saliency = mask.sum(dim=0, keepdim=True)
                    point = _saliency_to_point(saliency)

                    info = {
                        'saliency': saliency,
                        'mask': mask,
                        'areas': areas,
                        'energy': energy
                    }

                else:
                    assert False

                # Debug visualization (disabled).
                if False:
                    plt.figure()
                    plt.subplot(1, 2, 1)
                    imsc(saliency[0])
                    plt.plot(point[0, 0], point[0, 1], 'ro')
                    plt.subplot(1, 2, 2)
                    imsc(x[0])
                    plt.pause(0)

                results['pointing'][class_id] = self.pointing.evaluate(
                    y[0], class_id, point[0])
                results['pointing_difficult'][
                    class_id] = self.pointing_difficult.evaluate(
                        y[0], class_id, point[0])

                if self.log > 0:
                    image_name = self.data.as_image_name(y[0])
                    mongo_save(
                        self.db, self.experiment.name,
                        f"{image_name}-{class_id}",
                        data_to_mongo({
                            'image_name':
                            image_name,
                            'class_id':
                            class_id,
                            'pointing':
                            results['pointing'][class_id],
                            'pointing_difficult':
                            results['pointing_difficult'][class_id],
                        }))

                if self.log > 1:
                    mongo_save(self.db,
                               str(self.experiment.name) + "-details",
                               f"{image_name}-{class_id}", data_to_mongo(info))

            return results

        except Exception as ex:
            raise ProcessingError(self, self.experiment, self.model, x, y,
                                  class_id, image_size) from ex
def extremal_perturbation(model,
                          input,
                          target,
                          areas=[0.1],
                          perturbation=BLUR_PERTURBATION,
                          max_iter=800,
                          num_levels=8,
                          step=7,
                          sigma=21,
                          jitter=True,
                          variant=PRESERVE_VARIANT,
                          print_iter=None,
                          debug=False,
                          reward_func=simple_reward,
                          resize=False,
                          resize_mode='bilinear',
                          smooth=0):
    r"""Compute a set of extremal perturbations.

    The function takes a :attr:`model`, an :attr:`input` tensor :math:`x`
    of size :math:`1\times C\times H\times W`, and a :attr:`target`
    activation channel. It produces as output a
    :math:`K\times C\times H\times W` tensor where :math:`K` is the number
    of specified :attr:`areas`.

    Each mask, which has approximately the specified area, is searched
    in order to maximise the (spatial average of the) activations
    in channel :attr:`target`. Alternative objectives can be specified
    via :attr:`reward_func`.

    Args:
        model (:class:`torch.nn.Module`): model.
        input (:class:`torch.Tensor`): input tensor.
        target (int): target channel.
        areas (float or list of floats, optional): list of target areas for saliency
            masks. Defaults to `[0.1]`.
        perturbation (str, optional): :ref:`ep_perturbations`.
        max_iter (int, optional): number of iterations for optimizing the masks.
        num_levels (int, optional): number of buckets with which to discretize
            and linearly interpolate the perturbation
            (see :class:`Perturbation`). Defaults to 8.
        step (int, optional): mask step (see :class:`MaskGenerator`).
            Defaults to 7.
        sigma (float, optional): mask smoothing (see :class:`MaskGenerator`).
            Defaults to 21.
        jitter (bool, optional): randomly flip the image horizontally at each iteration.
            Defaults to True.
        variant (str, optional): :ref:`ep_variants`. Defaults to
            :attr:`PRESERVE_VARIANT`.
        print_iter (int, optional): frequency with which to print losses.
            Defaults to None.
        debug (bool, optional): If True, generate debug plots.
        reward_func (function, optional): function that generates reward tensor
            to backpropagate.
        resize (bool, optional): If True, upsamples the masks the same size
            as :attr:`input`. It is also possible to specify a pair
            (width, height) for a different size. Defaults to False.
        resize_mode (str, optional): Upsampling method to use. Defaults to
            ``'bilinear'``.
        smooth (float, optional): Apply Gaussian smoothing to the masks after
            computing them. Defaults to 0.

    Returns:
        A tuple containing the masks and the energies.
        The masks are stored as a :class:`torch.Tensor`
        of dimension :math:`K\times 1\times H\times W`, one mask per target
        area; the energies are stored as a :math:`K\times 2\times T` tensor
        recording the reward and the regularization term at each of the
        :math:`T` iterations.
    """
    if isinstance(areas, float):
        areas = [areas]
    momentum = 0.9
    learning_rate = 0.01
    regul_weight = 300
    device = input.device

    regul_weight_last = max(regul_weight / 2, 1)

    if debug:
        print(f"extremal_perturbation:\n"
              f"- target: {target}\n"
              f"- areas: {areas}\n"
              f"- variant: {variant}\n"
              f"- max_iter: {max_iter}\n"
              f"- step/sigma: {step}, {sigma}\n"
              f"- image size: {list(input.shape)}\n"
              f"- reward function: {reward_func.__name__}")

    # Disable gradients for model parameters.
    # TODO(av): undo on leaving the function.
    for p in model.parameters():
        p.requires_grad_(False)

    # Get the perturbation operator.
    # The perturbation can be applied at any layer of the network (depth).
    perturbation = Perturbation(input,
                                num_levels=num_levels,
                                type=perturbation).to(device)

    perturbation_str = '\n  '.join(str(perturbation).split('\n'))
    if debug:
        print(f"- {perturbation_str}")

    # Prepare the mask generator.
    shape = perturbation.pyramid.shape[2:]
    mask_generator = MaskGenerator(shape, step, sigma).to(device)
    h, w = mask_generator.shape_in
    pmask = torch.ones(len(areas), 1, h, w).to(device)
    if debug:
        print(f"- mask resolution:\n  {pmask.shape}")

    # Prepare reference area vector.
    max_area = np.prod(mask_generator.shape_out)
    reference = torch.ones(len(areas), max_area).to(device)
    for i, a in enumerate(areas):
        reference[i, :int(max_area * (1 - a))] = 0

    # Initialize optimizer.
    optimizer = optim.SGD([pmask],
                          lr=learning_rate,
                          momentum=momentum,
                          dampening=momentum)
    hist = torch.zeros((len(areas), 2, 0))

    for t in range(max_iter):
        pmask.requires_grad_(True)

        # Generate the mask.
        mask_, mask = mask_generator.generate(pmask)

        # Apply the mask.
        if variant == DELETE_VARIANT:
            x = perturbation.apply(1 - mask_)
        elif variant == PRESERVE_VARIANT:
            x = perturbation.apply(mask_)
        elif variant == DUAL_VARIANT:
            x = torch.cat((
                perturbation.apply(mask_),
                perturbation.apply(1 - mask_),
            ),
                          dim=0)
        else:
            assert False

        # Apply jitter to the masked data.
        if jitter and t % 2 == 0:
            x = torch.flip(x, dims=(3, ))

        # Evaluate the model on the masked data.
        y = model(x)

        # Get reward.
        reward = reward_func(y, target, variant=variant)

        # Reshape reward and average over spatial dimensions.
        reward = reward.reshape(len(areas), -1).mean(dim=1)

        # Area regularization.
        mask_sorted = mask.reshape(len(areas), -1).sort(dim=1)[0]
        regul = -((mask_sorted - reference)**2).mean(dim=1) * regul_weight
        energy = (reward + regul).sum()

        # Gradient step.
        optimizer.zero_grad()
        (-energy).backward()
        optimizer.step()

        pmask.data = pmask.data.clamp(0, 1)

        # Record energy.
        hist = torch.cat((hist,
                          torch.cat((reward.detach().cpu().view(
                              -1, 1, 1), regul.detach().cpu().view(-1, 1, 1)),
                                    dim=1)),
                         dim=2)

        # Adjust the regulariser/area constraint weight.
        regul_weight *= 1.0035

        # Diagnostics.
        debug_this_iter = debug and (t in (0, max_iter - 1)
                                     or regul_weight / regul_weight_last >= 2)

        if (print_iter is not None and t % print_iter == 0) or debug_this_iter:
            print("[{:04d}/{:04d}]".format(t + 1, max_iter), end="")
            for i, area in enumerate(areas):
                print(" [area:{:.2f} loss:{:.2f} reg:{:.2f}]".format(
                    area, hist[i, 0, -1], hist[i, 1, -1]),
                      end="")
            print()

        if debug_this_iter:
            regul_weight_last = regul_weight
            for i, a in enumerate(areas):
                plt.figure(i, figsize=(20, 6))
                plt.clf()
                ncols = 4 if variant == DUAL_VARIANT else 3
                plt.subplot(1, ncols, 1)
                plt.plot(hist[i, 0].numpy())
                plt.plot(hist[i, 1].numpy())
                plt.plot(hist[i].sum(dim=0).numpy())
                plt.legend(('energy', 'regul', 'both'))
                plt.title(f'target area:{a:.2f}')
                plt.subplot(1, ncols, 2)
                imsc(mask[i], lim=[0, 1])
                plt.title(f"min:{mask[i].min().item():.2f}"
                          f" max:{mask[i].max().item():.2f}"
                          f" area:{mask[i].sum() / mask[i].numel():.2f}")
                plt.subplot(1, ncols, 3)
                imsc(x[i])
                if variant == DUAL_VARIANT:
                    plt.subplot(1, ncols, 4)
                    imsc(x[i + len(areas)])
                plt.pause(0.001)

    mask_ = mask_.detach()

    # Resize saliency map.
    mask_ = resize_saliency(input, mask_, resize, mode=resize_mode)

    # Smooth saliency map.
    if smooth > 0:
        mask_ = imsmooth(mask_,
                         sigma=smooth * min(mask_.shape[2:]),
                         padding_mode='constant')

    return mask_, hist
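Based on the docstring above, a minimal call might look like the following sketch; the model, its (random) weights, the input, and target channel 281 are all placeholders, and max_iter is lowered from the default 800 only to keep the sketch fast:

# Sketch of a call following the docstring: a 1 x C x H x W input and a
# target channel yield one mask per requested area, plus the energy
# history (reward and regularization term per iteration).
import torch
from torchvision import models

model = models.resnet50(weights=None).eval()  # placeholder weights
x = torch.rand(1, 3, 224, 224)                # placeholder input
masks, hist = extremal_perturbation(model, x, target=281,
                                    areas=[0.05, 0.12], max_iter=100)
print(masks.shape)  # K x 1 x H' x W' (K = number of areas)
print(hist.shape)   # K x 2 x max_iter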