def extract_GS(self, X_train, X_test):
    """Compute GradientSHAP attributions of ``X_test`` using ``X_train`` as baselines.

    Args:
        self: object exposing ``net`` (the model to explain) and ``device``.
        X_train: tensor passed as the ``baselines`` argument of
            ``GradientShap.attribute``.
        X_test: tensor of inputs to attribute.

    Returns:
        numpy array of attributions for ``X_test`` (detached, on CPU).
    """
    gs = GradientShap(self.net)
    # perf_counter is monotonic and high-resolution, so the measured duration
    # cannot be skewed by wall-clock adjustments the way time.time() can.
    start = time.perf_counter()
    gs_attr_test = gs.attribute(X_test.to(self.device),
                                X_train.to(self.device))
    # NOTE(review): the label says "train" but this times the attribution of X_test.
    print("temps train", time.perf_counter() - start)
    return gs_attr_test.detach().cpu().numpy()
Beispiel #2
0
class GradientShapExplainer:
    """Wrap captum's ``GradientShap`` to produce attribution scores for a model.

    ``attribute`` either explains the whole input in one call
    (``retrospective=True``) or, by default, explains every prefix
    ``x[:, :, :t + 1]`` and keeps only the attribution of that prefix's last
    position ``t``.
    """

    def __init__(self, model, activation=None):
        """
        Args:
            model: torch module to explain; it is moved to ``self.device``.
            activation: callable applied to the model output before targets
                are chosen. ``None`` (the default) means ``torch.nn.Softmax(-1)``.
        """
        # Fall back to CPU instead of crashing on machines without CUDA.
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.base_model = model.to(self.device)
        self.base_model.device = self.device
        self.explainer = GradientShap(self.base_model)
        # Build the default per instance instead of sharing one module object
        # created once at function-definition time.
        self.activation = (torch.nn.Softmax(-1)
                           if activation is None else activation)

    def attribute(self, x, y, retrospective=False):
        """Return a numpy array of attribution scores with the same shape as ``x``.

        Args:
            x: input tensor, at least 3-D; prefixes are taken along the last
                axis (assumed (batch, features, time) — TODO confirm).
            y: targets; only used when ``retrospective`` is True.
            retrospective: if True, explain the full input against ``y`` in a
                single call and return absolute attributions. Otherwise explain
                each prefix against the model's own predictions.
        """
        x, y = x.to(self.device), y.to(self.device)

        if retrospective:
            score = self.explainer.attribute(
                x,
                target=y.long(),
                n_samples=50,
                stdevs=0.0001,
                # Baseline distribution: an all-zero input and the input itself.
                baselines=torch.cat([x * 0, x * 1]))
            return abs(score.detach().cpu().numpy())

        score = np.zeros(x.shape)
        for t in range(x.shape[-1]):
            x_in = x[:, :, :t + 1]
            pred = self.activation(self.base_model(x_in))

            if isinstance(self.activation, torch.nn.Softmax):
                # Multi-class output: explain the arg-max class.
                target = torch.argmax(pred, -1)
                imp = self.explainer.attribute(
                    x_in,
                    target=target.long(),
                    n_samples=50,
                    stdevs=0.0001,
                    baselines=torch.cat([x_in * 0, x_in * 1]))
                score[:, :, t] = imp.detach().cpu().numpy()[:, :, -1]
                continue

            n_labels = pred.shape[1]
            if n_labels > 1:
                imp = torch.zeros(list(x_in.shape) + [n_labels])
                for label_idx in range(n_labels):
                    # NOTE(review): the target is 0/1 (whether this label is
                    # predicted), i.e. an output *index*, not label_idx
                    # itself — confirm this is intended.
                    target = (pred[:, label_idx] > 0.5).float()
                    imp[:, :, :, label_idx] = self.explainer.attribute(
                        x_in,
                        target=target.long(),
                        baselines=(x_in * 0))
                # Keep, per feature, the largest attribution over all labels.
                score[:, :, t] = imp.detach().cpu().numpy().max(3)[:, :, -1]
            else:
                # Single-label output (e.g. spike): explain the predicted class.
                target = (pred > 0.5).float()[:, 0]
                imp = self.explainer.attribute(
                    x_in,
                    target=target.long(),
                    baselines=(x_in * 0))
                score[:, :, t] = abs(imp.detach().cpu().numpy()[:, :, -1])
        return score
Beispiel #3
0
    def compute_gradient_shap(self, img_path, target):
        """Return GradientSHAP attributions for the image at ``img_path``.

        The baseline distribution is two images: an all-zero copy of the model
        input and the model input itself. The attributions are squeezed and
        axis 0 is moved last (presumably CHW -> HWC, for plotting).
        """
        _, _, model_input = self.open_image(img_path)
        baselines = torch.cat([model_input * 0, model_input * 1])

        explainer = GradientShap(self.model)
        attributions = explainer.attribute(
            model_input,
            n_samples=50,
            stdevs=0.0001,
            baselines=baselines,
            target=target,
        )
        squeezed = attributions.squeeze().cpu().detach().numpy()
        return np.transpose(squeezed, (1, 2, 0))
Beispiel #4
0
    def test_basic_sensitivity_max_multiple_gradshap(self) -> None:
        """Check sensitivity_max of GradientShap on two inputs at two batch limits."""
        gs = GradientShap(BasicModel2())

        input1 = torch.tensor([0.0] * 5)
        input2 = torch.tensor([0.0] * 5)
        baseline1 = torch.arange(0, 2).float() / 1000
        baseline2 = torch.arange(0, 2).float() / 1000

        # Same check with a small and a large per-batch example limit;
        # torch.zeros(5) is presumably the expected sensitivity per example.
        for batch_limit in (2, 20):
            self.sensitivity_max_assert(
                gs.attribute,
                (input1, input2),
                torch.zeros(5),
                baselines=(baseline1, baseline2),
                max_examples_per_batch=batch_limit,
            )
Beispiel #5
0
class gradientshap_explainer:
    """GradientSHAP helper class.

    NOTE(review): ``__init__`` is defined twice in this class body. Python keeps
    only the last definition, so the ``(model, train_data)`` constructor below
    is dead code and ``self.explainer`` is never assigned by the surviving
    constructor; ``get_feature_importance`` would then raise AttributeError.
    This looks like two unrelated snippets merged together — confirm which
    constructor is intended before using this class.
    """
    def __init__(self, model, train_data):
        # Put the model in eval mode and wrap it with GradientShap.
        # NOTE(review): train_data is accepted but never used.
        model.eval()
        self.explainer = GradientShap(model)
        self.model = model

    def get_feature_importance(self, data):
        """Return detached GradientSHAP attributions of ``data`` for every class.

        One attribution tensor per class index ``0 .. n_classes - 1`` (read from
        ``global_vars.get('n_classes')``) is stacked along a new axis 0. The
        baseline distribution is Gaussian noise scaled by 0.001 with
        ``data``'s shape, created with ``torch.randn`` on the default device
        (presumably must match ``data``'s device — TODO confirm). The
        convergence deltas are computed but discarded (``[0]``).
        """
        data.requires_grad = True
        baseline_dist = torch.randn(data.shape) * 0.001
        return torch.stack([self.explainer.attribute(data, stdevs=0.09, n_samples=4, baselines=baseline_dist,
                                           target=i, return_convergence_delta=True)[0] for i in
                                                range(global_vars.get('n_classes'))], axis=0).detach()
    # NOTE(review): this second __init__ replaces the one above (see class docstring).
    def __init__(self,
                 normal_path,
                 adv_path,
                 collate_path,
                 model_path,
                 normal_only=False,
                 large_normal=False):
        """Load the model and the normal/adversarial X-ray datasets, and set up GradientShap.

        Args:
            normal_path: path passed to the normal ``XrayDataset``.
            adv_path: path of the adversarial ``XrayDataset`` (also passed as ``adv=``).
            collate_path: passed to both ``XrayDataset`` instances.
            model_path: file loaded with ``torch.load`` to obtain the model.
            normal_only: if True, ``self.labels`` contains only 0 labels.
            large_normal: if True, five times as many 0 labels are created.
        """
        self.normal_only = normal_only
        # NOTE(review): torch.load unpickles arbitrary objects; only load trusted files.
        self.model = torch.load(model_path)
        self.normal_dataset = XrayDataset(normal_path,
                                          collate_path,
                                          'Cardiomegaly',
                                          1,
                                          normalise=False,
                                          only=None)
        self.adv_dataset = XrayDataset(adv_path,
                                       collate_path,
                                       'Cardiomegaly',
                                       1,
                                       normalise=False,
                                       adv=adv_path)

        # NOTE(review): the number of normal samples is taken from the
        # adversarial dataset's size, not normal_dataset's — presumably to
        # balance the two classes; confirm.
        self.num_normal = len(self.adv_dataset)

        if large_normal:
            # For training anomaly detection AEs
            self.num_normal = self.num_normal * 5

        # Label 0 for each normal sample, then label 1 for each adversarial one.
        self.labels = [0 for i in range(self.num_normal)]

        if not self.normal_only:
            self.labels += [1 for j in range(len(self.adv_dataset))]

        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.device = device
        self.model = self.model.to(device)
        self.model.eval()

        self.shap = GradientShap(self.model)

        # Random baseline of shape (1, 3, 224, 224) on self.device.
        self.baseline = torch.randn(1,
                                    3,
                                    224,
                                    224,
                                    requires_grad=True,
                                    device=device)
Beispiel #7
0
class Explainer:
    """Produce GradientSHAP attribution maps for images.

    ``stdevs`` and ``n_samples`` are forwarded unchanged to
    ``GradientShap.attribute``.
    """

    def __init__(self, model, stdevs=0.09, n_samples=4):
        self.model = model
        self.explain = GradientShap(model)
        self.stdevs = stdevs
        self.n_samples = n_samples

    def get_attribution_map(self, img, target=None):
        """Return GradientSHAP attributions for ``img``.

        When ``target`` is None, the arg-max of the model output along dim 1 is
        explained. The baseline distribution is Gaussian noise (scaled by 0.001)
        shaped like ``img``. The convergence delta is requested but discarded.
        """
        chosen_target = (torch.argmax(self.model(img), 1)
                         if target is None else target)
        noise_baseline = torch.randn_like(img) * 0.001
        attributions, _delta = self.explain.attribute(
            img,
            baselines=noise_baseline,
            target=chosen_target,
            stdevs=self.stdevs,
            n_samples=self.n_samples,
            return_convergence_delta=True,
        )
        return attributions
