Example #1
File: wsoleval.py  Project: zphang/casme
 def get_mask(self, input_, target):
     saliency_ls = []
     for j in range(len(target)):
         input_single = input_[j:j + 1]
         target_single = target[j].item()
         if self.torchray_method == "grad_cam":
             from torchray.attribution.grad_cam import grad_cam
             saliency = grad_cam(
                 model=self.original_classifier,
                 input=input_single,
                 target=target_single,
                 saliency_layer='layer4',
             )
             saliency = self.grad_cam_upsampler(saliency)
         elif self.torchray_method == "guided_backprop":
             from torchray.attribution.guided_backprop import guided_backprop
             saliency = guided_backprop(
                 model=self.original_classifier,
                 input=input_single,
                 target=target_single,
                 resize=(224, 224),
                 smooth=0.02,
             )
         else:
             raise KeyError(self.torchray_method)
         if saliency.max() == saliency.min():
             saliency[:] = 1
         else:
             saliency = (saliency - saliency.min()) / (saliency.max() -
                                                       saliency.min())
         saliency_ls.append(saliency.detach())
     return torch.cat(saliency_ls, dim=0)
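The min-max step above rescales each per-sample saliency map to [0, 1], with an all-ones fallback when the map is constant (which would otherwise divide by zero). A standalone sketch of the same normalization, assuming nothing beyond PyTorch:

import torch

def normalize_saliency(saliency):
    # Rescale to [0, 1]; a constant map becomes all ones, as in get_mask above.
    lo, hi = saliency.min(), saliency.max()
    if hi == lo:
        return torch.ones_like(saliency)
    return (saliency - lo) / (hi - lo)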
Example #2
import numpy as np
import torch.nn.functional as F
from torchray.attribution.grad_cam import grad_cam

def compute_gradcam(model, preprocessed_image, label, saliency_layer=None):
    saliency = grad_cam(model, preprocessed_image, label, saliency_layer=saliency_layer)
    image_shape = (preprocessed_image.shape[-2], preprocessed_image.shape[-1])
    saliency = F.interpolate(saliency, image_shape, mode="bilinear", align_corners=False)
    grad = saliency.detach().cpu().clone().numpy()  # (1, 1, H, W), upsampled to the input resolution
    grad = np.concatenate((grad,) * 3, axis=1).squeeze()  # (3, H, W), repeated across channels
    return grad
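A minimal call sketch for compute_gradcam; the ResNet-18 model, random input, and 'layer4' layer name are illustrative stand-ins, not taken from the original project:

import torch
from torchvision import models

model = models.resnet18(pretrained=True).eval()
image = torch.randn(1, 3, 224, 224)  # stand-in for a preprocessed image
heatmap = compute_gradcam(model, image, label=0, saliency_layer='layer4')
print(heatmap.shape)  # (3, 224, 224): upsampled, then repeated across 3 channels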
Example #3
def get_torchray_saliency(original_classifier, input_, target, method):
    upsampler = nn.Upsample(scale_factor=32,
                            mode='bilinear',
                            align_corners=True)
    input_, target = input_.to(device), target.to(device)

    saliency_ls = []
    for j in range(len(target)):
        input_single = input_[j:j + 1]
        target_single = target[j].item()
        if method == "grad_cam":
            from torchray.attribution.grad_cam import grad_cam
            saliency = grad_cam(
                model=original_classifier,
                input=input_single,
                target=target_single,
                saliency_layer='layer4',
            )
            saliency = upsampler(saliency)
        elif method == "guided_backprop":
            from torchray.attribution.guided_backprop import guided_backprop
            saliency = guided_backprop(
                model=original_classifier,
                input=input_single,
                target=target_single,
                resize=(224, 224),
                smooth=0.02,
            )
        else:
            raise KeyError(method)
        saliency_ls.append(saliency.detach())
    mask = torch.cat(saliency_ls, dim=0)

    binarized_mask = binarize_mask(mask.clone())
    rectangular = torch.empty_like(binarized_mask)
    box_coord_ls = [BoxCoords(0, 0, 0, 0)] * len(input_)

    for idx in range(mask.size(0)):
        if binarized_mask[idx].sum() == 0:
            continue

        m = binarized_mask[idx].squeeze().cpu().numpy()
        rectangular[idx], box_coord_ls[idx] = get_rectangular_mask(m)

    classifier_output = original_classifier(input_, return_intermediate=False)
    _, max_indexes = classifier_output.data.max(1)
    is_correct = target.eq(max_indexes).long()

    return (
        mask.squeeze().cpu().numpy(),
        binarized_mask.cpu().numpy(),
        rectangular.squeeze().cpu().numpy(),
        is_correct.cpu().numpy(),
        box_coord_ls,
    )
Example #4
def get_attribution(model,
                    input,
                    target,
                    method,
                    device,
                    saliency_layer='features.norm5',
                    iba_wrapper=None):
    input = input.to(device)
    input.requires_grad = True

    # get attribution
    if method == "grad_cam":
        saliency_map = grad_cam(model,
                                input,
                                target,
                                saliency_layer=saliency_layer)
    elif method == "extremal_perturbation":
        saliency_map, _ = extremal_perturbation(model, input, target)
    elif method == 'ib':
        assert iba_wrapper, "Please pass an iba_wrapper as a function parameter!"
        saliency_map = iba_wrapper.iba(model, input, target, device)
    elif method == 'reverse_ib':
        assert iba_wrapper, "Please pass an iba_wrapper as a function parameter!"
        saliency_map = iba_wrapper.iba(model,
                                       input,
                                       target,
                                       device,
                                       reverse_lambda=True)
    elif method == "gradient":
        saliency_map = gradient(model, input, target)
    elif method == "excitation_backprop":
        saliency_map = excitation_backprop(model,
                                           input,
                                           target,
                                           saliency_layer=saliency_layer)
    elif method == "integrated_gradients":
        ig = IntegratedGradients(model)
        saliency_map, _ = ig.attribute(input,
                                       target=target,
                                       return_convergence_delta=True)
        saliency_map = saliency_map.squeeze().mean(0)
    else:
        raise ValueError(f"Unknown attribution method: {method}")

    # ib heatmap already a numpy array scaled to image size
    if method not in ('ib', 'reverse_ib'):
        saliency_map = saliency_map.detach().cpu().numpy().squeeze()
        shape = (224, 224)
        saliency_map = resize(saliency_map,
                              shape,
                              order=1,
                              preserve_range=True)
    return saliency_map
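A hedged call sketch for get_attribution; the DenseNet-121 (whose torchvision layer name 'features.norm5' matches the default saliency_layer) and the random input are assumptions for illustration:

import torch
from torchvision import models

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = models.densenet121(pretrained=True).to(device).eval()
x = torch.randn(1, 3, 224, 224)
heatmap = get_attribution(model, x, 0, 'grad_cam', device)
print(heatmap.shape)  # (224, 224) after the final skimage resize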
Example #5
 def gcam_l3(m, x, i):
     return grad_cam(m, x, i, saliency_layer=model.encoder_q.layer3)
Example #6
 def gcam(m, x, i):
     g = grad_cam(m, x, i, saliency_layer=model.encoder_q.layer4)
     return g
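Examples #5 and #6 pass a module object rather than a dotted name as saliency_layer; TorchRay resolves both forms. A minimal sketch with a plain torchvision ResNet standing in for model.encoder_q (an assumption; the original wraps a momentum-contrast encoder):

import torch
from torchvision import models
from torchray.attribution.grad_cam import grad_cam

model = models.resnet18(pretrained=True).eval()
x = torch.randn(1, 3, 224, 224)
by_name = grad_cam(model, x, 0, saliency_layer='layer3')        # dotted name
by_module = grad_cam(model, x, 0, saliency_layer=model.layer3)  # module object
assert torch.allclose(by_name, by_module)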
Example #7
    def __next__(self):
        self._lazy_init()
        x, y = next(self.data_iterator)
        torch.manual_seed(self.seed)

        if self.log:
            from torchray.benchmark.logging import mongo_load, mongo_save, \
                data_from_mongo, data_to_mongo

        try:
            assert len(x) == 1
            x = x.to(self.device)
            class_ids = self.data.as_class_ids(y[0])
            image_size = self.data.as_image_size(y[0])

            results = {'pointing': {}, 'pointing_difficult': {}}
            info = {}
            rise_saliency = None

            for class_id in class_ids:

                # Try to recover this result from the log.
                if self.log > 0:
                    image_name = self.data.as_image_name(y[0])
                    data = mongo_load(
                        self.db,
                        self.experiment.name,
                        f"{image_name}-{class_id}",
                    )
                    if data is not None:
                        data = data_from_mongo(data)
                        results['pointing'][class_id] = data['pointing']
                        results['pointing_difficult'][class_id] = data[
                            'pointing_difficult']
                        if self.debug:
                            print(f'{image_name}-{class_id} loaded from log')
                        continue

                # TODO(av): should now be obsolete
                if x.grad is not None:
                    x.grad.data.zero_()

                if self.experiment.method == "center":
                    w, h = image_size
                    point = torch.tensor([[w / 2, h / 2]])

                elif self.experiment.method == "gradient":
                    saliency = gradient(
                        self.model,
                        x,
                        class_id,
                        resize=image_size,
                        smooth=0.02,
                        get_backward_gradient=get_pointing_gradient)
                    point = _saliency_to_point(saliency)
                    info['saliency'] = saliency

                elif self.experiment.method == "deconvnet":
                    saliency = deconvnet(
                        self.model,
                        x,
                        class_id,
                        resize=image_size,
                        smooth=0.02,
                        get_backward_gradient=get_pointing_gradient)
                    point = _saliency_to_point(saliency)
                    info['saliency'] = saliency

                elif self.experiment.method == "guided_backprop":
                    saliency = guided_backprop(
                        self.model,
                        x,
                        class_id,
                        resize=image_size,
                        smooth=0.02,
                        get_backward_gradient=get_pointing_gradient)
                    point = _saliency_to_point(saliency)
                    info['saliency'] = saliency

                elif self.experiment.method == "grad_cam":
                    saliency = grad_cam(
                        self.model,
                        x,
                        class_id,
                        saliency_layer=self.gradcam_layer,
                        resize=image_size,
                        get_backward_gradient=get_pointing_gradient)
                    point = _saliency_to_point(saliency)
                    info['saliency'] = saliency

                elif self.experiment.method == "excitation_backprop":
                    saliency = excitation_backprop(
                        self.model,
                        x,
                        class_id,
                        self.saliency_layer,
                        resize=image_size,
                        get_backward_gradient=get_pointing_gradient)
                    point = _saliency_to_point(saliency)
                    info['saliency'] = saliency

                elif self.experiment.method == "contrastive_excitation_backprop":
                    saliency = contrastive_excitation_backprop(
                        self.model,
                        x,
                        class_id,
                        saliency_layer=self.saliency_layer,
                        contrast_layer=self.contrast_layer,
                        resize=image_size,
                        get_backward_gradient=get_pointing_gradient)
                    point = _saliency_to_point(saliency)
                    info['saliency'] = saliency

                elif self.experiment.method == "rise":
                    # For RISE, compute saliency map for all classes.
                    if rise_saliency is None:
                        rise_saliency = rise(self.model,
                                             x,
                                             resize=image_size,
                                             seed=self.seed)
                    saliency = rise_saliency[:, class_id, :, :].unsqueeze(1)
                    point = _saliency_to_point(saliency)
                    info['saliency'] = saliency

                elif self.experiment.method == "extremal_perturbation":

                    if self.experiment.dataset == 'voc_2007':
                        areas = [0.025, 0.05, 0.1, 0.2]
                    else:
                        areas = [0.018, 0.025, 0.05, 0.1]

                    if self.experiment.boom:
                        raise RuntimeError("BOOM!")

                    mask, energy = elp.extremal_perturbation(
                        self.model,
                        x,
                        class_id,
                        areas=areas,
                        num_levels=8,
                        step=7,
                        sigma=7 * 3,
                        max_iter=800,
                        debug=self.debug > 0,
                        jitter=True,
                        smooth=0.09,
                        resize=image_size,
                        perturbation='blur',
                        reward_func=elp.simple_reward,
                        variant=elp.PRESERVE_VARIANT,
                    )

                    saliency = mask.sum(dim=0, keepdim=True)
                    point = _saliency_to_point(saliency)

                    info = {
                        'saliency': saliency,
                        'mask': mask,
                        'areas': areas,
                        'energy': energy
                    }

                else:
                    assert False

                if False:
                    plt.figure()
                    plt.subplot(1, 2, 1)
                    imsc(saliency[0])
                    plt.plot(point[0, 0], point[0, 1], 'ro')
                    plt.subplot(1, 2, 2)
                    imsc(x[0])
                    plt.pause(0)

                results['pointing'][class_id] = self.pointing.evaluate(
                    y[0], class_id, point[0])
                results['pointing_difficult'][
                    class_id] = self.pointing_difficult.evaluate(
                        y[0], class_id, point[0])

                if self.log > 0:
                    image_name = self.data.as_image_name(y[0])
                    mongo_save(
                        self.db, self.experiment.name,
                        f"{image_name}-{class_id}",
                        data_to_mongo({
                            'image_name':
                            image_name,
                            'class_id':
                            class_id,
                            'pointing':
                            results['pointing'][class_id],
                            'pointing_difficult':
                            results['pointing_difficult'][class_id],
                        }))

                if self.log > 1:
                    mongo_save(self.db,
                               str(self.experiment.name) + "-details",
                               f"{image_name}-{class_id}", data_to_mongo(info))

            return results

        except Exception as ex:
            raise ProcessingError(self, self.experiment, self.model, x, y,
                                  class_id, image_size) from ex
Example #8
from torchray.attribution.grad_cam import grad_cam
from torchray.benchmark import get_example_data, plot_example

# Obtain example data.
model, x, category_id, _ = get_example_data()

# Grad-CAM backprop.
saliency = grad_cam(model, x, category_id, saliency_layer='features.29')

# Plots.
plot_example(x, saliency, 'grad-cam backprop', category_id)
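Here 'features.29' names a late convolutional-block activation of the VGG16 returned by get_example_data; for other backbones the layer name changes. A hedged variant for ResNet-50, where the usual Grad-CAM target is the final residual stage (layer name from torchvision's ResNet):

import torch
from torchvision import models
from torchray.attribution.grad_cam import grad_cam

model = models.resnet50(pretrained=True).eval()
x = torch.randn(1, 3, 224, 224)  # stand-in for a preprocessed image
saliency = grad_cam(model, x, 243, saliency_layer='layer4')
print(saliency.shape)  # coarse map, e.g. torch.Size([1, 1, 7, 7])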
Example #9
from gradual_extrapolation import GradualExtrapolator
from torchray.attribution.deconvnet import deconvnet
from torchray.attribution.grad_cam import grad_cam
from torchray.benchmark import plot_example
from torchvision import models

if __name__ == "__main__":
    input_batch = read_images_2_batch()

    # TODO: loop over all common networks
    model = models.__dict__["vgg16"](pretrained=True)
    model.eval()
    category_id = get_category_IDs(model, input_batch)
    deconvnet(model, input_batch, category_id)

    print("Processing Grad-CAM...")
    saliency_grad_cam = grad_cam(
        model,
        input_batch,
        category_id,
        saliency_layer="features.28",
    )
    plot_example(input_batch,
                 saliency_grad_cam,
                 "Grad-CAM",
                 category_id,
                 save_path="output_Grad-CAM.jpg")

    print("Processing Gradual Grad-CAM...")
    GradualExtrapolator.register_hooks(model)
    # we just need to process images to feed hooks
    model(input_batch)
    saliency_gradual_grad_cam = GradualExtrapolator.get_smooth_map(
        saliency_grad_cam)
Example #10
def plot_map(model, dataloader, label=None, covid=False, saliency_layer=None):
    """Plot an example.

    Args:
        model: trained classification model
        dataloader: dataloader containing the input images.
        label (str): name of the category.
        covid: whether the image is from the COVID dataset or the ChestX-ray dataset.
        saliency_layer: usually the output of the last convolutional layer.
    """

    if not covid:
        FINDINGS = [
            'Atelectasis', 'Cardiomegaly', 'Effusion', 'Infiltration', 'Mass',
            'Nodule', 'Pneumonia', 'Pneumothorax', 'Consolidation', 'Edema',
            'Emphysema', 'Fibrosis', 'Pleural_Thickening', 'Hernia'
        ]
    else:
        FINDINGS = ['Detector01', 'Detector2', 'Detector3']

    try:
        if not covid:
            inputs, labels, filename, bbox = next(dataloader)
            bbox = bbox.type(torch.cuda.IntTensor)
        else:
            inputs, labels, filename = next(dataloader)
    except StopIteration:
        print(
            "All examples exhausted - rerun cells above to generate new examples to review"
        )
        return None

    original = inputs.clone()
    inputs = inputs.to(device)
    original = original.to(device)
    original.requires_grad = True

    # create predictions for label of interest and all labels
    pred = torch.sigmoid(model(original)).data.cpu().numpy()[0]
    predx = ['%.3f' % elem for elem in list(pred)]

    preds_concat = pd.concat([
        pd.Series(FINDINGS),
        pd.Series(predx),
        pd.Series(labels.numpy().astype(bool)[0])
    ],
                             axis=1)
    preds = pd.DataFrame(data=preds_concat)
    preds.columns = ["Finding", "Predicted Probability", "Ground Truth"]
    preds.set_index("Finding", inplace=True)
    preds.sort_values(by='Predicted Probability',
                      inplace=True,
                      ascending=False)

    cxr = inputs.data.cpu().numpy().squeeze().transpose(1, 2, 0)
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    cxr = std * cxr + mean
    cxr = np.clip(cxr, 0, 1)

    if not covid:
        show_next(cxr, model, label, inputs, filename, bbox)

    if covid and label is None:
        label = preds.loc[preds["Ground Truth"] == True].index[0]

    category_id = FINDINGS.index(label)

    saliency = grad_cam(model,
                        original,
                        category_id,
                        saliency_layer=saliency_layer)

    fig, (showcxr, heatmap) = plt.subplots(ncols=2, figsize=(14, 5))

    showcxr.imshow(cxr)
    showcxr.axis('off')
    showcxr.set_title(filename[0])
    if not covid:
        rect_original = patches.Rectangle((bbox[0, 0], bbox[0, 1]),
                                          bbox[0, 2],
                                          bbox[0, 3],
                                          linewidth=2,
                                          edgecolor='r',
                                          facecolor='none',
                                          zorder=2)
        showcxr.add_patch(rect_original)

    hmap = sns.heatmap(saliency.detach().cpu().numpy().squeeze(),
                       cmap='viridis',
                       annot=False,
                       zorder=2,
                       linewidths=0)
    hmap.axis('off')
    hmap.set_title('TorchRay grad cam for category {}'.format(label),
                   fontsize=8)

    plt.show()

    print(preds)

    if covid:
        data_brixia = pd.read_csv("model/labels/metadata_global_v2.csv",
                                  sep=";")
        data_brixia.set_index("Filename", inplace=True)
        score = data_brixia.loc[filename[0].replace(".jpg", ".dcm"),
                                "BrixiaScore"].astype(str)
        print('Brixia 6 regions Score: ', '0' * (6 - len(score)) + score)
Example #11
    def plot_map(self,
                 model,
                 dataloader,
                 label,
                 methods=None,
                 covid=False,
                 regression=False,
                 saliency_layer=None,
                 overlay=False,
                 axes_a=None):
        """Plot an example.

      Args:
          model: trained classification model
          dataloader: dataloader containing the input images.
          label (str): name of the category.
          methods: list of attribution methods.
          covid: whether the image is from the COVID dataset or the ChestX-ray dataset.
          saliency_layer: usually the output of the last convolutional layer.
          overlay: whether to display the saliency map over the image.
          axes_a: matplotlib axes to plot into.
      """

        if not covid:
            FINDINGS = [
                'Atelectasis', 'Cardiomegaly', 'Effusion', 'Infiltration',
                'Mass', 'Nodule', 'Pneumonia', 'Pneumothorax', 'Consolidation',
                'Edema', 'Emphysema', 'Fibrosis', 'Pleural_Thickening',
                'Hernia'
            ]
        else:
            FINDINGS = ['Detector01', 'Detector2', 'Detector3']

        try:
            if not covid:
                inputs, labels, filename, bbox = next(dataloader)
                bbox = bbox.type(torch.cuda.IntTensor)
            else:
                inputs, labels, filename = next(dataloader)
                # for a consistent return value
                bbox = None
        except StopIteration:
            print(
                "All examples exhausted - rerun cells above to generate new examples to review"
            )
            return None

        original = inputs.clone()
        inputs = inputs.to(device)
        original = original.to(device)
        original.requires_grad = True

        # create predictions for label of interest and all labels
        pred = torch.sigmoid(model(original)).data.cpu().numpy()[0]
        predx = ['%.3f' % elem for elem in list(pred)]

        preds_concat = pd.concat([
            pd.Series(FINDINGS),
            pd.Series(predx),
            pd.Series(labels.numpy().astype(bool)[0])
        ],
                                 axis=1)
        preds = pd.DataFrame(data=preds_concat)
        preds.columns = ["Finding", "Predicted Probability", "Ground Truth"]
        preds.set_index("Finding", inplace=True)
        preds.sort_values(by='Predicted Probability',
                          inplace=True,
                          ascending=False)

        # normalize image
        cxr = inputs.data.cpu().numpy().squeeze().transpose(1, 2, 0)
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        cxr = std * cxr + mean
        cxr = np.clip(cxr, 0, 1)

        # For COVID, default to the ground-truth finding with the highest predicted probability
        if covid and label is None:
            label = preds.loc[preds["Ground Truth"] == True].index[0]

        category_id = FINDINGS.index(label)

        # if not covid:
        #     show_iba(cxr, model, label, inputs, labels, filename, bbox)
        #     show_iba_new(cxr, model, label, inputs, labels, filename, bbox)
        #      show_next(cxr, model, label, inputs, filename, bbox)

        if methods is None:
            methods = [
                'grad-cam backprop', 'gradient', 'deconvnet',
                'excitation backprop', 'guided backprop', 'linear approx',
                'original IB', 'IB with reversed mask'
            ]

        # plot original data with bounding box, else show brixia score
        showcxr = axes_a.flatten()[0]
        showcxr.imshow(cxr)
        showcxr.axis('off')
        showcxr.set_title(filename[0])
        if not covid:
            rect_original = patches.Rectangle((bbox[0, 0], bbox[0, 1]),
                                              bbox[0, 2],
                                              bbox[0, 3],
                                              linewidth=2,
                                              edgecolor='r',
                                              facecolor='none',
                                              zorder=2)
            showcxr.add_patch(rect_original)
        else:
            scores = self.local_score_dataset.getScore(filename[0])
            color_list = ["green", "yellow", "red", "black"]
            for idx, score in enumerate(scores):
                row = (1 - idx % 3 / 2) * 0.8 + 0.1
                col = idx // 3 * 0.8 + 0.1
                plt.text(col,
                         row,
                         score,
                         color="white",
                         fontsize=36,
                         bbox=dict(facecolor=color_list[score], alpha=0.7),
                         transform=showcxr.transAxes)

        # plot visualizations
        for method, hmap in zip(methods, axes_a.flatten()[1:]):
            if method == 'grad-cam backprop':
                saliency = grad_cam(model,
                                    original,
                                    category_id,
                                    saliency_layer=saliency_layer)
            elif method == 'gradient':
                saliency = torchray_gradient(model, original, category_id)
            elif method == 'deconvnet':
                saliency = deconvnet(model, original, category_id)
            elif method == 'excitation backprop':
                saliency = excitation_backprop(model,
                                               original,
                                               category_id,
                                               saliency_layer=saliency_layer)
            elif method == 'guided backprop':
                saliency = guided_backprop(model, original, category_id)
            elif method == 'linear approx':
                saliency = linear_approx(model,
                                         original,
                                         category_id,
                                         saliency_layer=saliency_layer)
            elif method == 'original IB':
                saliency = self.iba(original,
                                    labels.squeeze(),
                                    model_loss_closure_with_target=self.
                                    softmax_crossentropy_loss_with_target)
            elif method == 'IB with reversed mask':
                saliency = self.iba(original,
                                    labels.squeeze(),
                                    reverse_lambda=True,
                                    model_loss_closure_with_target=self.
                                    softmax_crossentropy_loss_with_target)
            else:
                raise ValueError(f"Unknown attribution method: {method}")

            if not overlay:
                sns.heatmap(saliency.detach().cpu().numpy().squeeze(),
                            cmap='viridis',
                            annot=False,
                            square=True,
                            cbar=False,
                            zorder=2,
                            linewidths=0,
                            ax=hmap)
            else:
                # resize the saliency map
                if isinstance(saliency, np.ndarray):
                    np_saliency = saliency
                else:
                    np_saliency = saliency.detach().cpu().numpy().squeeze()
                np_saliency = to_saliency_map(np_saliency, cxr.shape[:2])

                # plot saliency map with image
                plot_saliency_map(np_saliency, cxr, ax=hmap, colorbar=False)
            hmap.axis('off')
            hmap.set_title('{} for category {}'.format(method, label),
                           fontsize=12)

        return inputs, labels, filename, bbox, preds