Example #1
def interpret_model(originalPath='',reconPath='', origOutput='', reconOutput=''):
    # read the images
    print("reading input image")
    original_image = cv.imread(originalPath)
    original_image = cv.cvtColor(original_image, cv.COLOR_BGR2RGB)
    recon_image = cv.imread(reconPath)
    recon_image = cv.cvtColor(recon_image, cv.COLOR_BGR2RGB)
    
    # create torch tensors
    input = Image.open(originalPath)
    input = data_transform(input)
    input = torch.unsqueeze(input, 0)
    input.requires_grad = True
    recon = Image.open(reconPath)
    recon = data_transform(recon)
    recon = torch.unsqueeze(recon, 0)
    recon.requires_grad = True

    # run classification on the original and reconstructed images
    original_label_float = model(input.cuda(0))
    _, target_label = torch.max(original_label_float, 1)
    recon_label_float = model(recon.cuda(0))
    _, recon_label = torch.max(recon_label_float, 1)
    saliency = Saliency(model)
    grads = saliency.attribute(input.cuda(0), target = target_label)
    grads = np.transpose(grads.squeeze().cpu().detach().numpy(), (1, 2, 0))
    saliencyMap = viz.visualize_image_attr(grads, original_image, method="blended_heat_map", sign="all",
                            show_colorbar=True, title="Overlayed Saliency Map - Original")
    plt.savefig(origOutput + '/saliency_' + ntpath.basename(originalPath))

    grads = saliency.attribute(recon.cuda(0), target = recon_label)
    grads = np.transpose(grads.squeeze().cpu().detach().numpy(), (1, 2, 0))
    saliencyMap = viz.visualize_image_attr(grads, recon_image, method="blended_heat_map", sign="all",
                            show_colorbar=True, title="Overlayed Saliency Map - Recon")
    plt.savefig(reconOutput + '/saliency_' + ntpath.basename(reconPath))
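The snippet above relies on module-level `model` and `data_transform` objects plus several imports. A minimal setup sketch follows; the ResNet-18 backbone and ImageNet normalization are illustrative placeholders, not the original project's choices.

import ntpath

import cv2 as cv
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from captum.attr import Saliency
from captum.attr import visualization as viz
from torchvision import models, transforms

# assumed preprocessing pipeline and classifier (placeholders)
data_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
model = models.resnet18(pretrained=True).cuda(0).eval()  # any image classifier works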
Example #2
def saliency_map(
        batch: dict,
        saliency: Saliency,
        sign: str = 'all',
        method: str = 'blended_heat_map',
        use_pyplot: bool = False,
        fig_axis: tuple = None,
        mix_bg: bool = True,
        alpha_overlay: float = 0.7,
) -> Tuple[Any, torch.Tensor]:
    """
    :param batch: batch to visualise
    :param saliency: Saliency object initialised for trainer_module
    :param sign: sign of gradient attributes to visualise
    :param method: method of visualization to be used
    :param use_pyplot: whether to use pyplot
    :param fig_axis: optional (figure, axes) pair to draw on instead of creating a new one
    :param mix_bg: whether to mix semantic/aerial map with vehicles
    :param alpha_overlay: alpha value of the overlayed heat map
    :return: pair of figure and corresponding gradients tensor
    """

    batch['image'].requires_grad = True
    grads = saliency.attribute(batch['image'], abs=False, additional_forward_args=(
        batch['target_positions'], None if 'target_availabilities' not in batch else batch['target_availabilities'],
        False))
    batch['image'].requires_grad = False
    gradsm = grads.squeeze().cpu().detach().numpy()
    if len(gradsm.shape) == 3:
        gradsm = gradsm.reshape(1, *gradsm.shape)
    gradsm = np.transpose(gradsm, (0, 2, 3, 1))
    im = batch['image'].detach().cpu().numpy().transpose(0, 2, 3, 1)
    fig, axis = fig_axis if fig_axis is not None else plt.subplots(2 - mix_bg, im.shape[0], dpi=200, figsize=(6, 6))
    for b in range(im.shape[0]):
        if mix_bg:
            grad_norm = float(np.abs(gradsm[b, ...]).sum())
            viz.visualize_image_attr(
                gradsm[b, ...], im[b, ...], method=method,
                sign=sign, use_pyplot=use_pyplot,
                plt_fig_axis=(fig, axis if im.shape[0] == 1 else axis[b]),
                alpha_overlay=alpha_overlay,
                title=f'l1 grad: {grad_norm:.5f}',
            )
            ttl = (axis if im.shape[0] == 1 else axis[b]).title
            ttl.set_position([.5, 0.95])
            (axis if im.shape[0] == 1 else axis[b]).axis('off')
        else:
            for (s_channel, end_channel), row in [((im.shape[-1] - 3, im.shape[-1]), 0), ((0, im.shape[-1] - 3), 1)]:
                grad_norm = float(np.abs(gradsm[b, :, :, s_channel:end_channel]).sum())
                viz.visualize_image_attr(
                    gradsm[b, :, :, s_channel:end_channel], im[b, :, :, s_channel:end_channel], method=method,
                    sign=sign, use_pyplot=use_pyplot,
                    plt_fig_axis=(fig, axis[row] if im.shape[0] == 1 else axis[row][b]),
                    alpha_overlay=alpha_overlay, title=f'l1 grad: {grad_norm:.5f}',
                )
                ttl = (axis[row] if im.shape[0] == 1 else axis[row][b]).title
                ttl.set_position([.5, 0.95])
                (axis[row] if im.shape[0] == 1 else axis[row][b]).axis('off')
    return fig, grads
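A hypothetical usage sketch: `trainer_module` stands in for whatever model the surrounding code wraps, and its forward must accept the extra arguments passed via additional_forward_args above and return one scalar per sample so that Saliency can attribute it.

from captum.attr import Saliency

saliency = Saliency(trainer_module)
fig, grads = saliency_map(batch, saliency, sign='all', mix_bg=True)
fig.savefig('saliency_batch.png', dpi=200)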
Example #3
def display_explanation(explanation, image):
    original_image = np.transpose(image.squeeze().cpu().detach().numpy(),
                                  (1, 2, 0))

    viz.visualize_image_attr(explanation,
                             original_image,
                             method="blended_heat_map",
                             sign="all",
                             show_colorbar=True,
                             title="Overlayed Integrated Gradients")
Example #4
def explain_pair(alg, pair: List[torch.tensor], labels: List[int], **kwargs):
    """
    Use a Captum explanation algorithm on a pair of images and plot side-by-side

    Parameters:
        alg: Captum algorithm, e.g. Saliency()
        pair: list of 2 images as torch tensors
        labels: the labels for each image
        **kwargs: additional arguments for Captum algorithm
    """

    def _prepare_explainer_input(img):
        input = img.unsqueeze(0)
        input.requires_grad = True
        input = input.cuda()
        return input

    inputs = [_prepare_explainer_input(img) for img in pair]
    grads = [explainer(alg, inp, lab, **kwargs) for inp, lab in zip(inputs, labels)]

    unorm = UnNormalize(img_means, img_stds)
    org_images = [unorm(img) for img in pair]
    org_images = [
        np.transpose(org_img.cpu().detach().numpy(), (1, 2, 0))
        for org_img in org_images
    ]

    fig, (ax1, ax2) = plt.subplots(1, 2)
    _ = viz.visualize_image_attr(
        grads[0],
        org_images[0],
        method="blended_heat_map",
        sign="absolute_value",
        show_colorbar=True,
        title="Predicted",
        plt_fig_axis=(fig, ax1),
        # use_pyplot=False to avoid viz calling plt.show()
        use_pyplot=False,
    )
    _ = viz.visualize_image_attr(
        grads[1],
        org_images[1],
        method="blended_heat_map",
        sign="absolute_value",
        show_colorbar=True,
        title="Nearest neighbor",
        plt_fig_axis=(fig, ax2),
    )

    return fig, (ax1, ax2)
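`explain_pair` (and `plot_original_and_explained_pair` in Example #15) calls an `explainer` helper that is not shown here. The definition below is a plausible reconstruction inferred from the call sites, not the original implementation.

import numpy as np

def explainer(alg, inp, target, **kwargs):
    # run the Captum algorithm and return an HxWxC numpy array for visualize_image_attr
    attr = alg.attribute(inp, target=target, **kwargs)
    return np.transpose(attr.squeeze(0).cpu().detach().numpy(), (1, 2, 0))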
Example #5
def main() -> None:
    torch.cuda.empty_cache()
    # Load data
    train_loader, test_loader = load_data()

    # net = training_procedure(train_loader)
    net = joblib.load("saves/save_network_60.pickle")
    net = net.module.to(Constant.device)
    net.eval()

    # test_ensemble(net, test_loader)
    dataiter = iter(test_loader)

    for i in range(100):
        image, label = next(dataiter)

        image = image.to(Constant.device)
        label = label.to(Constant.device)

        attr_ig1 = generate_explanation(net, image, label, 1, True)
        attr_ig20 = generate_explanation(net, image, label, 20, True)
        attr_mean = generate_explanation(net, image, label, 1, False)

        original_image = np.transpose(image.squeeze().cpu().detach().numpy(),
                                      (1, 2, 0))

        fig_0 = viz.visualize_image_attr(None,
                                         original_image,
                                         method="original_image",
                                         title="Original image")[0]
        fig_1 = viz.visualize_image_attr(attr_ig1,
                                         original_image,
                                         method="masked_image",
                                         title="Masked DeepLift - 1")[0]
        fig_3 = viz.visualize_image_attr(attr_ig20,
                                         original_image,
                                         method="masked_image",
                                         title="Masked DeepLift - 20")[0]
        fig_2 = viz.visualize_image_attr(attr_mean,
                                         original_image,
                                         method="masked_image",
                                         title="Masked DeepLift - mean")[0]

        fig_0.savefig(f"images/image_{i}_0.png", dpi=fig_0.dpi)
        fig_1.savefig(f"images/image_{i}_1.png", dpi=fig_1.dpi)
        fig_2.savefig(f"images/image_{i}_2.png", dpi=fig_2.dpi)
        fig_3.savefig(f"images/image_{i}_3.png", dpi=fig_3.dpi)
Example #6
def main(mixup=False):
    prefix = "Mixup" if mixup else "Large"
    run = wandb.init(
        name=f"Interpretability ({prefix})",
        project="ct-interpretability",
        dir=DEFAULT_DATA_STORAGE,
        reinit=True,
    )
    model = get_model(mixup)
    dataset = get_miccai_2d(
        "test",
        transform=DEGREE[model.hparams.transform_degree]["test"],
        enhanced="Boundary" in model.hparams.loss_fx,
    )

    class_labels = dict(zip(range(1, model._n_classes), miccai.STRUCTURES))
    class_labels[0] = "Void"
    step = 0

    for sample in tqdm(dataset):
        preproc_img, masks, _, *others = sample
        normalized_inp = preproc_img.unsqueeze(0).to(device)
        normalized_inp.requires_grad = True
        masks = _squash_masks(masks, 10, masks.device)

        if len(masks.unique()) < 6:
            # Only keep samples with at least 5 structures (excluding background)
            continue

        out = model(normalized_inp)
        out_max = _squash_predictions(out).unsqueeze(1)

        log_samples(preproc_img, masks, out_max, class_labels, step)

        def segmentation_wrapper(input):
            return model(input).sum(dim=(2, 3))

        layer = model.unet.model[2][1].conv.unit0.conv
        lgc = LayerGradCam(segmentation_wrapper, layer)

        figures = []
        for structure in miccai.STRUCTURES:
            idx = structures.index(structure)
            gc_attr = lgc.attribute(normalized_inp, target=idx)
            fig, ax = viz.visualize_image_attr(
                gc_attr[0].cpu().permute(1, 2, 0).detach().numpy(),
                sign="all",
                use_pyplot=False,
            )
            ax.set_title(structure)
            figures.append(wandb.Image(fig))

        wandb.log({"GradCam Attributions": figures}, step=step)
        step += 1

    run.finish()
Example #7
 def draw_integrated_gradients(self, toplot_img, attributions_ig, fig, ax):
     _ = viz.visualize_image_attr(attributions_ig,
                                  toplot_img,
                                  'blended_heat_map',
                                  'positive',
                                  plt_fig_axis=(fig, ax),
                                  cmap=self.default_cmap,
                                  show_colorbar=False,
                                  use_pyplot=False,
                                  outlier_perc=2)
     ax.title.set_text("Integrated gradient")
Example #8
 def draw_occlusion(self, toplot_img, attributions_occ, fig, ax):
     _ = viz.visualize_image_attr(attributions_occ,
                                  toplot_img,
                                  'blended_heat_map',
                                  'positive',
                                  plt_fig_axis=(fig, ax),
                                  show_colorbar=False,
                                  cmap=self.default_cmap,
                                  outlier_perc=2,
                                  use_pyplot=False)
     ax.title.set_text("occlusion")
Example #9
 def draw_saliency(self, toplot_img, attributions_sa, fig, ax):
     _ = viz.visualize_image_attr(attributions_sa,
                                  toplot_img,
                                  'blended_heat_map',
                                  'absolute_value',
                                  plt_fig_axis=(fig, ax),
                                  cmap=self.default_cmap,
                                  show_colorbar=False,
                                  use_pyplot=False,
                                  outlier_perc=2)
     ax.title.set_text("Gradient saliency")
Example #10
def get_integrated_gradients_attribution_with_prediction(
        model: nn.Module, image: np.ndarray):
    torch.manual_seed(0)
    np.random.seed(0)

    transform_normalize = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    input_tensor = transform_normalize(image)
    input_tensor = input_tensor.unsqueeze(0)

    print("Performing forward pass...")
    model = model.eval()
    output = model(input_tensor)
    output = F.softmax(output, dim=1)
    _, pred_label_idx = torch.topk(output, 1)

    pred_label_idx.squeeze_()
    prediction = pred_label_idx.item()

    print("Generating visualization...")
    tic = time.time()
    visualization = IntegratedGradients(model)
    attributions = visualization.attribute(input_tensor,
                                           target=pred_label_idx,
                                           n_steps=5)
    attributions = attributions.squeeze().cpu().detach().numpy()
    attributions = np.transpose(attributions, (1, 2, 0))
    toc = time.time()
    print("Took %.02fs" % (toc - tic))

    custom_cmap = LinearSegmentedColormap.from_list("custom blue",
                                                    [(0, "#ffffff"),
                                                     (0.25, "#000000"),
                                                     (1, "#000000")],
                                                    N=256)

    fig, ax = viz.visualize_image_attr(
        attributions,
        method="heat_map",
        sign="positive",
        cmap=custom_cmap,
        show_colorbar=True,
    )

    ax.margins(0)
    fig.tight_layout(pad=0)

    return fig, prediction
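A hypothetical call: the ResNet-18 model and the image path are placeholders. The function expects an HxWxC RGB image array and a classification model.

import numpy as np
from PIL import Image
from torchvision import models

model = models.resnet18(pretrained=True)
image = np.array(Image.open("example.jpg").convert("RGB").resize((224, 224)))
fig, prediction = get_integrated_gradients_attribution_with_prediction(model, image)
fig.savefig("ig_heatmap.png")
print("Predicted class index:", prediction)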
Example #11
    def get_insights(self, tensor_data, _, target=0):
        # TODO: this will work with Image Classification, but needs work for segmentation
        all_attr = self.ig.attribute(tensor_data, target=target, n_steps=15)
        n, c, h, w = all_attr.size()
        reshape_attr = all_attr.view(n, h, w, c).detach().cpu().numpy()
        return_bytes = []
        for attr in reshape_attr:
            matplot_viz, _ = viz.visualize_image_attr(attr,
                                                      use_pyplot=False,
                                                      sign='all')

            fout = BytesIO()
            matplot_viz.savefig(fout)
            return_bytes.append(fout.getvalue())

        return return_bytes
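`self.ig` is not defined in this snippet. In a TorchServe-style handler it would typically be an IntegratedGradients instance built once around the loaded model, roughly as sketched below (an assumption, not the original class).

from captum.attr import IntegratedGradients

class ExplainerHandler:
    def __init__(self, model):
        self.model = model.eval()
        self.ig = IntegratedGradients(self.model)  # consumed by get_insights above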
Example #12
 def visualize(self,
               attributions: torch.Tensor,
               images: torch.Tensor,
               titles=()):
     attr_dl = np.transpose(attributions.cpu().detach().numpy(),
                            (0, 2, 3, 1))
     images_np = np.transpose((images.cpu().detach().numpy()), (0, 2, 3, 1))
     for idx, (img, attr) in enumerate(zip(images_np, attr_dl), 0):
         if len(titles) == 0:
             title_ = f"Attribution for {self.xai_algorithm.__class__.__name__} {idx + 1}"
         else:
             title_ = titles[idx]
         _ = viz.visualize_image_attr(attr,
                                      img,
                                      method="blended_heat_map",
                                      sign="all",
                                      show_colorbar=True,
                                      title=title_)
Example #13
def save_explanation(inputImage: torch.Tensor, modeladapter: torch.nn.Module,
                     cfg: DictConfig, pred_label_idx: int, pred_label_num: int,
                     gt_label_num: int, filename: str, filepath: str,
                     filename_without_ext: str, prediction_score: float):
    """
    Return explanation value dict
    """

    input_gradients = torch.unsqueeze(inputImage, 0)

    integrated_gradients = IntegratedGradients(modeladapter)
    default_cmap = LinearSegmentedColormap.from_list('custom blue',
                                                     [(0, '#ffffff'),
                                                      (0.25, '#000000'),
                                                      (1, '#000000')],
                                                     N=256)

    noise_tunnel = NoiseTunnel(integrated_gradients)
    attributions_ig = noise_tunnel.attribute(
        input_gradients,
        nt_samples=cfg.inference.captum.noise_tunnel.nt_samples,
        nt_samples_batch_size=cfg.inference.captum.noise_tunnel.
        nt_samples_batch_size,
        nt_type=cfg.inference.captum.noise_tunnel.nt_type,
        target=pred_label_idx)

    # Standard Captum Visualization
    figure, plot = viz.visualize_image_attr(np.transpose(
        attributions_ig.squeeze().cpu().detach().numpy(), (1, 2, 0)),
                                            method="heat_map",
                                            sign="positive",
                                            cmap=default_cmap,
                                            show_colorbar=True)

    dict_col_name = {}

    dict_col_name.update({
        "pred": pred_label_num,
        "GT": gt_label_num,
        "predict_score": prediction_score.squeeze_().item(),
        "image_path": filename
    })

    save_path_original = f"/PREDcls_{pred_label_num}_GTcls_{gt_label_num}_{filename}"

    save_path_explanation = (f"/PREDcls_{pred_label_num}_GTcls_{gt_label_num}_"
                             f"{filename_without_ext}_explain.png")

    if (pred_label_num == gt_label_num):

        # path to image

        shutil.copy(
            filepath,
            cfg.inference.captum.correct_explanation_path + save_path_original)

        figure.savefig(cfg.inference.captum.correct_explanation_path +
                       save_path_explanation,
                       dpi=figure.dpi)

    else:

        shutil.copy(
            filepath,
            cfg.inference.captum.error_explanation_path + save_path_original)

        figure.savefig(cfg.inference.captum.error_explanation_path +
                       save_path_explanation,
                       dpi=figure.dpi)

    plt.close()

    return dict_col_name
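A minimal sketch of the Hydra/OmegaConf config fields that save_explanation reads; every value and path below is a placeholder, not the original configuration.

from omegaconf import OmegaConf

cfg = OmegaConf.create({
    "inference": {
        "captum": {
            "noise_tunnel": {
                "nt_samples": 10,
                "nt_samples_batch_size": 2,
                "nt_type": "smoothgrad_sq",
            },
            "correct_explanation_path": "outputs/explanations/correct",
            "error_explanation_path": "outputs/explanations/error",
        }
    }
})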
Example #14
    target_class = torch.tensor([1])
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    img_list = glob.glob(data_dir + '/MSIMUT/*.png')
    original_image = np.array(Image.open(img_list[args.idx])).astype('uint8')
    ref = get_ref(original_image, args.reference)

    modelnames = [
        'style transfer (STRAP)', 'stain augmentation (SA)',
        'stain normalization (SN)'
    ]

    fig, ax = plt.subplots(1, 4, figsize=(24, 6))
    _, ax[0] = viz.visualize_image_attr(None,
                                        original_image,
                                        method="original_image",
                                        title="Original Image",
                                        plt_fig_axis=(fig, ax[0]),
                                        use_pyplot=False)
    for i, experiment in enumerate(
        ['style_transfer', 'stain_augmentation', 'stain_normalization']):
        model_path = state_dict_dir + '/' + experiment + '_' + str(
            args.kfold) + '.pth'
        model = get_model(model_path, num_classes)
        model.eval()

        path2hdf5_msi = low_data_dir + f'/crc-dx-test_msi_low_{args.radius}.hdf5'
        h5_low = h5py.File(path2hdf5_msi)
        low = h5_low['img'][args.idx]
        low = transform(low).unsqueeze(0)
        low.requires_grad = True
Example #15
def plot_original_and_explained_pair(
    pair: List[torch.tensor],
    labels: List[int],
    alg,
    pred: int,
    savename: str = None,
    method: str = "blended_heat_map",
    sign: str = "absolute_value",
):
    """
    Plot 2x2 grid of images. First row shows original images, second the gradient explanations.

    Args:
        pair (List[torch.tensor]): List of 2 images as torch tensors
        labels (List[int]): the true labels for the images
        alg ([type]): a Captum algorithm
        pred (int): the prediction for the original image
        savename (str, optional): If given, saves the image to disk. Defaults to None.
        method (str, optional): which visualization method to use (heat_map, blended_heat_map, original_image, masked_image, alpha_scaling)
        sign (str, optional): sign of attributions to visualize (positive, absolute_value, negative, all)
    """

    def _prepare_explainer_input(img):
        input = img.unsqueeze(0)
        input.requires_grad = True
        input = input.cuda()
        return input

    inputs = [_prepare_explainer_input(img) for img in pair]
    # Explaining the target
    grads_target = [explainer(alg, inp, lab) for inp, lab in zip(inputs, labels)]
    # Explaining the actual prediction
    grads_pred = [explainer(alg, inp, pred) for inp in inputs]

    unorm = UnNormalize(img_means, img_stds)
    org_images = [unorm(img) for img in pair]
    org_images = [
        np.transpose(org_img.cpu().detach().numpy(), (1, 2, 0))
        for org_img in org_images
    ]

    text_labels = [classes[l] for l in labels]  # get text label

    fig, axes = plt.subplots(2, 3)
    # plt.subplots_adjust(wspace=0.0001)
    ### Plot original images
    # Wrongly predicted
    _ = viz.visualize_image_attr(
        grads_target[0],
        org_images[0],
        method="original_image",
        title=f"true: {text_labels[0]}, pred: {classes[pred]}",
        plt_fig_axis=(fig, axes[0, 0]),
        use_pyplot=False,
    )

    # Nearest neighbor
    _ = viz.visualize_image_attr(
        grads_target[1],
        org_images[1],
        method="original_image",
        title=f"nn: {text_labels[1]}",
        plt_fig_axis=(fig, axes[1, 0]),
        use_pyplot=False,
    )

    ### Gradient explanations for predicted
    _ = viz.visualize_image_attr(
        grads_pred[0],
        org_images[0],
        method=method,
        sign=sign,  # org: "absolute_value"
        show_colorbar=True,
        title=f"Exp. wrt. {classes[pred]}",
        plt_fig_axis=(fig, axes[0, 1]),
        # use_pyplot=False to avoid viz calling plt.show()
        use_pyplot=False,
    )
    _ = viz.visualize_image_attr(
        grads_pred[1],
        org_images[1],
        method=method,
        sign=sign,
        show_colorbar=True,
        title="",
        plt_fig_axis=(fig, axes[1, 1]),
        use_pyplot=True,
    )
    ### Gradient explanations for target
    _ = viz.visualize_image_attr(
        grads_target[0],
        org_images[0],
        method=method,
        sign=sign,  # org: "absolute_value"
        show_colorbar=True,
        title=f"Exp. wrt. {text_labels[0]}",
        plt_fig_axis=(fig, axes[0, 2]),
        # use_pyplot=False to avoid viz calling plt.show()
        use_pyplot=False,
    )
    _ = viz.visualize_image_attr(
        grads_target[1],
        org_images[1],
        method=method,
        sign=sign,
        show_colorbar=True,
        title="",
        plt_fig_axis=(fig, axes[1, 2]),
        use_pyplot=False,
    )

    if savename is not None:
        plt.savefig(savename + ".png")

    plt.close()
    return fig, axes
Example #16
def visualize_maps(
        model: torch.nn.Module,
        inputs: Union[Tuple[torch.Tensor, torch.Tensor]],
        labels: torch.Tensor,
        title: str,
        second_occlusion: Tuple[int, int, int] = (1, 2, 2),
        baselines: Tuple[int, int] = (0, 0),
        closest: bool = False,
) -> None:
    """
    Visualizes the average of the inputs, or the single input, using various different XAI approaches
    """
    single = inputs[1].ndim == 2
    model.zero_grad()
    model.eval()
    occ = Occlusion(model)
    saliency = Saliency(model)
    saliency = NoiseTunnel(saliency)
    igrad = IntegratedGradients(model)
    igrad_2 = NoiseTunnel(igrad)
    # deep_lift = DeepLift(model)
    grad_shap = ShapleyValueSampling(model)
    output = model(inputs[0], inputs[1])
    output = F.softmax(output, dim=-1).argmax(dim=1, keepdim=True)
    labels = F.softmax(labels, dim=-1).argmax(dim=1, keepdim=True)
    if np.all(labels.cpu().numpy() == 1) and not closest:
        return
    if True:
        targets = labels
    else:
        targets = output
    print(targets)
    correct = targets.cpu().numpy() == labels.cpu().numpy()
    # if correct:
    #   return
    occ_out = occ.attribute(
        inputs,
        baselines=baselines,
        sliding_window_shapes=((1, 5, 5), second_occlusion),
        target=targets,
    )
    # occ_out2 = occ.attribute(inputs, sliding_window_shapes=((1,20,20), second_occlusion), strides=(8,1), target=targets)
    saliency_out = saliency.attribute(inputs,
                                      nt_type="smoothgrad_sq",
                                      n_samples=5,
                                      target=targets,
                                      abs=False)
    # igrad_out = igrad.attribute(inputs, target=targets, internal_batch_size=1)
    igrad_out = igrad_2.attribute(
        inputs,
        baselines=baselines,
        target=targets,
        n_samples=5,
        nt_type="smoothgrad_sq",
        internal_batch_size=1,
    )
    # deep_lift_out = deep_lift.attribute(inputs, target=targets)
    grad_shap_out = grad_shap.attribute(inputs,
                                        baselines=baselines,
                                        target=targets)

    if single:
        inputs = convert_to_image(inputs)
        occ_out = convert_to_image(occ_out)
        saliency_out = convert_to_image(saliency_out)
        igrad_out = convert_to_image(igrad_out)
        # grad_shap_out = convert_to_image(grad_shap_out)
    else:
        inputs = convert_to_image_multi(inputs)
        occ_out = convert_to_image_multi(occ_out)
        saliency_out = convert_to_image_multi(saliency_out)
        igrad_out = convert_to_image_multi(igrad_out)
        grad_shap_out = convert_to_image_multi(grad_shap_out)
    fig, axes = plt.subplots(2, 5)
    (fig, axes[0, 0]) = visualization.visualize_image_attr(
        occ_out[0][0],
        inputs[0][0],
        title="Original Image",
        method="original_image",
        show_colorbar=True,
        plt_fig_axis=(fig, axes[0, 0]),
        use_pyplot=False,
    )
    (fig, axes[0, 1]) = visualization.visualize_image_attr(
        occ_out[0][0],
        None,
        sign="all",
        title="Occ (5x5)",
        show_colorbar=True,
        plt_fig_axis=(fig, axes[0, 1]),
        use_pyplot=False,
    )
    (fig, axes[0, 2]) = visualization.visualize_image_attr(
        saliency_out[0][0],
        None,
        sign="all",
        title="Saliency",
        show_colorbar=True,
        plt_fig_axis=(fig, axes[0, 2]),
        use_pyplot=False,
    )
    (fig, axes[0, 3]) = visualization.visualize_image_attr(
        igrad_out[0][0],
        None,
        sign="all",
        title="Integrated Grad",
        show_colorbar=True,
        plt_fig_axis=(fig, axes[0, 3]),
        use_pyplot=False,
    )
    (fig, axes[0, 4]) = visualization.visualize_image_attr(
        grad_shap_out[0],
        None,
        title="GradSHAP",
        show_colorbar=True,
        plt_fig_axis=(fig, axes[0, 4]),
        use_pyplot=False,
    )
    ##### Second Input Labels #########################################################################################
    (fig, axes[1, 0]) = visualization.visualize_image_attr(
        occ_out[1],
        inputs[1],
        title="Original Aux",
        method="original_image",
        show_colorbar=True,
        plt_fig_axis=(fig, axes[1, 0]),
        use_pyplot=False,
    )
    (fig, axes[1, 1]) = visualization.visualize_image_attr(
        occ_out[1],
        None,
        sign="all",
        title="Occ (1x1)",
        show_colorbar=True,
        plt_fig_axis=(fig, axes[1, 1]),
        use_pyplot=False,
    )
    (fig, axes[1, 2]) = visualization.visualize_image_attr(
        saliency_out[1],
        None,
        sign="all",
        title="Saliency",
        show_colorbar=True,
        plt_fig_axis=(fig, axes[1, 2]),
        use_pyplot=False,
    )
    (fig, axes[1, 3]) = visualization.visualize_image_attr(
        igrad_out[1],
        None,
        sign="all",
        title="Integrated Grad",
        show_colorbar=True,
        plt_fig_axis=(fig, axes[1, 3]),
        use_pyplot=False,
    )
    (fig, axes[1, 4]) = visualization.visualize_image_attr(
        grad_shap_out[1],
        None,
        title="GradSHAP",
        show_colorbar=True,
        plt_fig_axis=(fig, axes[1, 4]),
        use_pyplot=False,
    )

    fig.suptitle(
        title +
        f" Label: {labels.cpu().numpy()} Pred: {targets.cpu().numpy()}")
    plt.savefig(
        f"{title}_{'single' if single else 'multi'}_{'Failed' if correct else 'Success'}_baseline{baselines[0]}.png",
        dpi=300,
    )
    plt.clf()
    plt.cla()
Example #17
    def vis_explanation(self, number):
        if len(self.explainVis) == 0:
            for i, batch in enumerate(self.test_loader):
                self.explainVis = batch
                break

        # oldIndices = self.test_loader.indices.copy()
        # self.test_loader.indices = self.test_loader.indices[:2]

        # datasetLoader = self.test_loader
        layer_gc = LayerGradCam(self.model, self.model.layer2[1].conv2)

        # for i, batch in enumerate(datasetLoader):

        lb = self.explainVis[1].to(device)
        # print(len(lb))
        img = self.explainVis[0].to(device)
        # plt.subplot(2,1,1)
        # plt.imshow(img.squeeze().cpu().numpy())

        pred = self.model(img)
        predlb = torch.argmax(pred, 1)
        imgCQ = img.clone()

        # print('Prediction label is :',predlb.cpu().numpy())
        # print('Ground Truth label is: ',lb.cpu().numpy())
        ##explain to me :
        gc_attr = layer_gc.attribute(imgCQ,
                                     target=predlb,
                                     relu_attributions=False)
        upsampled_attr = LayerAttribution.interpolate(gc_attr, (64, 64))

        gc_attr = layer_gc.attribute(imgCQ, target=lb, relu_attributions=False)
        upsampled_attrB = LayerAttribution.interpolate(gc_attr, (64, 64))
        if not os.path.exists('./pic'):
            os.mkdir('./pic')

        ####PLot################################################
        plotMe = viz.visualize_image_attr(
            upsampled_attr[7].detach().cpu().numpy().transpose([1, 2, 0]),
            original_image=img[7].detach().cpu().numpy().transpose([1, 2, 0]),
            method='heat_map',
            sign='all',
            plt_fig_axis=None,
            outlier_perc=2,
            cmap='inferno',
            alpha_overlay=0.2,
            show_colorbar=True,
            title=str(predlb[7]),
            fig_size=(8, 10),
            use_pyplot=True)

        plotMe[0].savefig('./pic/' + str(number) + 'NotEQPred.jpg')
        ################################################

        plotMe = viz.visualize_image_attr(
            upsampled_attrB[7].detach().cpu().numpy().transpose([1, 2, 0]),
            original_image=img[7].detach().cpu().numpy().transpose([1, 2, 0]),
            method='heat_map',
            sign='all',
            plt_fig_axis=None,
            outlier_perc=2,
            cmap='inferno',
            alpha_overlay=0.9,
            show_colorbar=True,
            title=str(lb[7].cpu()),
            fig_size=(8, 10),
            use_pyplot=True)

        plotMe[0].savefig('./pic/' + str(number) + 'NotEQLabel.jpg')
        ################################################

        outImg = img[7].squeeze().detach().cpu().numpy()
        fig2 = plt.figure(figsize=(12, 12))
        prImg = plt.imshow(outImg)
        fig2.savefig('./pic/' + str(number) + 'NotEQOrig.jpg')
        ################################################
        fig = plt.figure(figsize=(15, 10))
        ax = fig.add_subplot(111, projection='3d')

        z = upsampled_attr[7].squeeze().detach().cpu().numpy()
        x = np.arange(0, 64, 1)
        y = np.arange(0, 64, 1)
        X, Y = np.meshgrid(x, y)

        plll = ax.plot_surface(X, Y, z, cmap=cm.coolwarm)
        # Customize the z axis.
        # ax.set_zlim(np.min(z)+0.1*np.min(z),np.max(z)+0.1*np.max(z))
        ax.set_zlim(-0.02, 0.1)
        ax.zaxis.set_major_locator(LinearLocator(10))
        ax.zaxis.set_major_formatter(FormatStrFormatter('%.02f'))

        # Add a color bar which maps values to colors.
        fig.colorbar(plll, shrink=0.5, aspect=5)
        fig.savefig('./pic/' + str(number) + 'NotEQ3D.jpg')
Example #18
        results = wrapped_model.format_readout()
        norm_grads = []

        for i in range(len(grads)):
            # get the original image
            original_image = np.transpose((imgs[i].cpu().detach().numpy() / 2) + 0.5, (1, 2, 0))
            # get the gradients for this image
            img_grad = grads[i, :, :, :]
            # shape must be h,w,c, move first dim to the end
            img_grad = img_grad.permute((1, 2, 0))
            norm_grad = _normalize_image_attr(img_grad.numpy(), sign="all", outlier_perc=2)
            norm_grads.append(norm_grad)
            if save_png:
                att_score = results['attention_outputs'][0][i]
                title = 'Attention = ' + str(np.round(att_score,2))
                fig, ax = viz.visualize_image_attr(img_grad.numpy(), original_image, method="blended_heat_map",
                                        alpha_overlay=0.5, sign="absolute_value", show_colorbar=False, title=title)
                out_path = os.path.join('../output/images', (str(i) + '.png'))

                fig.savefig(out_path)

        # stack up the activation maps
        grad_output = np.stack(norm_grads)

        assert image.shape == grad_output.shape, 'Output and input size don\'t match'

        grad_output = itk.image_from_array(grad_output)
        raw_name, extension = os.path.splitext(fn)
        new_name = raw_name + '_out' + extension
        os.makedirs('/output/images', exist_ok=True)
        out_path = os.path.join('/output/images', new_name)
Example #19
                     "actual " + str(act.item()))

    from captum.attr import IntegratedGradients
    from captum.attr import visualization as viz

    from matplotlib.colors import LinearSegmentedColormap

    model.return_attn = False
    integrated_gradients = IntegratedGradients(model)
    attributions_ig = integrated_gradients.attribute(imout.unsqueeze(0),
                                                     target=None,
                                                     n_steps=100)
    default_cmap = LinearSegmentedColormap.from_list('custom blue',
                                                     [(0, '#ffffff'),
                                                      (0.25, '#000000'),
                                                      (1, '#000000')],
                                                     N=256)

    _ = viz.visualize_image_attr(
        np.transpose(attributions_ig.squeeze().cpu().detach().numpy(),
                     (1, 2, 0)),
        np.transpose(imout.squeeze().cpu().detach().numpy(), (1, 2, 0)),
        method='heat_map',
        cmap=default_cmap,
        show_colorbar=True,
        sign='positive',
        outlier_perc=1,
        plt_fig_axis=(fig, axs[2]))

    plt.show()
Example #20
def draw_occlusion(
        batch: dict,
        occlusion: Occlusion,
        window: tuple = (4, 4),
        stride: tuple = (8, 8),
        sign: str = 'positive',
        method: str = 'blended_heat_map',
        use_pyplot: bool = False,
        outlier_perc: float = 2.,
        fig_axis: tuple = None,
        mix_bg: bool = True,
        alpha_overlay: float = 0.7
):
    """
    :param batch: batch to visualise
    :param occlusion: Occlusion object initialised for trainer_module
    :param window: Shape of patch (hyperrectangle) to occlude each input
    :param stride: step by which the occlusion hyperrectangle should be shifted by in each direction
    :param sign: sign of gradient attributes to visualise
    :param method: method of visualization to be used
    :param use_pyplot: whether to use pyplot
    :param outlier_perc: percentage of outlier attribution values to clip before visualising
    :param fig_axis: optional (figure, axes) pair to draw on instead of creating a new one
    :param mix_bg: whether to mix semantic/aerial map with vehicles
    :param alpha_overlay: alpha value of the overlayed heat map
    :return: pair of figure and corresponding gradients tensor
    """

    strides = (batch['image'].shape[2] // stride[0], batch['image'].shape[3] // stride[1])
    window_size = (batch['image'].shape[2] // window[0], batch['image'].shape[3] // window[1])
    channels = batch['image'].shape[1]

    grads = occlusion.attribute(
        batch['image'], strides=(channels if mix_bg else channels - 3, *strides),
        sliding_window_shapes=(channels if mix_bg else channels - 3, *window_size),
        baselines=0,
        additional_forward_args=(
            batch['target_positions'],
            None if 'target_availabilities' not in batch else batch['target_availabilities'],
            False))

    gradsm = grads.squeeze().cpu().detach().numpy()
    if len(gradsm.shape) == 3:
        gradsm = gradsm.reshape(1, *gradsm.shape)
    gradsm = np.transpose(gradsm, (0, 2, 3, 1))
    im = batch['image'].detach().cpu().numpy().transpose(0, 2, 3, 1)
    fig, axis = fig_axis if fig_axis is not None else plt.subplots(2 - mix_bg, im.shape[0], dpi=200, figsize=(6, 6))
    for b in range(im.shape[0]):
        if mix_bg:
            viz.visualize_image_attr(
                gradsm[b, ...], im[b, ...], method=method,
                sign=sign, use_pyplot=use_pyplot,
                plt_fig_axis=(fig, axis if im.shape[0] == 1 else axis[b]),
                alpha_overlay=alpha_overlay, outlier_perc=outlier_perc,
            )
            ttl = (axis if im.shape[0] == 1 else axis[b]).title
            ttl.set_position([.5, 0.95])
            (axis if im.shape[0] == 1 else axis[b]).axis('off')
        else:
            for (s_channel, end_channel), row in [((im.shape[-1] - 3, im.shape[-1]), 0), ((0, im.shape[-1] - 3), 1)]:
                viz.visualize_image_attr(
                    gradsm[b, :, :, s_channel:end_channel], im[b, :, :, s_channel:end_channel], method=method,
                    sign=sign, use_pyplot=use_pyplot,
                    plt_fig_axis=(fig, axis[row] if im.shape[0] == 1 else axis[row][b]),
                    alpha_overlay=alpha_overlay, outlier_perc=outlier_perc,
                )
                ttl = (axis[row] if im.shape[0] == 1 else axis[row][b]).title
                ttl.set_position([.5, 0.95])
                (axis[row] if im.shape[0] == 1 else axis[row][b]).axis('off')
    return fig, grads
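Hypothetical usage, mirroring Example #2: `trainer_module` and the batch layout (an 'image' tensor plus the target keys referenced above) are assumptions about the surrounding training code.

from captum.attr import Occlusion

occlusion = Occlusion(trainer_module)
fig, grads = draw_occlusion(batch, occlusion, window=(4, 4), stride=(8, 8), sign='positive')
fig.savefig('occlusion_batch.png', dpi=200)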
Example #21
    attributions_ig = integrated_gradients.attribute(input,
                                                     target=pred_label_idx,
                                                     n_steps=200,
                                                     internal_batch_size=1)

    default_cmap = LinearSegmentedColormap.from_list('custom blue',
                                                     [(0, '#ffffff'),
                                                      (0.25, '#000000'),
                                                      (1, '#000000')],
                                                     N=256)

    vis_img = viz.visualize_image_attr(
        np.transpose(attributions_ig.squeeze().cpu().detach().numpy(),
                     (1, 2, 0)),
        np.transpose(transformed_img.squeeze().cpu().detach().numpy(),
                     (1, 2, 0)),
        method='heat_map',
        cmap=default_cmap,
        show_colorbar=True,
        sign='positive',
        outlier_perc=1)

    noise_tunnel = NoiseTunnel(integrated_gradients)

    attributions_ig_nt = noise_tunnel.attribute(input,
                                                nt_samples=10,
                                                nt_type='smoothgrad_sq',
                                                target=pred_label_idx,
                                                internal_batch_size=10)

    _ = viz.visualize_image_attr_multiple(
        np.transpose(attributions_ig_nt.squeeze().cpu().detach().numpy(),
Example #22
attr_dl = attribute_image_features(dl, input, baselines=input * 0)
attr_dl = np.transpose(attr_dl.squeeze(0).cpu().detach().numpy(), (1, 2, 0))

# %% [markdown]
# In the cell below we will visualize the attributions for `Saliency Maps`, `DeepLift`, `Integrated Gradients` and `Integrated Gradients with SmoothGrad`.

# %%
print('Original Image')
print('Predicted:', classes[predicted[ind]], ' Probability:',
      torch.max(F.softmax(outputs, 1)).item())

original_image = np.transpose((images[ind].cpu().detach().numpy() / 2) + 0.5,
                              (1, 2, 0))

_ = viz.visualize_image_attr(None,
                             original_image,
                             method="original_image",
                             title="Original Image")

_ = viz.visualize_image_attr(grads,
                             original_image,
                             method="blended_heat_map",
                             sign="absolute_value",
                             show_colorbar=True,
                             title="Overlayed Gradient Magnitudes")

_ = viz.visualize_image_attr(attr_ig,
                             original_image,
                             method="blended_heat_map",
                             sign="all",
                             show_colorbar=True,
                             title="Overlayed Integrated Gradients")
Example #23
# Running the cell with the ``integrated_gradients.attribute()`` call will
# usually take a minute or two.
#

# Initialize the attribution algorithm with the model
integrated_gradients = IntegratedGradients(model)

# Ask the algorithm to attribute the predicted class output to the input features
attributions_ig = integrated_gradients.attribute(input_img,
                                                 target=pred_label_idx,
                                                 n_steps=200)

# Show the original image for comparison
_ = viz.visualize_image_attr(
    None,
    np.transpose(transformed_img.squeeze().cpu().detach().numpy(), (1, 2, 0)),
    method="original_image",
    title="Original Image")

default_cmap = LinearSegmentedColormap.from_list('custom blue',
                                                 [(0, '#ffffff'),
                                                  (0.25, '#0000ff'),
                                                  (1, '#0000ff')],
                                                 N=256)

_ = viz.visualize_image_attr(
    np.transpose(attributions_ig.squeeze().cpu().detach().numpy(), (1, 2, 0)),
    np.transpose(transformed_img.squeeze().cpu().detach().numpy(), (1, 2, 0)),
    method='heat_map',
    cmap=default_cmap,
    show_colorbar=True,
Example #24
    def interpret(self,
                  samples,
                  target_layer=-1,
                  show_in_notebook=False,
                  explanation_configs={},
                  vis_configs={}):
        """Explain instance and return PP or PN with metadata. If pyTorch (captum) is used,
        the convergence delta is NOT returned by default.

        Args:
            samples (tensor or tuple of tensors): Samples to explain
            target_layer (int): for KerasModel, specify the target layer. 
                                Following example in: https://github.com/marcoancona/DeepExplain/blob/master/examples/mint_cnn_keras.ipynb
            explanation_configs (optional): optional arguments to pass to the explainer for attribution

        Returns:
            tensor (or tuple of tensors) containing attributions
        """
        if isinstance(self._model, TorchModel):
            if self._explainer.has_convergence_delta(
            ) and 'return_convergence_delta' not in explanation_configs:
                explanation_configs['return_convergence_delta'] = False
            explanation = self._explainer.attribute(
                inputs=self._model._prepare_sample(samples),
                **explanation_configs)
            if show_in_notebook:
                if 'return_convergence_delta' in explanation_configs and explanation_configs[
                        'return_convergence_delta']:
                    exp = explanation[0]
                else:
                    exp = explanation
                exp = np.transpose(exp.detach().numpy()[0], (1, 2, 0))
                normalizer = Normalize()
                if 'method' not in vis_configs:
                    vis_configs['method'] = 'masked_image'
                viz.visualize_image_attr(exp, normalizer(samples[0]),
                                         **vis_configs)

            return explanation
        else:
            with DeepExplain(session=K.get_session()) as de:
                input_tensor = self._model._model.inputs
                smpls = samples if isinstance(samples, list) else [samples]
                if self._method in {'occlusion', 'shapley_sampling'}:
                    warnings.warn(
                        'For perturbation methods, multiple inputs (modalities) are not supported.',
                        UserWarning)
                    smpls = smpls[0]
                    input_tensor = input_tensor[0]

                model = Model(inputs=input_tensor,
                              outputs=self._model._model.outputs)
                target_tensor = model(input_tensor)

                if show_in_notebook:
                    warnings.warn('Sorry! Visualization not implemented yet!',
                                  UserWarning)

                return de.explain(self._method,
                                  T=target_tensor,
                                  X=input_tensor,
                                  xs=smpls,
                                  **explanation_configs)