Ejemplo n.º 1
0
def explain_adata(model,
                  adata,
                  baseline_adata,
                  target,
                  device='cpu',
                  method='integrated'):
    """Compute per-gene attribution scores for the cells in ``adata``.

    Parameters
    ----------
    model : classifier exposing ``genes``, ``classes`` and a torch module.
    adata : AnnData with the cells to explain.
    baseline_adata : AnnData providing the attribution baseline.
    target : int class index, class name, or the string ``'all'``.
    device : torch device string for the computation.
    method : ``'integrated'`` (IntegratedGradients) or ``'deeplift'``
        (DeepLiftShap).

    Returns
    -------
    ``pd.DataFrame`` (cells x genes) for a single target, a dict mapping
    class name -> DataFrame when ``target == 'all'``, or ``None`` when the
    prediction dataloader fails during normalization.
    """
    try:
        a_dl = get_prediction_dataloader(adata, model.genes)
        b_dl = get_prediction_dataloader(baseline_adata, model.genes)
        a, ab = a_dl.dataset.adata, b_dl.dataset.adata
    except ZeroDivisionError:
        # Normalization inside the dataloader can divide by zero on
        # degenerate input; treat that as "no explanation available".
        logging.info(
            f'explain for target {target} with size of {adata.shape} failed in normalization'
        )
        return None

    # Densify sparse expression matrices before building tensors.
    X = a.X.toarray() if 'sparse' in str(type(a.X)).lower() else a.X
    baseline_X = ab.X.toarray() if 'sparse' in str(type(
        ab.X)).lower() else ab.X
    inputs = torch.tensor(X)
    baseline = torch.tensor(baseline_X)

    inputs, baseline = inputs.to(device), baseline.to(device)

    attr_model = AttributionWrapper(model.to(device))

    # Map a class-name target to its integer index; 'all' is handled below.
    if not isinstance(target, int) and target != 'all':
        target = model.classes.index(target)

    if method == 'deeplift':
        dls = DeepLiftShap(attr_model)
    elif method == 'integrated':
        # IntegratedGradients takes a single baseline; use the batch mean.
        baseline = torch.mean(baseline, dim=0).unsqueeze(dim=0)
        dls = IntegratedGradients(attr_model)
    else:
        # Fail fast instead of hitting a NameError on `dls` below.
        raise ValueError(
            f"unknown method {method!r}; expected 'deeplift' or 'integrated'")

    if target != 'all':
        attrs, _ = dls.attribute(inputs,
                                 baseline,
                                 target=target,
                                 return_convergence_delta=True)
        return pd.DataFrame(data=attrs.detach().cpu().numpy(),
                            columns=model.genes,
                            index=a.obs.index.to_list())
    else:
        # Compute one attribution matrix per class.
        attrs = {}
        for i in range(len(model.classes)):
            ct = model.classes[i]
            logging.info(f'calculating feature importances for {ct}')
            attr, _ = dls.attribute(inputs,
                                    baseline,
                                    target=i,
                                    return_convergence_delta=True)
            attrs[model.classes[i]] = pd.DataFrame(
                data=attr.detach().cpu().numpy(),
                columns=model.genes,
                index=a.obs.index.to_list())
        return attrs
Ejemplo n.º 2
0
def compute_deep_sharp(model, preprocessed_image, label, baseline=None):
    """Return a DeepLiftShap saliency map for ``label`` as a numpy array.

    Only ``baseline == "zero"`` is implemented: the baseline distribution is
    ten all-zero images shaped like the input. Any other value raises
    ``NotImplementedError``.
    """
    if baseline != "zero":
        raise NotImplementedError
    # Ten zero images, matching the input's trailing dimensions.
    zeros_shape = (10, ) + preprocessed_image.shape[1:]
    base_distribution = preprocessed_image.new_zeros(zeros_shape)
    explainer = DeepLiftShap(model)
    saliency = explainer.attribute(preprocessed_image,
                                   target=label,
                                   baselines=base_distribution)
    return saliency.detach().cpu().clone().numpy().squeeze()
Ejemplo n.º 3
0
class Explainer():
    """Wraps a torch model with a captum DeepLiftShap attribution engine."""

    def __init__(self, model):
        self.model = model
        self.explain = DeepLiftShap(model)
        return

    def get_attribution_map(self, img, target=None):
        """Attribute ``img`` against a near-zero random baseline.

        input:
        img: batch X channels X height X width [BCHW], torch Tensor
        target: class indices to explain; defaults to the model's own
            predictions on ``img``.

        output:
        attributions: torch Tensor with the same shape as ``img``
        """
        # Tiny gaussian noise stands in for the baseline distribution.
        baseline_dist = 0.001 * torch.randn_like(img)
        if target is None:
            # Explain whichever class the model itself predicts.
            target = torch.argmax(self.model(img), 1)
        attributions, delta = self.explain.attribute(
            img, baseline_dist, target=target, return_convergence_delta=True)
        return attributions
Ejemplo n.º 4
0
def get_explanation(generated_data, discriminator, prediction, XAItype="shap", cuda=True, trained_data=None,
                    data_type="mnist") -> None:
    """
    This function calculates the explanation for given generated images using the desired xAI systems and the
    :param generated_data: data created by the generator
    :type generated_data: torch.Tensor
    :param discriminator: the discriminator model
    :type discriminator: torch.nn.Module
    :param prediction: tensor of predictions by the discriminator on the generated data
    :type prediction: torch.Tensor
    :param XAItype: the type of xAI system to use. One of ("shap", "lime", "saliency")
    :type XAItype: str
    :param cuda: whether to use gpu
    :type cuda: bool
    :param trained_data: a batch from the dataset
    :type trained_data: torch.Tensor
    :param data_type: the type of the dataset used. One of ("cifar", "mnist", "fmnist")
    :type data_type: str
    :return:
    :rtype:
    """

    # initialize temp values to all 1s
    temp = values_target(size=generated_data.size(), value=1.0, cuda=cuda)

    # mask values with low prediction
    mask = (prediction < 0.5).view(-1)
    indices = (mask.nonzero(as_tuple=False)).detach().cpu().numpy().flatten().tolist()

    data = generated_data[mask, :]

    if len(indices) > 1:
        if XAItype == "saliency":
            for i in range(len(indices)):
                explainer = Saliency(discriminator)
                temp[indices[i], :] = explainer.attribute(data[i, :].detach().unsqueeze(0))

        elif XAItype == "shap":
            for i in range(len(indices)):
                explainer = DeepLiftShap(discriminator)
                temp[indices[i], :] = explainer.attribute(data[i, :].detach().unsqueeze(0), trained_data, target=0)

        elif XAItype == "lime":
            explainer = lime_image.LimeImageExplainer()
            global discriminatorLime
            # LIME needs a CPU copy of the discriminator for its sampling loop.
            discriminatorLime = deepcopy(discriminator)
            discriminatorLime.cpu()
            discriminatorLime.eval()
            for i in range(len(indices)):
                if data_type == "cifar":
                    tmp = data[i, :].detach().cpu().numpy()
                    tmp = np.reshape(tmp, (32, 32, 3)).astype(np.double)
                    exp = explainer.explain_instance(tmp, batch_predict_cifar, num_samples=100)
                else:
                    tmp = data[i, :].squeeze().detach().cpu().numpy().astype(np.double)
                    exp = explainer.explain_instance(tmp, batch_predict, num_samples=100)
                _, mask = exp.get_image_and_mask(exp.top_labels[0], positive_only=False, negative_only=False)
                # BUG FIX: np.float was removed in NumPy 1.24; use the builtin.
                temp[indices[i], :] = torch.tensor(mask.astype(float))
            del discriminatorLime
        else:
            raise Exception("wrong xAI type given")

    if cuda:
        temp = temp.cuda()
    set_values(normalize_vector(temp))
