Пример #1
0
def guided_backprop(model, image, results_dir, file_name):
    """Compute a Guided Backprop attribution map and save it as a NIfTI file.

    Args:
        model: Model handed to ``GuidedBackprop``.
        image: Input tensor passed to ``GuidedBackprop.attribute``.
        results_dir: Output location prefix. It is concatenated with
            ``file_name`` as-is, so it must already end with a path separator.
        file_name: Base name of the output file; ``"-HM.nii"`` is appended.
    """
    gbp = GuidedBackprop(model)
    attributions = gbp.attribute(image)

    # Drop singleton dimensions, then move to host memory before converting:
    # ``Tensor.numpy()`` raises for CUDA tensors and for tensors that still
    # require grad, so detach and copy to CPU first.
    attributions = attributions.squeeze().detach().cpu().numpy()
    heatmap = nib.Nifti1Image(attributions, affine=np.eye(4))

    nib.save(heatmap, results_dir + file_name + "-HM.nii")
Пример #2
0
    def test_convnet_multi_target_and_default_pert_func(self) -> None:
        r"""
        Check that the max-sensitivity of Guided Backprop, computed with the
        default perturbation function and one target per example, does not
        depend on ``max_examples_per_batch``.
        """
        gbp = GuidedBackprop(BasicModel_ConvNet_One_Conv())

        # 20 identical examples, each a 1x4x4 ramp of the values 1..16.
        ramp = torch.arange(1, 17).float()
        inputs = torch.stack([ramp] * 20, dim=0).view(20, 1, 4, 4)

        def max_sensitivity(batch_limit):
            # Only the internal batching limit differs between the two runs.
            return self.sensitivity_max_assert(
                gbp.attribute,
                inputs,
                torch.zeros(20),
                perturb_func=default_perturb_func,
                target=torch.tensor([1] * 20),
                n_perturb_samples=10,
                max_examples_per_batch=batch_limit,
            )

        sens1 = max_sensitivity(40)
        sens2 = max_sensitivity(5)
        assertTensorAlmostEqual(self, sens1, sens2)
Пример #3
0
def compute_attr_one_pixel_target(x, net, spatial_coords, method, **kwargs):
    """Compute an attribution map for one spatial output location of ``net``.

    ``net`` is wrapped in ``WrapperNet`` so that the output at
    ``spatial_coords`` is what gets explained.

    Args:
        x: A single input tensor, i.e. shape
            (batch_size, channel_size, H, W) = (1, 1, 28, 28).
        net: The network to explain.
        spatial_coords: ``(idx, idy)`` pair forwarded to ``WrapperNet``.
        method: One of ``'gradCAM'``, ``'deconv'`` or ``'GuidedBP'``.
        **kwargs: Optional ``target`` (forwarded to ``attribute``; defaults to
            ``None``) and ``wrapper_output`` (``WrapperNet`` output mode;
            defaults to ``'yg_pixel'``).

    Returns:
        The ``[0][0]`` slice of the attribution as a NumPy array.

    Raises:
        ValueError: If ``method`` is not one of the supported names. This
            includes ``'layerAct'``, for which no explainer was ever built.
    """
    supported_methods = ('gradCAM', 'deconv', 'GuidedBP')
    # Fail early with a clear message; previously an unsupported method ended
    # in an UnboundLocalError further down.
    if method not in supported_methods:
        raise ValueError(
            "Unsupported attribution method %r; expected one of %s"
            % (method, ', '.join(supported_methods)))

    from captum.attr import LayerGradCam, Deconvolution, GuidedBackprop

    # A missing target used to leave the name unbound; default to None instead.
    target = kwargs.get('target')
    output_mode = kwargs.get('wrapper_output', 'yg_pixel')

    idx, idy = spatial_coords
    wnet = WrapperNet(net, output_mode=output_mode, spatial_coords=(idx, idy))

    if method == 'gradCAM':
        xai = LayerGradCam(wnet, wnet.main_net.channel_adj)
    elif method == 'deconv':
        xai = Deconvolution(wnet)
    else:  # 'GuidedBP'
        xai = GuidedBackprop(wnet)

    attr = xai.attribute(x, target=target)
    return attr[0][0].clone().detach().cpu().numpy()
Пример #4
0
    def __init__(self, model, dataset=None, gpu=-1):
        """Store the model in evaluation mode and build its Guided Backprop explainer.

        Args:
            model: Model to explain; it is switched to eval mode here.
            dataset: Currently unused. An earlier check that its samples are
                4-dimensional is disabled.
            gpu: ``-1`` selects the CPU device, any other value selects
                ``cuda:<gpu>``.
        """
        self.model = model
        self.model.eval()

        # NOTE(review): the device is resolved but neither stored nor applied
        # to the model in this method.
        device = torch.device('cpu' if gpu == -1 else 'cuda:' + str(gpu))

        self.algo = GuidedBackprop(model)
Пример #5
0
class Explainer():
    """Thin wrapper that produces Guided Backprop attributions for a model."""

    def __init__(self, model):
        """Keep the model and create a ``GuidedBackprop`` explainer for it."""
        self.model = model
        self.explain = GuidedBackprop(model)

    def get_attribution_map(self, img, target=None):
        """Return the attribution of ``img`` with respect to ``target``.

        When ``target`` is ``None`` the model's own prediction (argmax over
        dimension 1 of its output) is used as the target.
        """
        if target is not None:
            return self.explain.attribute(img, target=target)

        predicted = torch.argmax(self.model(img), 1)
        return self.explain.attribute(img, target=predicted)
Пример #6
0
def compute_GBP(model,
                dataloader,
                num_state_inputs=54,
                model_type="S",
                no_pbar=False):
    '''
    Computes the Guided Backpropagation values - only used for models with state information.
    right now only works for force only predictions - can be adapted for torques as well

    Args:
        model: Model to explain; moved to the selected device as float and
            put in eval mode.
        dataloader: Iterable yielding ``(img, state, labels)`` batches.
        num_state_inputs: Number of columns of the state attribution.
        model_type: ``"S"`` explains ``state`` alone; any other value explains
            ``(img, state)`` and keeps only the state attribution.
        no_pbar: Disable the progress bar.

    Returns:
        A list with one array per target index (0, 1, 2). Each array holds the
        attributions shifted by their per-column minimum and then divided by
        the per-row sum.
    '''
    tqdm.write('Computing guided backprop...')
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    if torch.cuda.is_available():
        tqdm.write("using GPU acceleration")

    model = model.to(device, dtype=torch.float)
    model.eval()

    gbp = GuidedBackprop(model)
    gbp_norm = []

    for n in range(3):
        # Collect per-batch attributions and stack once at the end; stacking
        # inside the loop re-copies everything gathered so far on every batch.
        # The empty seed keeps the original column count and float64 dtype.
        chunks = [np.empty((0, num_state_inputs))]
        with tqdm(total=len(dataloader), disable=no_pbar) as pbar:
            for img, state, _ in dataloader:
                state = state.to(device, dtype=torch.float)
                if model_type != "S":
                    img = img.to(device, dtype=torch.float)
                    # Tuple input -> tuple output; index 1 is the state part.
                    attributions = gbp.attribute((img, state), target=n)[1]
                else:
                    attributions = gbp.attribute(state, target=n)
                chunks.append(attributions.cpu().numpy())
                pbar.update(1)

        stacked = np.vstack(chunks)
        shifted = stacked - np.min(stacked, axis=0)
        # NOTE(review): a row that equals the column minima everywhere sums to
        # zero and yields NaN here, exactly as before.
        gbp_norm.append(shifted / np.reshape(np.sum(shifted, axis=1),
                                             (-1, 1)))

    return gbp_norm
