from torchray.attribution.excitation_backprop import contrastive_excitation_backprop
from torchray.benchmark import get_example_data, plot_example

# Fetch the demo model / input image / target class shipped with TorchRay.
model, x, category_id, _ = get_example_data()

# Attribute the target class with contrastive excitation backprop, reading
# the saliency at an early layer and contrasting at a late one.
saliency = contrastive_excitation_backprop(
    model,
    x,
    category_id,
    saliency_layer='features.9',
    contrast_layer='features.30',
    classifier_layer='classifier.6',
)

# Show the input next to its saliency map.
plot_example(x, saliency, 'contrastive excitation backprop', category_id)
from torchray.attribution.extremal_perturbation import extremal_perturbation, contrastive_reward
from torchray.benchmark import get_example_data, plot_example
from torchray.utils import get_device

# Fetch the demo model, input image, and two target classes.
model, x, category_id_1, category_id_2 = get_example_data()

# Move everything to the GPU when one is available.
device = get_device()
model.to(device)
x = x.to(device)

# Compute an extremal-perturbation mask for each target class; the two
# classes get different mask areas (12% vs 5% of the image).
masks_1, _ = extremal_perturbation(
    model,
    x,
    category_id_1,
    reward_func=contrastive_reward,
    debug=True,
    areas=[0.12],
)
masks_2, _ = extremal_perturbation(
    model,
    x,
    category_id_2,
    reward_func=contrastive_reward,
    debug=True,
    areas=[0.05],
)
def _show_saliency(saliency, x):
    """Plot a saliency map beside the input image and hand the figure to
    Streamlit.

    Args:
        saliency: 2-D (after squeeze) saliency tensor.
        x: input image tensor; assumed (1, C, H, W) — squeezed and permuted
           to HWC for display. TODO(review): confirm layout against
           get_example_data().
    """
    fig = plt.figure(figsize=(40, 40))
    ax = fig.add_subplot(131)
    ax.imshow(np.asarray(saliency.squeeze().detach().numpy()))
    ax = fig.add_subplot(132)
    ax.imshow(np.asarray(x.cpu().squeeze().permute(1, 2, 0).detach().numpy()))
    st.pyplot(fig)