Beispiel #8
0
def find_genes_GSClass(drug, ens, meta, test_tcga_expr):
    """Rank genes by GradientSHAP attribution for sensitive vs. resistant samples.

    Samples whose ``meta['label']`` is 'Complete Response' or 'Partial Response'
    form the sensitive group; every other sample is resistant. Each group is
    passed to the ``attribute`` helper together with the other group, the
    results are ranked with ``get_ranked_list``, and both rankings are written
    as columns 'sensitive' / 'resistant' to ``res_dir + drug + '/genes.csv'``.
    """
    from captum.attr import GradientShap
    explainer = GradientShap(ens)

    is_sensitive = meta['label'].isin(['Complete Response', 'Partial Response'])

    sensitive_idx = meta.loc[is_sensitive].index
    sensitive_x = torch.FloatTensor(test_tcga_expr.loc[sensitive_idx].values)

    resistant_idx = meta.loc[~is_sensitive].index
    resistant_x = torch.FloatTensor(test_tcga_expr.loc[resistant_idx].values)

    # Presumably the third argument acts as the baseline set — confirm in attribute().
    sensitive_attr = attribute(explainer, sensitive_x, resistant_x, sensitive_idx)
    resistant_attr = attribute(explainer, resistant_x, sensitive_x, resistant_idx)

    sensitive_ranking = get_ranked_list(sensitive_attr)
    resistant_ranking = get_ranked_list(resistant_attr)

    table = pd.DataFrame(columns=['sensitive', 'resistant'])
    table['sensitive'] = sensitive_ranking
    table['resistant'] = resistant_ranking
    table.to_csv(res_dir + drug + '/genes.csv', index=False)
Beispiel #9
0
def return_spw_importances_(train_methyl_array,
                            val_methyl_array,
                            interest_col,
                            select_subtypes,
                            capsules_pickle,
                            include_last,
                            n_bins,
                            spw_config,
                            model_state_dict_pkl,
                            batch_size,
                            by_subtype=False):
    """Compute GradientSHAP importances of the pathway features of a MethylSPWNet.

    Steps:
      1. Load the train/validation ``MethylationArray`` pickles. Optionally drop
         NA samples, restrict to ``select_subtypes`` and bin ``interest_col``.
      2. Load the trained ``MethylSPWNet`` and run its ``pathways`` extractor
         over both splits to obtain per-sample pathway features.
      3. Explain ``model.output_net`` on those features with GradientShap.
         Two sampled training batches serve as baselines. The absolute
         attributions of 20 sampled validation batches are summed per module.

    Args:
        train_methyl_array, val_methyl_array: paths to MethylationArray pickles.
        interest_col: phenotype column holding the labels.
        select_subtypes: optional collection of label values to keep.
        capsules_pickle: ``torch.load``-able dict with keys 'final_modules',
            'modulecpgs' and 'module_names'.
        include_last: if False, restrict beta values to the ``modulecpgs`` columns.
        n_bins: if truthy, bin ``interest_col`` into this many bins. The
            validation split reuses the training bin edges.
        spw_config: path of the ``torch.load``-able MethylSPWNet kwargs.
        model_state_dict_pkl: path of the model state dict.
        batch_size: batch size used while extracting pathway features.
        by_subtype: if True, compute importances separately for each
            validation label with more than two samples and concatenate them
            with a 'subtype' column.

    Returns:
        pandas DataFrame with an 'importances' column indexed by module name,
        sorted in descending order (plus 'subtype' when ``by_subtype``).

    Raises:
        ValueError: if ``by_subtype`` is True but no validation label has more
            than two samples.
    """

    def to_cuda(t):
        """Move ``t`` to the GPU when CUDA is available, else return it unchanged."""
        return t.cuda() if torch.cuda.is_available() else t

    ma = MethylationArray.from_pickle(train_methyl_array)
    ma_v = MethylationArray.from_pickle(val_methyl_array)

    # Dropping NA samples is best-effort: report a failure instead of silently
    # swallowing every exception (a bare except also hid KeyboardInterrupt).
    try:
        ma.remove_na_samples(interest_col)
        ma_v.remove_na_samples(interest_col)
    except Exception as err:
        print(f"Skipping NA-sample removal for {interest_col!r}: {err}")

    if select_subtypes:
        ma.pheno = ma.pheno.loc[ma.pheno[interest_col].isin(select_subtypes)]
        ma.beta = ma.beta.loc[ma.pheno.index]
        ma_v.pheno = ma_v.pheno.loc[ma_v.pheno[interest_col].isin(
            select_subtypes)]
        ma_v.beta = ma_v.beta.loc[ma_v.pheno.index]

    # NOTE(review): torch.load unpickles arbitrary objects; only load trusted files.
    capsules_dict = torch.load(capsules_pickle)
    final_modules = capsules_dict['final_modules']
    modulecpgs = capsules_dict['modulecpgs']
    module_names = capsules_dict['module_names']

    if not include_last:
        ma.beta = ma.beta.loc[:, modulecpgs]
        ma_v.beta = ma_v.beta.loc[:, modulecpgs]

    original_interest_col = interest_col
    if n_bins:
        # Bin the training labels, then reuse the same edges for validation
        # so both splits share identical categories.
        new_interest_col = interest_col + '_binned'
        ma.pheno.loc[:, new_interest_col], bins = pd.cut(
            ma.pheno[interest_col], bins=n_bins, retbins=True)
        ma_v.pheno.loc[:, new_interest_col], _ = pd.cut(
            ma_v.pheno[interest_col], bins=bins, retbins=True)
        interest_col = new_interest_col

    datasets = {
        split: MethylationDataset(
            methyl_array,
            interest_col,
            modules=final_modules,
            module_names=module_names,
            original_interest_col=original_interest_col,
            run_spw=True)
        for split, methyl_array in (('train', ma), ('val', ma_v))
    }

    y_val = datasets['val'].y_label
    y_val_uniq = np.unique(y_val)

    dataloaders = {
        'train': DataLoader(datasets['train'],
                            batch_size=batch_size,
                            shuffle=True,
                            num_workers=8,
                            pin_memory=True,
                            drop_last=True),
        'val': DataLoader(datasets['val'],
                          batch_size=batch_size,
                          shuffle=False,
                          num_workers=8,
                          pin_memory=True,
                          drop_last=False),
    }

    spw_config = torch.load(spw_config)
    spw_config.pop('module_names')

    model = MethylSPWNet(**spw_config)
    model.load_state_dict(torch.load(model_state_dict_pkl))
    model = to_cuda(model)
    model.eval()

    pathway_extractor = model.pathways

    # Replace each split's loader with one over the extracted pathway features.
    tensor_data = dict(train=dict(X=[], y=[]), val=dict(X=[], y=[]))
    for k in tensor_data:
        for batch in dataloaders[k]:
            x = batch[0]
            y_true = batch[-1].argmax(1)
            # Module inputs sit between the features and the one-hot label.
            # Always pass the first module tensor so the extractor receives the
            # same type on CPU and GPU.
            modules_x = batch[1:-1][0]
            if torch.cuda.is_available():
                x = x.cuda()
                modules_x = modules_x.cuda()
            tensor_data[k]['X'].append(
                pathway_extractor(x, modules_x).detach().cpu())
            tensor_data[k]['y'].append(y_true.flatten().view(-1, 1))
        features = torch.cat(tensor_data[k]['X'], dim=0)
        labels = torch.cat(tensor_data[k]['y'], dim=0)
        print(features.size(), labels.size())
        tensor_data[k] = TensorDataset(features, labels)
        dataloaders[k] = DataLoader(tensor_data[k],
                                    batch_size=32,
                                    sampler=ImbalancedDatasetSampler(
                                        tensor_data[k]))

    model = model.output_net
    gs = GradientShap(model)
    # Baselines: two batches drawn from the (randomly sampled) training loader.
    X_train = torch.cat(
        [next(iter(dataloaders['train']))[0] for _ in range(2)], dim=0)
    X_train = to_cuda(X_train)

    def return_importances(dataloaders, X_train):
        """Sum |GradientSHAP| over 20 sampled validation batches, per module."""
        attributions = []
        for _ in range(20):
            batch = next(iter(dataloaders['val']))
            X_test = to_cuda(batch[0])
            y_test = to_cuda(batch[1].flatten())
            attributions.append(
                torch.abs(
                    gs.attribute(X_test,
                                 stdevs=0.03,
                                 n_samples=200,
                                 baselines=X_train,
                                 target=y_test,
                                 return_convergence_delta=False)))
        attributions = torch.sum(torch.cat(attributions, dim=0), dim=0)
        return pd.DataFrame(
            pd.Series(attributions.detach().cpu().numpy(),
                      index=module_names).sort_values(ascending=False),
            columns=['importances'])

    if not by_subtype:
        return return_importances(dataloaders, X_train)

    importances = []
    for k in y_val_uniq:
        idx = np.where(y_val == k)[0]
        if len(idx) <= 2:
            continue
        val_dataset = Subset(tensor_data['val'], idx)
        # Repeat small subsets so a shuffled loader yields at least ~64 rows.
        n_concat = int(np.ceil(64. / len(idx)))
        if n_concat > 1:
            val_dataset = ConcatDataset([val_dataset] * n_concat)
        dataloaders['val'] = DataLoader(val_dataset,
                                        batch_size=32,
                                        shuffle=True)
        df = return_importances(dataloaders, X_train)
        df['subtype'] = k
        importances.append(df)

    if not importances:
        raise ValueError(
            "No validation subtype has more than two samples; "
            "cannot compute per-subtype importances.")
    return pd.concat(importances)
Beispiel #10
0
    def generate_heatmap(self, img):
        """Render a GradientSHAP heat map for a BGR image and return the saved PNG path.

        The checkpoint at ``self.model_path`` is loaded (on CPU when
        ``self.is_cpu``). Its ``"model"`` entry is explained for its top-1
        prediction, and the figure is saved to
        ``model_output/<model_path_name>_heat_map.png``.
        """
        # NOTE(review): torch.load unpickles arbitrary objects; only load trusted checkpoints.
        if self.is_cpu:
            checkpoint = torch.load(self.model_path,
                                    map_location=torch.device("cpu"))
        else:
            checkpoint = torch.load(self.model_path)
        model = checkpoint["model"]
        model.eval()

        # Put the input on the device the loaded weights actually live on,
        # rather than a global ``device`` that can disagree with ``self.is_cpu``.
        first_param = next(model.parameters(), None)
        device = (first_param.device
                  if first_param is not None else torch.device("cpu"))

        transform = transforms.Compose(
            [transforms.Resize((64, 64)),
             transforms.ToTensor()])
        transform_normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                   std=[0.229, 0.224, 0.225])

        # The input is converted from BGR to RGB before building the PIL image.
        rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        pil_img = Image.fromarray(np.uint8(rgb))

        transformed_img = transform(pil_img)
        model_input = transform_normalize(transformed_img).unsqueeze(0).to(device)

        output = F.softmax(model(model_input), dim=1)
        _, pred_label_idx = torch.topk(output, 1)
        pred_label_idx.squeeze_()

        default_cmap = LinearSegmentedColormap.from_list("custom blue",
                                                         [(0, "#ffffff"),
                                                          (0.25, "#000000"),
                                                          (1, "#000000")],
                                                         N=256)

        gradient_shap = GradientShap(model)
        # Baseline distribution: an all-zero image and the input itself.
        baselines = torch.cat([model_input * 0, model_input * 1])
        attributions_gs = gradient_shap.attribute(
            model_input,
            n_samples=50,
            stdevs=0.0001,
            baselines=baselines,
            target=pred_label_idx,
        )

        out = viz.visualize_image_attr_multiple(
            np.transpose(attributions_gs.squeeze().cpu().detach().numpy(),
                         (1, 2, 0)),
            np.transpose(transformed_img.squeeze().cpu().detach().numpy(),
                         (1, 2, 0)),
            ["original_image", "heat_map"],
            ["all", "absolute_value"],
            cmap=default_cmap,
            show_colorbar=True,
        )

        path = "model_output/" + self.model_path_name + "_heat_map.png"
        out[1][0].get_figure().savefig(path)
        plt.clf()
        plt.close()

        return path
