def generate_extremal_perturbation(model, x, y, area=0.1):
    """Run a contrastive extremal perturbation with a single area budget.

    Args:
        model: network to explain (forwarded to ``extremal_perturbation``).
        x: input forwarded to ``extremal_perturbation``.
        y: target forwarded to ``extremal_perturbation``.
        area: sole entry of the ``areas`` list.

    Returns:
        Element ``[0]`` of the *second* value returned by
        ``extremal_perturbation``; the first value (the masks) is discarded.
        NOTE(review): confirm the second return value is really the intended
        output rather than ``masks[0]``.
    """
    outcome = extremal_perturbation(model, x, y,
                                    reward_func=contrastive_reward,
                                    debug=True,
                                    areas=[area])
    _masks, second_output = outcome
    return second_output[0]
def compute_extremal_perturbation(model,
                                  preprocessed_image,
                                  label,
                                  saliency_layer=None):
    """Return an extremal-perturbation mask replicated to three channels.

    Args:
        model: network to explain.
        preprocessed_image: input forwarded to ``extremal_perturbation``.
        label: object whose ``.item()`` gives the target passed on.
        saliency_layer: accepted for signature compatibility; unused here.

    Returns:
        NumPy array: the mask, copied to the CPU, concatenated three times
        along axis 1 and then squeezed.
    """
    mask, _unused = extremal_perturbation(model, preprocessed_image,
                                          label.item())
    mask_np = mask.detach().cpu().clone().numpy()
    stacked = np.concatenate([mask_np] * 3, axis=1)
    return stacked.squeeze()
Beispiel #3
0
            def _ep(m, x, i):
                """Return the first extremal-perturbation mask for target ``i``.

                The mask is passed through ``imsc(..., quiet=True)``; element
                ``[0]`` of that result gets a new leading axis before being
                returned. ``area`` is taken from the enclosing scope.
                """
                all_masks, _ = extremal_perturbation(
                    m, x, i,
                    reward_func=contrastive_reward,
                    debug=True,
                    areas=[area],
                )
                scaled = imsc(all_masks[0], quiet=True)[0]
                return scaled[None]
def get_attribution(model,
                    input,
                    target,
                    method,
                    device,
                    saliency_layer='features.norm5',
                    iba_wrapper=None):
    """Compute a saliency map for ``target`` using the requested method.

    Args:
        model: network to explain.
        input: input tensor; moved to ``device`` and marked ``requires_grad``.
        target: target class forwarded to the attribution method.
        method: one of "grad_cam", "extremal_perturbation", "ib",
            "reverse_ib", "gradient", "excitation_backprop",
            "integrated_gradients".
        device: device the input is moved to.
        saliency_layer: layer used by "grad_cam" and "excitation_backprop".
        iba_wrapper: object exposing ``iba(...)``; required for "ib" and
            "reverse_ib".

    Returns:
        For "ib"/"reverse_ib", whatever ``iba_wrapper.iba`` returns unchanged.
        For every other method, a NumPy array resized to (224, 224) with
        bilinear interpolation and the original value range preserved.

    Raises:
        ValueError: if ``method`` is not one of the supported names.
        AssertionError: if an IB method is requested without ``iba_wrapper``.
    """
    input = input.to(device)
    input.requires_grad = True

    # get attribution
    if method == "grad_cam":
        saliency_map = grad_cam(model,
                                input,
                                target,
                                saliency_layer=saliency_layer)
    elif method == "extremal_perturbation":
        saliency_map, _ = extremal_perturbation(model, input, target)
    elif method == 'ib':
        assert iba_wrapper, "Please give a iba wrapper as function parameter!"
        saliency_map = iba_wrapper.iba(model, input, target, device)
    elif method == 'reverse_ib':
        assert iba_wrapper, "Please give a iba wrapper as function parameter!"
        saliency_map = iba_wrapper.iba(model,
                                       input,
                                       target,
                                       device,
                                       reverse_lambda=True)
    elif method == "gradient":
        saliency_map = gradient(model, input, target)
    elif method == "excitation_backprop":
        saliency_map = excitation_backprop(model,
                                           input,
                                           target,
                                           saliency_layer=saliency_layer)
    elif method == "integrated_gradients":
        ig = IntegratedGradients(model)
        saliency_map, _ = ig.attribute(input,
                                       target=target,
                                       return_convergence_delta=True)
        # Average over the channel axis left after squeezing the batch axis.
        saliency_map = saliency_map.squeeze().mean(0)
    else:
        # Without this, an unknown name would surface later as an opaque
        # UnboundLocalError on ``saliency_map``.
        raise ValueError(f"Unknown attribution method: {method!r}")

    # ib heatmap already a numpy array scaled to image size
    if method not in ('ib', 'reverse_ib'):
        saliency_map = saliency_map.detach().cpu().numpy().squeeze()
        shape = (224, 224)
        saliency_map = resize(saliency_map,
                              shape,
                              order=1,
                              preserve_range=True)
    return saliency_map