Пример #7
0
class ImageInterpreter(object):
    """Produce and save Guided Backprop visualisations for an image model."""

    def __init__(self, model, dataset=None, gpu=-1):
        """Store the model in evaluation mode and build its explainer.

        Args:
            model: Model to explain; it is switched to eval mode here.
            dataset: Currently unused. An earlier check that its samples are
                4-dimensional is disabled.
            gpu: ``-1`` selects the CPU device, any other value selects
                ``cuda:<gpu>``.
        """
        self.model = model

        # Put model in evaluation mode
        self.model.eval()

        if gpu == -1:
            device = torch.device('cpu')
        else:
            device = torch.device('cuda:' + str(gpu))
        # Keep the resolved device instead of discarding it. The model itself
        # is left wherever the caller placed it.
        self.device = device

        self.algo = GuidedBackprop(model)

    def interpret_image(self, image, target_class, result_dir):
        """Compute Guided Backprop for ``image`` and save the derived images.

        Args:
            image: Input batch; only its first element is visualised.
            target_class: Target passed to ``GuidedBackprop.attribute``.
            result_dir: Path prefix; the suffixes ``_Guided_BP_color``,
                ``_Guided_BP_gray``, ``_pos_sal`` and ``_neg_sal`` are appended.
        """
        # Get gradients
        guided_grads = self.algo.attribute(
            image, target_class).detach().cpu().numpy()[0]
        print(guided_grads.shape)
        # Save colored gradients
        save_gradient_images(guided_grads, result_dir + '_Guided_BP_color')
        # Convert to grayscale. The image is copied to host memory first so
        # this also works for CUDA inputs, matching the gradient handling above.
        image_np = image.detach().cpu().numpy()[0]
        grayscale_guided_grads = convert_to_grayscale(guided_grads * image_np)
        # Save grayscale gradients
        save_gradient_images(grayscale_guided_grads,
                             result_dir + '_Guided_BP_gray')
        # Positive and negative saliency maps
        pos_sal, neg_sal = get_positive_negative_saliency(guided_grads)
        save_gradient_images(pos_sal, result_dir + '_pos_sal')
        save_gradient_images(neg_sal, result_dir + '_neg_sal')
        print('Guided backprop completed')
Пример #8
0
def heatmaps(args):
    """Render GradCAM, Deconvolution and Guided Backprop heatmaps for one sample batch.

    Loads the saved model of the project named by ``args['PROJECT_ID']``,
    draws one sample batch covering class indices 0..9, computes the three
    attributions with respect to ``y0`` and hands them to ``arrange_heatmaps``,
    which receives ``<PROJECT_DIR>/XAI/heatmaps.y0.jpeg`` as ``save_dir``.
    """
    print('heatmaps')
    project_id = args['PROJECT_ID']

    _, project_dir, model_path, _, _ = folder_check(project_id, CKPT_DIR='checkpoint')
    xai_dir = os.path.join(project_dir, 'XAI')
    if not os.path.exists(xai_dir):
        os.mkdir(xai_dir)

    from .sampler import Pytorch_GPT_MNIST_Sampler
    samp = Pytorch_GPT_MNIST_Sampler(compenv_mode=None, growth_mode=None)

    from .model import ResGPTNet34
    # NOTE(review): this freshly built network is replaced right away by the
    # object returned from torch.load below.
    net = ResGPTNet34(nG0=samp.gen.nG0, Nj=samp.gen.N_neighbor)
    net = torch.load(model_path)
    net.output_mode = 'prediction_only'
    net.to(device=device)
    net.eval()

    class_indices = np.array(range(10))
    x, y0, _, _ = samp.get_sample_batch(class_indices=class_indices, device=device)
    x.requires_grad = True

    save_path = os.path.join(xai_dir, 'heatmaps.y0.jpeg')

    from captum.attr import LayerGradCam, Deconvolution, GuidedBackprop # ShapleyValueSampling

    # Explainers are built lazily, one at a time, in this fixed order.
    explainer_factories = (
        ('gradCAM', lambda: LayerGradCam(net, net.channel_adj)),
        ('deconv', lambda: Deconvolution(net)),
        ('GuidedBP', lambda: GuidedBackprop(net)),
    )

    attrs = {}
    for label, build_explainer in explainer_factories:
        heat = build_explainer().attribute(x, target=y0)
        attrs[label] = heat.clone().detach().cpu().numpy()

    arrange_heatmaps(x.clone().detach().cpu().numpy(), attrs, save_dir=save_path)