Ejemplo n.º 5
0
    def __init__(self, predictor: Predictor):
        """Set up DeepLiftShap attribution over the predictor's captum sub-model."""
        CaptumAttribution.__init__(self, 'dlshap', predictor)

        # Captum needs a plain module; the predictor exposes one for this purpose.
        submodel = self.predictor._model.captum_sub_model()
        self.submodel = submodel
        DeepLiftShap.__init__(self, submodel)
def main(args):
    """Compute and plot saliency maps for MNIST digits 5-9 across models.

    For every model name in ``models`` this loads the pretrained checkpoint,
    instantiates the attribution methods enabled by the ``args`` flags, picks
    one test sample for each digit in 5..9, and writes a rescaled saliency
    plot (plus the optional two-step-rescaling variant) per method to
    ``args.Graph_dir``.
    """

    train_loader, test_loader = data_generator(args.data_dir,1)

    for m in range(len(models)):

        model_name = "model_{}_NumFeatures_{}".format(models[m],args.NumFeatures)
        model_filename = args.model_dir + 'm_' + model_name + '.pt'
        pretrained_model = torch.load(open(model_filename, "rb"),map_location=device)
        pretrained_model.to(device)

        if(args.GradFlag):
            Grad = Saliency(pretrained_model)
        if(args.IGFlag):
            IG = IntegratedGradients(pretrained_model)
        if(args.DLFlag):
            DL = DeepLift(pretrained_model)
        if(args.GSFlag):
            GS = GradientShap(pretrained_model)
        if(args.DLSFlag):
            DLS = DeepLiftShap(pretrained_model)
        if(args.SGFlag):
            Grad_ = Saliency(pretrained_model)
            SG = NoiseTunnel(Grad_)
        if(args.ShapleySamplingFlag):
            SS = ShapleyValueSampling(pretrained_model)
        # BUG FIX: this branch was gated on args.GSFlag (copy-paste from the
        # GradientShap branch above); FeaturePermutation has its own flag.
        if(args.FeaturePermutationFlag):
            FP = FeaturePermutation(pretrained_model)
        if(args.FeatureAblationFlag):
            FA = FeatureAblation(pretrained_model)
        if(args.OcclusionFlag):
            OS = Occlusion(pretrained_model)

        # Group-feature masks: timeMask groups by timestep (row index),
        # featureMask groups by feature (column index).
        timeMask=np.zeros((args.NumTimeSteps, args.NumFeatures),dtype=int)
        featureMask=np.zeros((args.NumTimeSteps, args.NumFeatures),dtype=int)
        for i in  range (args.NumTimeSteps):
            timeMask[i,:]=i

        # NOTE(review): this loop also iterates NumTimeSteps while filling
        # columns — only correct when NumTimeSteps == NumFeatures; featureMask
        # is unused below, so it has no effect either way.
        for i in  range (args.NumTimeSteps):
            featureMask[:,i]=i

        # Pick the first test-set index for each digit in 5..9.
        indexes = [[] for i in range(5,10)]
        for i ,(data, target) in enumerate(test_loader):
            if(target==5 or target==6 or target==7 or target==8 or target==9):
                index=target-5

                if(len(indexes[index])<1):
                    indexes[index].append(i)
        for j, index in enumerate(indexes):
            print(index)
        # indexes = [[21],[17],[84],[9]]

        for j, index in enumerate(indexes):
            print("Getting Saliency for number", j+1)
            for i, (data, target) in enumerate(test_loader):
                if(i in index):

                    labels =  target.to(device)

                    input = data.reshape(-1, args.NumTimeSteps, args.NumFeatures).to(device)
                    # NOTE(review): Variable/volatile is deprecated in modern
                    # PyTorch; requires_grad on the tensor is what matters here.
                    input = Variable(input,  volatile=False, requires_grad=True)

                    # Random baselines: a single one for pointwise methods and
                    # a 5x batch for distribution-based methods (GS / DLS).
                    baseline_single=torch.Tensor(np.random.random(input.shape)).to(device)
                    baseline_multiple=torch.Tensor(np.random.random((input.shape[0]*5,input.shape[1],input.shape[2]))).to(device)
                    inputMask= np.zeros((input.shape))
                    inputMask[:,:,:]=timeMask
                    inputMask =torch.Tensor(inputMask).to(device)
                    mask_single= torch.Tensor(timeMask).to(device)
                    mask_single=mask_single.reshape(1,args.NumTimeSteps, args.NumFeatures).to(device)

                    Data=data.reshape(args.NumTimeSteps, args.NumFeatures).data.cpu().numpy()

                    target_=int(target.data.cpu().numpy()[0])

                    plotExampleBox(Data,args.Graph_dir+'Sample_MNIST_'+str(target_)+'_index_'+str(i+1),greyScale=True)

                    if(args.GradFlag):
                        attributions = Grad.attribute(input, \
                                                      target=labels)

                        saliency_=Helper.givenAttGetRescaledSaliency(args.NumTimeSteps, args.NumFeatures,attributions)

                        plotExampleBox(saliency_[0],args.Graph_dir+models[m]+'_Grad_MNIST_'+str(target_)+'_index_'+str(i+1),greyScale=True)
                        if(args.TSRFlag):
                            TSR_attributions =  getTwoStepRescaling(Grad,input, args.NumFeatures,args.NumTimeSteps, labels,hasBaseline=None)
                            TSR_saliency=Helper.givenAttGetRescaledSaliency(args.NumTimeSteps, args.NumFeatures,TSR_attributions,isTensor=False)
                            plotExampleBox(TSR_saliency,args.Graph_dir+models[m]+'_TSR_Grad_MNIST_'+str(target_)+'_index_'+str(i+1),greyScale=True)

                    if(args.IGFlag):
                        attributions = IG.attribute(input,  \
                                                    baselines=baseline_single, \
                                                    target=labels)
                        saliency_=Helper.givenAttGetRescaledSaliency(args.NumTimeSteps, args.NumFeatures,attributions)

                        plotExampleBox(saliency_[0],args.Graph_dir+models[m]+'_IG_MNIST_'+str(target_)+'_index_'+str(i+1),greyScale=True)
                        if(args.TSRFlag):
                            TSR_attributions =  getTwoStepRescaling(IG,input, args.NumFeatures,args.NumTimeSteps, labels,hasBaseline=baseline_single)
                            TSR_saliency=Helper.givenAttGetRescaledSaliency(args.NumTimeSteps, args.NumFeatures,TSR_attributions,isTensor=False)
                            plotExampleBox(TSR_saliency,args.Graph_dir+models[m]+'_TSR_IG_MNIST_'+str(target_)+'_index_'+str(i+1),greyScale=True)

                    if(args.DLFlag):
                        attributions = DL.attribute(input,  \
                                                    baselines=baseline_single, \
                                                    target=labels)
                        saliency_=Helper.givenAttGetRescaledSaliency(args.NumTimeSteps, args.NumFeatures,attributions)
                        plotExampleBox(saliency_[0],args.Graph_dir+models[m]+'_DL_MNIST_'+str(target_)+'_index_'+str(i+1),greyScale=True)

                        if(args.TSRFlag):
                            TSR_attributions =  getTwoStepRescaling(DL,input, args.NumFeatures,args.NumTimeSteps, labels,hasBaseline=baseline_single)
                            TSR_saliency=Helper.givenAttGetRescaledSaliency(args.NumTimeSteps, args.NumFeatures,TSR_attributions,isTensor=False)
                            plotExampleBox(TSR_saliency,args.Graph_dir+models[m]+'_TSR_DL_MNIST_'+str(target_)+'_index_'+str(i+1),greyScale=True)

                    if(args.GSFlag):

                        attributions = GS.attribute(input,  \
                                                    baselines=baseline_multiple, \
                                                    stdevs=0.09,\
                                                    target=labels)
                        saliency_=Helper.givenAttGetRescaledSaliency(args.NumTimeSteps, args.NumFeatures,attributions)
                        plotExampleBox(saliency_[0],args.Graph_dir+models[m]+'_GS_MNIST_'+str(target_)+'_index_'+str(i+1),greyScale=True)

                        if(args.TSRFlag):
                            TSR_attributions =  getTwoStepRescaling(GS,input, args.NumFeatures,args.NumTimeSteps, labels,hasBaseline=baseline_multiple)
                            TSR_saliency=Helper.givenAttGetRescaledSaliency(args.NumTimeSteps, args.NumFeatures,TSR_attributions,isTensor=False)
                            plotExampleBox(TSR_saliency,args.Graph_dir+models[m]+'_TSR_GS_MNIST_'+str(target_)+'_index_'+str(i+1),greyScale=True)

                    if(args.DLSFlag):

                        attributions = DLS.attribute(input,  \
                                                    baselines=baseline_multiple, \
                                                    target=labels)
                        saliency_=Helper.givenAttGetRescaledSaliency(args.NumTimeSteps, args.NumFeatures,attributions)
                        plotExampleBox(saliency_[0],args.Graph_dir+models[m]+'_DLS_MNIST_'+str(target_)+'_index_'+str(i+1),greyScale=True)
                        if(args.TSRFlag):
                            TSR_attributions =  getTwoStepRescaling(DLS,input, args.NumFeatures,args.NumTimeSteps, labels,hasBaseline=baseline_multiple)
                            TSR_saliency=Helper.givenAttGetRescaledSaliency(args.NumTimeSteps, args.NumFeatures,TSR_attributions,isTensor=False)
                            plotExampleBox(TSR_saliency,args.Graph_dir+models[m]+'_TSR_DLS_MNIST_'+str(target_)+'_index_'+str(i+1),greyScale=True)

                    if(args.SGFlag):
                        attributions = SG.attribute(input, \
                                                    target=labels)
                        saliency_=Helper.givenAttGetRescaledSaliency(args.NumTimeSteps, args.NumFeatures,attributions)
                        plotExampleBox(saliency_[0],args.Graph_dir+models[m]+'_SG_MNIST_'+str(target_)+'_index_'+str(i+1),greyScale=True)
                        if(args.TSRFlag):
                            TSR_attributions =  getTwoStepRescaling(SG,input, args.NumFeatures,args.NumTimeSteps, labels)
                            TSR_saliency=Helper.givenAttGetRescaledSaliency(args.NumTimeSteps, args.NumFeatures,TSR_attributions,isTensor=False)
                            plotExampleBox(TSR_saliency,args.Graph_dir+models[m]+'_TSR_SG_MNIST_'+str(target_)+'_index_'+str(i+1),greyScale=True)

                    if(args.ShapleySamplingFlag):
                        attributions = SS.attribute(input, \
                                        baselines=baseline_single, \
                                        target=labels,\
                                        feature_mask=inputMask)
                        saliency_=Helper.givenAttGetRescaledSaliency(args.NumTimeSteps, args.NumFeatures,attributions)
                        plotExampleBox(saliency_[0],args.Graph_dir+models[m]+'_SVS_MNIST_'+str(target_)+'_index_'+str(i+1),greyScale=True)
                        if(args.TSRFlag):
                            TSR_attributions =  getTwoStepRescaling(SS,input, args.NumFeatures,args.NumTimeSteps, labels,hasBaseline=baseline_single,hasFeatureMask=inputMask)
                            TSR_saliency=Helper.givenAttGetRescaledSaliency(args.NumTimeSteps, args.NumFeatures,TSR_attributions,isTensor=False)
                            plotExampleBox(TSR_saliency,args.Graph_dir+models[m]+'_TSR_SVS_MNIST_'+str(target_)+'_index_'+str(i+1),greyScale=True)
                    # if(args.FeaturePermutationFlag):
                    #     attributions = FP.attribute(input, \
                    #                     target=labels),
                    #                     # perturbations_per_eval= 1,\
                    #                     # feature_mask=mask_single)
                    #     saliency_=Helper.givenAttGetRescaledSaliency(args.NumTimeSteps, args.NumFeatures,attributions)
                    #     plotExampleBox(saliency_[0],args.Graph_dir+models[m]+'_FP',greyScale=True)

                    if(args.FeatureAblationFlag):
                        attributions = FA.attribute(input, \
                                        target=labels)
                                        # perturbations_per_eval= input.shape[0],\
                                        # feature_mask=mask_single)
                        saliency_=Helper.givenAttGetRescaledSaliency(args.NumTimeSteps, args.NumFeatures,attributions)
                        plotExampleBox(saliency_[0],args.Graph_dir+models[m]+'_FA_MNIST_'+str(target_)+'_index_'+str(i+1),greyScale=True)
                        if(args.TSRFlag):
                            TSR_attributions =  getTwoStepRescaling(FA,input, args.NumFeatures,args.NumTimeSteps, labels)
                            TSR_saliency=Helper.givenAttGetRescaledSaliency(args.NumTimeSteps, args.NumFeatures,TSR_attributions,isTensor=False)
                            plotExampleBox(TSR_saliency,args.Graph_dir+models[m]+'_TSR_FA_MNIST_'+str(target_)+'_index_'+str(i+1),greyScale=True)

                    if(args.OcclusionFlag):
                        attributions = OS.attribute(input, \
                                        sliding_window_shapes=(1,int(args.NumFeatures/10)),
                                        target=labels,
                                        baselines=baseline_single)
                        saliency_=Helper.givenAttGetRescaledSaliency(args.NumTimeSteps, args.NumFeatures,attributions)

                        plotExampleBox(saliency_[0],args.Graph_dir+models[m]+'_FO_MNIST_'+str(target_)+'_index_'+str(i+1),greyScale=True)
                        if(args.TSRFlag):
                            TSR_attributions =  getTwoStepRescaling(OS,input, args.NumFeatures,args.NumTimeSteps, labels,hasBaseline=baseline_single,hasSliding_window_shapes= (1,int(args.NumFeatures/10)))
                            TSR_saliency=Helper.givenAttGetRescaledSaliency(args.NumTimeSteps, args.NumFeatures,TSR_attributions,isTensor=False)
                            plotExampleBox(TSR_saliency,args.Graph_dir+models[m]+'_TSR_FO_MNIST_'+str(target_)+'_index_'+str(i+1),greyScale=True)