# Example: contrastive extremal perturbations for the two example categories.
from torchray.attribution.extremal_perturbation import (contrastive_reward,
                                                        extremal_perturbation)
from torchray.benchmark import get_example_data, plot_example
from torchray.utils import get_device

# Obtain example data.
model, x, category_id_1, category_id_2 = get_example_data()

# Run on GPU if available.
device = get_device()
model.to(device)
x = x.to(device)

# Extremal perturbation backprop: one mask per category, each with its own
# area budget.
masks_1, _ = extremal_perturbation(
    model,
    x,
    category_id_1,
    reward_func=contrastive_reward,
    debug=True,
    areas=[0.12],
)

masks_2, _ = extremal_perturbation(
    model,
    x,
    category_id_2,
    reward_func=contrastive_reward,
    debug=True,
    areas=[0.05],
)

# Plots: show both masks (masks_2 was previously computed but never shown).
plot_example(x, masks_1, 'extremal perturbation', category_id_1)
plot_example(x, masks_2, 'extremal perturbation', category_id_2)
Beispiel #6
0
    def __next__(self):
        """Evaluate the pointing game on the next image from the iterator.

        For every class id of the image, the saliency method named by
        ``self.experiment.method`` is run (or the image center is used for the
        "center" baseline). The result is reduced to a point and scored with
        ``self.pointing`` and ``self.pointing_difficult``. When ``self.log > 0``,
        results are first looked up in, and afterwards saved to, MongoDB under
        ``f"{image_name}-{class_id}"``. When ``self.log > 1``, the auxiliary
        ``info`` dict is saved too, under ``<experiment name>-details``.

        Returns:
            dict with keys 'pointing' and 'pointing_difficult', each mapping
            class id to the corresponding evaluation result.

        Raises:
            StopIteration: propagated from ``self.data_iterator``.
            ProcessingError: wrapping any exception raised while processing
                the image.
        """
        self._lazy_init()
        x, y = next(self.data_iterator)
        torch.manual_seed(self.seed)

        if self.log:
            from torchray.benchmark.logging import mongo_load, mongo_save, \
                data_from_mongo, data_to_mongo

        # Bind the values reported by ProcessingError up front. If a failure
        # happens before they are assigned (e.g. the batch-size assert or
        # ``as_class_ids``), the handler would otherwise raise
        # UnboundLocalError and hide the real error.
        class_id = None
        image_size = None

        try:
            assert len(x) == 1
            x = x.to(self.device)
            class_ids = self.data.as_class_ids(y[0])
            image_size = self.data.as_image_size(y[0])

            results = {'pointing': {}, 'pointing_difficult': {}}
            info = {}
            # RISE returns maps for all classes at once; computed lazily once
            # per image and then indexed per class.
            rise_saliency = None

            for class_id in class_ids:

                # Try to recover this result from the log.
                if self.log > 0:
                    image_name = self.data.as_image_name(y[0])
                    data = mongo_load(
                        self.db,
                        self.experiment.name,
                        f"{image_name}-{class_id}",
                    )
                    if data is not None:
                        data = data_from_mongo(data)
                        results['pointing'][class_id] = data['pointing']
                        results['pointing_difficult'][class_id] = data[
                            'pointing_difficult']
                        if self.debug:
                            print(f'{image_name}-{class_id} loaded from log')
                        continue

                # TODO(av): should now be obsolete
                if x.grad is not None:
                    x.grad.data.zero_()

                if self.experiment.method == "center":
                    w, h = image_size
                    point = torch.tensor([[w / 2, h / 2]])

                elif self.experiment.method == "gradient":
                    saliency = gradient(
                        self.model,
                        x,
                        class_id,
                        resize=image_size,
                        smooth=0.02,
                        get_backward_gradient=get_pointing_gradient)
                    point = _saliency_to_point(saliency)
                    info['saliency'] = saliency

                elif self.experiment.method == "deconvnet":
                    saliency = deconvnet(
                        self.model,
                        x,
                        class_id,
                        resize=image_size,
                        smooth=0.02,
                        get_backward_gradient=get_pointing_gradient)
                    point = _saliency_to_point(saliency)
                    info['saliency'] = saliency

                elif self.experiment.method == "guided_backprop":
                    saliency = guided_backprop(
                        self.model,
                        x,
                        class_id,
                        resize=image_size,
                        smooth=0.02,
                        get_backward_gradient=get_pointing_gradient)
                    point = _saliency_to_point(saliency)
                    info['saliency'] = saliency

                elif self.experiment.method == "grad_cam":
                    saliency = grad_cam(
                        self.model,
                        x,
                        class_id,
                        saliency_layer=self.gradcam_layer,
                        resize=image_size,
                        get_backward_gradient=get_pointing_gradient)
                    point = _saliency_to_point(saliency)
                    info['saliency'] = saliency

                elif self.experiment.method == "excitation_backprop":
                    saliency = excitation_backprop(
                        self.model,
                        x,
                        class_id,
                        self.saliency_layer,
                        resize=image_size,
                        get_backward_gradient=get_pointing_gradient)
                    point = _saliency_to_point(saliency)
                    info['saliency'] = saliency

                elif self.experiment.method == "contrastive_excitation_backprop":
                    saliency = contrastive_excitation_backprop(
                        self.model,
                        x,
                        class_id,
                        saliency_layer=self.saliency_layer,
                        contrast_layer=self.contrast_layer,
                        resize=image_size,
                        get_backward_gradient=get_pointing_gradient)
                    point = _saliency_to_point(saliency)
                    info['saliency'] = saliency

                elif self.experiment.method == "rise":
                    # For RISE, compute saliency map for all classes.
                    if rise_saliency is None:
                        rise_saliency = rise(self.model,
                                             x,
                                             resize=image_size,
                                             seed=self.seed)
                    saliency = rise_saliency[:, class_id, :, :].unsqueeze(1)
                    point = _saliency_to_point(saliency)
                    info['saliency'] = saliency

                elif self.experiment.method == "extremal_perturbation":

                    if self.experiment.dataset == 'voc_2007':
                        areas = [0.025, 0.05, 0.1, 0.2]
                    else:
                        areas = [0.018, 0.025, 0.05, 0.1]

                    if self.experiment.boom:
                        raise RuntimeError("BOOM!")

                    mask, energy = elp.extremal_perturbation(
                        self.model,
                        x,
                        class_id,
                        areas=areas,
                        num_levels=8,
                        step=7,
                        sigma=7 * 3,
                        max_iter=800,
                        debug=self.debug > 0,
                        jitter=True,
                        smooth=0.09,
                        resize=image_size,
                        perturbation='blur',
                        reward_func=elp.simple_reward,
                        variant=elp.PRESERVE_VARIANT,
                    )

                    # Combine the masks for all areas into a single map.
                    saliency = mask.sum(dim=0, keepdim=True)
                    point = _saliency_to_point(saliency)

                    info = {
                        'saliency': saliency,
                        'mask': mask,
                        'areas': areas,
                        'energy': energy
                    }

                else:
                    assert False

                results['pointing'][class_id] = self.pointing.evaluate(
                    y[0], class_id, point[0])
                results['pointing_difficult'][
                    class_id] = self.pointing_difficult.evaluate(
                        y[0], class_id, point[0])

                if self.log > 0:
                    image_name = self.data.as_image_name(y[0])
                    mongo_save(
                        self.db, self.experiment.name,
                        f"{image_name}-{class_id}",
                        data_to_mongo({
                            'image_name':
                            image_name,
                            'class_id':
                            class_id,
                            'pointing':
                            results['pointing'][class_id],
                            'pointing_difficult':
                            results['pointing_difficult'][class_id],
                        }))

                # image_name is bound just above, since log > 1 implies log > 0.
                if self.log > 1:
                    mongo_save(self.db,
                               str(self.experiment.name) + "-details",
                               f"{image_name}-{class_id}", data_to_mongo(info))

            return results

        except Exception as ex:
            raise ProcessingError(self, self.experiment, self.model, x, y,
                                  class_id, image_size) from ex