Пример #9
0
            ch_type,
            data_type,
            seed=seed,
            num_workers=num_workers,
            chunkload=chunkload,
            include=(0, 1, 0),
            ages=ages,
        )

        model_filepath = save_path + f"{net_name}.pt"
        logging.info(net.name)

        _, net_state, _ = load_checkpoint(model_filepath)
        net.load_state_dict(net_state)

        gbp = GuidedBackprop(net)
        gbps = []
        targets = []
        for i, batch in enumerate(validloader):
            X, y = batch
            X = X.view(-1, *net.input_size).to(device)
            y = y.to(torch.int64).to(device)
            for trial, target in zip(X, y):
                gbps.append(
                    gbp.attribute(torch.unsqueeze(trial.float(), 0),
                                  target).cpu())
                targets.append(target.cpu().numpy())

            gbps = torch.cat(gbps).numpy()
            savemat(save_path + f"gbp_{net_name}.mat", {
                "gbp": gbps,
Пример #10
0
def measure_filter_model(
    model_version,
    dataset,
    out_folder,
    weights_dir,
    device,
    method=METHODS["gradcam"],
    sample_images=50,
    step=1,
    use_infidelity=False,
    use_sensitivity=False,
    render=False,
    ids=None,
):
    """Compare attributions of filtered images against the unfiltered one via SSIM.

    The function builds the requested model, loads its weights, iterates the
    filtered test split one image at a time and computes an attribution for
    each image. Samples are consumed in groups of ``len(OUR_FILTERS)``; once a
    group is complete, the SSIM between the attribution stored under
    ``"none"`` and the attribution of every filter in the group is computed.
    One row per group is written to
    ``<out_folder>/<model_version>-<dataset>-<method>-ssim-with-range.csv``.

    Args:
        model_version: ``"resnet18"``, ``"resnet50"`` or ``"densenet"``; any
            other value builds the EfficientNet-B0 model.
        dataset: Dataset key used for ``NUM_OF_CLASSES``, ``CustomDataset`` and
            ``MAX_ATT_VALUES``.
        out_folder: Directory receiving rendered figures and the CSV file.
        weights_dir: Path given to ``torch.load`` for the model state dict.
        device: Device the model and inputs are moved to.
        method: Attribution method identifier (compared against ``METHODS``
            entries and the literal ``"lime"``).
        sample_images: Not used in this function.
        step: Forwarded to ``CustomDataset``.
        use_infidelity: Compute the infidelity metric per sample.
        use_sensitivity: Compute the max-sensitivity metric per sample.
        render: Save a figure with the original image and heat map per sample.
        ids: Forwarded to ``CustomDataset``.
    """
    invTrans = get_inverse_normalization_transformation()
    data_dir = os.path.join("data")

    # Build the architecture that matches the stored weights.
    if model_version == "resnet18":
        model = create_resnet18_model(num_of_classes=NUM_OF_CLASSES[dataset])
    elif model_version == "resnet50":
        model = create_resnet50_model(num_of_classes=NUM_OF_CLASSES[dataset])
    elif model_version == "densenet":
        model = create_densenet121_model(
            num_of_classes=NUM_OF_CLASSES[dataset])
    else:
        model = create_efficientnetb0_model(
            num_of_classes=NUM_OF_CLASSES[dataset])

    model.load_state_dict(torch.load(weights_dir))

    # print(model)

    model.eval()
    model.to(device)

    test_dataset = CustomDataset(
        dataset=dataset,
        transformer=get_default_transformation(),
        data_type="test",
        root_dir=data_dir,
        step=step,
        add_filters=True,
        ids=ids,
    )
    # batch_size=1 and shuffle=False: the grouping logic below relies on the
    # order in which samples arrive.
    data_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=1,
                                              shuffle=False,
                                              num_workers=4)

    # NOTE(review): image_ids is not used anywhere below, and the sample size
    # equals the population size here.
    try:
        image_ids = random.sample(range(0, test_dataset.__len__()),
                                  test_dataset.__len__())
    except ValueError:
        raise ValueError(
            f"Image sample number ({test_dataset.__len__()}) exceeded dataset size ({test_dataset.__len__()})."
        )

    classes_map = test_dataset.classes_map

    print(f"Measuring {model_version} on {dataset} dataset, with {method}")
    print("-" * 10)
    pbar = tqdm(total=test_dataset.__len__(), desc="Model test completion")
    multipy_by_inputs = False
    # Select the attribution algorithm together with its sample counts.
    # NOTE(review): if `method` matches none of the branches, `attr_method`
    # stays unbound and the NoiseTunnel construction below fails.
    if method == METHODS["ig"]:
        attr_method = IntegratedGradients(model)
        nt_samples = 1
        n_perturb_samples = 1
    if method == METHODS["saliency"]:
        attr_method = Saliency(model)
        nt_samples = 8
        n_perturb_samples = 2
    if method == METHODS["gradcam"]:
        # GuidedGradCam is attached to the first convolution of each model.
        if model_version == "efficientnet":
            attr_method = GuidedGradCam(model, model._conv_stem)
        elif model_version == "densenet":
            attr_method = GuidedGradCam(model, model.features.conv0)
        else:
            attr_method = GuidedGradCam(model, model.conv1)
        nt_samples = 8
        n_perturb_samples = 2
    if method == METHODS["deconv"]:
        attr_method = Deconvolution(model)
        nt_samples = 8
        n_perturb_samples = 2
    if method == METHODS["gradshap"]:
        attr_method = GradientShap(model)
        nt_samples = 8
        n_perturb_samples = 2
    if method == METHODS["gbp"]:
        attr_method = GuidedBackprop(model)
        nt_samples = 8
        n_perturb_samples = 2
    if method == "lime":
        attr_method = Lime(model)
        nt_samples = 8
        n_perturb_samples = 2
        # NOTE(review): `lime_mask` is not defined in this function;
        # presumably a module-level value - verify.
        feature_mask = torch.tensor(lime_mask).to(device)
        multipy_by_inputs = True
    # Integrated Gradients is used directly; every other method is wrapped in
    # a NoiseTunnel.
    if method == METHODS["ig"]:
        nt = attr_method
    else:
        nt = NoiseTunnel(attr_method)
    scores = []

    # Perturbation used by the infidelity metric: subtract Gaussian noise
    # with standard deviation 0.003 from the inputs.
    @infidelity_perturb_func_decorator(multipy_by_inputs=multipy_by_inputs)
    def perturb_fn(inputs):
        noise = torch.tensor(np.random.normal(0, 0.003, inputs.shape)).float()
        noise = noise.to(device)
        return inputs - noise

    # Samples are assigned to these names purely by arrival order.
    # NOTE(review): presumably the dataset yields each image once per filter
    # in exactly this order - confirm against CustomDataset.
    OUR_FILTERS = [
        "none",
        "fx_freaky_details 2,10,1,11,0,32,0",
        "normalize_local 8,10",
        "fx_boost_chroma 90,0,0",
        "fx_mighty_details 25,1,25,1,11,0",
        "sharpen 300",
    ]
    idx = 0
    filter_count = 0
    filter_attrs = {filter_name: [] for filter_name in OUR_FILTERS}
    predicted_main_class = 0
    for input, label in data_loader:
        pbar.update(1)
        inv_input = invTrans(input)
        input = input.to(device)
        input.requires_grad = True
        output = model(input)
        output = F.softmax(output, dim=1)
        prediction_score, pred_label_idx = torch.topk(output, 1)
        prediction_score = prediction_score.cpu().detach().numpy()[0][0]
        pred_label_idx.squeeze_()
        # The class predicted for the "none" slot becomes the attribution
        # target for every sample of the current group.
        if OUR_FILTERS[filter_count] == 'none':
            predicted_main_class = pred_label_idx.item()

        if method == METHODS["gradshap"]:
            baseline = torch.randn(input.shape)
            baseline = baseline.to(device)

        # NOTE(review): the lime branch attributes towards the fixed target 1
        # rather than predicted_main_class.
        if method == "lime":
            attributions = attr_method.attribute(input, target=1, n_samples=50)
        elif method == METHODS["ig"]:
            attributions = nt.attribute(
                input,
                target=predicted_main_class,
                n_steps=25,
            )
        elif method == METHODS["gradshap"]:
            attributions = nt.attribute(input,
                                        target=predicted_main_class,
                                        baselines=baseline)
        else:
            attributions = nt.attribute(
                input,
                nt_type="smoothgrad",
                nt_samples=nt_samples,
                target=predicted_main_class,
            )

        if use_infidelity:
            infid = infidelity(model,
                               perturb_fn,
                               input,
                               attributions,
                               target=predicted_main_class)
            inf_value = infid.cpu().detach().numpy()[0]
        else:
            inf_value = 0

        if use_sensitivity:
            if method == "lime":
                sens = sensitivity_max(
                    attr_method.attribute,
                    input,
                    target=predicted_main_class,
                    n_perturb_samples=1,
                    n_samples=200,
                    feature_mask=feature_mask,
                )
            elif method == METHODS["ig"]:
                sens = sensitivity_max(
                    nt.attribute,
                    input,
                    target=predicted_main_class,
                    n_perturb_samples=n_perturb_samples,
                    n_steps=25,
                )
            elif method == METHODS["gradshap"]:
                sens = sensitivity_max(
                    nt.attribute,
                    input,
                    target=predicted_main_class,
                    n_perturb_samples=n_perturb_samples,
                    baselines=baseline,
                )
            else:
                sens = sensitivity_max(
                    nt.attribute,
                    input,
                    target=predicted_main_class,
                    n_perturb_samples=n_perturb_samples,
                )
            sens_value = sens.cpu().detach().numpy()[0]
        else:
            sens_value = 0

        # filter_name = test_dataset.data.iloc[pbar.n]["filter"].split(" ")[0]
        attr_data = attributions.squeeze().cpu().detach().numpy()
        if render:
            # Axes are moved from position 0 to last for visualisation.
            fig, ax = viz.visualize_image_attr_multiple(
                np.transpose(attr_data, (1, 2, 0)),
                np.transpose(inv_input.squeeze().cpu().detach().numpy(),
                             (1, 2, 0)),
                ["original_image", "heat_map"],
                ["all", "positive"],
                titles=["original_image", "heat_map"],
                cmap=default_cmap,
                show_colorbar=True,
                use_pyplot=False,
                fig_size=(8, 6),
            )
            if use_sensitivity or use_infidelity:
                ax[0].set_xlabel(
                    f"Infidelity: {'{0:.6f}'.format(inf_value)}\n Sensitivity: {'{0:.6f}'.format(sens_value)}"
                )
            fig.suptitle(
                f"True: {classes_map[str(label.numpy()[0])][0]}, Pred: {classes_map[str(pred_label_idx.item())][0]}\nScore: {'{0:.4f}'.format(prediction_score)}",
                fontsize=16,
            )
            fig.savefig(
                os.path.join(
                    out_folder,
                    f"{str(idx)}-{str(filter_count)}-{str(label.numpy()[0])}-{str(OUR_FILTERS[filter_count])}-{classes_map[str(label.numpy()[0])][0]}-{classes_map[str(pred_label_idx.item())][0]}.png",
                ))
            plt.close(fig)
        # if pbar.n > 25:
        #     break
        # Despite the name, this is the softmax score of predicted_main_class
        # (the class predicted for the group's "none" sample), not of `label`.
        score_for_true_label = output.cpu().detach().numpy(
        )[0][predicted_main_class]

        # Store [attribution with axis 0 moved last, formatted score] for the
        # current filter slot.
        filter_attrs[OUR_FILTERS[filter_count]] = [
            np.moveaxis(attr_data, 0, -1),
            "{0:.8f}".format(score_for_true_label),
        ]

        data_range_for_current_set = MAX_ATT_VALUES[model_version][method][
            dataset]
        filter_count += 1
        # A full group has been collected: score every filter against "none"
        # and start the next group.
        if filter_count >= len(OUR_FILTERS):
            ssims = []
            for rot in OUR_FILTERS:
                ssims.append("{0:.8f}".format(
                    ssim(
                        filter_attrs["none"][0],
                        filter_attrs[rot][0],
                        win_size=11,
                        data_range=data_range_for_current_set,
                        multichannel=True,
                    )))
                ssims.append(filter_attrs[rot][1])

            scores.append(ssims)
            filter_count = 0
            predicted_main_class = 0
            idx += 1

    pbar.close()

    # Column names: "<filter>-ssim" and "<filter>-score" per filter, matching
    # the order in which values were appended to each row.
    indexes = []

    for filter_name in OUR_FILTERS:
        indexes.append(str(filter_name) + "-ssim")
        indexes.append(str(filter_name) + "-score")
    np.savetxt(
        os.path.join(
            out_folder,
            f"{model_version}-{dataset}-{method}-ssim-with-range.csv"),
        np.array(scores),
        delimiter=";",
        fmt="%s",
        header=";".join([str(rot) for rot in indexes]),
    )

    print(f"Artifacts stored at {out_folder}")
Пример #11
0
 def __init__(self,model):
     """Keep a reference to ``model`` and build a Guided Backprop explainer for it."""
     self.model = model
     self.explain = GuidedBackprop(self.model)
Пример #12
0
class SignalInterpreter(object):
    """Explain a signal model with Guided Backprop next to a similar reference signal.

    On construction the reference signals ``X`` are pushed through the model's
    top-level ``torch.nn.Sequential`` children and the flattened result is
    reduced to three PCA components. ``interpret_signal`` projects a new signal
    the same way, picks the closest reference signal in PCA space and plots
    both signals together with their Guided Backprop attributions.

    Attributes:
        X: Reference signals, moved to ``device``.
        y: Labels belonging to ``X``.
        algo: ``GuidedBackprop`` explainer for the model.
        model: The model in evaluation mode.
        device: Device selected through the ``gpu`` argument.
        channel_labels: Labels passed in by the caller.
        modules: Mapping from child name to ``torch.nn.Sequential`` child.
        pca: Fitted three-component PCA.
        pca_result: PCA projection of the reference signals.
        pca_layer_result: One PCA column per entry of ``channel_labels``.
    """

    def __init__(self, model, X, y, gpu, channel_labels):
        """Prepare the explainer and fit the PCA on the reference signals.

        Args:
            model: Model to explain; put in eval mode and moved to the device.
            X: Reference signals; must be a 3-dimensional tensor.
            y: Labels for ``X``.
            gpu: ``-1`` for CPU, otherwise the CUDA device index.
            channel_labels: Sequence whose length must equal ``len(X.shape)``.

        Raises:
            IllegalDataDimensionException: If ``X`` is not 3-dimensional or
                ``channel_labels`` has the wrong length.
        """
        self.X, self.y = X, y

        self.algo = GuidedBackprop(model)

        self.model = model.eval()

        if len(self.X.shape) != 3:
            raise IllegalDataDimensionException("Expected data to have dimensions 3 but got dimension ",
                                                str(len(self.X.shape)))

        # NOTE(review): this compares against the number of dimensions of X
        # (always 3 at this point), not against a channel count - confirm
        # that this is intended.
        if len(self.X.shape) != len(channel_labels):
            raise IllegalDataDimensionException("Expected channel_labels to contain ", str(len(self.X.shape)),
                                                " labels but got ", str(len(channel_labels)), " labels instead.")

        self.channel_labels = channel_labels

        if gpu == -1:
            self.device = torch.device('cpu')
        else:
            self.device = torch.device('cuda:' + str(gpu))

        model.to(self.device)

        # Only top-level children that are exactly torch.nn.Sequential are
        # used for the intermediate forward pass.
        self.modules = {}
        for top, module in model.named_children():
            if type(module) is torch.nn.Sequential:
                self.modules[top] = module
        counter = len(self.modules)

        print(f"Total Interpretable layers: {counter}")

        # Tensor.to() is not in-place: keep the returned tensor, otherwise the
        # reference signals stay on their original device.
        self.X = self.X.to(self.device)

        flattened = self.X.flatten(start_dim=1)
        output = self.X
        for layer in self.modules.values():
            layer.eval().to(self.device)
            output = layer(output)
            flattened = output.flatten(start_dim=1)

        print("Intermediate Forward Pass Through " + str(counter) + " Layers")

        print("Generating PCA...")

        self.pca = PCA(n_components=3)
        # Copy to host memory before handing the data to PCA.
        self.pca_result = self.pca.fit_transform(flattened.detach().cpu().numpy())
        self.pca_layer_result = [self.pca_result[:, i]
                                 for i in range(len(channel_labels))]

        print("Generation Complete")

    def interpret_signal(self, signal):
        """Plot ``signal`` and its nearest reference signal with attributions.

        Args:
            signal: A single signal without batch dimension.

        Returns:
            The value returned by the plotting helper (the ``plt`` module).
        """
        # Batch of one, on the same device as the model.
        batched_signal = signal.unsqueeze(0).to(self.device)

        output = batched_signal
        for layer in self.modules.values():
            layer.eval().to(self.device)
            output = layer(output)

        output = torch.flatten(output, start_dim=1)

        reduced_signal = self.pca.transform(output.detach().cpu().numpy())
        example_index = self.__closest_point(reduced_signal)
        example_signal = self.X[example_index]

        signal_attribution = self.algo.attribute(
            batched_signal, target=0).detach().cpu().numpy()[0]
        example_signal_attribution = self.algo.attribute(
            example_signal.unsqueeze(0), target=0).detach().cpu().numpy()[0]

        # Weight the attributions by the signal values on the host.
        signal_attribution = convert_to_grayscale(
            signal_attribution * signal.detach().cpu().numpy())
        example_signal_attribution = convert_to_grayscale(
            example_signal_attribution * example_signal.detach().cpu().numpy())

        # save_gradient_images(grayscale_guided_grads, result_dir + '_Guided_BP_gray')

        predicted_class = int(torch.round(torch.sigmoid(self.model(batched_signal))).item())

        return self.__plot_signals([signal, example_signal],
                                   [signal_attribution, example_signal_attribution],
                                   predicted_class,
                                   self.y[example_index].item(),
                                   self.channel_labels)

    def __closest_point(self, point):
        """Return the index of the PCA-projected reference closest to ``point``.

        Args:
            point: Array broadcastable against ``self.pca_result``.

        Returns:
            Index of the row of ``self.pca_result`` with the smallest squared
            Euclidean distance to ``point``.
        """
        points = np.asarray(self.pca_result)
        deltas = points - point
        dist_2 = np.einsum('ij,ij->i', deltas, deltas)
        return np.argmin(dist_2)

    def __plot_signals(self, signals, attribution, output_test, output_training, channel_labels):
        """Plot the interpreted signal and the example signal with attributions.

        Args:
            signals: ``[signal, example_signal]`` tensors.
            attribution: Attributions matching ``signals``.
            output_test: Predicted class shown in the first title.
            output_training: Label of the example shown in the second title.
            channel_labels: Currently unused.

        Returns:
            The ``plt`` module, after ``plt.show()`` has been called.
        """
        fig = plt.figure(figsize=(40, 40))

        # Each panel shows the channel-wise sum of the signal and, in red, the
        # channel-wise sum of its attribution.
        ax = fig.add_subplot(4, 4, 1)
        ax.set_title("Interpreted Signal with Output Class "+str(output_test))
        ax.plot(sum(signals[0]).detach().cpu())
        ax.plot(sum(attribution[0]), color="r")

        ax = fig.add_subplot(4, 4, 2)
        ax.set_title("Example Signal with Output Class "+str(output_training))
        # Detach and copy to host like the first panel, so a CUDA reference
        # signal can be plotted as well.
        ax.plot(sum(signals[1]).detach().cpu())
        ax.plot(sum(attribution[1]), color="r")

        plt.show()
        return plt
                                        num_workers=1)

    # --------- 3. model define ---------
    if (model_name == 'u2net'):
        print("...load U2NET---173.6 MB")
        net = U2NET(3, 1)
    elif (model_name == 'u2netp'):
        print("...load U2NEP---4.7 MB")
        net = U2NETP(3, 1)
    criterion = torch.nn.CrossEntropyLoss()
    net.load_state_dict(torch.load(model_dir))
    if torch.cuda.is_available():
        net.cuda()
    net.eval()

    gbp = GuidedBackprop(net)

    #    ggCAM = GuidedGradCam(net)

    # --------- 4. inference for each image ---------
    for i_test, data_test in enumerate(test_salobj_dataloader):

        print("inferencing:", img_name_list[i_test].split(os.sep)[-1])

        inputs_test = data_test['image']
        inputs_test = inputs_test.type(torch.FloatTensor)

        label_test = data_test['label']
        label_test = label_test.type(torch.FloatTensor)
        print(label_test.shape)