Ejemplo n.º 7
0
def run_saliency_methods(saliency_methods,
                         pretrained_model,
                         test_shape,
                         train_loader,
                         test_loader,
                         device,
                         model_type,
                         model_name,
                         saliency_dir,
                         tsr_graph_dir=None,
                         tsr_inputs_to_graph=()):
    """Run the requested saliency methods over the test set and save results.

    Parameters
    ----------
    saliency_methods : collection of method names (e.g. "Grad", "IG_TSR",
        "DLS", "FeaturePermutation", "FIT", ...) selecting what to compute.
    pretrained_model : the trained torch model to attribute.
    test_shape : (num_samples, num_timesteps, num_features) of the test set.
    train_loader, test_loader : torch DataLoaders; train_loader is only used
        to fit the FIT feature generator.
    device : torch device for model and tensors.
    model_type, model_name : strings used to build the output file names.
    saliency_dir : directory where ``.npy`` rescaled-saliency arrays go.
    tsr_graph_dir, tsr_inputs_to_graph : optional plotting config forwarded
        to the TSR variants.

    Saves one ``<model_name>_<model_type>_<METHOD>_rescaled.npy`` array per
    requested method; returns nothing.
    """
    _, num_timesteps, num_features = test_shape

    run_grad = "Grad" in saliency_methods
    run_grad_tsr = "Grad_TSR" in saliency_methods
    run_ig = "IG" in saliency_methods
    run_ig_tsr = "IG_TSR" in saliency_methods
    run_dl = "DL" in saliency_methods
    run_gs = "GS" in saliency_methods
    run_dls = "DLS" in saliency_methods
    run_dls_tsr = "DLS_TSR" in saliency_methods
    run_sg = "SG" in saliency_methods
    run_shapley_sampling = "ShapleySampling" in saliency_methods
    run_feature_permutation = "FeaturePermutation" in saliency_methods
    run_feature_ablation = "FeatureAblation" in saliency_methods
    run_occlusion = "Occlusion" in saliency_methods
    run_fit = "FIT" in saliency_methods
    run_ifit = "IFIT" in saliency_methods
    run_wfit = "WFIT" in saliency_methods
    run_iwfit = "IWFIT" in saliency_methods

    if run_grad or run_grad_tsr:
        Grad = Saliency(pretrained_model)
    if run_grad:
        rescaledGrad = np.zeros(test_shape)
    if run_grad_tsr:
        rescaledGrad_TSR = np.zeros(test_shape)

    if run_ig or run_ig_tsr:
        IG = IntegratedGradients(pretrained_model)
    if run_ig:
        rescaledIG = np.zeros(test_shape)
    if run_ig_tsr:
        rescaledIG_TSR = np.zeros(test_shape)

    if run_dl:
        rescaledDL = np.zeros(test_shape)
        DL = DeepLift(pretrained_model)

    if run_gs:
        rescaledGS = np.zeros(test_shape)
        GS = GradientShap(pretrained_model)

    if run_dls or run_dls_tsr:
        DLS = DeepLiftShap(pretrained_model)
    if run_dls:
        rescaledDLS = np.zeros(test_shape)
    if run_dls_tsr:
        rescaledDLS_TSR = np.zeros(test_shape)

    if run_sg:
        rescaledSG = np.zeros(test_shape)
        Grad_ = Saliency(pretrained_model)
        SG = NoiseTunnel(Grad_)

    if run_shapley_sampling:
        rescaledShapleySampling = np.zeros(test_shape)
        SS = ShapleyValueSampling(pretrained_model)

    # BUG FIX: this setup was gated on run_gs (copy-paste from the
    # GradientShap branch), so requesting FeaturePermutation without GS
    # crashed with a NameError on FP below.
    if run_feature_permutation:
        rescaledFeaturePermutation = np.zeros(test_shape)
        FP = FeaturePermutation(pretrained_model)

    if run_feature_ablation:
        rescaledFeatureAblation = np.zeros(test_shape)
        FA = FeatureAblation(pretrained_model)

    if run_occlusion:
        rescaledOcclusion = np.zeros(test_shape)
        OS = Occlusion(pretrained_model)

    if run_fit:
        rescaledFIT = np.zeros(test_shape)
        FIT = FITExplainer(pretrained_model, ft_dim_last=True)
        generator = JointFeatureGenerator(num_features, data='none')
        # TODO: Increase epochs
        FIT.fit_generator(generator, train_loader, test_loader, n_epochs=300)

    if run_ifit:
        rescaledIFIT = np.zeros(test_shape)
    if run_wfit:
        rescaledWFIT = np.zeros(test_shape)
    if run_iwfit:
        rescaledIWFIT = np.zeros(test_shape)

    idx = 0
    # Group-feature mask: all features at a timestep share that timestep's id.
    mask = np.zeros((num_timesteps, num_features), dtype=int)
    for i in range(num_timesteps):
        mask[i, :] = i

    for i, (samples, labels) in enumerate(test_loader):
        input = samples.reshape(-1, num_timesteps, num_features).to(device)
        input = Variable(input, volatile=False, requires_grad=True)

        batch_size = input.shape[0]
        # Single random baseline for pointwise methods; a 5x batch for the
        # distribution-based ones (GradientShap / DeepLiftShap).
        baseline_single = torch.from_numpy(np.random.random(
            input.shape)).to(device)
        baseline_multiple = torch.from_numpy(
            np.random.random((input.shape[0] * 5, input.shape[1],
                              input.shape[2]))).to(device)
        inputMask = np.zeros((input.shape))
        inputMask[:, :, :] = mask
        inputMask = torch.from_numpy(inputMask).to(device)
        mask_single = torch.from_numpy(mask).to(device)
        mask_single = mask_single.reshape(1, num_timesteps,
                                          num_features).to(device)
        labels = torch.tensor(labels.int().tolist()).to(device)

        if run_grad:
            attributions = Grad.attribute(input, target=labels)
            rescaledGrad[
                idx:idx +
                batch_size, :, :] = Helper.givenAttGetRescaledSaliency(
                    num_timesteps, num_features, attributions)
        if run_grad_tsr:
            rescaledGrad_TSR[idx:idx + batch_size, :, :] = get_tsr_saliency(
                Grad,
                input,
                labels,
                graph_dir=tsr_graph_dir,
                graph_name=f'{model_name}_{model_type}_Grad_TSR',
                inputs_to_graph=tsr_inputs_to_graph,
                cur_batch=i)

        if run_ig:
            attributions = IG.attribute(input,
                                        baselines=baseline_single,
                                        target=labels)
            rescaledIG[idx:idx +
                       batch_size, :, :] = Helper.givenAttGetRescaledSaliency(
                           num_timesteps, num_features, attributions)
        if run_ig_tsr:
            rescaledIG_TSR[idx:idx + batch_size, :, :] = get_tsr_saliency(
                IG,
                input,
                labels,
                baseline=baseline_single,
                graph_dir=tsr_graph_dir,
                graph_name=f'{model_name}_{model_type}_IG_TSR',
                inputs_to_graph=tsr_inputs_to_graph,
                cur_batch=i)

        if run_dl:
            attributions = DL.attribute(input,
                                        baselines=baseline_single,
                                        target=labels)
            rescaledDL[idx:idx +
                       batch_size, :, :] = Helper.givenAttGetRescaledSaliency(
                           num_timesteps, num_features, attributions)

        if run_gs:
            attributions = GS.attribute(input,
                                        baselines=baseline_multiple,
                                        stdevs=0.09,
                                        target=labels)
            rescaledGS[idx:idx +
                       batch_size, :, :] = Helper.givenAttGetRescaledSaliency(
                           num_timesteps, num_features, attributions)

        if run_dls:
            attributions = DLS.attribute(input,
                                         baselines=baseline_multiple,
                                         target=labels)
            rescaledDLS[idx:idx +
                        batch_size, :, :] = Helper.givenAttGetRescaledSaliency(
                            num_timesteps, num_features, attributions)
        if run_dls_tsr:
            rescaledDLS_TSR[idx:idx + batch_size, :, :] = get_tsr_saliency(
                DLS,
                input,
                labels,
                baseline=baseline_multiple,
                graph_dir=tsr_graph_dir,
                graph_name=f'{model_name}_{model_type}_DLS_TSR',
                inputs_to_graph=tsr_inputs_to_graph,
                cur_batch=i)

        if run_sg:
            attributions = SG.attribute(input, target=labels)
            rescaledSG[idx:idx +
                       batch_size, :, :] = Helper.givenAttGetRescaledSaliency(
                           num_timesteps, num_features, attributions)

        if run_shapley_sampling:
            attributions = SS.attribute(input,
                                        baselines=baseline_single,
                                        target=labels,
                                        feature_mask=inputMask)
            rescaledShapleySampling[
                idx:idx +
                batch_size, :, :] = Helper.givenAttGetRescaledSaliency(
                    num_timesteps, num_features, attributions)

        if run_feature_permutation:
            attributions = FP.attribute(input,
                                        target=labels,
                                        perturbations_per_eval=input.shape[0],
                                        feature_mask=mask_single)
            rescaledFeaturePermutation[
                idx:idx +
                batch_size, :, :] = Helper.givenAttGetRescaledSaliency(
                    num_timesteps, num_features, attributions)

        if run_feature_ablation:
            attributions = FA.attribute(input, target=labels)
            # perturbations_per_eval= input.shape[0],\
            # feature_mask=mask_single)
            rescaledFeatureAblation[
                idx:idx +
                batch_size, :, :] = Helper.givenAttGetRescaledSaliency(
                    num_timesteps, num_features, attributions)

        if run_occlusion:
            attributions = OS.attribute(input,
                                        sliding_window_shapes=(1,
                                                               num_features),
                                        target=labels,
                                        baselines=baseline_single)
            rescaledOcclusion[
                idx:idx +
                batch_size, :, :] = Helper.givenAttGetRescaledSaliency(
                    num_timesteps, num_features, attributions)

        if run_fit:
            attributions = torch.from_numpy(FIT.attribute(input, labels))
            rescaledFIT[idx:idx +
                        batch_size, :, :] = Helper.givenAttGetRescaledSaliency(
                            num_timesteps, num_features, attributions)

        if run_ifit:
            attributions = torch.from_numpy(
                inverse_fit_attribute(input,
                                      pretrained_model,
                                      ft_dim_last=True))
            rescaledIFIT[idx:idx + batch_size, :, :] = attributions

        if run_wfit:
            attributions = torch.from_numpy(
                wfit_attribute(input,
                               pretrained_model,
                               N=test_shape[1],
                               ft_dim_last=True,
                               single_label=True))
            rescaledWFIT[idx:idx + batch_size, :, :] = attributions

        if run_iwfit:
            attributions = torch.from_numpy(
                wfit_attribute(input,
                               pretrained_model,
                               N=test_shape[1],
                               ft_dim_last=True,
                               single_label=True,
                               inverse=True))
            rescaledIWFIT[idx:idx + batch_size, :, :] = attributions

        idx += batch_size

    if run_grad:
        print("Saving Grad", model_name + "_" + model_type)
        np.save(
            saliency_dir + model_name + "_" + model_type + "_Grad_rescaled",
            rescaledGrad)
    if run_grad_tsr:
        print("Saving Grad_TSR", model_name + "_" + model_type)
        np.save(
            saliency_dir + model_name + "_" + model_type +
            "_Grad_TSR_rescaled", rescaledGrad_TSR)

    if run_ig:
        print("Saving IG", model_name + "_" + model_type)
        np.save(saliency_dir + model_name + "_" + model_type + "_IG_rescaled",
                rescaledIG)
    if run_ig_tsr:
        print("Saving IG_TSR", model_name + "_" + model_type)
        np.save(
            saliency_dir + model_name + "_" + model_type + "_IG_TSR_rescaled",
            rescaledIG_TSR)

    if run_dl:
        print("Saving DL", model_name + "_" + model_type)
        np.save(saliency_dir + model_name + "_" + model_type + "_DL_rescaled",
                rescaledDL)

    if run_gs:
        print("Saving GS", model_name + "_" + model_type)
        np.save(saliency_dir + model_name + "_" + model_type + "_GS_rescaled",
                rescaledGS)

    if run_dls:
        print("Saving DLS", model_name + "_" + model_type)
        np.save(saliency_dir + model_name + "_" + model_type + "_DLS_rescaled",
                rescaledDLS)
    if run_dls_tsr:
        print("Saving DLS_TSR", model_name + "_" + model_type)
        np.save(
            saliency_dir + model_name + "_" + model_type + "_DLS_TSR_rescaled",
            rescaledDLS_TSR)

    if run_sg:
        print("Saving SG", model_name + "_" + model_type)
        np.save(saliency_dir + model_name + "_" + model_type + "_SG_rescaled",
                rescaledSG)

    if run_shapley_sampling:
        print("Saving ShapleySampling", model_name + "_" + model_type)
        np.save(
            saliency_dir + model_name + "_" + model_type +
            "_ShapleySampling_rescaled", rescaledShapleySampling)

    if run_feature_permutation:
        print("Saving FeaturePermutation", model_name + "_" + model_type)
        np.save(
            saliency_dir + model_name + "_" + model_type +
            "_FeaturePermutation_rescaled", rescaledFeaturePermutation)

    if run_feature_ablation:
        print("Saving FeatureAblation", model_name + "_" + model_type)
        np.save(
            saliency_dir + model_name + "_" + model_type +
            "_FeatureAblation_rescaled", rescaledFeatureAblation)

    if run_occlusion:
        print("Saving Occlusion", model_name + "_" + model_type)
        np.save(
            saliency_dir + model_name + "_" + model_type +
            "_Occlusion_rescaled", rescaledOcclusion)

    if run_fit:
        print("Saving FIT", model_name + "_" + model_type)
        np.save(saliency_dir + model_name + "_" + model_type + "_FIT_rescaled",
                rescaledFIT)

    if run_ifit:
        print("Saving IFIT", model_name + "_" + model_type)
        np.save(
            saliency_dir + model_name + "_" + model_type + "_IFIT_rescaled",
            rescaledIFIT)

    if run_wfit:
        print("Saving WFIT", model_name + "_" + model_type)
        np.save(
            saliency_dir + model_name + "_" + model_type + "_WFIT_rescaled",
            rescaledWFIT)

    if run_iwfit:
        print("Saving IWFIT", model_name + "_" + model_type)
        np.save(
            saliency_dir + model_name + "_" + model_type + "_IWFIT_rescaled",
            rescaledIWFIT)