Beispiel #7
0
def _first_batch(dataset, batch_size):
    """Return the first batch of ``dataset`` from an unshuffled DataLoader."""
    loader = torch.utils.data.DataLoader(dataset,
                                         batch_size,
                                         shuffle=False,
                                         num_workers=1,
                                         pin_memory=True)
    # Use the builtin next(): the Python 2-era ``.next()`` alias is not
    # guaranteed on DataLoader iterators in newer PyTorch releases.
    return next(iter(loader))


def _tensor_to_rgb_image(image):
    """Convert a CHW tensor with values in [0, 1] to an RGB PIL image."""
    pixels = (image.cpu().detach().numpy() * 255).astype(np.uint8)
    return Image.fromarray(pixels.transpose((1, 2, 0)), mode='RGB')


def _rescale_to_unit(mask):
    """Shift ``mask`` so its minimum is 0, then divide by its maximum."""
    mask = mask - np.min(mask)
    peak = np.max(mask)
    # A constant map has peak 0; dividing would yield NaNs, so keep zeros.
    return mask / peak if peak > 0 else mask


def _normalize_mask(mask, min_level):
    """Clip negatives, rescale to [0, 1], floor at ``min_level``, rescale."""
    mask = _rescale_to_unit(np.maximum(mask, 0))
    return _rescale_to_unit(np.maximum(mask, min_level))