Пример #14
0
def get_attribution(real_img,
                    fake_img,
                    real_class,
                    fake_class,
                    net_module,
                    checkpoint_path,
                    input_shape,
                    channels,
                    methods=["ig", "grads", "gc", "ggc", "dl", "ingrad", "random", "residual"],
                    output_classes=6,
                    downsample_factors=[(2,2), (2,2), (2,2), (2,2)]):
    """Compute normalised attribution maps for a real/fake image pair.

    Args:
        real_img: Real image as an array.
        fake_img: Fake (counterfactual) image as an array.
        real_class: Target class used for ``real_img``.
        fake_class: Target class used for ``fake_img``.
        net_module: Forwarded to ``init_network`` / ``get_sgc``.
        checkpoint_path: Forwarded to ``init_network`` / ``get_sgc``.
        input_shape: Network input shape; the first two entries are used as
            the rescale size for GradCAM maps.
        channels: Forwarded to ``init_network`` / ``get_sgc``.
        methods: Names of the attribution families to compute. Handled names
            are ``"residual"``, ``"random"``, ``"gc"``, ``"ggc"``, ``"ig"``,
            ``"dl"`` and ``"ingrad"``; other entries (such as ``"grads"``)
            are ignored. The default list is never mutated here.
        output_classes: Forwarded to ``init_network`` / ``get_sgc``.
        downsample_factors: Forwarded to ``init_network`` / ``get_sgc``. The
            default list is never mutated here.

    Returns:
        ``(attrs_norm, attrs_names)``: the attribution maps as NumPy arrays,
        each divided by its maximum absolute value, and the matching names in
        the same order.
    """
    imgs = [image_to_tensor(normalize_image(real_img).astype(np.float32)),
            image_to_tensor(normalize_image(fake_img).astype(np.float32))]

    classes = [real_class, fake_class]
    net = init_network(checkpoint_path, input_shape, net_module, channels, output_classes=output_classes,eval_net=True, require_grad=False,
                       downsample_factors=downsample_factors)

    attrs = []
    attrs_names = []

    def record(name, attr):
        # Keep maps and names aligned by always appending them together.
        attrs.append(attr)
        attrs_names.append(name)

    def last_conv2d():
        # Last module whose type is exactly torch.nn.Conv2d.
        return [module for module in net.modules() if type(module) == torch.nn.Conv2d][-1]

    def compute_sgc():
        # SCAM difference maps for the real and the fake image.
        return get_sgc(real_img, fake_img, real_class,
                       fake_class, net_module, checkpoint_path,
                       input_shape, channels, None, output_classes=output_classes,
                       downsample_factors=downsample_factors)

    # Filled by the "gc" branch, or on demand by the "ggc" branch.
    sgc_diffs = None

    if "residual" in methods:
        res = np.abs(real_img - fake_img)
        res = res - np.min(res)
        record("residual", torch.tensor(res/np.max(res)))

    if "random" in methods:
        rand = np.abs(np.random.randn(*np.shape(real_img)))
        # Public location of gaussian_filter; the scipy.ndimage.filters
        # namespace is deprecated.
        rand = np.abs(scipy.ndimage.gaussian_filter(rand, 4))
        rand = rand - np.min(rand)
        rand = rand/np.max(np.abs(rand))
        record("random", torch.tensor(rand))

    if "gc" in methods:
        net.zero_grad()
        layer_gc = LayerGradCam(net, last_conv2d())
        gc_real = layer_gc.attribute(imgs[0], target=classes[0])
        gc_fake = layer_gc.attribute(imgs[1], target=classes[1])

        rescale_size = (input_shape[0], input_shape[1])
        gc_real = project_layer_activations_to_input_rescale(gc_real.cpu().detach().numpy(), rescale_size)
        gc_fake = project_layer_activations_to_input_rescale(gc_fake.cpu().detach().numpy(), rescale_size)

        record("gc_real", torch.tensor(gc_real[0,0,:,:]))
        record("gc_fake", torch.tensor(gc_fake[0,0,:,:]))

        # SCAM
        sgc_diffs = compute_sgc()
        record("gc_diff_0", sgc_diffs[0])
        record("gc_diff_1", sgc_diffs[1])

    if "ggc" in methods:
        net.zero_grad()
        guided_gc = GuidedGradCam(net, last_conv2d())
        ggc_real = guided_gc.attribute(imgs[0], target=classes[0])
        ggc_fake = guided_gc.attribute(imgs[1], target=classes[1])

        record("ggc_real", ggc_real[0,0,:,:])
        record("ggc_fake", ggc_fake[0,0,:,:])

        net.zero_grad()
        gbp = GuidedBackprop(net)
        gbp_real = gbp.attribute(imgs[0], target=classes[0])
        gbp_fake = gbp.attribute(imgs[1], target=classes[1])

        record("gbp_real", gbp_real[0,0,:,:])
        record("gbp_fake", gbp_fake[0,0,:,:])

        # The guided SCAM maps need the SCAM difference maps. They used to be
        # available only when "gc" was requested too (NameError otherwise), so
        # compute them here when they are missing.
        if sgc_diffs is None:
            sgc_diffs = compute_sgc()

        record("ggc_diff_0", gbp_real[0,0,:,:] * sgc_diffs[0])
        record("ggc_diff_1", gbp_fake[0,0,:,:] * sgc_diffs[1])

    # IG
    if "ig" in methods:
        baseline = image_to_tensor(np.zeros(input_shape, dtype=np.float32))
        net.zero_grad()
        ig = IntegratedGradients(net)
        # The convergence deltas are requested as before but not used.
        ig_real, _ = ig.attribute(imgs[0], baseline, target=classes[0], return_convergence_delta=True)
        ig_fake, _ = ig.attribute(imgs[1], baseline, target=classes[1], return_convergence_delta=True)
        ig_diff_0, _ = ig.attribute(imgs[0], imgs[1], target=classes[0], return_convergence_delta=True)
        ig_diff_1, _ = ig.attribute(imgs[1], imgs[0], target=classes[1], return_convergence_delta=True)

        record("ig_real", ig_real[0,0,:,:])
        record("ig_fake", ig_fake[0,0,:,:])
        record("ig_diff_0", ig_diff_0[0,0,:,:])
        record("ig_diff_1", ig_diff_1[0,0,:,:])

    # DL
    if "dl" in methods:
        net.zero_grad()
        dl = DeepLift(net)
        dl_real = dl.attribute(imgs[0], target=classes[0])
        dl_fake = dl.attribute(imgs[1], target=classes[1])
        dl_diff_0 = dl.attribute(imgs[0], baselines=imgs[1], target=classes[0])
        dl_diff_1 = dl.attribute(imgs[1], baselines=imgs[0], target=classes[1])

        record("dl_real", dl_real[0,0,:,:])
        record("dl_fake", dl_fake[0,0,:,:])
        record("dl_diff_0", dl_diff_0[0,0,:,:])
        record("dl_diff_1", dl_diff_1[0,0,:,:])

    # INGRAD
    if "ingrad" in methods:
        net.zero_grad()
        saliency = Saliency(net)
        grads_real = saliency.attribute(imgs[0],
                                        target=classes[0])
        grads_fake = saliency.attribute(imgs[1],
                                        target=classes[1])

        record("grads_real", grads_real[0,0,:,:])
        record("grads_fake", grads_fake[0,0,:,:])

        net.zero_grad()
        input_x_gradient = InputXGradient(net)
        ingrad_real = input_x_gradient.attribute(imgs[0], target=classes[0])
        ingrad_fake = input_x_gradient.attribute(imgs[1], target=classes[1])

        # Gradient of one image times the difference towards the other image.
        ingrad_diff_0 = grads_fake * (imgs[0] - imgs[1])
        ingrad_diff_1 = grads_real * (imgs[1] - imgs[0])

        record("ingrad_real", torch.abs(ingrad_real[0,0,:,:]))
        record("ingrad_fake", torch.abs(ingrad_fake[0,0,:,:]))
        record("ingrad_diff_0", torch.abs(ingrad_diff_0[0,0,:,:]))
        record("ingrad_diff_1", torch.abs(ingrad_diff_1[0,0,:,:]))

    attrs = [a.detach().cpu().numpy() for a in attrs]
    # NOTE(review): an all-zero map divides by zero here and yields NaN.
    attrs_norm = [a/np.max(np.abs(a)) for a in attrs]

    return attrs_norm, attrs_names