Beispiel #11
0
def measure_model(
    model_version,
    dataset,
    out_folder,
    weights_dir,
    device,
    method=METHODS["gradcam"],
    sample_images=50,
    step=1,
):
    """Score an attribution method with captum's infidelity and sensitivity_max.

    Every image of the ``dataset`` test split is attributed for the model's
    top-1 prediction. The per-image ``[infidelity, sensitivity]`` pairs are
    written to ``<out_folder>/<model_version>-<dataset>-<method>.csv``.
    In addition, ``sample_images`` randomly chosen images are rendered to PNG.

    Args:
        model_version: "resnet18", "resnet50", "densenet", or anything else
            for EfficientNet-B0.
        dataset: dataset key used for ``NUM_OF_CLASSES`` and ``CustomDataset``.
        out_folder: directory receiving the PNG figures and the CSV.
        weights_dir: path of the model state dict.
        device: torch device to run on.
        method: one of the ``METHODS`` values, or "lime".
        sample_images: number of images to visualise.
        step: forwarded to ``CustomDataset``.

    Raises:
        ValueError: if ``sample_images`` exceeds the test-set size or
            ``method`` is not supported.
    """
    invTrans = get_inverse_normalization_transformation()
    data_dir = os.path.join("data")

    if model_version == "resnet18":
        model = create_resnet18_model(num_of_classes=NUM_OF_CLASSES[dataset])
    elif model_version == "resnet50":
        model = create_resnet50_model(num_of_classes=NUM_OF_CLASSES[dataset])
    elif model_version == "densenet":
        model = create_densenet121_model(
            num_of_classes=NUM_OF_CLASSES[dataset])
    else:
        model = create_efficientnetb0_model(
            num_of_classes=NUM_OF_CLASSES[dataset])

    model.load_state_dict(torch.load(weights_dir))
    model.eval()
    model.to(device)

    test_dataset = CustomDataset(
        dataset=dataset,
        transformer=get_default_transformation(),
        data_type="test",
        root_dir=data_dir,
        step=step,
    )
    data_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=1,
                                              shuffle=False,
                                              num_workers=4)

    n_images = len(test_dataset)
    if sample_images > n_images:
        raise ValueError(
            f"Image sample number ({sample_images}) exceeded dataset size ({n_images})."
        )
    # pbar.n is 1-based once it has been updated inside the loop, so the ids
    # are drawn from 1..n_images (an id of 0 could never match). A set gives
    # O(1) membership tests.
    image_ids = set(random.sample(range(1, n_images + 1), sample_images))

    classes_map = test_dataset.classes_map

    # Pick the attribution method and its sample counts; reject unknown
    # methods up front instead of failing later on an unbound name.
    multipy_by_inputs = False
    nt_samples = 8
    n_perturb_samples = 10
    if method == METHODS["ig"]:
        attr_method = IntegratedGradients(model)
        n_perturb_samples = 3
    elif method == METHODS["saliency"]:
        attr_method = Saliency(model)
    elif method == METHODS["gradcam"]:
        if model_version == "efficientnet":
            attr_method = GuidedGradCam(model, model._conv_stem)
        elif model_version == "densenet":
            attr_method = GuidedGradCam(model, model.features.conv0)
        else:
            attr_method = GuidedGradCam(model, model.conv1)
    elif method == METHODS["deconv"]:
        attr_method = Deconvolution(model)
    elif method == METHODS["gradshap"]:
        attr_method = GradientShap(model)
        if model_version == "efficientnet":
            n_perturb_samples = 3
        elif model_version == "densenet":
            n_perturb_samples = 2
    elif method == METHODS["gbp"]:
        attr_method = GuidedBackprop(model)
    elif method == "lime":
        attr_method = Lime(model)
        feature_mask = torch.tensor(lime_mask).to(device)
        multipy_by_inputs = True
    else:
        raise ValueError(f"Unsupported attribution method: {method!r}")

    # Integrated Gradients is used directly; every other method is smoothed
    # through a NoiseTunnel.
    nt = attr_method if method == METHODS['ig'] else NoiseTunnel(attr_method)

    print(f"Measuring {model_version} on {dataset} dataset, with {method}")
    print("-" * 10)

    @infidelity_perturb_func_decorator(multipy_by_inputs=multipy_by_inputs)
    def perturb_fn(inputs):
        noise = torch.tensor(np.random.normal(0, 0.003, inputs.shape)).float()
        noise = noise.to(device)
        return inputs - noise

    scores = []
    pbar = tqdm(total=n_images, desc="Model test completion")
    try:
        for image, label in data_loader:
            pbar.update(1)
            inv_input = invTrans(image)
            image = image.to(device)
            image.requires_grad = True
            output = F.softmax(model(image), dim=1)
            prediction_score, pred_label_idx = torch.topk(output, 1)
            prediction_score = prediction_score.cpu().detach().numpy()[0][0]
            pred_label_idx.squeeze_()

            if method == METHODS['gradshap']:
                baseline = torch.randn(image.shape).to(device)

            if method == "lime":
                # NOTE(review): LIME always explains class 1 here, not the prediction.
                attributions = attr_method.attribute(image, target=1, n_samples=50)
            elif method == METHODS['ig']:
                attributions = nt.attribute(image,
                                            target=pred_label_idx,
                                            n_steps=25)
            elif method == METHODS['gradshap']:
                attributions = nt.attribute(image,
                                            target=pred_label_idx,
                                            baselines=baseline)
            else:
                attributions = nt.attribute(image,
                                            nt_type="smoothgrad",
                                            nt_samples=nt_samples,
                                            target=pred_label_idx)

            infid = infidelity(model,
                               perturb_fn,
                               image,
                               attributions,
                               target=pred_label_idx)

            if method == "lime":
                sens = sensitivity_max(attr_method.attribute,
                                       image,
                                       target=pred_label_idx,
                                       n_perturb_samples=1,
                                       n_samples=200,
                                       feature_mask=feature_mask)
            elif method == METHODS['ig']:
                sens = sensitivity_max(nt.attribute,
                                       image,
                                       target=pred_label_idx,
                                       n_perturb_samples=n_perturb_samples,
                                       n_steps=25)
            elif method == METHODS['gradshap']:
                sens = sensitivity_max(nt.attribute,
                                       image,
                                       target=pred_label_idx,
                                       n_perturb_samples=n_perturb_samples,
                                       baselines=baseline)
            else:
                sens = sensitivity_max(nt.attribute,
                                       image,
                                       target=pred_label_idx,
                                       n_perturb_samples=n_perturb_samples)

            inf_value = infid.cpu().detach().numpy()[0]
            sens_value = sens.cpu().detach().numpy()[0]

            if pbar.n in image_ids:
                attr_data = attributions.squeeze().cpu().detach().numpy()
                true_name = classes_map[str(label.numpy()[0])][0]
                pred_name = classes_map[str(pred_label_idx.item())][0]
                fig, ax = viz.visualize_image_attr_multiple(
                    np.transpose(attr_data, (1, 2, 0)),
                    np.transpose(inv_input.squeeze().cpu().detach().numpy(),
                                 (1, 2, 0)),
                    ["original_image", "heat_map"],
                    ["all", "positive"],
                    titles=["original_image", "heat_map"],
                    cmap=default_cmap,
                    show_colorbar=True,
                    use_pyplot=False,
                    fig_size=(8, 6),
                )
                ax[0].set_xlabel(
                    f"Infidelity: {inf_value:.6f}\n Sensitivity: {sens_value:.6f}"
                )
                fig.suptitle(
                    f"True: {true_name}, Pred: {pred_name}\nScore: {prediction_score:.4f}",
                    fontsize=16,
                )
                fig.savefig(
                    os.path.join(out_folder,
                                 f"{pbar.n}-{true_name}-{pred_name}.png"))
                plt.close(fig)

            scores.append([inf_value, sens_value])
    finally:
        # Close the progress bar even if an attribution step raises.
        pbar.close()

    np.savetxt(
        os.path.join(out_folder, f"{model_version}-{dataset}-{method}.csv"),
        np.array(scores),
        delimiter=",",
        header="infidelity,sensitivity",
    )

    print(f"Artifacts stored at {out_folder}")