def app(model=None, in_dist_name="in", ood_data_names=("out", "out2"), image=True):
    """Top-level Streamlit page for the Incepto dashboard.

    Renders the intro markdown, lets the user pick a dataset and a
    visualization, runs the chosen attribution method on the TorchRay
    example data, and optionally visualizes a model filter via gradient
    ascent.

    Args:
        model: torch module to inspect. Defaults to
            ``torchvision.models.resnet18().eval()``, built lazily here
            instead of eagerly in the signature (the original evaluated it
            at import time).
        in_dist_name: label for the in-distribution dataset.
        ood_data_names: labels for the out-of-distribution datasets. A
            tuple default replaces the original mutable list default.
        image: True for image data (adds colour/pixel-distribution plots),
            False for 1-D signal data.
    """
    # Local import: grad_cam was never wired up (the original "Grad-CAM"
    # branch ran linear_approx instead), so it may be missing from the
    # module-level imports.
    from torchray.attribution.grad_cam import grad_cam

    if model is None:
        model = torchvision.models.resnet18().eval()

    # Render the readme as markdown.
    st.markdown(get_file_content_as_string("Incepto/dashboard/intro.md"))
    layers = get_layers(model)

    if st.sidebar.button("Go to Guide"):
        caching.clear_cache()
        st.markdown(get_file_content_as_string("Incepto/dashboard/details.md"))

    st.sidebar.title("Data Settings")
    # Select which dataset to explore.
    dataset = st.sidebar.selectbox(
        "Set Dataset:",
        (in_dist_name, *ood_data_names),
    )

    # Attribution methods available for both data modalities.
    # NOTE(review): fixed the original typo "Excitation Backpropgation";
    # no branch below handles it either way.
    common_methods = (
        "Deconvolution", "Excitation Backpropagation", "Gradient",
        "Grad-CAM", "Guided Backpropagation", "Linear Approximation",
        "Extremal Perturbation", "RISE",
    )
    if image:
        visualization = st.sidebar.selectbox(
            "Set Visualization Type:",
            ("-", "Color Distribution for Entire Dataset",
             "Pixel Distribution for Entire Dataset", *common_methods),
        )
    else:
        visualization = st.sidebar.selectbox(
            "Set Visualization Type:",
            ("-", "Average Signal for Entire Dataset", *common_methods),
        )

    if image:
        # Dispatch table for the layer-based saliency methods.
        # NOTE(review): the original called linear_approx for "Grad-CAM"
        # and gradient for "Linear Approximation" — apparent copy-paste
        # bugs, fixed here.
        backprop_methods = {
            "Deconvolution": deconvnet,
            "Grad-CAM": grad_cam,
            "Guided Backpropagation": guided_backprop,
            "Gradient": gradient,
            "Linear Approximation": linear_approx,
        }
        if visualization in backprop_methods:
            with st.spinner("Generating Plot"):
                caching.clear_cache()
                saliency_layer = st.selectbox("Select Layer:", tuple(layers))
                _, x, category_id, _ = get_example_data()
                saliency = backprop_methods[visualization](
                    model, x.cpu(), category_id,
                    saliency_layer=saliency_layer,
                )
                _show_saliency(saliency, x)
        elif visualization == "Extremal Perturbation":
            with st.spinner("Generating Plot"):
                caching.clear_cache()
                _, x, category_id, _ = get_example_data()
                masks, _ = extremal_perturbation(
                    model, x.cpu(), category_id,
                    reward_func=contrastive_reward,
                    debug=False,
                    areas=[0.12],
                )
                _show_saliency(masks, x)
        elif visualization == "RISE":
            with st.spinner("Generating Plot"):
                caching.clear_cache()
                _, x, category_id, _ = get_example_data()
                # RISE returns one map per class; keep the target class only.
                saliency = rise(model, x.cpu())
                saliency = saliency[:, category_id].unsqueeze(0)
                _show_saliency(saliency, x)
        elif visualization == "Color Distribution for Entire Dataset":
            with st.spinner("Generating Plot"):
                caching.clear_cache()
                _, x, category_id, _ = get_example_data()
                x = sum(x) / len(x)
                # Renamed from `image`, which shadowed the boolean parameter.
                image_arr = x.cpu().detach().numpy()
                fig = plt.figure()
                mpl.rcParams.update({'font.size': 15})
                # NOTE(review): this indexes the LAST axis as channels; a
                # torch image tensor is usually (C, H, W) — confirm layout.
                for channel, color in ((0, 'red'), (1, 'Green'), (2, 'Blue')):
                    plt.hist(image_arr[:, :, channel].ravel(), bins=256,
                             color=color, alpha=0.5)
                plt.xlabel('Intensity Value')
                plt.ylabel('Count')
                plt.legend(['Red_Channel', 'Green_Channel', 'Blue_Channel'])
                st.pyplot(fig)
        elif visualization == "Pixel Distribution for Entire Dataset":
            with st.spinner("Generating Plot"):
                caching.clear_cache()
                _, x, category_id, _ = get_example_data()
                x = sum(x) / len(x)
                flat = x.cpu().detach().numpy().ravel()
                fig = plt.figure(figsize=(40, 40))
                plt.ylabel("Count")
                plt.xlabel("Intensity Value")
                mpl.rcParams.update({'font.size': 55})
                plt.hist(flat, bins=256)
                # Variance of the Laplacian — a standard sharpness proxy.
                vlo = cv2.Laplacian(flat, cv2.CV_32F).var()
                plt.text(1, 1, ('Variance of Laplacian: ' + str(vlo)),
                         style='italic',
                         bbox={'facecolor': 'white', 'alpha': 0.5, 'pad': 10})
                st.pyplot(fig)
                # Restore the global font size changed above.
                mpl.rcParams.update({'font.size': 15})

    if st.sidebar.button("Visualize Model"):
        saliency_layer = st.selectbox("Select Layer:", tuple(layers))
        # Renamed from `filter`, which shadowed the builtin.
        filter_idx = st.number_input(label="Enter a filter number:",
                                     step=1, min_value=1, value=1)
        g_ascent = GradientAscent(model)
        g_ascent.use_gpu = False
        # The original hard-coded model.conv1 (plus a no-op
        # exec("layer = model.conv1") — exec cannot rebind a local in
        # Python 3). Resolve the user-selected layer by name instead,
        # keeping conv1 as the fallback the original always used.
        modules = dict(model.named_modules())
        layer = modules[saliency_layer] if saliency_layer in modules else model.conv1
        img = g_ascent.visualize(layer, filter_idx, title=saliency_layer,
                                 return_output=True)[0][0][0]
        fig = plt.figure(figsize=(40, 40))
        ax = fig.add_subplot(131)
        ax.imshow(np.asarray(img.cpu().detach().numpy()))
        st.pyplot(fig)