Пример #15
0
def evaluation_ten_classes(initiate_or_load_model,
                           config_data,
                           singleton_scope=False,
                           reshape_size=None,
                           FIND_OPTIM_BRANCH_MODEL=False,
                           realtime_update=False,
                           ALLOW_ADHOC_NOPTIM=False):
    """Load a model, wrap it in the configured attribution method and evaluate it.

    Args:
        initiate_or_load_model: callable invoked as
            ``initiate_or_load_model(model_path, INFO_DIR, config_data, verbose=...)``
            and expected to return a ``(net, evaluator)`` pair.
        config_data: configuration mapping; ``'xai_mode'`` is always read, and
            ``'branch_name_label'`` / ``'model_name'`` are read when
            ``FIND_OPTIM_BRANCH_MODEL`` is set.
        singleton_scope: when true, inspect a single datapoint instead of
            running the aggregate evaluation.
        reshape_size: forwarded to ``aggregate_evaluation``.
        FIND_OPTIM_BRANCH_MODEL: when true, evaluate the branch checkpoint
            (``.optim`` preferred over ``.noptim``) instead of the main model.
        realtime_update: forwarded to ``aggregate_evaluation``.
        ALLOW_ADHOC_NOPTIM: debug switch; copies the branch checkpoint to a
            ``.noptim`` file before the lookup.

    Raises:
        RuntimeError: if neither branch checkpoint exists, or if ``xai_mode``
            names no supported attribution method.
    """
    from pipeline.training.training_utils import prepare_save_dirs
    xai_mode = config_data['xai_mode']
    MODEL_DIR, INFO_DIR, CACHE_FOLDER_DIR = prepare_save_dirs(config_data)

    VERBOSE = 0

    if FIND_OPTIM_BRANCH_MODEL:
        branch_label = str(config_data['branch_name_label'])
        branch_folder = MODEL_DIR[:MODEL_DIR.find('.model')] + '.' + branch_label
        checkpoint_path = os.path.join(
            branch_folder,
            str(config_data['model_name']) + '.' + branch_label + '.model')

        if ALLOW_ADHOC_NOPTIM:  # this is intended only for debug runs
            print('<< [EXY1] ALLOWING ADHOC NOPTIM >>')
            import shutil
            shutil.copyfile(checkpoint_path, checkpoint_path + '.noptim')

        # The fully optimized checkpoint wins over the partially optimized one.
        if os.path.exists(checkpoint_path + '.optim'):
            checkpoint_path = checkpoint_path + '.optim'
            print(
                '  Using the OPTIMIZED branch model for [%s] xai evaluation: %s'
                % (str(xai_mode), str(checkpoint_path)))
        elif os.path.exists(checkpoint_path + '.noptim'):
            checkpoint_path = checkpoint_path + '.noptim'
            print(
                '  Using the partially optimized branch model for [%s] xai evaluation: %s'
                % (str(xai_mode), str(checkpoint_path)))
        else:
            raise RuntimeError(
                'Attempting to find .optim or .noptim model, but not found.')

        if VERBOSE >= 250:
            print(
                '  """You may see a warning by pytorch for ReLu backward hook. It has been fixed externally, so you can ignore it."""'
            )
    else:
        print(
            'Using the following the model from (only) continuous training for xai evaluation [%s]'
            % (str(xai_mode)))
        checkpoint_path = MODEL_DIR

    net, evaluator = initiate_or_load_model(checkpoint_path,
                                            INFO_DIR,
                                            config_data,
                                            verbose=VERBOSE)

    # Ordered (mode name, lazy constructor) pairs; only the matching
    # constructor is ever invoked.
    attribution_builders = (
        ('Saliency', lambda: Saliency(net)),
        ('IntegratedGradients', lambda: IntegratedGradients(net)),
        ('InputXGradient', lambda: InputXGradient(net)),
        ('DeepLift', lambda: DeepLift(net)),
        ('GuidedBackprop', lambda: GuidedBackprop(net)),
        # first layer
        ('GuidedGradCam',
         lambda: GuidedGradCam(net, net.select_first_layer())),
        ('Deconvolution', lambda: Deconvolution(net)),
        ('GradientShap', lambda: GradientShap(net)),
        ('DeepLiftShap', lambda: DeepLiftShap(net)),
    )
    for mode_name, build in attribution_builders:
        if xai_mode == mode_name:
            attrmodel = build()
            break
    else:
        raise RuntimeError('No valid attribution selected.')

    if singleton_scope:
        # just to observe a single datapoint, mostly for debugging
        singleton_scope_oberservation(net, attrmodel, config_data,
                                      CACHE_FOLDER_DIR)
        return

    aggregate_evaluation(net,
                         attrmodel,
                         config_data,
                         CACHE_FOLDER_DIR,
                         reshape_size=reshape_size,
                         realtime_update=realtime_update,
                         EVALUATE_BRANCH=FIND_OPTIM_BRANCH_MODEL)