def main(args):

    train_loader, test_loader = data_generator(args.data_dir,1)

    for m in range(len(models)):

        model_name = "model_{}_NumFeatures_{}".format(models[m],args.NumFeatures)
        model_filename = args.model_dir + 'm_' + model_name + '.pt'
        pretrained_model = torch.load(open(model_filename, "rb"),map_location=device) 
        pretrained_model.to(device)



        if(args.GradFlag):
            Grad = Saliency(pretrained_model)
        if(args.IGFlag):
            IG = IntegratedGradients(pretrained_model)
        if(args.DLFlag):
            DL = DeepLift(pretrained_model)
        if(args.GSFlag):
            GS = GradientShap(pretrained_model)
        if(args.DLSFlag):
            DLS = DeepLiftShap(pretrained_model)                 
        if(args.SGFlag):
            Grad_ = Saliency(pretrained_model)
            SG = NoiseTunnel(Grad_)
        if(args.ShapleySamplingFlag):
            SS = ShapleyValueSampling(pretrained_model)
        if(args.GSFlag):
            FP = FeaturePermutation(pretrained_model)
        if(args.FeatureAblationFlag):
            FA = FeatureAblation(pretrained_model)         
        if(args.OcclusionFlag):
            OS = Occlusion(pretrained_model)

        timeMask=np.zeros((args.NumTimeSteps, args.NumFeatures),dtype=int)
        featureMask=np.zeros((args.NumTimeSteps, args.NumFeatures),dtype=int)
        for i in  range (args.NumTimeSteps):
            timeMask[i,:]=i

        for i in  range (args.NumTimeSteps):
            featureMask[:,i]=i

        indexes = [[] for i in range(5,10)]
        for i ,(data, target) in enumerate(test_loader):
            if(target==5 or target==6 or target==7 or target==8 or target==9):
                index=target-5

                if(len(indexes[index])<1):
                    indexes[index].append(i)
        for j, index in enumerate(indexes):
            print(index)
        # indexes = [[21],[17],[84],[9]]

        for j, index in enumerate(indexes):
            print("Getting Saliency for number", j+1)
            for i, (data, target) in enumerate(test_loader):
                if(i in index):
                        
                    labels =  target.to(device)
             
                    input = data.reshape(-1, args.NumTimeSteps, args.NumFeatures).to(device)
                    input = Variable(input,  volatile=False, requires_grad=True)

                    baseline_single=torch.Tensor(np.random.random(input.shape)).to(device)
                    baseline_multiple=torch.Tensor(np.random.random((input.shape[0]*5,input.shape[1],input.shape[2]))).to(device)
                    inputMask= np.zeros((input.shape))
                    inputMask[:,:,:]=timeMask
                    inputMask =torch.Tensor(inputMask).to(device)
                    mask_single= torch.Tensor(timeMask).to(device)
                    mask_single=mask_single.reshape(1,args.NumTimeSteps, args.NumFeatures).to(device)

                    Data=data.reshape(args.NumTimeSteps, args.NumFeatures).data.cpu().numpy()
                    
                    target_=int(target.data.cpu().numpy()[0])

                    plotExampleBox(Data,args.Graph_dir+'Sample_MNIST_'+str(target_)+'_index_'+str(i+1),greyScale=True)




                    if(args.GradFlag):
                        attributions = Grad.attribute(input, \
                                                      target=labels)
                        
                        saliency_=Helper.givenAttGetRescaledSaliency(args.NumTimeSteps, args.NumFeatures,attributions)

                        plotExampleBox(saliency_[0],args.Graph_dir+models[m]+'_Grad_MNIST_'+str(target_)+'_index_'+str(i+1),greyScale=True)
                        if(args.TSRFlag):
                            TSR_attributions =  getTwoStepRescaling(Grad,input, args.NumFeatures,args.NumTimeSteps, labels,hasBaseline=None)
                            TSR_saliency=Helper.givenAttGetRescaledSaliency(args.NumTimeSteps, args.NumFeatures,TSR_attributions,isTensor=False)
                            plotExampleBox(TSR_saliency,args.Graph_dir+models[m]+'_TSR_Grad_MNIST_'+str(target_)+'_index_'+str(i+1),greyScale=True)



                    if(args.IGFlag):
                        attributions = IG.attribute(input,  \
                                                    baselines=baseline_single, \
                                                    target=labels)
                        saliency_=Helper.givenAttGetRescaledSaliency(args.NumTimeSteps, args.NumFeatures,attributions)

                        plotExampleBox(saliency_[0],args.Graph_dir+models[m]+'_IG_MNIST_'+str(target_)+'_index_'+str(i+1),greyScale=True)
                        if(args.TSRFlag):
                            TSR_attributions =  getTwoStepRescaling(IG,input, args.NumFeatures,args.NumTimeSteps, labels,hasBaseline=baseline_single)
                            TSR_saliency=Helper.givenAttGetRescaledSaliency(args.NumTimeSteps, args.NumFeatures,TSR_attributions,isTensor=False)
                            plotExampleBox(TSR_saliency,args.Graph_dir+models[m]+'_TSR_IG_MNIST_'+str(target_)+'_index_'+str(i+1),greyScale=True)




                    if(args.DLFlag):
                        attributions = DL.attribute(input,  \
                                                    baselines=baseline_single, \
                                                    target=labels)
                        saliency_=Helper.givenAttGetRescaledSaliency(args.NumTimeSteps, args.NumFeatures,attributions)
                        plotExampleBox(saliency_[0],args.Graph_dir+models[m]+'_DL_MNIST_'+str(target_)+'_index_'+str(i+1),greyScale=True)


                        if(args.TSRFlag):
                            TSR_attributions =  getTwoStepRescaling(DL,input, args.NumFeatures,args.NumTimeSteps, labels,hasBaseline=baseline_single)
                            TSR_saliency=Helper.givenAttGetRescaledSaliency(args.NumTimeSteps, args.NumFeatures,TSR_attributions,isTensor=False)
                            plotExampleBox(TSR_saliency,args.Graph_dir+models[m]+'_TSR_DL_MNIST_'+str(target_)+'_index_'+str(i+1),greyScale=True)




                    if(args.GSFlag):

                        attributions = GS.attribute(input,  \
                                                    baselines=baseline_multiple, \
                                                    stdevs=0.09,\
                                                    target=labels)
                        saliency_=Helper.givenAttGetRescaledSaliency(args.NumTimeSteps, args.NumFeatures,attributions)
                        plotExampleBox(saliency_[0],args.Graph_dir+models[m]+'_GS_MNIST_'+str(target_)+'_index_'+str(i+1),greyScale=True)

 
                        if(args.TSRFlag):
                            TSR_attributions =  getTwoStepRescaling(GS,input, args.NumFeatures,args.NumTimeSteps, labels,hasBaseline=baseline_multiple)
                            TSR_saliency=Helper.givenAttGetRescaledSaliency(args.NumTimeSteps, args.NumFeatures,TSR_attributions,isTensor=False)
                            plotExampleBox(TSR_saliency,args.Graph_dir+models[m]+'_TSR_GS_MNIST_'+str(target_)+'_index_'+str(i+1),greyScale=True)


                    if(args.DLSFlag):

                        attributions = DLS.attribute(input,  \
                                                    baselines=baseline_multiple, \
                                                    target=labels)
                        saliency_=Helper.givenAttGetRescaledSaliency(args.NumTimeSteps, args.NumFeatures,attributions)
                        plotExampleBox(saliency_[0],args.Graph_dir+models[m]+'_DLS_MNIST_'+str(target_)+'_index_'+str(i+1),greyScale=True)
                        if(args.TSRFlag):
                            TSR_attributions =  getTwoStepRescaling(DLS,input, args.NumFeatures,args.NumTimeSteps, labels,hasBaseline=baseline_multiple)
                            TSR_saliency=Helper.givenAttGetRescaledSaliency(args.NumTimeSteps, args.NumFeatures,TSR_attributions,isTensor=False)
                            plotExampleBox(TSR_saliency,args.Graph_dir+models[m]+'_TSR_DLS_MNIST_'+str(target_)+'_index_'+str(i+1),greyScale=True)



                    if(args.SGFlag):
                        attributions = SG.attribute(input, \
                                                    target=labels)
                        saliency_=Helper.givenAttGetRescaledSaliency(args.NumTimeSteps, args.NumFeatures,attributions)
                        plotExampleBox(saliency_[0],args.Graph_dir+models[m]+'_SG_MNIST_'+str(target_)+'_index_'+str(i+1),greyScale=True)
                        if(args.TSRFlag):
                            TSR_attributions =  getTwoStepRescaling(SG,input, args.NumFeatures,args.NumTimeSteps, labels)
                            TSR_saliency=Helper.givenAttGetRescaledSaliency(args.NumTimeSteps, args.NumFeatures,TSR_attributions,isTensor=False)
                            plotExampleBox(TSR_saliency,args.Graph_dir+models[m]+'_TSR_SG_MNIST_'+str(target_)+'_index_'+str(i+1),greyScale=True)


                    if(args.ShapleySamplingFlag):
                        attributions = SS.attribute(input, \
                                        baselines=baseline_single, \
                                        target=labels,\
                                        feature_mask=inputMask)
                        saliency_=Helper.givenAttGetRescaledSaliency(args.NumTimeSteps, args.NumFeatures,attributions)
                        plotExampleBox(saliency_[0],args.Graph_dir+models[m]+'_SVS_MNIST_'+str(target_)+'_index_'+str(i+1),greyScale=True)
                        if(args.TSRFlag):
                            TSR_attributions =  getTwoStepRescaling(SS,input, args.NumFeatures,args.NumTimeSteps, labels,hasBaseline=baseline_single,hasFeatureMask=inputMask)
                            TSR_saliency=Helper.givenAttGetRescaledSaliency(args.NumTimeSteps, args.NumFeatures,TSR_attributions,isTensor=False)
                            plotExampleBox(TSR_saliency,args.Graph_dir+models[m]+'_TSR_SVS_MNIST_'+str(target_)+'_index_'+str(i+1),greyScale=True)
                    # if(args.FeaturePermutationFlag):
                    #     attributions = FP.attribute(input, \
                    #                     target=labels),
                    #                     # perturbations_per_eval= 1,\
                    #                     # feature_mask=mask_single)
                    #     saliency_=Helper.givenAttGetRescaledSaliency(args.NumTimeSteps, args.NumFeatures,attributions)
                    #     plotExampleBox(saliency_[0],args.Graph_dir+models[m]+'_FP',greyScale=True)


                    if(args.FeatureAblationFlag):
                        attributions = FA.attribute(input, \
                                        target=labels)
                                        # perturbations_per_eval= input.shape[0],\
                                        # feature_mask=mask_single)
                        saliency_=Helper.givenAttGetRescaledSaliency(args.NumTimeSteps, args.NumFeatures,attributions)
                        plotExampleBox(saliency_[0],args.Graph_dir+models[m]+'_FA_MNIST_'+str(target_)+'_index_'+str(i+1),greyScale=True)
                        if(args.TSRFlag):
                            TSR_attributions =  getTwoStepRescaling(FA,input, args.NumFeatures,args.NumTimeSteps, labels)
                            TSR_saliency=Helper.givenAttGetRescaledSaliency(args.NumTimeSteps, args.NumFeatures,TSR_attributions,isTensor=False)
                            plotExampleBox(TSR_saliency,args.Graph_dir+models[m]+'_TSR_FA_MNIST_'+str(target_)+'_index_'+str(i+1),greyScale=True)

                    if(args.OcclusionFlag):
                        attributions = OS.attribute(input, \
                                        sliding_window_shapes=(1,int(args.NumFeatures/10)),
                                        target=labels,
                                        baselines=baseline_single)
                        saliency_=Helper.givenAttGetRescaledSaliency(args.NumTimeSteps, args.NumFeatures,attributions)

                        plotExampleBox(saliency_[0],args.Graph_dir+models[m]+'_FO_MNIST_'+str(target_)+'_index_'+str(i+1),greyScale=True)
                        if(args.TSRFlag):
                            TSR_attributions =  getTwoStepRescaling(OS,input, args.NumFeatures,args.NumTimeSteps, labels,hasBaseline=baseline_single,hasSliding_window_shapes= (1,int(args.NumFeatures/10)))
                            TSR_saliency=Helper.givenAttGetRescaledSaliency(args.NumTimeSteps, args.NumFeatures,TSR_attributions,isTensor=False)
                            plotExampleBox(TSR_saliency,args.Graph_dir+models[m]+'_TSR_FO_MNIST_'+str(target_)+'_index_'+str(i+1),greyScale=True)
Beispiel #13
0
 def __init__(self, model, activation=torch.nn.Softmax(-1)):
     """Prepare a Captum ``GradientShap`` explainer for ``model``.

     Args:
         model: Model to explain. It is moved to ``self.device`` and also given
             a ``device`` attribute holding the same device string.
         activation: Stored as ``self.activation`` (default: softmax over the
             last dimension); presumably applied to model outputs by other
             methods of this class.
     """
     # Use the GPU when one is present; otherwise fall back to the CPU instead
     # of failing on machines without CUDA (the original hard-coded 'cuda').
     self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
     self.base_model = model.to(self.device)
     self.base_model.device = self.device
     self.explainer = GradientShap(self.base_model)
     self.activation = activation
