コード例 #1
0
from torchray.attribution.excitation_backprop import contrastive_excitation_backprop
from torchray.benchmark import get_example_data, plot_example

# Fetch the example model, input and target category (fourth value unused).
model, x, category_id, _ = get_example_data()

# Names of the layers handed to contrastive excitation backprop.
layer_names = dict(
    saliency_layer='features.9',
    contrast_layer='features.30',
    classifier_layer='classifier.6',
)

# Compute the saliency for the target category.
saliency = contrastive_excitation_backprop(model, x, category_id, **layer_names)

# Plot the input together with the saliency result.
plot_example(x, saliency, 'contrastive excitation backprop', category_id)
コード例 #2
0
from torchray.attribution.extremal_perturbation import extremal_perturbation, contrastive_reward
from torchray.benchmark import get_example_data, plot_example
from torchray.utils import get_device

# Fetch the example model, input and the two category ids to explain.
model, x, category_id_1, category_id_2 = get_example_data()

# Move the model and the input to the device chosen by get_device().
device = get_device()
model.to(device)
x = x.to(device)

# Options common to both extremal perturbation runs.
shared_options = dict(reward_func=contrastive_reward, debug=True)

# Mask for the first category, with a target area of 0.12.
masks_1, _ = extremal_perturbation(
    model, x, category_id_1, areas=[0.12], **shared_options
)

# Mask for the second category, with a target area of 0.05.
masks_2, _ = extremal_perturbation(
    model, x, category_id_2, areas=[0.05], **shared_options
)
コード例 #3
0
ファイル: app.py プロジェクト: anas-awadalla/Incepto
def _to_numpy(tensor):
    """Return ``tensor`` as a NumPy array, detached from autograd and on the CPU."""
    return tensor.detach().cpu().numpy()


def _show_figure(fig):
    """Send ``fig`` to the Streamlit page and close it.

    Closing the figure afterwards keeps pyplot from accumulating open figures
    across script reruns.
    """
    st.pyplot(fig)
    plt.close(fig)


def _show_saliency(saliency, x):
    """Draw a saliency map and the input image side by side.

    Args:
        saliency: tensor that squeezes down to a 2-D map.
        x: input tensor that squeezes down to channels-first (C, H, W); it is
            permuted to (H, W, C) for ``imshow``.
    """
    fig = plt.figure(figsize=(40, 40))
    ax = fig.add_subplot(131)
    ax.imshow(_to_numpy(saliency.squeeze()))
    ax = fig.add_subplot(132)
    ax.imshow(_to_numpy(x.squeeze().permute(1, 2, 0)))
    _show_figure(fig)


def _layer_attribution(visualization):
    """Map a visualization name to its layer-based attribution function.

    Returns ``None`` when ``visualization`` is not one of the methods that
    take a ``saliency_layer`` argument.
    """
    if visualization == "Deconvolution":
        return deconvnet
    if visualization == "Grad-CAM":
        # Imported here because the original code never referenced grad_cam
        # (the Grad-CAM branch mistakenly called linear_approx).
        from torchray.attribution.grad_cam import grad_cam
        return grad_cam
    if visualization == "Guided Backpropagation":
        return guided_backprop
    if visualization == "Gradient":
        return gradient
    if visualization == "Linear Approximation":
        return linear_approx
    return None