Пример #16
0
def measure_model(
    model_version,
    dataset,
    out_folder,
    weights_dir,
    device,
    method=METHODS["gradcam"],
    sample_images=50,
    step=1,
):
    """Score one attribution method on the test split with infidelity and max-sensitivity.

    For every test sample the attribution of the predicted class is computed,
    then scored with ``infidelity`` and ``sensitivity_max``. A random subset of
    samples is additionally rendered to PNG files, and all scores are written
    to ``<out_folder>/<model_version>-<dataset>-<method>.csv``.

    Args:
        model_version: "resnet18", "resnet50" or "densenet"; any other value
            builds the EfficientNet-B0 model.
        dataset: dataset name, used for ``NUM_OF_CLASSES`` and ``CustomDataset``.
        out_folder: directory receiving the PNG files and the CSV.
        weights_dir: path of the state dict passed to ``torch.load``.
        device: device the model and inputs are moved to.
        method: one of the ``METHODS`` values, or the literal "lime".
        sample_images: number of test samples to visualise.
        step: forwarded to ``CustomDataset``.

    Raises:
        ValueError: if ``sample_images`` cannot be sampled from the dataset,
            or if ``method`` is not a supported attribution method.
    """
    invTrans = get_inverse_normalization_transformation()
    data_dir = os.path.join("data")

    num_classes = NUM_OF_CLASSES[dataset]
    if model_version == "resnet18":
        model = create_resnet18_model(num_of_classes=num_classes)
    elif model_version == "resnet50":
        model = create_resnet50_model(num_of_classes=num_classes)
    elif model_version == "densenet":
        model = create_densenet121_model(num_of_classes=num_classes)
    else:
        model = create_efficientnetb0_model(num_of_classes=num_classes)

    # map_location lets weights saved on another device load onto `device`.
    model.load_state_dict(torch.load(weights_dir, map_location=device))

    model.eval()
    model.to(device)

    test_dataset = CustomDataset(
        dataset=dataset,
        transformer=get_default_transformation(),
        data_type="test",
        root_dir=data_dir,
        step=step,
    )
    data_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=1,
                                              shuffle=False,
                                              num_workers=4)
    dataset_size = len(test_dataset)

    try:
        # A set keeps the per-sample membership test O(1).
        image_ids = set(random.sample(range(dataset_size), sample_images))
    except ValueError as err:
        raise ValueError(
            f"Image sample number ({sample_images}) exceeded dataset size ({dataset_size})."
        ) from err

    classes_map = test_dataset.classes_map

    print(f"Measuring {model_version} on {dataset} dataset, with {method}")
    print("-" * 10)

    # Shared defaults; individual methods override them below.
    nt_samples = 8
    n_perturb_samples = 10
    multiply_by_inputs = False
    if method == METHODS["ig"]:
        attr_method = IntegratedGradients(model)
        n_perturb_samples = 3
    elif method == METHODS["saliency"]:
        attr_method = Saliency(model)
    elif method == METHODS["gradcam"]:
        if model_version == "efficientnet":
            attr_method = GuidedGradCam(model, model._conv_stem)
        elif model_version == "densenet":
            attr_method = GuidedGradCam(model, model.features.conv0)
        else:
            attr_method = GuidedGradCam(model, model.conv1)
    elif method == METHODS["deconv"]:
        attr_method = Deconvolution(model)
    elif method == METHODS["gradshap"]:
        attr_method = GradientShap(model)
        if model_version == "efficientnet":
            n_perturb_samples = 3
        elif model_version == "densenet":
            n_perturb_samples = 2
    elif method == METHODS["gbp"]:
        attr_method = GuidedBackprop(model)
    elif method == "lime":
        attr_method = Lime(model)
        feature_mask = torch.tensor(lime_mask).to(device)
        multiply_by_inputs = True
    else:
        # Fail with a clear message instead of a NameError further down.
        raise ValueError(f"Unsupported attribution method: {method!r}")

    # Integrated gradients is used directly; every other method is wrapped
    # in a NoiseTunnel.
    if method == METHODS['ig']:
        nt = attr_method
    else:
        nt = NoiseTunnel(attr_method)
    scores = []

    # The keyword spelling is the one the original call used.
    @infidelity_perturb_func_decorator(multipy_by_inputs=multiply_by_inputs)
    def perturb_fn(inputs):
        noise = torch.tensor(np.random.normal(0, 0.003, inputs.shape)).float()
        noise = noise.to(device)
        return inputs - noise

    pbar = tqdm(total=dataset_size, desc="Model test completion")
    for sample_index, (inputs, label) in enumerate(data_loader):
        pbar.update(1)
        inv_input = invTrans(inputs)
        inputs = inputs.to(device)
        inputs.requires_grad = True
        output = model(inputs)
        output = F.softmax(output, dim=1)
        prediction_score, pred_label_idx = torch.topk(output, 1)
        prediction_score = prediction_score.cpu().detach().numpy()[0][0]
        pred_label_idx.squeeze_()

        if method == METHODS['gradshap']:
            baseline = torch.randn(inputs.shape)
            baseline = baseline.to(device)

        if method == "lime":
            attributions = attr_method.attribute(inputs,
                                                 target=1,
                                                 n_samples=50)
        elif method == METHODS['ig']:
            attributions = nt.attribute(
                inputs,
                target=pred_label_idx,
                n_steps=25,
            )
        elif method == METHODS['gradshap']:
            attributions = nt.attribute(inputs,
                                        target=pred_label_idx,
                                        baselines=baseline)
        else:
            attributions = nt.attribute(
                inputs,
                nt_type="smoothgrad",
                nt_samples=nt_samples,
                target=pred_label_idx,
            )

        infid = infidelity(model,
                           perturb_fn,
                           inputs,
                           attributions,
                           target=pred_label_idx)

        if method == "lime":
            sens = sensitivity_max(
                attr_method.attribute,
                inputs,
                target=pred_label_idx,
                n_perturb_samples=1,
                n_samples=200,
                feature_mask=feature_mask,
            )
        elif method == METHODS['ig']:
            sens = sensitivity_max(
                nt.attribute,
                inputs,
                target=pred_label_idx,
                n_perturb_samples=n_perturb_samples,
                n_steps=25,
            )
        elif method == METHODS['gradshap']:
            sens = sensitivity_max(nt.attribute,
                                   inputs,
                                   target=pred_label_idx,
                                   n_perturb_samples=n_perturb_samples,
                                   baselines=baseline)
        else:
            sens = sensitivity_max(
                nt.attribute,
                inputs,
                target=pred_label_idx,
                n_perturb_samples=n_perturb_samples,
            )
        inf_value = infid.cpu().detach().numpy()[0]
        sens_value = sens.cpu().detach().numpy()[0]

        # image_ids holds 0-based dataset positions, so compare against the
        # 0-based loop index rather than the already-incremented pbar.n.
        if sample_index in image_ids:
            true_name = classes_map[str(label.numpy()[0])][0]
            pred_name = classes_map[str(pred_label_idx.item())][0]
            attr_data = attributions.squeeze().cpu().detach().numpy()
            fig, ax = viz.visualize_image_attr_multiple(
                np.transpose(attr_data, (1, 2, 0)),
                np.transpose(inv_input.squeeze().cpu().detach().numpy(),
                             (1, 2, 0)),
                ["original_image", "heat_map"],
                ["all", "positive"],
                titles=["original_image", "heat_map"],
                cmap=default_cmap,
                show_colorbar=True,
                use_pyplot=False,
                fig_size=(8, 6),
            )
            ax[0].set_xlabel(
                f"Infidelity: {inf_value:.6f}\n Sensitivity: {sens_value:.6f}"
            )
            fig.suptitle(
                f"True: {true_name}, Pred: {pred_name}\nScore: {prediction_score:.4f}",
                fontsize=16,
            )
            # File names keep the 1-based progress counter as their prefix.
            fig.savefig(
                os.path.join(
                    out_folder,
                    f"{str(pbar.n)}-{true_name}-{pred_name}.png",
                ))
            plt.close(fig)

        scores.append([inf_value, sens_value])
    pbar.close()

    np.savetxt(
        os.path.join(out_folder, f"{model_version}-{dataset}-{method}.csv"),
        np.array(scores),
        delimiter=",",
        header="infidelity,sensitivity",
    )

    print(f"Artifacts stored at {out_folder}")