Beispiel #14
0
def evaluation_ten_classes(initiate_or_load_model,
                           config_data,
                           singleton_scope=False,
                           reshape_size=None,
                           FIND_OPTIM_BRANCH_MODEL=False,
                           realtime_update=False,
                           ALLOW_ADHOC_NOPTIM=False):
    """Load a trained model and run the XAI evaluation selected by ``config_data['xai_mode']``.

    Args:
        initiate_or_load_model: Callable ``(model_dir, info_dir, config_data,
            verbose=...)`` returning a ``(net, evaluator)`` pair.
        config_data: Configuration dict. Always reads ``'xai_mode'``; when
            ``FIND_OPTIM_BRANCH_MODEL`` is set also reads ``'branch_name_label'``
            and ``'model_name'``.
        singleton_scope: If true, observe a single datapoint only (debugging).
        reshape_size: Forwarded to ``aggregate_evaluation``.
        FIND_OPTIM_BRANCH_MODEL: Evaluate the branch model (``.optim`` preferred,
            then ``.noptim``) instead of the continuously trained model.
        realtime_update: Forwarded to ``aggregate_evaluation``.
        ALLOW_ADHOC_NOPTIM: Debug only. Copy the branch model file to
            ``.noptim`` before the lookup so it can be evaluated un-optimized.

    Raises:
        RuntimeError: If no ``.optim``/``.noptim`` branch model is found, or if
            ``xai_mode`` is not a supported attribution method.
    """
    from pipeline.training.training_utils import prepare_save_dirs
    xai_mode = config_data['xai_mode']
    MODEL_DIR, INFO_DIR, CACHE_FOLDER_DIR = prepare_save_dirs(config_data)

    ############################
    VERBOSE = 0
    ############################

    if not FIND_OPTIM_BRANCH_MODEL:
        print(
            'Using the following the model from (only) continuous training for xai evaluation [%s]'
            % (str(xai_mode)))
        model_dir = MODEL_DIR
    else:
        # Strip the '.model' suffix to derive the branch folder. str.find
        # returns -1 when the suffix is absent, and slicing with -1 would
        # silently drop the last character of the path, so keep the full
        # path in that case.
        suffix_at = MODEL_DIR.find('.model')
        model_stem = MODEL_DIR if suffix_at < 0 else MODEL_DIR[:suffix_at]
        BRANCH_FOLDER_DIR = model_stem + '.%s' % (
            str(config_data['branch_name_label']))
        BRANCH_MODEL_DIR = os.path.join(
            BRANCH_FOLDER_DIR,
            '%s.%s.model' % (str(config_data['model_name']),
                             str(config_data['branch_name_label'])))

        if ALLOW_ADHOC_NOPTIM:  # this is intended only for debug runs
            print('<< [EXY1] ALLOWING ADHOC NOPTIM >>')
            import shutil
            shutil.copyfile(BRANCH_MODEL_DIR, BRANCH_MODEL_DIR + '.noptim')

        # Prefer the fully optimized branch model, then the partial one.
        if os.path.exists(BRANCH_MODEL_DIR + '.optim'):
            BRANCH_MODEL_DIR = BRANCH_MODEL_DIR + '.optim'
            print(
                '  Using the OPTIMIZED branch model for [%s] xai evaluation: %s'
                % (str(xai_mode), str(BRANCH_MODEL_DIR)))
        elif os.path.exists(BRANCH_MODEL_DIR + '.noptim'):
            BRANCH_MODEL_DIR = BRANCH_MODEL_DIR + '.noptim'
            print(
                '  Using the partially optimized branch model for [%s] xai evaluation: %s'
                % (str(xai_mode), str(BRANCH_MODEL_DIR)))
        else:
            raise RuntimeError(
                'Attempting to find .optim or .noptim model, but not found.')
        if VERBOSE >= 250:
            print(
                '  """You may see a warning by pytorch for ReLu backward hook. It has been fixed externally, so you can ignore it."""'
            )
        model_dir = BRANCH_MODEL_DIR

    # The evaluator half of the returned pair is not needed here.
    net, _ = initiate_or_load_model(model_dir,
                                    INFO_DIR,
                                    config_data,
                                    verbose=VERBOSE)

    # Build the attribution method; names are looked up lazily so only the
    # selected method has to be importable.
    if xai_mode == 'Saliency':
        attrmodel = Saliency(net)
    elif xai_mode == 'IntegratedGradients':
        attrmodel = IntegratedGradients(net)
    elif xai_mode == 'InputXGradient':
        attrmodel = InputXGradient(net)
    elif xai_mode == 'DeepLift':
        attrmodel = DeepLift(net)
    elif xai_mode == 'GuidedBackprop':
        attrmodel = GuidedBackprop(net)
    elif xai_mode == 'GuidedGradCam':
        attrmodel = GuidedGradCam(net, net.select_first_layer())  # first layer
    elif xai_mode == 'Deconvolution':
        attrmodel = Deconvolution(net)
    elif xai_mode == 'GradientShap':
        attrmodel = GradientShap(net)
    elif xai_mode == 'DeepLiftShap':
        attrmodel = DeepLiftShap(net)
    else:
        raise RuntimeError('No valid attribution selected.')

    if singleton_scope:  # just to observe a single datapoint, mostly for debugging
        singleton_scope_oberservation(net, attrmodel, config_data,
                                      CACHE_FOLDER_DIR)
    else:
        aggregate_evaluation(net,
                             attrmodel,
                             config_data,
                             CACHE_FOLDER_DIR,
                             reshape_size=reshape_size,
                             realtime_update=realtime_update,
                             EVALUATE_BRANCH=FIND_OPTIM_BRANCH_MODEL)
# Map each feature index to its averaged Integrated Gradients value.
avg_ig_vals_dict = {i: avg_ig_vals[i] for i in range(len(avg_ig_vals))}

# Now we can get only the top X many results. ordered stores the keys sorted
# by decreasing absolute attribution.
ordered = sorted(avg_ig_vals_dict,
                 key=lambda i: abs(avg_ig_vals_dict[i]),
                 reverse=True)
top = {i: avg_ig_vals_dict[i] for i in ordered[:args.num_ig]}

vis_importance(list(top),
               list(top.values()),
               plotter_instance=plotter,
               title="Integrated Gradients (Top {})".format(args.num_ig))
print('======================= Done!')

# Try looking at SHAP values now. This uses Captum's GradientShap (an
# expected-gradients approximation), not the DeepLift-based DeepLiftShap, so
# the progress message names GradientShap.
shap = GradientShap(model)

print(
    "======================= Calculating SHAP Values (GradientShap Approximation)..."
)
# We calculate it using only one sample
batch_size = 1

# Reset our dataloader
test_loader = DataLoader(dataset=test_set,
                         batch_size=batch_size,
                         shuffle=False,
                         collate_fn=visit_collate_fn,
                         num_workers=0)
if args.mean:
    # NOTE(review): this branch body looks like a splice from an unrelated
    # example. As written, the only effect when ``args.mean`` is truthy is to
    # bind a module-level function named ``__init__``; nothing calls it here.
    # Confirm what the ``args.mean`` branch was meant to do.
    def __init__(self, predictor: Predictor):
        """Initialise a 'gradshap' attribution over the predictor's Captum sub-model.

        Calls ``CaptumAttribution.__init__`` with the name ``'gradshap'``, then
        initialises the ``GradientShap`` base with the sub-model returned by
        ``self.predictor._model.captum_sub_model()`` (``self.predictor`` is
        presumably set by ``CaptumAttribution.__init__`` - confirm).
        """
        CaptumAttribution.__init__(self, 'gradshap', predictor)

        self.submodel = self.predictor._model.captum_sub_model()
        GradientShap.__init__(self, self.submodel)