def _load_eval_model(args, device):
    """Load a fresh backbone, move it to ``device`` and set eval mode."""
    model = load_backbone(args)
    model = model.to(device)
    model.eval()
    return model


def for_vis(args):
    """Render saliency overlays for one validation image with many methods.

    The first validation image of ``args.dataset`` is explained for every class
    index in ``range(args.num_classes)`` with IGOS, RISE, Extremal
    Perturbation and IBA. Each result is drawn with ``show_cam_on_image``.
    Afterwards each torchcam extractor (CAM, GradCAM, GradCAMpp,
    SmoothGradCAMpp, ScoreCAM, SSCAM) is rendered via ``make_grad``.

    Args:
        args: namespace providing at least ``dataset``, ``img_size``,
            ``batch_size``, ``device``, ``num_classes`` and ``grad_min_level``
            (plus whatever ``load_backbone`` and the dataset helpers read).

    Raises:
        ValueError: if ``args.dataset`` is not ConText, ImageNet, MNIST or
            CUB200.
    """
    transform = transforms.Compose([
        transforms.Resize((args.img_size, args.img_size)),
        transforms.ToTensor(),
    ])
    rgb_normalize = transforms.Compose([
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    if args.dataset == 'ConText':
        _, val = MakeList(args).get_data()
        data = _first_batch(ConText(val, transform=transform), args.batch_size)
        label = data["label"][0]
        image_orl = _tensor_to_rgb_image(data["image"][0])
        image = transform(image_orl)
        transform = rgb_normalize
    elif args.dataset == 'ImageNet':
        _, val = MakeListImage(args).get_data()
        data = _first_batch(ConText(val, transform=transform), args.batch_size)
        label = data["label"][0].item()
        image_orl = _tensor_to_rgb_image(data["image"][0])
        image = transform(image_orl)
        transform = rgb_normalize
    elif args.dataset == 'MNIST':
        dataset_val = datasets.MNIST('./data/mnist',
                                     train=False,
                                     transform=transform)
        image = _first_batch(dataset_val, args.batch_size)[0][0]
        label = ''
        image_orl = Image.fromarray(
            (image.cpu().detach().numpy() * 255).astype(np.uint8)[0], mode='L')
        image = transform(image_orl)
        transform = transforms.Compose(
            [transforms.Normalize((0.1307, ), (0.3081, ))])
    elif args.dataset == 'CUB200':
        dataset_val = CUB_200(args, train=False, transform=transform)
        data = _first_batch(dataset_val, args.batch_size)
        label = data["label"][0]
        image_orl = _tensor_to_rgb_image(data["image"][0])
        image = transform(image_orl)
        transform = rgb_normalize
    else:
        raise ValueError(f"Unsupported dataset for visualization: "
                         f"{args.dataset!r}")

    # Normalize and add the batch axis.
    image = transform(image).unsqueeze(0)
    device = torch.device(args.device)

    ### IGOS
    model = _load_eval_model(args, device)

    image_orl_for_blur = np.float32(image_orl) / 255.
    img, blurred_img, logitori = Get_blurred_img(image_orl_for_blur,
                                                 label,
                                                 model,
                                                 resize_shape=(260, 260),
                                                 Gaussian_param=[51, 50],
                                                 Median_param=11,
                                                 blur_type='Gaussian',
                                                 use_cuda=1)

    # Every overlay below is drawn on the same resized image; resizing an
    # image to its current size is a copy, so doing it once is equivalent.
    image_orl = image_orl.resize((args.img_size, args.img_size),
                                 Image.BILINEAR)

    for target_index in tqdm(range(0, args.num_classes)):
        # Explain target_index (as the other methods below do), so each
        # class gets its own mask.
        _, upsampled_mask, _, _, _, _, _ = Integrated_Mask(
            img,
            blurred_img,
            model,
            target_index,
            max_iterations=15,
            integ_iter=20,
            tv_beta=2,
            l1_coeff=0.01 * 100,
            tv_coeff=0.2 * 100,
            size_init=8,
            use_cuda=1)
        mask = upsampled_mask.cpu().detach().numpy()[0, 0]
        # Flip the mask so that low IGOS mask values become high values.
        mask = -mask + mask.max() * 2.
        show_cam_on_image(image_orl,
                          _normalize_mask(mask, args.grad_min_level),
                          target_index, 'IGOS')

    del model

    ### torchray (RISE)
    model = _load_eval_model(args, device)

    for target_index in tqdm(range(0, args.num_classes)):
        mask = rise(model, image.to(device), target_index)
        mask = mask.cpu().numpy()[0, 0]
        show_cam_on_image(image_orl,
                          _normalize_mask(mask, args.grad_min_level),
                          target_index, 'RISE')

    del model

    ### torchray (Extremal)
    model = _load_eval_model(args, device)

    for target_index in tqdm(range(0, args.num_classes)):
        mask, _ = extremal_perturbation(model, image.to(device), target_index)
        mask = mask.cpu().numpy()[0, 0]
        show_cam_on_image(image_orl,
                          _normalize_mask(mask, args.grad_min_level),
                          target_index, 'Extremal')

    del model

    ### IBA
    model = _load_eval_model(args, device)

    imagenet_dir = '../../data/imagenet/ILSVRC/Data/CLS-LOC/validation'
    # Add a per-sample bottleneck after model.layer4.
    iba = IBA(model.layer4)

    # Estimate the mean and variance of the feature map at this layer.
    val_set = get_imagenet_folder(imagenet_dir)
    val_loader = DataLoader(val_set,
                            batch_size=64,
                            shuffle=True,
                            num_workers=4)
    iba.estimate(model, val_loader, n_samples=5000, progbar=True)

    for target_index in tqdm(range(0, args.num_classes)):
        # Loss for one batch: negative log-probability of target_index.
        # target_index is bound as a default to avoid late-binding surprises.
        def model_loss_closure(x, target_index=target_index):
            return -torch.log_softmax(
                model(x.to(device)), 1)[:, target_index].mean()

        heatmap = iba.analyze(image, model_loss_closure)
        show_cam_on_image(image_orl,
                          _normalize_mask(heatmap, args.grad_min_level),
                          target_index, 'IBA')

    # Layer names for a ResNet-style backbone.
    conv_layer = 'layer4'
    input_layer = 'conv1'
    fc_layer = 'fc'

    ### torchcam
    del model
    model = _load_eval_model(args, device)
    # Hook the corresponding layer in the model
    cam_extractors = [
        CAM(model, conv_layer, fc_layer),
        GradCAM(model, conv_layer),
        GradCAMpp(model, conv_layer),
        SmoothGradCAMpp(model, conv_layer, input_layer),
        ScoreCAM(model, conv_layer, input_layer),
        SSCAM(model, conv_layer, input_layer),
    ]
    cam_extractors_names = [
        'CAM',
        'GradCAM',
        'GradCAMpp',
        'SmoothGradCAMpp',
        'ScoreCAM',
        'SSCAM',
    ]
    for extractor, extractor_name in zip(cam_extractors, cam_extractors_names):
        model.zero_grad()

        # A fresh forward pass per extractor feeds its hooks and gives it
        # its own graph for any backward pass.
        output1 = model(image.to(device))
        output = F.softmax(output1, dim=1)
        prediction_score, pred_label_idx = torch.topk(output, 1)

        pred_label_idx.squeeze_()
        predicted_label = str(pred_label_idx.item())
        print('Predicted:', predicted_label, '(',
              prediction_score.squeeze().item(), ')')

        make_grad(extractor, output1, image_orl, args.grad_min_level,
                  extractor_name)
        extractor.clear_hooks()
Beispiel #8
0
def _show_saliency(saliency, x):
    """Plot ``saliency`` next to the example image ``x`` via ``st.pyplot``."""
    fig = plt.figure(figsize=(40, 40))
    ax = fig.add_subplot(131)
    ax.imshow(np.asarray(saliency.squeeze().detach().numpy()))
    ax = fig.add_subplot(132)
    # Squeeze the batch axis and move channels last for imshow.
    ax.imshow(np.asarray(x.cpu().squeeze().permute(1, 2, 0).detach().numpy()))
    st.pyplot(fig)


def app(model = torchvision.models.resnet18().eval(), in_dist_name="in", ood_data_names=("out", "out2"), image=True):
    """Streamlit dashboard for exploring attribution maps of ``model``.

    Args:
        model: network to visualize. NOTE: the default is built once, when
            the function is defined, and shared by every call.
        in_dist_name: label of the in-distribution dataset in the selector.
        ood_data_names: labels of the out-of-distribution datasets.
        image: if true, offer the image visualizations; otherwise the signal
            menu is shown (no signal visualizations are implemented here).
    """
    # Render the readme as markdown using st.markdown.
    st.markdown(get_file_content_as_string("Incepto/dashboard/intro.md"))
    layers = get_layers(model)
    # Once we have the dependencies, add a selector for the app mode on the sidebar.
    if st.sidebar.button("Go to Guide"):
        caching.clear_cache()
        st.markdown(get_file_content_as_string("Incepto/dashboard/details.md"))

    st.sidebar.title("Data Settings")

    dataset = st.sidebar.selectbox(
        "Set Dataset:",
        (in_dist_name, *ood_data_names),
    )

    if image:
        visualization = st.sidebar.selectbox(
            "Set Visualization Type:",
            ("-", "Color Distribution for Entire Dataset", "Pixel Distribution for Entire Dataset", "Deconvolution", "Excitation Backpropgation","Gradient","Grad-CAM","Guided Backpropagation","Linear Approximation", "Extremal Perturbation", "RISE"),
        )
    else:
        visualization = st.sidebar.selectbox(
            "Set Visualization Type:",
            ("-", "Average Signal for Entire Dataset", "Deconvolution", "Excitation Backpropgation","Gradient","Grad-CAM","Guided Backpropagation","Linear Approximation", "Extremal Perturbation", "RISE"),
        )

    if image:
        # Imported locally so the Grad-CAM option works even when the module
        # does not import grad_cam itself.
        from torchray.attribution.grad_cam import grad_cam

        # Layer-based methods share one code path; each option must map to
        # its own attribution function.
        layer_methods = {
            "Deconvolution": deconvnet,
            "Grad-CAM": grad_cam,
            "Guided Backpropagation": guided_backprop,
            "Gradient": gradient,
            "Linear Approximation": linear_approx,
        }
        # NOTE(review): "Excitation Backpropgation" is offered in the menu but
        # has no branch below, so selecting it renders nothing.
        if visualization in layer_methods:
            with st.spinner("Generating Plot"):
                caching.clear_cache()
                saliency_layer = st.selectbox("Select Layer:", tuple(layers))
                _, x, category_id, _ = get_example_data()
                attribute = layer_methods[visualization]
                saliency = attribute(model, x.cpu(), category_id,
                                     saliency_layer=saliency_layer)
                _show_saliency(saliency, x)
        elif visualization == "Extremal Perturbation":
            with st.spinner("Generating Plot"):
                caching.clear_cache()
                _, x, category_id, _ = get_example_data()
                masks_1, _ = extremal_perturbation(
                    model, x.cpu(), category_id,
                    reward_func=contrastive_reward,
                    debug=False,
                    areas=[0.12],)
                _show_saliency(masks_1, x)
        elif visualization == "RISE":
            with st.spinner("Generating Plot"):
                caching.clear_cache()
                _, x, category_id, _ = get_example_data()
                saliency = rise(model, x.cpu())
                saliency = saliency[:, category_id].unsqueeze(0)
                _show_saliency(saliency, x)
        elif visualization == "Color Distribution for Entire Dataset":
            with st.spinner("Generating Plot"):
                caching.clear_cache()
                _, x, category_id, _ = get_example_data()
                # Average over the batch axis. The result is channel-first,
                # matching the permute(1, 2, 0) used for display elsewhere.
                x = sum(x)/len(x)
                pixels = x.cpu().detach().numpy()
                fig = plt.figure()
                mpl.rcParams.update({'font.size': 15})
                # Index the leading (channel) axis so each histogram covers one
                # whole colour plane.
                _ = plt.hist(pixels[0].ravel(), bins = 256, color = 'red', alpha = 0.5)
                _ = plt.hist(pixels[1].ravel(), bins = 256, color = 'Green', alpha = 0.5)
                _ = plt.hist(pixels[2].ravel(), bins = 256, color = 'Blue', alpha = 0.5)
                _ = plt.xlabel('Intensity Value')
                _ = plt.ylabel('Count')
                _ = plt.legend(['Red_Channel', 'Green_Channel', 'Blue_Channel'])

                st.pyplot(fig)
        elif visualization == "Pixel Distribution for Entire Dataset":
            with st.spinner("Generating Plot"):
                caching.clear_cache()

                _, x, category_id, _ = get_example_data()
                x = sum(x)/len(x)
                pixels = x.cpu().detach().numpy()
                fig = plt.figure(figsize=(40,40))
                plt.ylabel("Count")
                plt.xlabel("Intensity Value")
                mpl.rcParams.update({'font.size': 55})

                ax = plt.hist(pixels.ravel(), bins = 256)
                # NOTE(review): the Laplacian is taken over the flattened
                # array, not the 2-D image; confirm this is the intended metric.
                vlo = cv2.Laplacian(pixels.ravel(), cv2.CV_32F).var()
                plt.text(1, 1, ('Variance of Laplacian: '+str(vlo)), style='italic', bbox={'facecolor': 'white', 'alpha': 0.5, 'pad': 10})
                st.pyplot(fig)
            # Restore the font size enlarged for this plot.
            mpl.rcParams.update({'font.size': 15})

    if st.sidebar.button("Visualize Model"):
        saliency_layer = st.selectbox("Select Layer:", tuple(layers))
        filter_index = st.number_input(label="Enter a filter number:", step=1, min_value=1, value=1)
        g_ascent = GradientAscent(model)
        g_ascent.use_gpu = False
        # NOTE(review): the visualized layer is always model.conv1; the
        # selected saliency_layer is only used as the plot title.
        layer = model.conv1
        img = g_ascent.visualize(layer, filter_index, title=saliency_layer,return_output=True)[0][0][0]
        fig = plt.figure(figsize=(40,40))
        ax = fig.add_subplot(131)
        ax.imshow(np.asarray(img.cpu().detach().numpy() ))
        st.pyplot(fig)