Ejemplo n.º 8
0
agg_w = learned_weights[0].T
for k in range(1, len(learned_weights)):
  agg_w = np.dot(agg_w, learned_weights[k].T)

print(agg_w.shape)
np.save('agg_cancer_model_weights_sdae.npy', agg_w)
"""

print(
    "< =============== Run DeepLIFT to measure gene contributions ====================== >"
)

# One full-dataset batch so attribution covers every sample in a single pass.
dl_loader = DataLoader(x_tr, batch_size=x_tr.shape[0])
# Attribution is computed on CPU.
model.to(torch.device("cpu"))
deeplift = DeepLift(model)
deeplift_shap = DeepLiftShap(model)
# Reshape the baseline to a single row (1 x num_genes).
# NOTE(review): Tensor.resize is a deprecated in-place op — confirm
# reshape/view wasn't intended here.
deeplift_baseline = deeplift_baseline.resize(1, deeplift_baseline.shape[0])

# Holders for per-cancer-type attributions; overwritten inside the loop below.
attribution_brca = 0
attribution_ucec = 0
attribution_kirc = 0
attribution_luad = 0
attribution_skcm = 0

total = 0

# Absolute DeepLIFT attributions per target class
# (presumably 0 = BRCA, 1 = UCEC, matching the variable names — verify).
for i, inputs in enumerate(dl_loader, 0):
    attribution_brca = torch.abs(
        deeplift.attribute(inputs, target=0, baselines=deeplift_baseline))
    attribution_ucec = torch.abs(
        deeplift.attribute(inputs, target=1, baselines=deeplift_baseline))
Ejemplo n.º 9
0
 def __init__(self, model):
     """Store the model and build a DeepLiftShap explainer over it."""
     self.model = model
     self.explain = DeepLiftShap(model)
     return
Ejemplo n.º 10
0
def get_grad_imp(model,
                 X,
                 y=None,
                 mode='grad',
                 return_y=False,
                 clip=False,
                 baselines=None):
    """Compute per-feature attribution scores for `model` on inputs `X`.

    Parameters
    ----------
    model : callable / torch.nn.Module mapping `X` to class logits.
    X : torch.Tensor of inputs; its grad flag is enabled temporarily.
    y : optional tensor of target class indices; defaults to the model's
        argmax prediction per sample.
    mode : one of 'grad', 'deeplift', 'deepliftshap', 'gradcam'.
    return_y : if True, also return the (possibly inferred) targets.
    clip : if True, post-process attributions with `myclip`.
    baselines : baselines forwarded to DeepLiftShap (that mode only).

    Returns
    -------
    attributions, or (attributions, y) when `return_y` is True.

    Raises
    ------
    NotImplementedError
        For an unrecognized `mode`.
    """
    X.requires_grad_(True)
    try:
        if mode == 'grad':
            logits = model(X)
            if y is None:
                y = logits.argmax(dim=1)
            # Gradient of the sum of the selected logits w.r.t. the input;
            # .sum() produces the scalar that autograd.grad requires.
            attributions = torch.autograd.grad(
                logits[torch.arange(len(logits)), y].sum(), X)[0].detach()
        else:
            if y is None:
                with torch.no_grad():
                    logits = model(X)
                    y = logits.argmax(dim=1)

            if mode == 'deeplift':
                dl = DeepLift(model)
                attributions = dl.attribute(inputs=X, baselines=0.,
                                            target=y).detach()
            elif mode == 'deepliftshap':
                dl = DeepLiftShap(model)
                # Attribute in pairs to bound peak memory usage.
                attributions = []
                for idx in range(0, len(X), 2):
                    the_x, the_y = X[idx:(idx + 2)], y[idx:(idx + 2)]
                    attribution = dl.attribute(inputs=the_x,
                                               baselines=baselines,
                                               target=the_y)
                    attributions.append(attribution.detach())
                attributions = torch.cat(attributions, dim=0)
            elif mode == 'gradcam':
                orig_lgc = LayerGradCam(model, model.body[0])
                attributions = orig_lgc.attribute(X, target=y)
                # Upsample the coarse CAM back to the input's spatial size.
                attributions = F.interpolate(attributions,
                                             size=X.shape[-2:],
                                             mode='bilinear')
            else:
                # BUG FIX: message used JS-style f'${mode}', which printed a
                # literal '$' before the mode name.
                raise NotImplementedError(f'{mode} is not specified.')

        if clip:
            attributions = myclip(attributions)
    finally:
        # Always restore the input's grad flag, even if a branch raised.
        X.requires_grad_(False)

    if not return_y:
        return attributions
    return attributions, y
Ejemplo n.º 11
0
def evaluation_ten_classes(initiate_or_load_model,
                           config_data,
                           singleton_scope=False,
                           reshape_size=None,
                           FIND_OPTIM_BRANCH_MODEL=False,
                           realtime_update=False,
                           ALLOW_ADHOC_NOPTIM=False):
    """Load a trained network and evaluate it with the XAI method in config_data.

    Args:
        initiate_or_load_model: callable(model_dir, info_dir, config_data,
            verbose=...) returning (net, evaluator).
        config_data: dict with at least 'xai_mode'; when
            FIND_OPTIM_BRANCH_MODEL is True it also needs 'model_name' and
            'branch_name_label'.
        singleton_scope: if True, observe a single datapoint (debugging)
            instead of running the aggregate evaluation.
        reshape_size: forwarded to aggregate_evaluation.
        FIND_OPTIM_BRANCH_MODEL: if True, load the branch model saved with
            an '.optim' (preferred) or '.noptim' suffix instead of the
            continuously trained model.
        realtime_update: forwarded to aggregate_evaluation.
        ALLOW_ADHOC_NOPTIM: debug-only; copies the branch model file to a
            '.noptim' twin so the lookup below can succeed.

    Raises:
        RuntimeError: if no '.optim'/'.noptim' branch model file exists, or
            if 'xai_mode' names no supported attribution method.
    """
    from pipeline.training.training_utils import prepare_save_dirs
    xai_mode = config_data['xai_mode']
    MODEL_DIR, INFO_DIR, CACHE_FOLDER_DIR = prepare_save_dirs(config_data)

    ############################
    VERBOSE = 0
    ############################

    if not FIND_OPTIM_BRANCH_MODEL:
        # Plain path: evaluate the continuously trained model as-is.
        print(
            'Using the following the model from (only) continuous training for xai evaluation [%s]'
            % (str(xai_mode)))
        net, evaluator = initiate_or_load_model(MODEL_DIR,
                                                INFO_DIR,
                                                config_data,
                                                verbose=VERBOSE)
    else:
        # Branch path: derive '<model>.<branch>/<model_name>.<branch>.model'
        # from MODEL_DIR by replacing its '.model' suffix.
        BRANCH_FOLDER_DIR = MODEL_DIR[:MODEL_DIR.find('.model')] + '.%s' % (
            str(config_data['branch_name_label']))
        BRANCH_MODEL_DIR = os.path.join(
            BRANCH_FOLDER_DIR,
            '%s.%s.model' % (str(config_data['model_name']),
                             str(config_data['branch_name_label'])))
        # BRANCH_MODEL_DIR = MODEL_DIR[:MODEL_DIR.find('.model')] + '.%s.model'%(str(config_data['branch_name_label']))

        if ALLOW_ADHOC_NOPTIM:  # this is intended only for debug runs
            print('<< [EXY1] ALLOWING ADHOC NOPTIM >>')
            import shutil
            shutil.copyfile(BRANCH_MODEL_DIR, BRANCH_MODEL_DIR + '.noptim')

        # Prefer the fully optimized '.optim' snapshot; fall back to
        # '.noptim'; otherwise fail loudly.
        if os.path.exists(BRANCH_MODEL_DIR + '.optim'):
            BRANCH_MODEL_DIR = BRANCH_MODEL_DIR + '.optim'
            print(
                '  Using the OPTIMIZED branch model for [%s] xai evaluation: %s'
                % (str(xai_mode), str(BRANCH_MODEL_DIR)))
        elif os.path.exists(BRANCH_MODEL_DIR + '.noptim'):
            BRANCH_MODEL_DIR = BRANCH_MODEL_DIR + '.noptim'
            print(
                '  Using the partially optimized branch model for [%s] xai evaluation: %s'
                % (str(xai_mode), str(BRANCH_MODEL_DIR)))
        else:
            raise RuntimeError(
                'Attempting to find .optim or .noptim model, but not found.')
        if VERBOSE >= 250:
            print(
                '  """You may see a warning by pytorch for ReLu backward hook. It has been fixed externally, so you can ignore it."""'
            )
        net, evaluator = initiate_or_load_model(BRANCH_MODEL_DIR,
                                                INFO_DIR,
                                                config_data,
                                                verbose=VERBOSE)

    # Map the configured xai_mode onto the matching Captum attribution object.
    if xai_mode == 'Saliency': attrmodel = Saliency(net)
    elif xai_mode == 'IntegratedGradients':
        attrmodel = IntegratedGradients(net)
    elif xai_mode == 'InputXGradient':
        attrmodel = InputXGradient(net)
    elif xai_mode == 'DeepLift':
        attrmodel = DeepLift(net)
    elif xai_mode == 'GuidedBackprop':
        attrmodel = GuidedBackprop(net)
    elif xai_mode == 'GuidedGradCam':
        attrmodel = GuidedGradCam(net, net.select_first_layer())  # first layer
    elif xai_mode == 'Deconvolution':
        attrmodel = Deconvolution(net)
    elif xai_mode == 'GradientShap':
        attrmodel = GradientShap(net)
    elif xai_mode == 'DeepLiftShap':
        attrmodel = DeepLiftShap(net)
    else:
        raise RuntimeError('No valid attribution selected.')

    if singleton_scope:  # just to observe a single datapoint, mostly for debugging
        singleton_scope_oberservation(net, attrmodel, config_data,
                                      CACHE_FOLDER_DIR)
    else:
        aggregate_evaluation(net,
                             attrmodel,
                             config_data,
                             CACHE_FOLDER_DIR,
                             reshape_size=reshape_size,
                             realtime_update=realtime_update,
                             EVALUATE_BRANCH=FIND_OPTIM_BRANCH_MODEL)
def main(args, DatasetsTypes, DataGenerationTypes, models, device):
    """Run every enabled saliency method over each (model, dataset) combo.

    For each combination of model architecture and simulated time-series
    dataset, the pretrained "<name>_BEST.pkl" model is loaded; if its test
    accuracy is at least 90, every attribution method enabled via the
    corresponding ``args.*Flag`` is computed over the whole test set,
    rescaled via ``Helper.givenAttGetRescaledSaliency``, optionally plotted
    for one random sample (``args.plot``) and saved as .npy arrays
    (``args.save``).  Under-performing combinations are logged and appended
    to ``args.ignore_list`` instead.

    Args:
        args: namespace with data/model paths, NumFeatures, NumTimeSteps,
            batch_size, the per-method flags, and plot/save switches.
        DatasetsTypes: list of dataset base names.
        DataGenerationTypes: list of generation-process names (None = Box).
        models: list of model architecture names.
        device: torch device used for inputs, baselines and masks.
    """
    for m in range(len(models)):
        for x in range(len(DatasetsTypes)):
            for y in range(len(DataGenerationTypes)):
                # None denotes the plain "Box" dataset variant.
                if DataGenerationTypes[y] is None:
                    args.DataName = DatasetsTypes[x] + "_Box"
                else:
                    args.DataName = (DatasetsTypes[x] + "_" +
                                     DataGenerationTypes[y])

                # Shared filename suffix for all four simulated arrays.
                suffix = ("_F_" + str(args.NumFeatures) + "_TS_" +
                          str(args.NumTimeSteps) + ".npy")

                # Labels live in column 0 of the metadata arrays.
                Training = np.load(args.data_dir + "SimulatedTrainingData_" +
                                   args.DataName + suffix)
                TrainingMetaDataset = np.load(args.data_dir +
                                              "SimulatedTrainingMetaData_" +
                                              args.DataName + suffix)
                TrainingLabel = TrainingMetaDataset[:, 0]

                Testing = np.load(args.data_dir + "SimulatedTestingData_" +
                                  args.DataName + suffix)
                TestingDataset_MetaData = np.load(
                    args.data_dir + "SimulatedTestingMetaData_" +
                    args.DataName + suffix)
                TestingLabel = TestingDataset_MetaData[:, 0]

                # Min-max scale in flattened form, fitting on training only,
                # then restore the (samples, timesteps, features) layout.
                Training = Training.reshape(Training.shape[0], -1)
                Testing = Testing.reshape(Testing.shape[0], -1)

                scaler = MinMaxScaler()
                scaler.fit(Training)
                Training = scaler.transform(Training)
                Testing = scaler.transform(Testing)

                TrainingRNN = Training.reshape(Training.shape[0],
                                               args.NumTimeSteps,
                                               args.NumFeatures)
                TestingRNN = Testing.reshape(Testing.shape[0],
                                             args.NumTimeSteps,
                                             args.NumFeatures)

                train_dataRNN = data_utils.TensorDataset(
                    torch.from_numpy(TrainingRNN),
                    torch.from_numpy(TrainingLabel))
                train_loaderRNN = data_utils.DataLoader(
                    train_dataRNN, batch_size=args.batch_size, shuffle=True)

                test_dataRNN = data_utils.TensorDataset(
                    torch.from_numpy(TestingRNN),
                    torch.from_numpy(TestingLabel))
                test_loaderRNN = data_utils.DataLoader(
                    test_dataRNN, batch_size=args.batch_size, shuffle=False)

                modelName = "Simulated" + args.DataName

                saveModelName = "../Models/" + models[m] + "/" + modelName
                saveModelBestName = saveModelName + "_BEST.pkl"

                pretrained_model = torch.load(saveModelBestName,
                                              map_location=device)
                Test_Acc = checkAccuracy(test_loaderRNN, pretrained_model,
                                         args)
                print('{} {} model BestAcc {:.4f}'.format(
                    args.DataName, models[m], Test_Acc))

                if Test_Acc >= 90:
                    # Instantiate only the requested attribution methods,
                    # plus a matching output buffer for each.
                    if args.GradFlag:
                        rescaledGrad = np.zeros(TestingRNN.shape)
                        Grad = Saliency(pretrained_model)
                    if args.IGFlag:
                        rescaledIG = np.zeros(TestingRNN.shape)
                        IG = IntegratedGradients(pretrained_model)
                    if args.DLFlag:
                        rescaledDL = np.zeros(TestingRNN.shape)
                        DL = DeepLift(pretrained_model)
                    if args.GSFlag:
                        rescaledGS = np.zeros(TestingRNN.shape)
                        GS = GradientShap(pretrained_model)
                    if args.DLSFlag:
                        rescaledDLS = np.zeros(TestingRNN.shape)
                        DLS = DeepLiftShap(pretrained_model)
                    if args.SGFlag:
                        rescaledSG = np.zeros(TestingRNN.shape)
                        SG = NoiseTunnel(Saliency(pretrained_model))
                    if args.ShapleySamplingFlag:
                        rescaledShapleySampling = np.zeros(TestingRNN.shape)
                        SS = ShapleyValueSampling(pretrained_model)
                    # BUG FIX: this guard previously tested args.GSFlag, so
                    # FeaturePermutation was only prepared when GradientShap
                    # was on (NameError for FP below otherwise).
                    if args.FeaturePermutationFlag:
                        rescaledFeaturePermutation = np.zeros(
                            TestingRNN.shape)
                        FP = FeaturePermutation(pretrained_model)
                    if args.FeatureAblationFlag:
                        rescaledFeatureAblation = np.zeros(TestingRNN.shape)
                        FA = FeatureAblation(pretrained_model)
                    if args.OcclusionFlag:
                        rescaledOcclusion = np.zeros(TestingRNN.shape)
                        OS = Occlusion(pretrained_model)

                    idx = 0
                    # Feature mask grouping all features of one time step
                    # under the same group id.
                    mask = np.zeros((args.NumTimeSteps, args.NumFeatures),
                                    dtype=int)
                    for i in range(args.NumTimeSteps):
                        mask[i, :] = i

                    for i, (samples, labels) in enumerate(test_loaderRNN):

                        print('[{}/{}] {} {} model accuracy {:.2f}'.format(
                            i, len(test_loaderRNN), models[m], args.DataName,
                            Test_Acc))

                        inputs = samples.reshape(
                            -1, args.NumTimeSteps,
                            args.NumFeatures).to(device)
                        inputs = Variable(inputs,
                                          volatile=False,
                                          requires_grad=True)

                        batch_size = inputs.shape[0]
                        # Random baselines: one per sample, plus a 5x pool
                        # for the multi-baseline methods (GS / DLS).
                        baseline_single = torch.from_numpy(
                            np.random.random(inputs.shape)).to(device)
                        baseline_multiple = torch.from_numpy(
                            np.random.random(
                                (inputs.shape[0] * 5, inputs.shape[1],
                                 inputs.shape[2]))).to(device)
                        inputMask = np.zeros(inputs.shape)
                        inputMask[:, :, :] = mask
                        inputMask = torch.from_numpy(inputMask).to(device)
                        mask_single = torch.from_numpy(mask).to(device)
                        mask_single = mask_single.reshape(
                            1, args.NumTimeSteps,
                            args.NumFeatures).to(device)
                        labels = torch.tensor(
                            labels.int().tolist()).to(device)

                        if args.GradFlag:
                            attributions = Grad.attribute(inputs,
                                                          target=labels)
                            rescaledGrad[idx:idx + batch_size, :, :] = (
                                Helper.givenAttGetRescaledSaliency(
                                    args, attributions))

                        if args.IGFlag:
                            attributions = IG.attribute(
                                inputs,
                                baselines=baseline_single,
                                target=labels)
                            rescaledIG[idx:idx + batch_size, :, :] = (
                                Helper.givenAttGetRescaledSaliency(
                                    args, attributions))

                        if args.DLFlag:
                            attributions = DL.attribute(
                                inputs,
                                baselines=baseline_single,
                                target=labels)
                            rescaledDL[idx:idx + batch_size, :, :] = (
                                Helper.givenAttGetRescaledSaliency(
                                    args, attributions))

                        if args.GSFlag:
                            attributions = GS.attribute(
                                inputs,
                                baselines=baseline_multiple,
                                stdevs=0.09,
                                target=labels)
                            rescaledGS[idx:idx + batch_size, :, :] = (
                                Helper.givenAttGetRescaledSaliency(
                                    args, attributions))

                        if args.DLSFlag:
                            attributions = DLS.attribute(
                                inputs,
                                baselines=baseline_multiple,
                                target=labels)
                            rescaledDLS[idx:idx + batch_size, :, :] = (
                                Helper.givenAttGetRescaledSaliency(
                                    args, attributions))

                        if args.SGFlag:
                            attributions = SG.attribute(inputs,
                                                        target=labels)
                            rescaledSG[idx:idx + batch_size, :, :] = (
                                Helper.givenAttGetRescaledSaliency(
                                    args, attributions))

                        if args.ShapleySamplingFlag:
                            attributions = SS.attribute(
                                inputs,
                                baselines=baseline_single,
                                target=labels,
                                feature_mask=inputMask)
                            rescaledShapleySampling[
                                idx:idx + batch_size, :, :] = (
                                    Helper.givenAttGetRescaledSaliency(
                                        args, attributions))

                        if args.FeaturePermutationFlag:
                            attributions = FP.attribute(
                                inputs,
                                target=labels,
                                perturbations_per_eval=inputs.shape[0],
                                feature_mask=mask_single)
                            rescaledFeaturePermutation[
                                idx:idx + batch_size, :, :] = (
                                    Helper.givenAttGetRescaledSaliency(
                                        args, attributions))

                        if args.FeatureAblationFlag:
                            attributions = FA.attribute(inputs,
                                                        target=labels)
                            rescaledFeatureAblation[
                                idx:idx + batch_size, :, :] = (
                                    Helper.givenAttGetRescaledSaliency(
                                        args, attributions))

                        if args.OcclusionFlag:
                            # Occlude one full time step at a time.
                            attributions = OS.attribute(
                                inputs,
                                sliding_window_shapes=(1, args.NumFeatures),
                                target=labels,
                                baselines=baseline_single)
                            rescaledOcclusion[
                                idx:idx + batch_size, :, :] = (
                                    Helper.givenAttGetRescaledSaliency(
                                        args, attributions))

                        idx += batch_size

                    if args.plot:
                        # Visualize one random test sample and its saliency
                        # map from every enabled method.
                        index = random.randint(0, TestingRNN.shape[0] - 1)
                        plotExampleBox(TestingRNN[index, :, :],
                                       args.Saliency_Maps_graphs_dir +
                                       args.DataName + "_" + models[m] +
                                       '_sample',
                                       flip=True)

                        print("Plotting sample", index)
                        if args.GradFlag:
                            plotExampleBox(rescaledGrad[index, :, :],
                                           args.Saliency_Maps_graphs_dir +
                                           args.DataName + "_" + models[m] +
                                           '_Grad',
                                           greyScale=True,
                                           flip=True)

                        if args.IGFlag:
                            plotExampleBox(rescaledIG[index, :, :],
                                           args.Saliency_Maps_graphs_dir +
                                           args.DataName + "_" + models[m] +
                                           '_IG',
                                           greyScale=True,
                                           flip=True)

                        if args.DLFlag:
                            plotExampleBox(rescaledDL[index, :, :],
                                           args.Saliency_Maps_graphs_dir +
                                           args.DataName + "_" + models[m] +
                                           '_DL',
                                           greyScale=True,
                                           flip=True)

                        if args.GSFlag:
                            plotExampleBox(rescaledGS[index, :, :],
                                           args.Saliency_Maps_graphs_dir +
                                           args.DataName + "_" + models[m] +
                                           '_GS',
                                           greyScale=True,
                                           flip=True)

                        if args.DLSFlag:
                            plotExampleBox(rescaledDLS[index, :, :],
                                           args.Saliency_Maps_graphs_dir +
                                           args.DataName + "_" + models[m] +
                                           '_DLS',
                                           greyScale=True,
                                           flip=True)

                        if args.SGFlag:
                            plotExampleBox(rescaledSG[index, :, :],
                                           args.Saliency_Maps_graphs_dir +
                                           args.DataName + "_" + models[m] +
                                           '_SG',
                                           greyScale=True,
                                           flip=True)

                        if args.ShapleySamplingFlag:
                            plotExampleBox(
                                rescaledShapleySampling[index, :, :],
                                args.Saliency_Maps_graphs_dir +
                                args.DataName + "_" + models[m] +
                                '_ShapleySampling',
                                greyScale=True,
                                flip=True)

                        if args.FeaturePermutationFlag:
                            plotExampleBox(
                                rescaledFeaturePermutation[index, :, :],
                                args.Saliency_Maps_graphs_dir +
                                args.DataName + "_" + models[m] +
                                '_FeaturePermutation',
                                greyScale=True,
                                flip=True)

                        if args.FeatureAblationFlag:
                            plotExampleBox(
                                rescaledFeatureAblation[index, :, :],
                                args.Saliency_Maps_graphs_dir +
                                args.DataName + "_" + models[m] +
                                '_FeatureAblation',
                                greyScale=True,
                                flip=True)

                        if args.OcclusionFlag:
                            plotExampleBox(rescaledOcclusion[index, :, :],
                                           args.Saliency_Maps_graphs_dir +
                                           args.DataName + "_" + models[m] +
                                           '_Occlusion',
                                           greyScale=True,
                                           flip=True)

                    if args.save:
                        # Persist each enabled method's rescaled saliency.
                        if args.GradFlag:
                            print("Saving Grad", modelName + "_" + models[m])
                            np.save(
                                args.Saliency_dir + modelName + "_" +
                                models[m] + "_Grad_rescaled", rescaledGrad)

                        if args.IGFlag:
                            print("Saving IG", modelName + "_" + models[m])
                            np.save(
                                args.Saliency_dir + modelName + "_" +
                                models[m] + "_IG_rescaled", rescaledIG)

                        if args.DLFlag:
                            print("Saving DL", modelName + "_" + models[m])
                            np.save(
                                args.Saliency_dir + modelName + "_" +
                                models[m] + "_DL_rescaled", rescaledDL)

                        if args.GSFlag:
                            print("Saving GS", modelName + "_" + models[m])
                            np.save(
                                args.Saliency_dir + modelName + "_" +
                                models[m] + "_GS_rescaled", rescaledGS)

                        if args.DLSFlag:
                            print("Saving DLS", modelName + "_" + models[m])
                            np.save(
                                args.Saliency_dir + modelName + "_" +
                                models[m] + "_DLS_rescaled", rescaledDLS)

                        if args.SGFlag:
                            print("Saving SG", modelName + "_" + models[m])
                            np.save(
                                args.Saliency_dir + modelName + "_" +
                                models[m] + "_SG_rescaled", rescaledSG)

                        if args.ShapleySamplingFlag:
                            print("Saving ShapleySampling",
                                  modelName + "_" + models[m])
                            np.save(
                                args.Saliency_dir + modelName + "_" +
                                models[m] + "_ShapleySampling_rescaled",
                                rescaledShapleySampling)

                        if args.FeaturePermutationFlag:
                            print("Saving FeaturePermutation",
                                  modelName + "_" + models[m])
                            np.save(
                                args.Saliency_dir + modelName + "_" +
                                models[m] + "_FeaturePermutation_rescaled",
                                rescaledFeaturePermutation)

                        if args.FeatureAblationFlag:
                            print("Saving FeatureAblation",
                                  modelName + "_" + models[m])
                            np.save(
                                args.Saliency_dir + modelName + "_" +
                                models[m] + "_FeatureAblation_rescaled",
                                rescaledFeatureAblation)

                        if args.OcclusionFlag:
                            print("Saving Occlusion",
                                  modelName + "_" + models[m])
                            np.save(
                                args.Saliency_dir + modelName + "_" +
                                models[m] + "_Occlusion_rescaled",
                                rescaledOcclusion)

                else:
                    # Model below the 90% accuracy threshold: log it and
                    # append it to the ignore list.  (basicConfig only
                    # takes effect on its first call.)
                    logging.basicConfig(filename=args.log_file,
                                        level=logging.DEBUG)

                    logging.debug('{} {} model BestAcc {:.4f}'.format(
                        args.DataName, models[m], Test_Acc))

                    # Mode 'a' creates the file when missing, so a single
                    # branch replaces the previous exists/'w'-vs-'a' split.
                    with open(args.ignore_list, 'a') as fp:
                        fp.write(args.DataName + '_' + models[m] + '\n')