Beispiel #17
0
def run_saliency_methods(saliency_methods,
                         pretrained_model,
                         test_shape,
                         train_loader,
                         test_loader,
                         device,
                         model_type,
                         model_name,
                         saliency_dir,
                         tsr_graph_dir=None,
                         tsr_inputs_to_graph=()):
    """Compute rescaled saliency maps for each requested method and save them.

    Args:
        saliency_methods: Collection of method names to run. Recognised names
            are "Grad", "Grad_TSR", "IG", "IG_TSR", "DL", "GS", "DLS",
            "DLS_TSR", "SG", "ShapleySampling", "FeaturePermutation",
            "FeatureAblation", "Occlusion", "FIT", "IFIT", "WFIT" and "IWFIT";
            any other entry is ignored.
        pretrained_model: Model being explained.
        test_shape: Unpacked as ``(_, num_timesteps, num_features)``. One
            ``np.zeros(test_shape)`` output array is allocated per enabled
            method.
        train_loader: Only used to fit the FIT generator.
        test_loader: Iterable of ``(samples, labels)`` batches. Samples are
            reshaped to ``(-1, num_timesteps, num_features)``.
        device: Device the inputs, baselines, masks and labels are moved to.
        model_type: Used in output file names and progress messages.
        model_name: Used in output file names and progress messages.
        saliency_dir: Path prefix for the saved arrays.
        tsr_graph_dir: Forwarded to ``get_tsr_saliency``.
        tsr_inputs_to_graph: Forwarded to ``get_tsr_saliency``.

    Each enabled method ``M`` is written with ``np.save`` to
    ``saliency_dir + model_name + "_" + model_type + "_M_rescaled"``.
    """
    _, num_timesteps, num_features = test_shape

    # All methods this function can run, in allocation/save order. Each name
    # doubles as the tag in the output file name.
    known_methods = ("Grad", "Grad_TSR", "IG", "IG_TSR", "DL", "GS", "DLS",
                     "DLS_TSR", "SG", "ShapleySampling", "FeaturePermutation",
                     "FeatureAblation", "Occlusion", "FIT", "IFIT", "WFIT",
                     "IWFIT")
    rescaled = {
        name: np.zeros(test_shape)
        for name in known_methods if name in saliency_methods
    }

    run_grad = "Grad" in rescaled
    run_grad_tsr = "Grad_TSR" in rescaled
    run_ig = "IG" in rescaled
    run_ig_tsr = "IG_TSR" in rescaled
    run_dl = "DL" in rescaled
    run_gs = "GS" in rescaled
    run_dls = "DLS" in rescaled
    run_dls_tsr = "DLS_TSR" in rescaled
    run_sg = "SG" in rescaled
    run_shapley_sampling = "ShapleySampling" in rescaled
    run_feature_permutation = "FeaturePermutation" in rescaled
    run_feature_ablation = "FeatureAblation" in rescaled
    run_occlusion = "Occlusion" in rescaled
    run_fit = "FIT" in rescaled
    run_ifit = "IFIT" in rescaled
    run_wfit = "WFIT" in rescaled
    run_iwfit = "IWFIT" in rescaled

    # --- Build the explainers that are actually needed ---
    if run_grad or run_grad_tsr:
        Grad = Saliency(pretrained_model)
    if run_ig or run_ig_tsr:
        IG = IntegratedGradients(pretrained_model)
    if run_dl:
        DL = DeepLift(pretrained_model)
    if run_gs:
        GS = GradientShap(pretrained_model)
    if run_dls or run_dls_tsr:
        DLS = DeepLiftShap(pretrained_model)
    if run_sg:
        SG = NoiseTunnel(Saliency(pretrained_model))
    if run_shapley_sampling:
        SS = ShapleyValueSampling(pretrained_model)
    # Previously guarded by `run_gs`, which raised NameError whenever
    # "FeaturePermutation" was requested without "GS".
    if run_feature_permutation:
        FP = FeaturePermutation(pretrained_model)
    if run_feature_ablation:
        FA = FeatureAblation(pretrained_model)
    if run_occlusion:
        OS = Occlusion(pretrained_model)
    if run_fit:
        FIT = FITExplainer(pretrained_model, ft_dim_last=True)
        generator = JointFeatureGenerator(num_features, data='none')
        # TODO: Increase epochs
        FIT.fit_generator(generator, train_loader, test_loader, n_epochs=300)

    def rescale(attributions):
        """Rescale raw attributions with the shared Helper routine."""
        return Helper.givenAttGetRescaledSaliency(num_timesteps, num_features,
                                                  attributions)

    # Feature mask that gives every feature of timestep t the group id t.
    mask = np.zeros((num_timesteps, num_features), dtype=int)
    for t in range(num_timesteps):
        mask[t, :] = t

    idx = 0
    for i, (samples, labels) in enumerate(test_loader):
        inputs = samples.reshape(-1, num_timesteps, num_features).to(device)
        # Fresh leaf tensor that tracks gradients; this is exactly what the
        # deprecated `Variable(x, requires_grad=True)` produced.
        inputs = inputs.detach().requires_grad_(True)

        batch_size = inputs.shape[0]
        batch = slice(idx, idx + batch_size)
        # Random baselines: one per sample, and five per sample for the
        # baseline-distribution methods (GradientShap, DeepLiftShap).
        baseline_single = torch.from_numpy(np.random.random(
            inputs.shape)).to(device)
        baseline_multiple = torch.from_numpy(
            np.random.random((batch_size * 5, inputs.shape[1],
                              inputs.shape[2]))).to(device)
        input_mask = np.zeros((inputs.shape))
        input_mask[:, :, :] = mask
        input_mask = torch.from_numpy(input_mask).to(device)
        mask_single = torch.from_numpy(mask).reshape(
            1, num_timesteps, num_features).to(device)
        labels = torch.tensor(labels.int().tolist()).to(device)

        if run_grad:
            rescaled["Grad"][batch] = rescale(
                Grad.attribute(inputs, target=labels))
        if run_grad_tsr:
            rescaled["Grad_TSR"][batch] = get_tsr_saliency(
                Grad,
                inputs,
                labels,
                graph_dir=tsr_graph_dir,
                graph_name=f'{model_name}_{model_type}_Grad_TSR',
                inputs_to_graph=tsr_inputs_to_graph,
                cur_batch=i)

        if run_ig:
            rescaled["IG"][batch] = rescale(
                IG.attribute(inputs, baselines=baseline_single,
                             target=labels))
        if run_ig_tsr:
            rescaled["IG_TSR"][batch] = get_tsr_saliency(
                IG,
                inputs,
                labels,
                baseline=baseline_single,
                graph_dir=tsr_graph_dir,
                graph_name=f'{model_name}_{model_type}_IG_TSR',
                inputs_to_graph=tsr_inputs_to_graph,
                cur_batch=i)

        if run_dl:
            rescaled["DL"][batch] = rescale(
                DL.attribute(inputs, baselines=baseline_single,
                             target=labels))

        if run_gs:
            rescaled["GS"][batch] = rescale(
                GS.attribute(inputs,
                             baselines=baseline_multiple,
                             stdevs=0.09,
                             target=labels))

        if run_dls:
            rescaled["DLS"][batch] = rescale(
                DLS.attribute(inputs, baselines=baseline_multiple,
                              target=labels))
        if run_dls_tsr:
            rescaled["DLS_TSR"][batch] = get_tsr_saliency(
                DLS,
                inputs,
                labels,
                baseline=baseline_multiple,
                graph_dir=tsr_graph_dir,
                graph_name=f'{model_name}_{model_type}_DLS_TSR',
                inputs_to_graph=tsr_inputs_to_graph,
                cur_batch=i)

        if run_sg:
            rescaled["SG"][batch] = rescale(
                SG.attribute(inputs, target=labels))

        if run_shapley_sampling:
            rescaled["ShapleySampling"][batch] = rescale(
                SS.attribute(inputs,
                             baselines=baseline_single,
                             target=labels,
                             feature_mask=input_mask))

        if run_feature_permutation:
            rescaled["FeaturePermutation"][batch] = rescale(
                FP.attribute(inputs,
                             target=labels,
                             perturbations_per_eval=batch_size,
                             feature_mask=mask_single))

        if run_feature_ablation:
            # perturbations_per_eval / feature_mask intentionally left at defaults.
            rescaled["FeatureAblation"][batch] = rescale(
                FA.attribute(inputs, target=labels))

        if run_occlusion:
            rescaled["Occlusion"][batch] = rescale(
                OS.attribute(inputs,
                             sliding_window_shapes=(1, num_features),
                             target=labels,
                             baselines=baseline_single))

        if run_fit:
            rescaled["FIT"][batch] = rescale(
                torch.from_numpy(FIT.attribute(inputs, labels)))

        if run_ifit:
            rescaled["IFIT"][batch] = torch.from_numpy(
                inverse_fit_attribute(inputs,
                                      pretrained_model,
                                      ft_dim_last=True))

        if run_wfit:
            rescaled["WFIT"][batch] = torch.from_numpy(
                wfit_attribute(inputs,
                               pretrained_model,
                               N=test_shape[1],
                               ft_dim_last=True,
                               single_label=True))

        if run_iwfit:
            rescaled["IWFIT"][batch] = torch.from_numpy(
                wfit_attribute(inputs,
                               pretrained_model,
                               N=test_shape[1],
                               ft_dim_last=True,
                               single_label=True,
                               inverse=True))

        idx += batch_size

    # --- Persist every computed map (same order as `known_methods`) ---
    run_tag = model_name + "_" + model_type
    for method_name, saliency in rescaled.items():
        print("Saving " + method_name, run_tag)
        np.save(saliency_dir + run_tag + "_" + method_name + "_rescaled",
                saliency)
Beispiel #18
0
 def __init__(self, model, stdevs=0.09, n_samples=4):
     """Keep ``model`` plus a Captum ``GradientShap`` wrapper and its settings.

     Args:
         model: Model to explain; stored as ``self.model`` and wrapped as
             ``self.explain``.
         stdevs: Stored as ``self.stdevs`` (presumably forwarded later as
             GradientShap's ``stdevs`` noise level - confirm).
         n_samples: Stored as ``self.n_samples`` (presumably forwarded later
             as GradientShap's ``n_samples`` - confirm).
     """
     self.model, self.explain = model, GradientShap(model)
     self.stdevs, self.n_samples = stdevs, n_samples
Beispiel #19
0
 def __init__(self, model, train_data):
     """Put ``model`` in eval mode and wrap it in a Captum ``GradientShap``.

     Args:
         model: Model to explain (presumably a torch ``nn.Module``); its
             ``eval()`` is called before wrapping.
         train_data: Accepted for interface compatibility; not used here.
     """
     model.eval()
     explainer = GradientShap(model)
     self.model = model
     self.explainer = explainer
Beispiel #20
0
def main(args):
    """Interactively visualise GradientShap or DeepLift attributions for one image.

    Loads ``./model/base<args.model_index>.pkl`` together with its decision
    thresholds and prints predictions for five observations of validation
    image ``args.image_index``. It then asks on stdin which pathology and
    which attribution method to visualise. Returns early (``None``) when
    the user picks "Exit".

    Raises:
        NotImplementedError: If a menu answer is not one of the listed options.
    """
    warnings.filterwarnings("ignore")
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    print('Executing on device:', device)

    # --- Get Threshold ---
    model_path = './model/base' + str(args.model_index) + '.pkl'
    threshold, _ = get_threshold(model_dir=model_path,
                                 use_cached=args.use_cached,
                                 search_space=np.linspace(
                                     0, 1, args.search_num))

    # --- Load Model ---
    model = torch.load(model_path, map_location=device).to(device).eval()

    # --- Fix Seed ---
    torch.manual_seed(123)
    np.random.seed(123)

    # --- Label ---
    label = pd.read_csv('../CheXpert-v1.0-small/valid.csv')
    target_label = np.array(list(label.keys())[5:])
    target_obsrv = np.array([8, 2, 6, 5, 10])
    label = label.values
    label_gd = label[:, 5:]

    # --- Image ---
    img_index = args.image_index
    transform = transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ])
    print('Patient:', label[img_index][0])
    img = transform(
        Image.open(os.path.join('../', label[img_index][0])).convert('RGB'))
    # NOTE(review): ToTensor already yields values in [0, 1]; the extra /255
    # is kept as-is - confirm it matches the training preprocessing.
    img_transformed = (img / 255).unsqueeze(0).to(device)

    # --- Predict ---
    pred = model(img_transformed)
    print()
    print('[Prediction]')
    print('{:17s}|{:4s}|{:4s}|{:4s}|{:3s}'.format('Pathology', 'Prob', 'Thrs',
                                                  'Pred', 'Ans'))
    rows = zip(target_label[target_obsrv], pred[0][target_obsrv],
               threshold[target_obsrv], label_gd[img_index][target_obsrv])
    for name, score, cutoff, truth in rows:
        prob = score.item()
        print('{:17s}:{:4.2f} {:4.2f} {:4d} {:3d}'.format(
            name, prob, cutoff, int(prob > cutoff), int(truth)))

    print()
    # Release the prediction and GPU cache, then run attributions on the CPU.
    del pred
    torch.cuda.empty_cache()
    gc.collect()
    cpu = torch.device('cpu')
    model = model.to(cpu)
    img_transformed = img_transformed.to(cpu)

    # --- Visualization ---
    # Menu answer -> (output index of the pathology, message to print).
    pathology_menu = {
        '0': (8, 'Diagnosis on Atelectasis'),
        '1': (2, 'Diagnosis on Cardiomegaly'),
        '2': (6, 'Diagnosis on Consolidation'),
        '3': (5, 'Diagnosis on Edema'),
        '4': (10, 'Diagnosis on Pleural Effusion'),
    }
    pathology_input = input(
        'Please enter which pathology to visualize:\n[0]Atelectasis\n[1]Cardiomegaly\n[2]Consolidation\n[3]Edema\n[4]Pleural Effusion\n[5]Exit\n'
    )
    if pathology_input == '5':
        print('Exiting...')
        return
    if pathology_input not in pathology_menu:
        raise NotImplementedError('Only 0-5 are valid input values')
    pathology, diagnosis_message = pathology_menu[pathology_input]
    print(diagnosis_message)

    default_cmap = LinearSegmentedColormap.from_list('custom blue',
                                                     [(0, '#ffffff'),
                                                      (0.25, '#000000'),
                                                      (1, '#000000')],
                                                     N=256)

    def show_attribution(attributions, signs):
        """Plot the original image next to a heat map of ``attributions``."""
        _ = viz.visualize_image_attr_multiple(
            np.transpose(attributions.squeeze().cpu().detach().numpy(),
                         (1, 2, 0)),
            np.transpose(img.squeeze().cpu().detach().numpy(), (1, 2, 0)),
            ["original_image", "heat_map"], signs,
            cmap=default_cmap,
            show_colorbar=True)

    print()
    method_input = input(
        'Please enter which method to visualize:\n[0]GradientShap\n[1]DeepLift\n[2]Exit\n'
    )
    if method_input == '2':
        print('Exiting...')
        return
    if method_input not in ('0', '1'):
        raise NotImplementedError('Only 0-2 are valid input values')

    if method_input == '0':
        print('Using GradientShap')
        # --- Gradient Shap ---
        explainer = GradientShap(model)
        # Baseline distribution: an all-black image and the image itself.
        baseline_dist = torch.cat([img_transformed * 0, img_transformed * 1])
        attributions = explainer.attribute(img_transformed,
                                           n_samples=50,
                                           stdevs=0.0001,
                                           baselines=baseline_dist,
                                           target=pathology)
        show_attribution(attributions, ["all", "absolute_value"])
        del attributions
    else:
        print('Using DeepLIFT')
        # --- Deep Lift ---
        model = model_transform(model)
        explainer = DeepLift(model)
        attributions = explainer.attribute(img_transformed,
                                           target=pathology,
                                           baselines=img_transformed * 0)
        show_attribution(attributions, ["all", "positive"])
        del attributions

    # Integrated Gradients and NoiseTunnel(IntegratedGradients) variants were
    # sketched here previously as disabled code; they are not offered in the
    # menu above.

    gc.collect()