Пример #17
0
def compute_guided_backprop(model, preprocessed_image, label, saliency_layer=None):
    """Return the Guided Backprop attribution of ``preprocessed_image`` as a NumPy array.

    Args:
        model: model handed to ``GuidedBackprop``.
        preprocessed_image: input passed to ``attribute``.
        label: passed as the second positional argument of ``attribute``.
        saliency_layer: accepted but not used by this function.

    Returns:
        The attribution, detached, copied to CPU and squeezed.
    """
    explainer = GuidedBackprop(model)
    attribution = explainer.attribute(preprocessed_image, label)
    host_copy = attribution.detach().cpu().clone()
    return host_copy.numpy().squeeze()
Пример #18
0
 Deconvolution.get_name():
 ConfigParameters(params={}),
 Occlusion.get_name():
 ConfigParameters(
     params={
         "sliding_window_shapes": StrConfig(value=""),
         "strides": StrConfig(value=""),
         "perturbations_per_eval": NumberConfig(value=1, limit=(1, 100)),
     },
     post_process={
         "sliding_window_shapes": _str_to_tuple,
         "strides": _str_to_tuple,
         "perturbations_per_eval": int,
     },
 ),
 GuidedBackprop.get_name():
 ConfigParameters(params={}),
 InputXGradient.get_name():
 ConfigParameters(params={}),
 Saliency.get_name():
 ConfigParameters(
     params={"abs": StrEnumConfig(limit=["True", "False"], value="True")},
     post_process={"abs": _str_to_bool}),
 # Won't work as Relu is being used in multiple places (same layer can't be shared)
 # DeepLift.get_name(): ConfigParameters(
 #     params={}
 # ),
 LayerIntegratedGradients.get_name():
 ConfigParameters(
     params={
         "n_steps":
Пример #19
0
# Convert X and y from numpy arrays to Torch Tensors.
# NOTE(review): concatenating along dim 0 presumably relies on preprocess()
# returning tensors that already carry a leading batch dimension -- confirm.
X_tensor = torch.cat([preprocess(Image.fromarray(x)) for x in X], dim=0)
y_tensor = torch.LongTensor(y)

# Compute Guided GradCam attributions with respect to model.features[3].
gc_result = GuidedGradCam(model, model.features[3])
attr_ggc = compute_attributions(gc_result, X_tensor, target=y_tensor)
# Reduce dim 1 to the largest absolute attribution; `i` receives the
# matching indices returned by torch.max.
attr_ggc = attr_ggc.abs()
attr_ggc, i = torch.max(attr_ggc, dim=1)
# Add a leading dimension of size 1 before handing the maps to the plotter.
attr_ggc = attr_ggc.unsqueeze(0)
visualize_attr_maps('visualization/Captum_Guided_GradCam.png', X, y,
                    class_names, attr_ggc, ['GuidedGradCam'],
                    lambda attr: attr.detach().numpy())

# Compute Guided Backprop attributions, post-processed exactly like the
# Guided GradCam maps above.
gbp_result = GuidedBackprop(model)
attr_gbp = compute_attributions(gbp_result, X_tensor, target=y_tensor)
attr_gbp = attr_gbp.abs()
attr_gbp, i = torch.max(attr_gbp, dim=1)
attr_gbp = attr_gbp.unsqueeze(0)
visualize_attr_maps('visualization/Captum_Guided_BackProp.png', X, y,
                    class_names, attr_gbp, ['GuidedBackprop'],
                    lambda attr: attr.detach().numpy())

# Try out different layers
# and observe how the attributions change.
layer = model.features[3]

layer_act = LayerActivation(model, layer)
layer_act_attr = compute_attributions(layer_act, X_tensor)
# Despite the `_sum` suffix, this is the mean over axis 1 (dimension kept).
layer_act_attr_sum = layer_act_attr.mean(axis=1, keepdim=True)
Пример #20
0
# interpret.interpret_image(image=image_tensor
#                           , target_class=target
#                           , result_dir="./")


# %%
from misc_functions import (get_example_params,
                            convert_to_grayscale,
                            save_gradient_images,
                            get_positive_negative_saliency)
from captum.attr import GuidedBackprop
from parkinsonsNet import Network

# Load the serialized model from disk, mapping its storages onto the CPU.
model = torch.load("/home/anasa2/pre_trained/parkinsonsNet-rest_mpower-rest.pth", map_location="cpu")  

# Guided Backprop explainer wrapping the loaded model.
algo = GuidedBackprop(model)

# %%
import numpy as np

# Random input of shape (1, 3, 4000) with gradients enabled.
# NOTE(review): presumably (batch, channels, samples) -- confirm against the model.
input = torch.randn(1, 3, 4000, requires_grad=True)
# Attribution for target 0; [0] selects the single batch element.
attribution = algo.attribute(input, target=0).detach().cpu().numpy()[0]
# Weight the attribution by the input, pass it through convert_to_grayscale
# and round the result.
attribution = np.round(convert_to_grayscale(attribution* input.detach().numpy()[0]))
save_gradient_images(attribution, 'signal_color')

# %%
import pandas as pd
import numpy as np
import os
from torch.utils.data import Dataset, DataLoader
import math
Пример #21
0
 def __init__(self, trainer):
     """Run the CaptumDerivative and GuidedBackprop initialisers on this instance."""
     # NOTE(review): the second call reads self.trainer, so
     # CaptumDerivative.__init__ presumably stores `trainer` on the
     # instance -- confirm.
     CaptumDerivative.__init__(self, trainer)
     GuidedBackprop.__init__(self, self.trainer)
Пример #22
0
    def __init__(self, model, X, y, gpu, channel_labels):
        """
        Set up Guided Backprop for ``model`` and fit a 3-component PCA on the
        flattened output of its last ``torch.nn.Sequential`` child.

        Args:
            model: torch module; its direct ``torch.nn.Sequential`` children
                are applied to ``X`` in registration order.
            X: input tensor; must have exactly 3 dimensions.
            y: labels, stored on the instance unchanged.
            gpu: CUDA device index, or -1 to stay on the CPU.
            channel_labels: labels; must contain as many entries as ``X`` has
                dimensions.

        Raises:
            IllegalDataDimensionException: if ``X`` is not 3-dimensional or
                the number of ``channel_labels`` does not match.
        """
        self.X, self.y = X, y

        self.algo = GuidedBackprop(model)

        self.model = model.eval()

        if len(self.X.shape) != 3:
            raise IllegalDataDimensionException("Expected data to have dimensions 3 but got dimension ",
                                                str(len(self.X.shape)))

        if len(self.X.shape) != len(channel_labels):
            raise IllegalDataDimensionException("Expected channel_labels to contain ", str(len(self.X.shape)),
                                                " labels but got ", str(len(channel_labels)), " labels instead.")

        self.channel_labels = channel_labels

        if gpu == -1:
            self.device = torch.device('cpu')
        else:
            self.device = torch.device('cuda:' + str(gpu))

        model.to(self.device)

        # Exact type match: subclasses of Sequential are not picked up.
        self.modules = {
            name: module
            for name, module in model.named_children()
            if type(module) is torch.nn.Sequential
        }
        counter = len(self.modules)

        print(f"Total Interpretable layers: {counter}")

        # Tensor.to() is not in-place; keep the moved tensor so the input
        # lives on the same device as the model.
        self.X = self.X.to(self.device)

        # Flattened input followed by the flattened output of every block.
        flattened_outputs = [self.X.flatten(start_dim=1)]
        output = self.X
        for layer in self.modules.values():
            layer.eval().to(self.device)
            output = layer(output)
            flattened_outputs.append(output.flatten(start_dim=1))

        print("Intermediate Forward Pass Through " + str(len(flattened_outputs) - 1) + " Layers")

        print("Generating PCA...")

        self.pca = PCA(n_components=3)
        # Bring the tensor back to host memory before converting to NumPy.
        self.pca_result = self.pca.fit_transform(
            flattened_outputs[-1].detach().cpu().numpy())
        # One PCA component column per channel label.
        self.pca_layer_result = [
            self.pca_result[:, i] for i in range(len(channel_labels))
        ]

        print("Generation Complete")
Пример #23
0
def generate_saliency(model_path, saliency_path, saliency, aggregation):
    """Compute per-token, per-class saliency scores and write them as JSON lines.

    Uses the module-level ``args``, ``tokenizer`` and ``device``. The model
    is restored from ``model_path``. Test-set predictions are cached in
    ``model_path + '.predictions'`` if that file does not exist yet. One
    attribution per class is then computed for every batch, and one JSON
    object per instance is written to ``saliency_path``.

    Args:
        model_path: checkpoint holding the ``'args'`` and ``'model'`` entries.
        saliency_path: output file, one JSON object per line.
        saliency: 'deeplift', 'guided', 'sal', 'inputx' or 'occlusion'.
        aggregation: forwarded to ``summarize_attributions`` as ``type``.

    Returns:
        List with one entry per batch: the summed counter values divided by
        the batch size. Empty when ``args.no_time`` is set.

    Raises:
        ValueError: if ``saliency`` names no supported method.
    """
    checkpoint = torch.load(model_path,
                            map_location=lambda storage, loc: storage)
    model_args = Namespace(**checkpoint['args'])
    n_labels = checkpoint['args']['labels']
    if args.model == 'lstm':
        model = LSTM_MODEL(tokenizer, model_args,
                           n_labels=n_labels).to(device)
        model.load_state_dict(checkpoint['model'])
    elif args.model == 'trans':
        transformer_config = BertConfig.from_pretrained(
            'bert-base-uncased', num_labels=model_args.labels)
        model_cp = BertForSequenceClassification.from_pretrained(
            'bert-base-uncased', config=transformer_config).to(device)
        # The checkpoint loaded above is reused; no second read from disk.
        model_cp.load_state_dict(checkpoint['model'])
        model = BertModelWrapper(model_cp)
    else:
        model = CNN_MODEL(tokenizer, model_args,
                          n_labels=n_labels).to(device)
        model.load_state_dict(checkpoint['model'])

    # NOTE(review): the model is deliberately left in training mode for the
    # attribution passes -- confirm this is still required.
    model.train()

    pad_to_max = False
    if saliency == 'deeplift':
        ablator = DeepLift(model)
    elif saliency == 'guided':
        ablator = GuidedBackprop(model)
    elif saliency == 'sal':
        ablator = Saliency(model)
    elif saliency == 'inputx':
        ablator = InputXGradient(model)
    elif saliency == 'occlusion':
        ablator = Occlusion(model)
    else:
        # Fail with a clear message instead of a NameError further down.
        raise ValueError(f"Unsupported saliency method: {saliency!r}")

    coll_call = get_collate_fn(dataset=args.dataset, model=args.model)

    return_attention_masks = args.model == 'trans'

    collate_fn = partial(coll_call,
                         tokenizer=tokenizer,
                         device=device,
                         return_attention_masks=return_attention_masks,
                         pad_to_max_length=pad_to_max)
    test = get_dataset(path=args.dataset_dir,
                       mode=args.split,
                       dataset=args.dataset)
    batch_size = args.batch_size if args.batch_size is not None else \
        model_args.batch_size
    test_dl = DataLoader(batch_size=batch_size,
                         dataset=test,
                         shuffle=False,
                         collate_fn=collate_fn)

    # PREDICTIONS
    predictions_path = model_path + '.predictions'
    if not os.path.exists(predictions_path):
        predictions = defaultdict(list)
        for batch in tqdm(test_dl, desc='Running test prediction... '):
            if args.model == 'trans':
                logits = model(batch[0],
                               attention_mask=batch[1],
                               labels=batch[2].long())
            else:
                logits = model(batch[0])
            logits = logits.detach().cpu().numpy().tolist()
            predicted = np.argmax(np.array(logits), axis=-1)
            predictions['class'] += predicted.tolist()
            predictions['logits'] += logits

        with open(predictions_path, 'w') as out:
            json.dump(predictions, out)

    # COMPUTE SALIENCY
    # Every method except occlusion attributes on embeddings, not token ids.
    uses_embeddings = saliency != 'occlusion'
    if uses_embeddings:
        embedding_layer_name = 'model.bert.embeddings' if args.model == \
                                                          'trans' else \
            'embedding'
        interpretable_embedding = configure_interpretable_embedding_layer(
            model, embedding_layer_name)

    class_attr_list = defaultdict(list)
    token_ids = []
    saliency_flops = []

    try:
        for batch in tqdm(test_dl, desc='Running Saliency Generation...'):
            if args.model == 'cnn':
                additional = None
            elif args.model == 'trans':
                additional = (batch[1], batch[2])
            else:
                additional = batch[-1]

            token_ids += batch[0].detach().cpu().numpy().tolist()
            if uses_embeddings:
                input_embeddings = \
                    interpretable_embedding.indices_to_embeddings(batch[0])

            if not args.no_time:
                high.start_counters([
                    events.PAPI_FP_OPS,
                ])
            for cls_ in range(n_labels):
                if saliency == 'occlusion':
                    attributions = ablator.attribute(
                        batch[0],
                        sliding_window_shapes=(args.sw, ),
                        target=cls_,
                        additional_forward_args=additional)
                else:
                    attributions = ablator.attribute(
                        input_embeddings,
                        target=cls_,
                        additional_forward_args=additional)

                attributions = summarize_attributions(
                    attributions,
                    type=aggregation,
                    model=model,
                    tokens=batch[0]).detach().cpu().numpy().tolist()
                class_attr_list[cls_] += [list(_l) for _l in attributions]

            if not args.no_time:
                saliency_flops.append(
                    sum(high.stop_counters()) / batch[0].shape[0])
    finally:
        # Always restore the original embedding layer, even if an
        # attribution call raised.
        if uses_embeddings:
            remove_interpretable_embedding_layer(model,
                                                 interpretable_embedding)

    # SERIALIZE
    print('Serializing...', flush=True)
    with open(saliency_path, 'w') as out:
        for instance_i, _ in enumerate(test):
            saliencies = []
            for token_i, token_id in enumerate(token_ids[instance_i]):
                token_sal = {'token': tokenizer.ids_to_tokens[token_id]}
                for cls_ in range(n_labels):
                    token_sal[int(
                        cls_)] = class_attr_list[cls_][instance_i][token_i]
                saliencies.append(token_sal)

            out.write(json.dumps({'tokens': saliencies}) + '\n')
            out.flush()

    return saliency_flops