Example no. 1
        def decnn(m, x, i):
            # DeConvNet saliency for class i, normalized for display with
            # imsc; [None] restores the leading batch dimension.
            saliency = deconvnet(m, x, i)
            saliency = imsc(saliency[0], quiet=True)[0][None]

            return saliency
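A minimal usage sketch for `decnn` (the model, random input, and class index below are placeholders; the imports assume TorchRay's public API):

import torch
import torchvision
from torchray.attribution.deconvnet import deconvnet
from torchray.utils import imsc

model = torchvision.models.resnet18().eval()
x = torch.randn(1, 3, 224, 224)  # placeholder input: one RGB image
saliency = decnn(model, x, 207)  # 207: arbitrary class index
print(saliency.shape)            # torch.Size([1, 3, 224, 224])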
Example no. 2
    def __next__(self):
        self._lazy_init()
        x, y = next(self.data_iterator)
        torch.manual_seed(self.seed)

        if self.log:
            from torchray.benchmark.logging import mongo_load, mongo_save, \
                data_from_mongo, data_to_mongo

        try:
            assert len(x) == 1
            x = x.to(self.device)
            class_ids = self.data.as_class_ids(y[0])
            image_size = self.data.as_image_size(y[0])

            results = {'pointing': {}, 'pointing_difficult': {}}
            info = {}
            rise_saliency = None

            for class_id in class_ids:

                # Try to recover this result from the log.
                if self.log > 0:
                    image_name = self.data.as_image_name(y[0])
                    data = mongo_load(
                        self.db,
                        self.experiment.name,
                        f"{image_name}-{class_id}",
                    )
                    if data is not None:
                        data = data_from_mongo(data)
                        results['pointing'][class_id] = data['pointing']
                        results['pointing_difficult'][class_id] = data['pointing_difficult']
                        if self.debug:
                            print(f'{image_name}-{class_id} loaded from log')
                        continue

                # TODO(av): should now be obsolete
                if x.grad is not None:
                    x.grad.data.zero_()

                if self.experiment.method == "center":
                    w, h = image_size
                    point = torch.tensor([[w / 2, h / 2]])

                elif self.experiment.method == "gradient":
                    saliency = gradient(
                        self.model,
                        x,
                        class_id,
                        resize=image_size,
                        smooth=0.02,
                        get_backward_gradient=get_pointing_gradient)
                    point = _saliency_to_point(saliency)
                    info['saliency'] = saliency

                elif self.experiment.method == "deconvnet":
                    saliency = deconvnet(
                        self.model,
                        x,
                        class_id,
                        resize=image_size,
                        smooth=0.02,
                        get_backward_gradient=get_pointing_gradient)
                    point = _saliency_to_point(saliency)
                    info['saliency'] = saliency

                elif self.experiment.method == "guided_backprop":
                    saliency = guided_backprop(
                        self.model,
                        x,
                        class_id,
                        resize=image_size,
                        smooth=0.02,
                        get_backward_gradient=get_pointing_gradient)
                    point = _saliency_to_point(saliency)
                    info['saliency'] = saliency

                elif self.experiment.method == "grad_cam":
                    saliency = grad_cam(
                        self.model,
                        x,
                        class_id,
                        saliency_layer=self.gradcam_layer,
                        resize=image_size,
                        get_backward_gradient=get_pointing_gradient)
                    point = _saliency_to_point(saliency)
                    info['saliency'] = saliency

                elif self.experiment.method == "excitation_backprop":
                    saliency = excitation_backprop(
                        self.model,
                        x,
                        class_id,
                        self.saliency_layer,
                        resize=image_size,
                        get_backward_gradient=get_pointing_gradient)
                    point = _saliency_to_point(saliency)
                    info['saliency'] = saliency

                elif self.experiment.method == "contrastive_excitation_backprop":
                    saliency = contrastive_excitation_backprop(
                        self.model,
                        x,
                        class_id,
                        saliency_layer=self.saliency_layer,
                        contrast_layer=self.contrast_layer,
                        resize=image_size,
                        get_backward_gradient=get_pointing_gradient)
                    point = _saliency_to_point(saliency)
                    info['saliency'] = saliency

                elif self.experiment.method == "rise":
                    # For RISE, compute saliency map for all classes.
                    if rise_saliency is None:
                        rise_saliency = rise(self.model,
                                             x,
                                             resize=image_size,
                                             seed=self.seed)
                    saliency = rise_saliency[:, class_id, :, :].unsqueeze(1)
                    point = _saliency_to_point(saliency)
                    info['saliency'] = saliency

                elif self.experiment.method == "extremal_perturbation":

                    if self.experiment.dataset == 'voc_2007':
                        areas = [0.025, 0.05, 0.1, 0.2]
                    else:
                        areas = [0.018, 0.025, 0.05, 0.1]

                    if self.experiment.boom:
                        raise RuntimeError("BOOM!")

                    mask, energy = elp.extremal_perturbation(
                        self.model,
                        x,
                        class_id,
                        areas=areas,
                        num_levels=8,
                        step=7,
                        sigma=7 * 3,
                        max_iter=800,
                        debug=self.debug > 0,
                        jitter=True,
                        smooth=0.09,
                        resize=image_size,
                        perturbation='blur',
                        reward_func=elp.simple_reward,
                        variant=elp.PRESERVE_VARIANT,
                    )

                    saliency = mask.sum(dim=0, keepdim=True)
                    point = _saliency_to_point(saliency)

                    info = {
                        'saliency': saliency,
                        'mask': mask,
                        'areas': areas,
                        'energy': energy
                    }

                else:
                    assert False, f'unknown method: {self.experiment.method}'

                # Debug visualization (disabled by default).
                if False:
                    plt.figure()
                    plt.subplot(1, 2, 1)
                    imsc(saliency[0])
                    plt.plot(point[0, 0], point[0, 1], 'ro')
                    plt.subplot(1, 2, 2)
                    imsc(x[0])
                    plt.pause(0)

                results['pointing'][class_id] = self.pointing.evaluate(
                    y[0], class_id, point[0])
                results['pointing_difficult'][class_id] = \
                    self.pointing_difficult.evaluate(y[0], class_id, point[0])

                if self.log > 0:
                    image_name = self.data.as_image_name(y[0])
                    mongo_save(
                        self.db, self.experiment.name,
                        f"{image_name}-{class_id}",
                        data_to_mongo({
                            'image_name': image_name,
                            'class_id': class_id,
                            'pointing': results['pointing'][class_id],
                            'pointing_difficult': results['pointing_difficult'][class_id],
                        }))

                if self.log > 1:
                    mongo_save(self.db,
                               str(self.experiment.name) + "-details",
                               f"{image_name}-{class_id}", data_to_mongo(info))

            return results

        except Exception as ex:
            raise ProcessingError(self, self.experiment, self.model, x, y,
                                  class_id, image_size) from ex
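Example no. 2 depends on a `_saliency_to_point` helper that is not shown above. A plausible sketch, assuming the saliency tensor has shape N x 1 x H x W: take the argmax of each flattened map and convert the flat index back to (x, y) coordinates, matching how the "center" method builds its point.

import torch

def _saliency_to_point(saliency):
    # Plausible sketch of the helper used above; assumes N x 1 x H x W input.
    assert len(saliency.shape) == 4
    w = saliency.shape[3]
    # Flat index of the maximum of each saliency map.
    point = torch.argmax(saliency.view(len(saliency), -1), dim=1)
    # Recover (x, y) from the flat index.
    return torch.cat((
        (point % w).float().view(-1, 1),
        (point // w).float().view(-1, 1),
    ), dim=1)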
Example no. 3
from torchray.attribution.deconvnet import deconvnet
from torchray.benchmark import get_example_data, plot_example

# Obtain example data.
model, x, category_id, _ = get_example_data()

# DeConvNet method.
saliency = deconvnet(model, x, category_id)

# Plots.
plot_example(x, saliency, 'deconvnet', category_id)
Example no. 4
from torchray.attribution.grad_cam import grad_cam
from torchray.attribution.excitation_backprop import contrastive_excitation_backprop
from torchray.attribution.deconvnet import deconvnet

from utils import read_images_2_batch, get_category_IDs, plot_example
from gradual_extrapolation import GradualExtrapolator
from torchvision import models

if __name__ == "__main__":
    input_batch = read_images_2_batch()

    # TODO: loop over all common networks
    model = models.__dict__["vgg16"](pretrained=True)
    model.eval()
    category_id = get_category_IDs(model, input_batch)
    # Sanity-check run of DeConvNet; its saliency map is not used below.
    deconvnet(model, input_batch, category_id)

    print("Processing Grad-CAM...")
    saliency_grad_cam = grad_cam(
        model,
        input_batch,
        category_id,
        saliency_layer="features.28",
    )
    plot_example(input_batch,
                 saliency_grad_cam,
                 "Grad-CAM",
                 category_id,
                 save_path="output_Grad-CAM.jpg")

    print("Processing Gradual Grad-CAM...")
Example no. 5
def app(model=None, in_dist_name="in", ood_data_names=("out", "out2"), image=True):
    # Build the default model lazily instead of in the signature, where it
    # would be constructed once at import time.
    if model is None:
        model = torchvision.models.resnet18().eval()
    # Render the readme as markdown using st.markdown.
    st.markdown(get_file_content_as_string("Incepto/dashboard/intro.md"))
    layers = get_layers(model)
    # Once we have the dependencies, add a selector for the app mode on the sidebar.
    if st.sidebar.button("Go to Guide"):
        caching.clear_cache()
        st.markdown(get_file_content_as_string("Incepto/dashboard/details.md"))

    st.sidebar.title("Data Settings")
    # Select which dataset to explore.

    dataset = st.sidebar.selectbox(
        "Set Dataset:",
        (in_dist_name,*ood_data_names),
    )
    
    if image:
        visualization = st.sidebar.selectbox(
            "Set Visualization Type:",
            ("-", "Color Distribution for Entire Dataset",
             "Pixel Distribution for Entire Dataset", "Deconvolution",
             "Excitation Backpropagation", "Gradient", "Grad-CAM",
             "Guided Backpropagation", "Linear Approximation",
             "Extremal Perturbation", "RISE"),
        )
    else:
        visualization = st.sidebar.selectbox(
            "Set Visualization Type:",
            ("-", "Average Signal for Entire Dataset", "Deconvolution",
             "Excitation Backpropagation", "Gradient", "Grad-CAM",
             "Guided Backpropagation", "Linear Approximation",
             "Extremal Perturbation", "RISE"),
        )

    if image:
        if visualization == "Deconvolution":
                caching.clear_cache()
                saliency_layer=st.selectbox("Select Layer:",tuple(layers))
                # st.number_input(label="Enter a channel number:", step=1, min_value=0, value=0)
                _, x, category_id, _ = get_example_data()
                saliency = deconvnet(model, x.cpu(), category_id, saliency_layer=saliency_layer)
                fig = plt.figure(figsize=(40,40))
                ax = fig.add_subplot(131)
                ax.imshow(np.asarray(saliency.squeeze()))
                ax = fig.add_subplot(132)
                ax.imshow(np.asarray(x.cpu().squeeze().permute(1,2,0).detach().numpy() ))
                st.pyplot(fig)
        elif visualization == "Grad-CAM":
            with st.spinner("Generating Plot"):
                caching.clear_cache()
                saliency_layer=st.selectbox("Select Layer:",tuple(layers))
                # st.number_input(label="Enter a channel number:", step=1, min_value=0, value=0)
                _, x, category_id, _ = get_example_data()
                saliency = linear_approx(model, x.cpu(), category_id, saliency_layer=saliency_layer)
                fig = plt.figure(figsize=(40,40))
                ax = fig.add_subplot(131)
                ax.imshow(np.asarray(saliency.squeeze().detach().numpy() ))
                ax = fig.add_subplot(132)
                ax.imshow(np.asarray(x.cpu().squeeze().permute(1,2,0).detach().numpy() ))
                st.pyplot(fig)
        elif visualization == "Guided Backpropagation":
            with st.spinner("Generating Plot"):
                caching.clear_cache()
                saliency_layer=st.selectbox("Select Layer:",tuple(layers))
                # st.number_input(label="Enter a channel number:", step=1, min_value=0, value=0)
                _, x, category_id, _ = get_example_data()
                saliency = guided_backprop(model, x.cpu(), category_id, saliency_layer=saliency_layer)
                fig = plt.figure(figsize=(40,40))
                ax = fig.add_subplot(131)
                ax.imshow(np.asarray(saliency.squeeze().detach().numpy() ))
                ax = fig.add_subplot(132)
                ax.imshow(np.asarray(x.cpu().squeeze().permute(1,2,0).detach().numpy() ))
                st.pyplot(fig)
        elif visualization == "Gradient":
            with st.spinner("Generating Plot"):
                caching.clear_cache()
                saliency_layer=st.selectbox("Select Layer:",tuple(layers))
                # st.number_input(label="Enter a channel number:", step=1, min_value=0, value=0)
                _, x, category_id, _ = get_example_data()
                saliency = gradient(model, x.cpu(), category_id, saliency_layer=saliency_layer)
                fig = plt.figure(figsize=(40,40))
                ax = fig.add_subplot(131)
                ax.imshow(np.asarray(saliency.squeeze().detach().numpy() ))
                ax = fig.add_subplot(132)
                ax.imshow(np.asarray(x.cpu().squeeze().permute(1,2,0).detach().numpy() ))
                st.pyplot(fig)
        elif visualization == "Linear Approximation":
            with st.spinner("Generating Plot"):
                caching.clear_cache()
                saliency_layer=st.selectbox("Select Layer:",tuple(layers))
                # st.number_input(label="Enter a channel number:", step=1, min_value=0, value=0)
                _, x, category_id, _ = get_example_data()
                saliency = gradient(model, x.cpu(), category_id, saliency_layer=saliency_layer)
                fig = plt.figure(figsize=(40,40))
                ax = fig.add_subplot(131)
                ax.imshow(np.asarray(saliency.squeeze().detach().numpy() ))
                ax = fig.add_subplot(132)
                ax.imshow(np.asarray(x.cpu().squeeze().permute(1,2,0).detach().numpy() ))
            st.pyplot(fig)
        elif visualization == "Extremal Perturbation":
            with st.spinner("Generating Plot"):
                caching.clear_cache()
                # saliency_layer=st.selectbox("Select Layer:",tuple(layers))
                # st.number_input(label="Enter a channel number:", step=1, min_value=0, value=0)
                _, x, category_id, _ = get_example_data()
                masks_1, _ = extremal_perturbation(
                    model, x.cpu(), category_id,
                    reward_func=contrastive_reward,
                    debug=False,
                    areas=[0.12],)
                fig = plt.figure(figsize=(40,40))
                ax = fig.add_subplot(131)
                ax.imshow(np.asarray(masks_1.squeeze().detach().numpy() ))
                ax = fig.add_subplot(132)
                ax.imshow(np.asarray(x.cpu().squeeze().permute(1,2,0).detach().numpy() ))
                st.pyplot(fig)
        elif visualization == "RISE":
            with st.spinner("Generating Plot"):
                caching.clear_cache()
                # saliency_layer=st.selectbox("Select Layer:",tuple(layers))
                # st.number_input(label="Enter a channel number:", step=1, min_value=0, value=0)
                _, x, category_id, _ = get_example_data()
                saliency = rise(model, x.cpu())
                saliency = saliency[:, category_id].unsqueeze(0)
                fig = plt.figure(figsize=(40,40))
                ax = fig.add_subplot(131)
                ax.imshow(np.asarray(saliency.squeeze().detach().numpy() ))
                ax = fig.add_subplot(132)
                ax.imshow(np.asarray(x.cpu().squeeze().permute(1,2,0).detach().numpy() ))
                st.pyplot(fig)
        elif visualization == "Color Distribution for Entire Dataset":
            with st.spinner("Generating Plot"):
                caching.clear_cache()
                # saliency_layer=st.selectbox("Select Layer:",tuple(layers))
                # st.number_input(label="Enter a channel number:", step=1, min_value=0, value=0)
                _, x, category_id, _ = get_example_data()
                x = sum(x)/len(x)
                image = x.cpu().detach().numpy()
                fig = plt.figure()
                mpl.rcParams.update({'font.size': 15})
                _ = plt.hist(image[:, :, 0].ravel(), bins = 256, color = 'red', alpha = 0.5)
                _ = plt.hist(image[:, :, 1].ravel(), bins = 256, color = 'Green', alpha = 0.5)
                _ = plt.hist(image[:, :, 2].ravel(), bins = 256, color = 'Blue', alpha = 0.5)
                _ = plt.xlabel('Intensity Value')
                _ = plt.ylabel('Count')
                _ = plt.legend(['Red_Channel', 'Green_Channel', 'Blue_Channel'])
                
                st.pyplot(fig)
        elif visualization == "Pixel Distribution for Entire Dataset":
            with st.spinner("Generating Plot"):
                caching.clear_cache()
                
                _, x, category_id, _ = get_example_data()
                x = sum(x)/len(x)
                image = x.cpu().detach().numpy()
                fig = plt.figure(figsize=(40, 40))
                # Set the font size before creating the labels so it applies.
                mpl.rcParams.update({'font.size': 55})
                plt.ylabel("Count")
                plt.xlabel("Intensity Value")

                ax = plt.hist(x.cpu().detach().numpy().ravel(), bins = 256)
                vlo = cv2.Laplacian(x.cpu().detach().numpy().ravel(), cv2.CV_32F).var()
                plt.text(1, 1, ('Variance of Laplacian: '+str(vlo)), style='italic', bbox={'facecolor': 'white', 'alpha': 0.5, 'pad': 10})
                st.pyplot(fig)
            mpl.rcParams.update({'font.size': 15})
            
    if st.sidebar.button("Visualize Model"):
        saliency_layer = st.selectbox("Select Layer:", tuple(layers))
        filter_idx = st.number_input(label="Enter a filter number:", step=1, min_value=1, value=1)
        g_ascent = GradientAscent(model)
        g_ascent.use_gpu = False
        # Note: the selected layer name is only used as the plot title;
        # gradient ascent always runs on model.conv1 here.
        layer = model.conv1
        img = g_ascent.visualize(layer, filter_idx, title=saliency_layer, return_output=True)[0][0][0]
        fig = plt.figure(figsize=(40,40))
        ax = fig.add_subplot(131)
        ax.imshow(np.asarray(img.cpu().detach().numpy() ))
        st.pyplot(fig)
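The app also relies on a `get_layers` helper that is not shown. A minimal sketch, assuming it should return the dotted names of the model's leaf modules for the layer-selection dropdowns (the exact predicate is an assumption):

def get_layers(model):
    # Assumed behavior: dotted names of all leaf modules, suitable for
    # TorchRay's saliency_layer argument.
    return [name for name, module in model.named_modules()
            if name and len(list(module.children())) == 0]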
Example no. 6
    def plot_map(self,
                 model,
                 dataloader,
                 label,
                 methods=None,
                 covid=False,
                 regression=False,
                 saliency_layer=None,
                 overlay=False,
                 axes_a=None):
        """Plot an example.

      Args:
          model: trained classification model
          dataloader: containing input images.
          label (str): Name of Category.
          methods: list of attribution methods
          covid: whether the image is from the Covid Dataset or the Chesxtray Dataset.
          saliency_layer: usually output of the last convolutional layer.
          overlay: if display saliency map over image
          axes_a: axis of plot
      """

        if not covid:
            FINDINGS = [
                'Atelectasis', 'Cardiomegaly', 'Effusion', 'Infiltration',
                'Mass', 'Nodule', 'Pneumonia', 'Pneumothorax', 'Consolidation',
                'Edema', 'Emphysema', 'Fibrosis', 'Pleural_Thickening',
                'Hernia'
            ]
        else:
            FINDINGS = ['Detector01', 'Detector2', 'Detector3']

        try:
            if not covid:
                inputs, labels, filename, bbox = next(dataloader)
                bbox = bbox.type(torch.cuda.IntTensor)
            else:
                inputs, labels, filename = next(dataloader)
                # for consistent return value
                bbox = None
        except StopIteration:
            print(
                "All examples exhausted - rerun cells above to generate new examples to review"
            )
            return None

        original = inputs.clone()
        inputs = inputs.to(device)
        original = original.to(device)
        original.requires_grad = True

        # create predictions for label of interest and all labels
        pred = torch.sigmoid(model(original)).data.cpu().numpy()[0]
        predx = ['%.3f' % elem for elem in list(pred)]

        preds_concat = pd.concat([
            pd.Series(FINDINGS),
            pd.Series(predx),
            pd.Series(labels.numpy().astype(bool)[0])
        ],
                                 axis=1)
        preds = pd.DataFrame(data=preds_concat)
        preds.columns = ["Finding", "Predicted Probability", "Ground Truth"]
        preds.set_index("Finding", inplace=True)
        preds.sort_values(by='Predicted Probability',
                          inplace=True,
                          ascending=False)

        # normalize image
        cxr = inputs.data.cpu().numpy().squeeze().transpose(1, 2, 0)
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        cxr = std * cxr + mean
        cxr = np.clip(cxr, 0, 1)

        # For COVID without an explicit label, use the ground-truth-positive
        # finding with the highest predicted probability (preds is sorted by
        # predicted probability, descending).
        if covid and label is None:
            label = preds.loc[preds["Ground Truth"] == True].index[0]

        category_id = FINDINGS.index(label)

        # if not covid:
        #     show_iba(cxr, model, label, inputs, labels, filename, bbox)
        #     show_iba_new(cxr, model, label, inputs, labels, filename, bbox)
        #      show_next(cxr, model, label, inputs, filename, bbox)

        if methods is None:
            methods = [
                'grad-cam backprop', 'gradient', 'deconvnet',
                'excitation backprop', 'guided backprop', 'linear approx',
                'original IB', 'IB with reversed mask'
            ]

        # Plot the original image with its bounding box; for COVID, show the Brixia score instead.
        showcxr = axes_a.flatten()[0]
        showcxr.imshow(cxr)
        showcxr.axis('off')
        showcxr.set_title(filename[0])
        if not covid:
            rect_original = patches.Rectangle((bbox[0, 0], bbox[0, 1]),
                                              bbox[0, 2],
                                              bbox[0, 3],
                                              linewidth=2,
                                              edgecolor='r',
                                              facecolor='none',
                                              zorder=2)
            showcxr.add_patch(rect_original)
        else:
            scores = self.local_score_dataset.getScore(filename[0])
            color_list = ["green", "yellow", "red", "black"]
            for idx, score in enumerate(scores):
                row = (1 - idx % 3 / 2) * 0.8 + 0.1
                col = idx // 3 * 0.8 + 0.1
                plt.text(col,
                         row,
                         score,
                         color="white",
                         fontsize=36,
                         bbox=dict(facecolor=color_list[score], alpha=0.7),
                         transform=showcxr.transAxes)

        # plot visualizations
        for method, hmap in zip(methods, axes_a.flatten()[1:]):
            if method == 'grad-cam backprop':
                saliency = grad_cam(model,
                                    original,
                                    category_id,
                                    saliency_layer=saliency_layer)
            elif method == 'gradient':
                saliency = torchray_gradient(model, original, category_id)
            elif method == 'deconvnet':
                saliency = deconvnet(model, original, category_id)
            elif method == 'excitation backprop':
                saliency = excitation_backprop(model,
                                               original,
                                               category_id,
                                               saliency_layer=saliency_layer)
            elif method == 'guided backprop':
                saliency = guided_backprop(model, original, category_id)
            elif method == 'linear approx':
                saliency = linear_approx(model,
                                         original,
                                         category_id,
                                         saliency_layer=saliency_layer)
            elif method == 'original IB':
                saliency = self.iba(
                    original,
                    labels.squeeze(),
                    model_loss_closure_with_target=self.softmax_crossentropy_loss_with_target)
            elif method == 'IB with reversed mask':
                saliency = self.iba(
                    original,
                    labels.squeeze(),
                    reverse_lambda=True,
                    model_loss_closure_with_target=self.softmax_crossentropy_loss_with_target)

            if not overlay:
                sns.heatmap(saliency.detach().cpu().numpy().squeeze(),
                            cmap='viridis',
                            annot=False,
                            square=True,
                            cbar=False,
                            zorder=2,
                            linewidths=0,
                            ax=hmap)
            else:
                # resize the saliency map
                if isinstance(saliency, np.ndarray):
                    np_saliency = saliency
                else:
                    np_saliency = saliency.detach().cpu().numpy().squeeze()
                np_saliency = to_saliency_map(np_saliency, cxr.shape[:2])

                # plot saliency map with image
                plot_saliency_map(np_saliency, cxr, ax=hmap, colorbar=False)
            hmap.axis('off')
            hmap.set_title('{} for category {}'.format(method, label),
                           fontsize=12)

        return inputs, labels, filename, bbox, preds
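For reference, a hypothetical invocation of `plot_map`; `plotter`, `model`, and `loader` are stand-ins not defined in the source. `plot_map` calls `next()` on its `dataloader` argument and fills a flattened grid of axes, one panel for the image plus one per method:

import matplotlib.pyplot as plt

# Eight default methods plus the original image -> a 3 x 3 grid.
fig, axes_a = plt.subplots(3, 3, figsize=(18, 18))
outputs = plotter.plot_map(           # `plotter` owns plot_map (hypothetical)
    model,                            # trained classifier (hypothetical)
    iter(loader),                     # plot_map calls next() on this
    label='Cardiomegaly',             # one of the FINDINGS categories
    saliency_layer='features.norm5',  # hypothetical layer name
    overlay=True,
    axes_a=axes_a)
if outputs is not None:
    inputs, labels, filename, bbox, preds = outputs
plt.show()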