Beispiel #21
0
def measure_filter_model(
    model_version,
    dataset,
    out_folder,
    weights_dir,
    device,
    method=METHODS["gradcam"],
    sample_images=50,
    step=1,
    use_infidelity=False,
    use_sensitivity=False,
    render=False,
    ids=None,
):
    """Measure how attribution maps change when the input image is filtered.

    The test split is loaded with ``add_filters=True`` and consumed in order;
    every ``len(OUR_FILTERS)`` consecutive samples are treated as one group,
    the i-th sample of a group being attributed to ``OUR_FILTERS[i]``
    (assumes ``CustomDataset`` yields them in exactly that order -- TODO
    confirm). Every sample of a group is explained for the class predicted on
    the group's ``"none"`` sample. Once a group is complete, one CSV row is
    recorded holding, per filter, the SSIM between the ``"none"`` attribution
    map and that filter's map, followed by the softmax score of that class.

    Args:
        model_version: ``"resnet18"``, ``"resnet50"`` or ``"densenet"``; any
            other value builds an EfficientNet-B0.
        dataset: dataset key used by ``CustomDataset`` and by the
            ``NUM_OF_CLASSES`` / ``MAX_ATT_VALUES`` tables.
        out_folder: directory receiving the CSV and the rendered figures.
        weights_dir: path of the saved ``state_dict``.
        device: device the model and inputs are moved to.
        method: a ``METHODS`` value (``ig``, ``saliency``, ``gradcam``,
            ``deconv``, ``gradshap``, ``gbp``) or ``"lime"``.
        sample_images: unused; kept for backward compatibility.
        step: forwarded to ``CustomDataset``.
        use_infidelity: compute Captum ``infidelity`` for every sample.
        use_sensitivity: compute Captum ``sensitivity_max`` for every sample.
        render: save an image/heat-map figure per sample (the only place the
            infidelity/sensitivity values are reported).
        ids: forwarded to ``CustomDataset``.

    Raises:
        ValueError: if ``method`` is not one of the supported methods.
    """
    invTrans = get_inverse_normalization_transformation()
    data_dir = os.path.join("data")

    num_classes = NUM_OF_CLASSES[dataset]
    if model_version == "resnet18":
        model = create_resnet18_model(num_of_classes=num_classes)
    elif model_version == "resnet50":
        model = create_resnet50_model(num_of_classes=num_classes)
    elif model_version == "densenet":
        model = create_densenet121_model(num_of_classes=num_classes)
    else:
        model = create_efficientnetb0_model(num_of_classes=num_classes)

    # Load the weights onto the CPU (where the freshly built model lives);
    # this also works for GPU-saved checkpoints on a CPU-only machine.
    model.load_state_dict(torch.load(weights_dir, map_location="cpu"))
    model.eval()
    model.to(device)

    test_dataset = CustomDataset(
        dataset=dataset,
        transformer=get_default_transformation(),
        data_type="test",
        root_dir=data_dir,
        step=step,
        add_filters=True,
        ids=ids,
    )
    # batch_size=1 and no shuffling: the grouping below relies on the
    # dataset order.
    data_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=1,
                                              shuffle=False,
                                              num_workers=4)
    classes_map = test_dataset.classes_map

    # Spelling matches the keyword of Captum's
    # infidelity_perturb_func_decorator.
    multipy_by_inputs = False
    if method == METHODS["ig"]:
        attr_method = IntegratedGradients(model)
        nt_samples = 1
        n_perturb_samples = 1
    elif method == METHODS["saliency"]:
        attr_method = Saliency(model)
        nt_samples = 8
        n_perturb_samples = 2
    elif method == METHODS["gradcam"]:
        # GuidedGradCam is attached to an architecture-specific conv layer.
        if model_version == "efficientnet":
            gradcam_layer = model._conv_stem
        elif model_version == "densenet":
            gradcam_layer = model.features.conv0
        else:
            gradcam_layer = model.conv1
        attr_method = GuidedGradCam(model, gradcam_layer)
        nt_samples = 8
        n_perturb_samples = 2
    elif method == METHODS["deconv"]:
        attr_method = Deconvolution(model)
        nt_samples = 8
        n_perturb_samples = 2
    elif method == METHODS["gradshap"]:
        attr_method = GradientShap(model)
        nt_samples = 8
        n_perturb_samples = 2
    elif method == METHODS["gbp"]:
        attr_method = GuidedBackprop(model)
        nt_samples = 8
        n_perturb_samples = 2
    elif method == "lime":
        attr_method = Lime(model)
        nt_samples = 8
        n_perturb_samples = 2
        # lime_mask is not defined in this function; presumably a
        # module-level segmentation mask -- TODO confirm.
        feature_mask = torch.tensor(lime_mask).to(device)
        multipy_by_inputs = True
    else:
        raise ValueError(f"Unsupported attribution method: {method!r}")

    # Integrated Gradients is used as-is; every other method is wrapped in a
    # NoiseTunnel.
    nt = attr_method if method == METHODS["ig"] else NoiseTunnel(attr_method)

    print(f"Measuring {model_version} on {dataset} dataset, with {method}")
    print("-" * 10)
    pbar = tqdm(total=len(test_dataset), desc="Model test completion")

    @infidelity_perturb_func_decorator(multipy_by_inputs=multipy_by_inputs)
    def perturb_fn(inputs):
        # Small Gaussian perturbation used by the infidelity metric.
        noise = torch.tensor(np.random.normal(0, 0.003, inputs.shape)).float()
        noise = noise.to(device)
        return inputs - noise

    OUR_FILTERS = [
        "none",
        "fx_freaky_details 2,10,1,11,0,32,0",
        "normalize_local 8,10",
        "fx_boost_chroma 90,0,0",
        "fx_mighty_details 25,1,25,1,11,0",
        "sharpen 300",
    ]
    # Loop-invariant SSIM data range for this model/method/dataset.
    data_range_for_current_set = MAX_ATT_VALUES[model_version][method][
        dataset]

    scores = []
    idx = 0  # index of the current filter group
    filter_count = 0  # position of the current sample inside its group
    filter_attrs = {filter_name: [] for filter_name in OUR_FILTERS}
    predicted_main_class = 0
    for image, label in data_loader:
        pbar.update(1)
        inv_image = invTrans(image)
        image = image.to(device)
        image.requires_grad = True
        output = F.softmax(model(image), dim=1)
        prediction_score, pred_label_idx = torch.topk(output, 1)
        prediction_score = prediction_score.cpu().detach().numpy()[0][0]
        pred_label_idx.squeeze_()
        # The unfiltered sample fixes the class explained for the group.
        if OUR_FILTERS[filter_count] == 'none':
            predicted_main_class = pred_label_idx.item()

        if method == METHODS["gradshap"]:
            baseline = torch.randn(image.shape).to(device)

        if method == "lime":
            # NOTE(review): every other method targets predicted_main_class;
            # the hard-coded target=1 looks unintended -- confirm.
            attributions = attr_method.attribute(image, target=1, n_samples=50)
        elif method == METHODS["ig"]:
            attributions = nt.attribute(
                image,
                target=predicted_main_class,
                n_steps=25,
            )
        elif method == METHODS["gradshap"]:
            attributions = nt.attribute(image,
                                        target=predicted_main_class,
                                        baselines=baseline)
        else:
            attributions = nt.attribute(
                image,
                nt_type="smoothgrad",
                nt_samples=nt_samples,
                target=predicted_main_class,
            )

        if use_infidelity:
            infid = infidelity(model,
                               perturb_fn,
                               image,
                               attributions,
                               target=predicted_main_class)
            inf_value = infid.cpu().detach().numpy()[0]
        else:
            inf_value = 0

        if use_sensitivity:
            if method == "lime":
                sens = sensitivity_max(
                    attr_method.attribute,
                    image,
                    target=predicted_main_class,
                    n_perturb_samples=1,
                    n_samples=200,
                    feature_mask=feature_mask,
                )
            elif method == METHODS["ig"]:
                sens = sensitivity_max(
                    nt.attribute,
                    image,
                    target=predicted_main_class,
                    n_perturb_samples=n_perturb_samples,
                    n_steps=25,
                )
            elif method == METHODS["gradshap"]:
                sens = sensitivity_max(
                    nt.attribute,
                    image,
                    target=predicted_main_class,
                    n_perturb_samples=n_perturb_samples,
                    baselines=baseline,
                )
            else:
                sens = sensitivity_max(
                    nt.attribute,
                    image,
                    target=predicted_main_class,
                    n_perturb_samples=n_perturb_samples,
                )
            sens_value = sens.cpu().detach().numpy()[0]
        else:
            sens_value = 0

        attr_data = attributions.squeeze().cpu().detach().numpy()
        if render:
            fig, ax = viz.visualize_image_attr_multiple(
                np.transpose(attr_data, (1, 2, 0)),
                np.transpose(inv_image.squeeze().cpu().detach().numpy(),
                             (1, 2, 0)),
                ["original_image", "heat_map"],
                ["all", "positive"],
                titles=["original_image", "heat_map"],
                cmap=default_cmap,
                show_colorbar=True,
                use_pyplot=False,
                fig_size=(8, 6),
            )
            if use_sensitivity or use_infidelity:
                ax[0].set_xlabel(
                    f"Infidelity: {'{0:.6f}'.format(inf_value)}\n Sensitivity: {'{0:.6f}'.format(sens_value)}"
                )
            fig.suptitle(
                f"True: {classes_map[str(label.numpy()[0])][0]}, Pred: {classes_map[str(pred_label_idx.item())][0]}\nScore: {'{0:.4f}'.format(prediction_score)}",
                fontsize=16,
            )
            fig.savefig(
                os.path.join(
                    out_folder,
                    f"{str(idx)}-{str(filter_count)}-{str(label.numpy()[0])}-{str(OUR_FILTERS[filter_count])}-{classes_map[str(label.numpy()[0])][0]}-{classes_map[str(pred_label_idx.item())][0]}.png",
                ))
            plt.close(fig)

        # Softmax score of the group's main (predicted, not true) class.
        score_for_main_class = output.cpu().detach().numpy(
        )[0][predicted_main_class]

        # Store the map channels-last for SSIM, plus the formatted score.
        filter_attrs[OUR_FILTERS[filter_count]] = [
            np.moveaxis(attr_data, 0, -1),
            "{0:.8f}".format(score_for_main_class),
        ]

        filter_count += 1
        if filter_count >= len(OUR_FILTERS):
            # Group complete: compare every filter's map with the "none" map.
            ssims = []
            for rot in OUR_FILTERS:
                ssims.append("{0:.8f}".format(
                    ssim(
                        filter_attrs["none"][0],
                        filter_attrs[rot][0],
                        win_size=11,
                        data_range=data_range_for_current_set,
                        multichannel=True,
                    )))
                ssims.append(filter_attrs[rot][1])

            scores.append(ssims)
            filter_count = 0
            predicted_main_class = 0
            idx += 1

    pbar.close()

    indexes = []
    for filter_name in OUR_FILTERS:
        indexes.append(str(filter_name) + "-ssim")
        indexes.append(str(filter_name) + "-score")
    np.savetxt(
        os.path.join(
            out_folder,
            f"{model_version}-{dataset}-{method}-ssim-with-range.csv"),
        np.array(scores),
        delimiter=";",
        fmt="%s",
        header=";".join([str(rot) for rot in indexes]),
    )

    print(f"Artifacts stored at {out_folder}")