def _render_image_visualization(model, layers, visualization):
    """Render the visualization selected in the sidebar for image data.

    Every view operates on the example input returned by
    ``get_example_data()``. Names without a handler (including "-" and
    "Excitation Backpropgation") render nothing.
    """
    attribute = _layer_attribution(visualization)
    if attribute is not None:
        with st.spinner("Generating Plot"):
            caching.clear_cache()
            saliency_layer = st.selectbox("Select Layer:", tuple(layers))
            _, x, category_id, _ = get_example_data()
            x = x.cpu()
            saliency = attribute(model, x, category_id, saliency_layer=saliency_layer)
            _show_saliency(saliency, x)
    elif visualization == "Extremal Perturbation":
        with st.spinner("Generating Plot"):
            caching.clear_cache()
            _, x, category_id, _ = get_example_data()
            x = x.cpu()
            masks, _ = extremal_perturbation(
                model, x, category_id,
                reward_func=contrastive_reward,
                debug=False,
                areas=[0.12],
            )
            _show_saliency(masks, x)
    elif visualization == "RISE":
        with st.spinner("Generating Plot"):
            caching.clear_cache()
            _, x, category_id, _ = get_example_data()
            x = x.cpu()
            saliency = rise(model, x)
            # Keep only the map that belongs to the example's category.
            saliency = saliency[:, category_id].unsqueeze(0)
            _show_saliency(saliency, x)
    elif visualization == "Color Distribution for Entire Dataset":
        with st.spinner("Generating Plot"):
            caching.clear_cache()
            _, x, _, _ = get_example_data()
            # Average over the batch axis; the result is channels-first
            # (C, H, W), as the permute(1, 2, 0) used for display implies.
            pixels = _to_numpy(x.mean(dim=0))
            fig = plt.figure()
            mpl.rcParams.update({'font.size': 15})
            # Index the channel axis (axis 0). The previous image[:, :, k]
            # selected a single image column instead of channel k.
            # Assumes RGB channel order -- TODO confirm.
            plt.hist(pixels[0].ravel(), bins=256, color='red', alpha=0.5)
            plt.hist(pixels[1].ravel(), bins=256, color='Green', alpha=0.5)
            plt.hist(pixels[2].ravel(), bins=256, color='Blue', alpha=0.5)
            plt.xlabel('Intensity Value')
            plt.ylabel('Count')
            plt.legend(['Red_Channel', 'Green_Channel', 'Blue_Channel'])
            _show_figure(fig)
    elif visualization == "Pixel Distribution for Entire Dataset":
        try:
            with st.spinner("Generating Plot"):
                caching.clear_cache()
                _, x, _, _ = get_example_data()
                pixels = _to_numpy(x.mean(dim=0))
                # Set the large font before any text is created so the axis
                # labels pick it up as well.
                mpl.rcParams.update({'font.size': 55})
                fig = plt.figure(figsize=(40, 40))
                plt.ylabel("Count")
                plt.xlabel("Intensity Value")
                plt.hist(pixels.ravel(), bins=256)
                # The Laplacian is a 2-D operator, so apply it to a 2-D
                # (channel-averaged) image rather than to the raveled pixels.
                grey = np.ascontiguousarray(pixels.mean(axis=0), dtype=np.float32)
                vlo = cv2.Laplacian(grey, cv2.CV_32F).var()
                plt.text(1, 1, ('Variance of Laplacian: '+str(vlo)), style='italic', bbox={'facecolor': 'white', 'alpha': 0.5, 'pad': 10})
                _show_figure(fig)
        finally:
            # Restore the font size used by the other plots even on failure.
            mpl.rcParams.update({'font.size': 15})


def app(model=None, in_dist_name="in", ood_data_names=("out", "out2"), image=True):
    """Render the Incepto Streamlit dashboard.

    Args:
        model: network to inspect. When ``None`` a
            ``torchvision.models.resnet18()`` in eval mode is created on
            demand (previously it was built once when the module was loaded).
        in_dist_name: label of the in-distribution dataset in the selector.
        ood_data_names: labels of the out-of-distribution datasets.
        image: whether the data are images. This selects the list of
            visualization types; plots are only rendered when it is true.
    """
    if model is None:
        model = torchvision.models.resnet18().eval()

    # Render the readme as markdown using st.markdown.
    st.markdown(get_file_content_as_string("Incepto/dashboard/intro.md"))
    layers = get_layers(model)

    if st.sidebar.button("Go to Guide"):
        caching.clear_cache()
        st.markdown(get_file_content_as_string("Incepto/dashboard/details.md"))

    st.sidebar.title("Data Settings")

    # NOTE(review): the selected dataset is not read anywhere below; every
    # view uses the example input from get_example_data().
    dataset = st.sidebar.selectbox(
        "Set Dataset:",
        (in_dist_name, *ood_data_names),
    )

    if image:
        dataset_views = (
            "Color Distribution for Entire Dataset",
            "Pixel Distribution for Entire Dataset",
        )
    else:
        dataset_views = ("Average Signal for Entire Dataset",)
    # NOTE(review): "Excitation Backpropgation" (sic) has no handler yet.
    visualization = st.sidebar.selectbox(
        "Set Visualization Type:",
        (
            "-",
            *dataset_views,
            "Deconvolution",
            "Excitation Backpropgation",
            "Gradient",
            "Grad-CAM",
            "Guided Backpropagation",
            "Linear Approximation",
            "Extremal Perturbation",
            "RISE",
        ),
    )

    if image:
        _render_image_visualization(model, layers, visualization)

    if st.sidebar.button("Visualize Model"):
        # An explicit key keeps this widget distinct from the identical
        # "Select Layer:" selectbox that a layer-based view may have created
        # earlier in the same run.
        saliency_layer = st.selectbox(
            "Select Layer:", tuple(layers), key="visualize_model_layer"
        )
        filter_index = st.number_input(label="Enter a filter number:", step=1, min_value=1, value=1)
        g_ascent = GradientAscent(model)
        g_ascent.use_gpu = False
        # NOTE(review): the visualized layer is hard-coded to model.conv1; the
        # selected layer is only used as the plot title. The former
        # exec("layer = model.conv1") could not rebind this local and was
        # removed.
        layer = model.conv1
        print(layer)
        img = g_ascent.visualize(layer, filter_index, title=saliency_layer, return_output=True)[0][0][0]
        fig = plt.figure(figsize=(40, 40))
        ax = fig.add_subplot(131)
        ax.imshow(_to_numpy(img))
        _show_figure(fig)