Beispiel #22
0
    saliency = Saliency(model)
    grads = saliency.attribute(x, target=labels[idx].item())
    grads_sal.append(grads.squeeze().cpu().detach().numpy())

    # Occlusion
    occlusion = Occlusion(model)
    grads = occlusion.attribute(x, strides = (1, int(FS / 100)), target=labels[idx].item(), sliding_window_shapes=(1, int(FS / 10)), baselines=0)
    grads_occ.append(grads.squeeze().cpu().detach().numpy())

    # Integrated Gradients
    integrated_gradients = IntegratedGradients(model)
    grads = integrated_gradients.attribute(x, target=labels[idx].item(), n_steps=1000)
    grads_igrad.append(grads.squeeze().cpu().detach().numpy())

    # Gradient SHAP
    gshap = GradientShap(model)
    baseline_dist = torch.cat([x*0, x*1])
    grads = gshap.attribute(x, n_samples=10, stdevs=0.1, baselines=baseline_dist, target=labels[idx].item())
    grads_gshap.append(grads.squeeze().cpu().detach().numpy())

    # DeepLIFT
    dlift = DeepLift(model)
    grads = dlift.attribute(x, x*0, target=labels[idx].item())
    grads_dlift.append(grads.squeeze().cpu().detach().numpy())

    signal.append(x.squeeze().cpu().detach().numpy())

# Persist the raw input signals together with every attribution map collected
# above, keyed by method, into results/interp_<model file name>.pk.
with open(os.path.join('results',
                       'interp_' + os.path.basename(MODEL) + '.pk'),
          'wb') as hf:
    pk.dump(
        {
            'x': signal,
            'sal': grads_sal,
            'occ': grads_occ,
            'igrad': grads_igrad,
            'gshap': grads_gshap,
            'dlift': grads_dlift,
        },
        hf,
    )

def _load_simulated_array(args, kind):
    """Load one simulated ``.npy`` array for the dataset in ``args.DataName``.

    The path is
    ``<data_dir>Simulated<kind>_<DataName>_F_<NumFeatures>_TS_<NumTimeSteps>.npy``
    where ``kind`` is e.g. ``"TestingData"`` or ``"TestingMetaData"``.
    """
    return np.load(args.data_dir + "Simulated" + kind + "_" + args.DataName +
                   "_F_" + str(args.NumFeatures) + "_TS_" +
                   str(args.NumTimeSteps) + ".npy")


def _record_low_accuracy(args, model_type, test_acc):
    """Log a model that missed the accuracy threshold and add it to the ignore list."""
    logging.basicConfig(filename=args.log_file, level=logging.DEBUG)
    logging.debug('{} {} model BestAcc {:.4f}'.format(args.DataName,
                                                      model_type, test_acc))
    # Append mode creates the file when it does not exist yet.
    with open(args.ignore_list, "a") as fp:
        fp.write(args.DataName + '_' + model_type + '\n')


def main(args, DatasetsTypes, DataGenerationTypes, models, device):
    """Compute, plot and save saliency maps for every simulated dataset/model pair.

    For each model type in ``models`` and each combination of
    ``DatasetsTypes`` x ``DataGenerationTypes`` (``None`` means ``"_Box"``),
    this loads the simulated train/test arrays and min-max scales the test
    split with a scaler fitted on the training split. It then loads the
    pretrained model from ``../Models/<model>/Simulated<DataName>_BEST.pkl``
    and measures its test accuracy with ``checkAccuracy``.

    * Accuracy >= 90: every attribution method whose ``args.<Name>Flag`` is
      set is run on the whole test split and rescaled with
      ``Helper.givenAttGetRescaledSaliency``. The results are optionally
      plotted (``args.plot``, one random test sample) and saved
      (``args.save``).
    * Otherwise the pair is logged to ``args.log_file`` and appended to
      ``args.ignore_list``.

    Note: ``args.DataName`` is overwritten for every combination.
    """
    # (name used in output file names, boolean attribute on ``args``), in the
    # order the maps are computed, plotted and saved.
    saliency_methods = (
        ("Grad", "GradFlag"),
        ("IG", "IGFlag"),
        ("DL", "DLFlag"),
        ("GS", "GSFlag"),
        ("DLS", "DLSFlag"),
        ("SG", "SGFlag"),
        ("ShapleySampling", "ShapleySamplingFlag"),
        ("FeaturePermutation", "FeaturePermutationFlag"),
        ("FeatureAblation", "FeatureAblationFlag"),
        ("Occlusion", "OcclusionFlag"),
    )
    explainer_factories = {
        "Grad": Saliency,
        "IG": IntegratedGradients,
        "DL": DeepLift,
        "GS": GradientShap,
        "DLS": DeepLiftShap,
        "SG": lambda model: NoiseTunnel(Saliency(model)),
        "ShapleySampling": ShapleyValueSampling,
        "FeaturePermutation": FeaturePermutation,
        "FeatureAblation": FeatureAblation,
        "Occlusion": Occlusion,
    }

    for model_type in models:
        for dataset_type in DatasetsTypes:
            for generation_type in DataGenerationTypes:
                if generation_type is None:
                    args.DataName = dataset_type + "_Box"
                else:
                    args.DataName = dataset_type + "_" + generation_type

                n_steps, n_features = args.NumTimeSteps, args.NumFeatures

                training = _load_simulated_array(args, "TrainingData")
                testing = _load_simulated_array(args, "TestingData")
                # Column 0 of the metadata array is used as the label.
                testing_label = _load_simulated_array(
                    args, "TestingMetaData")[:, 0]

                # Flatten to (samples, time*features) so the scaler treats
                # every (time step, feature) cell as its own column; it is
                # fitted on the training split only.
                training = training.reshape(training.shape[0], -1)
                testing = testing.reshape(testing.shape[0], -1)
                scaler = MinMaxScaler().fit(training)
                testing_rnn = scaler.transform(testing).reshape(
                    testing.shape[0], n_steps, n_features)

                test_loader = data_utils.DataLoader(
                    data_utils.TensorDataset(torch.from_numpy(testing_rnn),
                                             torch.from_numpy(testing_label)),
                    batch_size=args.batch_size,
                    shuffle=False)

                model_name = "Simulated" + args.DataName
                # NOTE: torch.load unpickles a whole model object; only use
                # trusted checkpoint files.
                pretrained_model = torch.load(
                    "../Models/" + model_type + "/" + model_name + "_BEST.pkl",
                    map_location=device)
                test_acc = checkAccuracy(test_loader, pretrained_model, args)
                print('{} {} model BestAcc {:.4f}'.format(
                    args.DataName, model_type, test_acc))

                # `not >=` rather than `<` so a NaN accuracy is still treated
                # as failing, exactly like the original if/else.
                if not test_acc >= 90:
                    _record_low_accuracy(args, model_type, test_acc)
                    continue

                enabled = [
                    name for name, flag in saliency_methods
                    if getattr(args, flag)
                ]
                explainers = {
                    name: explainer_factories[name](pretrained_model)
                    for name in enabled
                }
                rescaled = {
                    name: np.zeros(testing_rnn.shape)
                    for name in enabled
                }

                # Feature-group mask: all features of time step t share id t.
                mask = np.zeros((n_steps, n_features), dtype=int)
                mask[:, :] = np.arange(n_steps)[:, np.newaxis]

                idx = 0
                for i, (samples, labels) in enumerate(test_loader):
                    print('[{}/{}] {} {} model accuracy {:.2f}'.format(
                        i, len(test_loader), model_type, args.DataName,
                        test_acc))

                    # Leaf tensor that requires grad (replaces the deprecated
                    # Variable(..., volatile=False, requires_grad=True)).
                    inputs = samples.reshape(-1, n_steps, n_features).to(
                        device).detach().requires_grad_(True)
                    batch_size = inputs.shape[0]

                    # Random baselines: one per sample, and five per sample
                    # for the GS / DLS methods.
                    baseline_single = torch.from_numpy(
                        np.random.random(inputs.shape)).to(device)
                    baseline_multiple = torch.from_numpy(
                        np.random.random((batch_size * 5, n_steps,
                                          n_features))).to(device)
                    input_mask = np.zeros(inputs.shape)
                    input_mask[:, :, :] = mask
                    input_mask = torch.from_numpy(input_mask).to(device)
                    mask_single = torch.from_numpy(mask).to(device).reshape(
                        1, n_steps, n_features)
                    # Truncate to int, then int64 class indices for Captum.
                    labels = labels.int().long().to(device)

                    # Method-specific keyword arguments for .attribute().
                    extra_kwargs = {
                        "Grad": {},
                        "IG": {"baselines": baseline_single},
                        "DL": {"baselines": baseline_single},
                        "GS": {
                            "baselines": baseline_multiple,
                            "stdevs": 0.09
                        },
                        "DLS": {"baselines": baseline_multiple},
                        "SG": {},
                        "ShapleySampling": {
                            "baselines": baseline_single,
                            "feature_mask": input_mask
                        },
                        "FeaturePermutation": {
                            "perturbations_per_eval": batch_size,
                            "feature_mask": mask_single
                        },
                        "FeatureAblation": {},
                        "Occlusion": {
                            "sliding_window_shapes": (1, n_features),
                            "baselines": baseline_single
                        },
                    }
                    for name in enabled:
                        attributions = explainers[name].attribute(
                            inputs, target=labels, **extra_kwargs[name])
                        rescaled[name][idx:idx + batch_size, :, :] = \
                            Helper.givenAttGetRescaledSaliency(
                                args, attributions)

                    idx += batch_size

                if args.plot:
                    index = random.randint(0, testing_rnn.shape[0] - 1)
                    graphs_prefix = (args.Saliency_Maps_graphs_dir +
                                     args.DataName + "_" + model_type)
                    plotExampleBox(testing_rnn[index, :, :],
                                   graphs_prefix + '_sample',
                                   flip=True)

                    print("Plotting sample", index)
                    for name in enabled:
                        plotExampleBox(rescaled[name][index, :, :],
                                       graphs_prefix + '_' + name,
                                       greyScale=True,
                                       flip=True)

                if args.save:
                    run_name = model_name + "_" + model_type
                    for name in enabled:
                        print("Saving " + name, run_name)
                        np.save(
                            args.Saliency_dir + run_name + "_" + name +
                            "_rescaled", rescaled[name])