示例#1
0
def get_explanation(generated_data, discriminator, prediction, XAItype="shap", cuda=True, trained_data=None,
                    data_type="mnist") -> None:
    """
    Compute xAI explanations for the generated samples the discriminator scores low, and store them.

    Every generated sample whose discriminator prediction is below 0.5 is explained individually with the
    requested xAI method; all other samples keep an explanation of all ones. The resulting tensor is passed
    through ``normalize_vector`` and handed to ``set_values`` (nothing is returned).

    :param generated_data: data created by the generator
    :type generated_data: torch.Tensor
    :param discriminator: the discriminator model
    :type discriminator: torch.nn.Module
    :param prediction: tensor of predictions by the discriminator on the generated data
    :type prediction: torch.Tensor
    :param XAItype: the type of xAI system to use. One of ("shap", "lime", "saliency")
    :type XAItype: str
    :param cuda: whether to use gpu
    :type cuda: bool
    :param trained_data: a batch from the dataset (used as the DeepLiftShap baseline distribution)
    :type trained_data: torch.Tensor
    :param data_type: the type of the dataset used. One of ("cifar", "mnist", "fmnist")
    :type data_type: str
    :raises ValueError: if ``XAItype`` is not one of the supported names
    :return: None; the normalized explanation is passed to ``set_values``
    :rtype: None
    """
    # Validate up front so a bad XAItype is reported even when no sample needs explaining.
    if XAItype not in ("saliency", "shap", "lime"):
        raise ValueError("wrong xAI type given")

    # Initialize the explanation to all 1s; samples that are not selected below keep this value.
    temp = values_target(size=generated_data.size(), value=1.0, cuda=cuda)

    # Select the samples with a low discriminator prediction.
    low_mask = (prediction < 0.5).view(-1)
    indices = low_mask.nonzero(as_tuple=False).detach().cpu().numpy().flatten().tolist()

    # data[row] is the generated sample at position indices[row] of generated_data.
    data = generated_data[low_mask, :]

    # Each sample is explained on its own, so a single selected sample is handled as well.
    if indices:
        if XAItype == "saliency":
            explainer = Saliency(discriminator)
            for row, idx in enumerate(indices):
                temp[idx, :] = explainer.attribute(data[row, :].detach().unsqueeze(0))

        elif XAItype == "shap":
            explainer = DeepLiftShap(discriminator)
            for row, idx in enumerate(indices):
                temp[idx, :] = explainer.attribute(data[row, :].detach().unsqueeze(0), trained_data, target=0)

        else:  # "lime"
            explainer = lime_image.LimeImageExplainer()
            # batch_predict / batch_predict_cifar are expected to read this module-level CPU copy
            # of the discriminator (presumably; they are defined elsewhere).
            global discriminatorLime
            discriminatorLime = deepcopy(discriminator)
            discriminatorLime.cpu()
            discriminatorLime.eval()
            try:
                for row, idx in enumerate(indices):
                    if data_type == "cifar":
                        # NOTE(review): this is a reshape, not a transpose, into (32, 32, 3);
                        # confirm it matches the layout batch_predict_cifar expects.
                        image = data[row, :].detach().cpu().numpy()
                        image = np.reshape(image, (32, 32, 3)).astype(np.double)
                        exp = explainer.explain_instance(image, batch_predict_cifar, num_samples=100)
                    else:
                        image = data[row, :].squeeze().detach().cpu().numpy().astype(np.double)
                        exp = explainer.explain_instance(image, batch_predict, num_samples=100)
                    _, lime_mask = exp.get_image_and_mask(exp.top_labels[0], positive_only=False,
                                                          negative_only=False)
                    # np.float was removed in NumPy 1.24; np.float64 is the dtype it aliased.
                    temp[idx, :] = torch.tensor(lime_mask.astype(np.float64))
            finally:
                # Drop the global copy even if an explanation fails.
                del discriminatorLime

    if cuda:
        temp = temp.cuda()
    set_values(normalize_vector(temp))
def main(args):
    """
    Plot attribution maps for the first test sample of each label 5-9, for every model in ``models``.

    For each name in the module-level ``models`` list, the serialized model
    ``args.model_dir + 'm_model_<name>_NumFeatures_<NumFeatures>.pt'`` is loaded onto ``device`` and the
    Captum attribution methods switched on by the ``args.*Flag`` options are built. For the first test
    sample of each label 5..9, the raw input and every enabled attribution map are plotted into
    ``args.Graph_dir`` with ``plotExampleBox``. With ``args.TSRFlag``, the two-step-rescaled (TSR)
    variant of each map is plotted too.

    :param args: parsed command-line options; reads data_dir, model_dir, Graph_dir, NumTimeSteps,
        NumFeatures, TSRFlag and the GradFlag / IGFlag / DLFlag / GSFlag / DLSFlag / SGFlag /
        ShapleySamplingFlag / FeatureAblationFlag / OcclusionFlag switches
        (FeaturePermutationFlag is optional).
    """
    # NOTE(review): the second argument is presumably the batch size; the code below assumes
    # one sample per batch.
    train_loader, test_loader = data_generator(args.data_dir, 1)

    num_steps = args.NumTimeSteps
    num_feats = args.NumFeatures
    # Occlusion window: one time step by a tenth of the features (at least one feature wide).
    occlusion_window = (1, max(1, num_feats // 10))

    def _plot_attributions(attributions, path):
        # Rescale a batched attribution tensor and plot the first sample of the batch.
        saliency_ = Helper.givenAttGetRescaledSaliency(num_steps, num_feats, attributions)
        plotExampleBox(saliency_[0], path, greyScale=True)

    def _plot_tsr_attributions(tsr_attributions, path):
        # getTwoStepRescaling output is rescaled with isTensor=False and plotted as a whole.
        tsr_saliency = Helper.givenAttGetRescaledSaliency(num_steps, num_feats, tsr_attributions,
                                                          isTensor=False)
        plotExampleBox(tsr_saliency, path, greyScale=True)

    for model_label in models:

        model_name = "model_{}_NumFeatures_{}".format(model_label, args.NumFeatures)
        model_filename = args.model_dir + 'm_' + model_name + '.pt'
        # The with-block closes the file handle once the model has been loaded.
        with open(model_filename, "rb") as model_file:
            pretrained_model = torch.load(model_file, map_location=device)
        pretrained_model.to(device)

        if args.GradFlag:
            Grad = Saliency(pretrained_model)
        if args.IGFlag:
            IG = IntegratedGradients(pretrained_model)
        if args.DLFlag:
            DL = DeepLift(pretrained_model)
        if args.GSFlag:
            GS = GradientShap(pretrained_model)
        if args.DLSFlag:
            DLS = DeepLiftShap(pretrained_model)
        if args.SGFlag:
            SG = NoiseTunnel(Saliency(pretrained_model))
        if args.ShapleySamplingFlag:
            SS = ShapleyValueSampling(pretrained_model)
        # FeaturePermutation has its own (optional) switch; its plotting section below is disabled.
        if getattr(args, "FeaturePermutationFlag", False):
            FP = FeaturePermutation(pretrained_model)
        if args.FeatureAblationFlag:
            FA = FeatureAblation(pretrained_model)
        if args.OcclusionFlag:
            OS = Occlusion(pretrained_model)

        # timeMask[t, f] == t groups cells by time step; featureMask[t, f] == f groups them by feature.
        timeMask = np.zeros((num_steps, num_feats), dtype=int)
        featureMask = np.zeros((num_steps, num_feats), dtype=int)
        timeMask[:, :] = np.arange(num_steps)[:, None]
        featureMask[:, :] = np.arange(num_feats)[None, :]

        # indexes[k] holds the test-loader position of the first sample labelled 5 + k (empty if none).
        indexes = [[] for _ in range(5, 10)]
        for i, (data, target) in enumerate(test_loader):
            label = int(target)
            if 5 <= label <= 9 and not indexes[label - 5]:
                indexes[label - 5].append(i)
            if all(indexes):
                break  # every label already has its sample
        for index in indexes:
            print(index)
        # indexes = [[21],[17],[84],[9]]

        # NOTE(review): re-locating samples by position assumes test_loader iterates in a fixed order.
        for j, index in enumerate(indexes):
            print("Getting Saliency for number", j + 1)
            if not index:
                continue  # no test sample with this label
            for i, (data, target) in enumerate(test_loader):
                if i not in index:
                    continue

                labels = target.to(device)

                inputs = data.reshape(-1, num_steps, num_feats).to(device)
                # Leaf tensor that tracks gradients (modern replacement for the deprecated Variable wrapper).
                inputs = inputs.detach().requires_grad_(True)

                # Random baselines: one per sample, and five per sample for GradientShap / DeepLiftShap.
                baseline_single = torch.Tensor(np.random.random(inputs.shape)).to(device)
                baseline_multiple = torch.Tensor(
                    np.random.random((inputs.shape[0] * 5, inputs.shape[1], inputs.shape[2]))).to(device)
                # Shapley-sampling feature mask: all cells of one time step form one feature group.
                inputMask = np.zeros(inputs.shape)
                inputMask[:, :, :] = timeMask
                inputMask = torch.Tensor(inputMask).to(device)
                mask_single = torch.Tensor(timeMask).to(device)
                mask_single = mask_single.reshape(1, num_steps, num_feats).to(device)

                Data = data.reshape(num_steps, num_feats).data.cpu().numpy()
                target_ = int(target.data.cpu().numpy()[0])

                # Every plot of this sample shares this file-name suffix, e.g. "_MNIST_7_index_42".
                suffix = '_MNIST_' + str(target_) + '_index_' + str(i + 1)
                model_prefix = args.Graph_dir + model_label

                plotExampleBox(Data, args.Graph_dir + 'Sample' + suffix, greyScale=True)

                if args.GradFlag:
                    _plot_attributions(Grad.attribute(inputs, target=labels),
                                       model_prefix + '_Grad' + suffix)
                    if args.TSRFlag:
                        _plot_tsr_attributions(
                            getTwoStepRescaling(Grad, inputs, num_feats, num_steps, labels, hasBaseline=None),
                            model_prefix + '_TSR_Grad' + suffix)

                if args.IGFlag:
                    _plot_attributions(IG.attribute(inputs, baselines=baseline_single, target=labels),
                                       model_prefix + '_IG' + suffix)
                    if args.TSRFlag:
                        _plot_tsr_attributions(
                            getTwoStepRescaling(IG, inputs, num_feats, num_steps, labels,
                                                hasBaseline=baseline_single),
                            model_prefix + '_TSR_IG' + suffix)

                if args.DLFlag:
                    _plot_attributions(DL.attribute(inputs, baselines=baseline_single, target=labels),
                                       model_prefix + '_DL' + suffix)
                    if args.TSRFlag:
                        _plot_tsr_attributions(
                            getTwoStepRescaling(DL, inputs, num_feats, num_steps, labels,
                                                hasBaseline=baseline_single),
                            model_prefix + '_TSR_DL' + suffix)

                if args.GSFlag:
                    _plot_attributions(GS.attribute(inputs, baselines=baseline_multiple, stdevs=0.09,
                                                    target=labels),
                                       model_prefix + '_GS' + suffix)
                    if args.TSRFlag:
                        _plot_tsr_attributions(
                            getTwoStepRescaling(GS, inputs, num_feats, num_steps, labels,
                                                hasBaseline=baseline_multiple),
                            model_prefix + '_TSR_GS' + suffix)

                if args.DLSFlag:
                    _plot_attributions(DLS.attribute(inputs, baselines=baseline_multiple, target=labels),
                                       model_prefix + '_DLS' + suffix)
                    if args.TSRFlag:
                        _plot_tsr_attributions(
                            getTwoStepRescaling(DLS, inputs, num_feats, num_steps, labels,
                                                hasBaseline=baseline_multiple),
                            model_prefix + '_TSR_DLS' + suffix)

                if args.SGFlag:
                    _plot_attributions(SG.attribute(inputs, target=labels),
                                       model_prefix + '_SG' + suffix)
                    if args.TSRFlag:
                        _plot_tsr_attributions(
                            getTwoStepRescaling(SG, inputs, num_feats, num_steps, labels),
                            model_prefix + '_TSR_SG' + suffix)

                if args.ShapleySamplingFlag:
                    _plot_attributions(SS.attribute(inputs, baselines=baseline_single, target=labels,
                                                    feature_mask=inputMask),
                                       model_prefix + '_SVS' + suffix)
                    if args.TSRFlag:
                        _plot_tsr_attributions(
                            getTwoStepRescaling(SS, inputs, num_feats, num_steps, labels,
                                                hasBaseline=baseline_single, hasFeatureMask=inputMask),
                            model_prefix + '_TSR_SVS' + suffix)

                # FeaturePermutation plots are disabled; to re-enable:
                # if getattr(args, "FeaturePermutationFlag", False):
                #     _plot_attributions(FP.attribute(inputs, target=labels),  # optionally
                #                        # perturbations_per_eval=1, feature_mask=mask_single
                #                        args.Graph_dir + model_label + '_FP')

                if args.FeatureAblationFlag:
                    # optionally: perturbations_per_eval=inputs.shape[0], feature_mask=mask_single
                    _plot_attributions(FA.attribute(inputs, target=labels),
                                       model_prefix + '_FA' + suffix)
                    if args.TSRFlag:
                        _plot_tsr_attributions(
                            getTwoStepRescaling(FA, inputs, num_feats, num_steps, labels),
                            model_prefix + '_TSR_FA' + suffix)

                if args.OcclusionFlag:
                    _plot_attributions(OS.attribute(inputs, sliding_window_shapes=occlusion_window,
                                                    target=labels, baselines=baseline_single),
                                       model_prefix + '_FO' + suffix)
                    if args.TSRFlag:
                        _plot_tsr_attributions(
                            getTwoStepRescaling(OS, inputs, num_feats, num_steps, labels,
                                                hasBaseline=baseline_single,
                                                hasSliding_window_shapes=occlusion_window),
                            model_prefix + '_TSR_FO' + suffix)

                # Each slot of indexes holds at most one position, so this label is done.
                break
示例#3
0
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_yticklabels([])
        ax.set_xticklabels([])
        #     plt.axis('off')
        if i == 1:
            ax.set_ylabel(r"$s'$",
                          rotation=0,
                          fontsize=10,
                          fontfamily="Times New Roman")
            ax.yaxis.set_label_coords(-0.22, 0.4)

        input = np.array((obs, next_obs)).astype(np.float32)
        input = th.tensor(np.expand_dims(input, axis=0)).to('cpu').to(dtype)
        input.requires_grad = True
        attributions = sal.attribute(input)
        attributions = np.abs(attributions.detach()[0, ...])

        ax = plt.subplot(3, w, w + i)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_yticklabels([])
        ax.set_xticklabels([])
        ax.imshow(attributions[0, ...],
                  cmap='gray',
                  vmin=attributions.min(),
                  vmax=attributions.max())
        if i == 1:
            ax.set_ylabel(r"$\left\vert \frac{dR}{ds} \right\vert$",
                          rotation=0,
                          fontsize=10,
示例#4
0
# Accumulators filled per sample, each named after the attribution method whose
# results it collects; `signal` collects the raw inputs. Each is a distinct empty list.
grads_sal, grads_igrad, grads_occ = [], [], []
grads_gshap, grads_dlift, signal = [], [], []

for idx in range(16):
    x = signals[idx].float().unsqueeze(0)
    x.requires_grad = True

    model.eval()

    # Saliency
    saliency = Saliency(model)
    grads = saliency.attribute(x, target=labels[idx].item())
    grads_sal.append(grads.squeeze().cpu().detach().numpy())

    # Occlusion
    occlusion = Occlusion(model)
    grads = occlusion.attribute(x,
                                strides=(1, int(FS / 100)),
                                target=labels[idx].item(),
                                sliding_window_shapes=(1, int(FS / 10)),
                                baselines=0)
    grads_occ.append(grads.squeeze().cpu().detach().numpy())

    # Integrated Gradients
    integrated_gradients = IntegratedGradients(model)
    grads = integrated_gradients.attribute(x,
                                           target=labels[idx].item(),
示例#5
0
 def extract_Sa(self, X_test):
     """Compute Captum Saliency attributions of ``self.net`` for ``X_test``.

     ``X_test`` is moved to ``self.device`` before attribution. The elapsed time of the
     attribution call is printed, and the result is returned as a detached NumPy array on the CPU.

     :param X_test: batch of inputs to explain (a tensor accepted by ``self.net``).
     :return: the saliency attributions as a NumPy array.
     """
     saliency = Saliency(self.net)
     # perf_counter is monotonic and high-resolution, unlike time.time(), so it suits durations.
     start = time.perf_counter()
     saliency_attr_test = saliency.attribute(X_test.to(self.device))
     # NOTE(review): the label says "train" although this times attribution of X_test.
     print("temps train", time.perf_counter() - start)
     return saliency_attr_test.detach().cpu().numpy()
示例#6
0
def visualize_maps(
        model: torch.nn.Module,
        inputs: Union[Tuple[torch.Tensor, torch.Tensor]],
        labels: torch.Tensor,
        title: str,
        second_occlusion: Tuple[int, int, int] = (1, 2, 2),
        baselines: Tuple[int, int] = (0, 0),
        closest: bool = False,
        use_predictions: bool = False,
) -> None:
    """
    Visualize the average of the inputs, or the single input, using several XAI approaches.

    ``inputs`` is a pair fed as ``model(inputs[0], inputs[1])``. It is treated as a single sample when
    ``inputs[1]`` is 2-D. Attributions are computed with:

    - Occlusion (window ``(1, 5, 5)`` for the first input, ``second_occlusion`` for the second),
    - NoiseTunnel-wrapped Saliency and Integrated Gradients ("smoothgrad_sq", 5 samples),
    - Shapley value sampling.

    They are drawn in a 2 x 5 grid: row 0 shows the first input, row 1 the second. The figure is saved as
    ``"{title}_{single|multi}_{Failed|Success}_baseline{baselines[0]}.png"`` at 300 dpi and then closed.

    :param model: model taking the two inputs; the class index is the argmax of its softmaxed output
    :param inputs: the two model inputs
    :param labels: label scores; the class index is their argmax over the last dimension
    :param title: figure title and file-name prefix
    :param second_occlusion: occlusion window for the second input
    :param baselines: baselines for the two inputs (Occlusion, Integrated Gradients, Shapley sampling)
    :param closest: when False, nothing is plotted if every label is class 1
    :param use_predictions: attribute with respect to the model's predicted class instead of the label
        (default False: the labels are used)
    """
    single = inputs[1].ndim == 2
    model.zero_grad()
    model.eval()
    occ = Occlusion(model)
    saliency = NoiseTunnel(Saliency(model))
    igrad = NoiseTunnel(IntegratedGradients(model))
    # deep_lift = DeepLift(model)
    # NOTE: the panels titled "GradSHAP" below actually show Shapley value sampling.
    grad_shap = ShapleyValueSampling(model)
    output = model(inputs[0], inputs[1])
    output = F.softmax(output, dim=-1).argmax(dim=1, keepdim=True)
    labels = F.softmax(labels, dim=-1).argmax(dim=1, keepdim=True)
    if np.all(labels.cpu().numpy() == 1) and not closest:
        return
    targets = output if use_predictions else labels
    print(targets)
    # Collapse to a single bool: the truth value of a multi-element array is ambiguous, which
    # broke the file name below for batches. NOTE(review): with use_predictions=False this is
    # always True, and True maps to 'Failed' in the file name; confirm that mapping is intended.
    correct = bool(np.all(targets.cpu().numpy() == labels.cpu().numpy()))

    occ_out = occ.attribute(
        inputs,
        baselines=baselines,
        sliding_window_shapes=((1, 5, 5), second_occlusion),
        target=targets,
    )
    saliency_out = saliency.attribute(inputs,
                                      nt_type="smoothgrad_sq",
                                      n_samples=5,
                                      target=targets,
                                      abs=False)
    igrad_out = igrad.attribute(
        inputs,
        baselines=baselines,
        target=targets,
        n_samples=5,
        nt_type="smoothgrad_sq",
        internal_batch_size=1,
    )
    grad_shap_out = grad_shap.attribute(inputs,
                                        baselines=baselines,
                                        target=targets)

    if single:
        inputs = convert_to_image(inputs)
        occ_out = convert_to_image(occ_out)
        saliency_out = convert_to_image(saliency_out)
        igrad_out = convert_to_image(igrad_out)
        # grad_shap_out = convert_to_image(grad_shap_out)
    else:
        inputs = convert_to_image_multi(inputs)
        occ_out = convert_to_image_multi(occ_out)
        saliency_out = convert_to_image_multi(saliency_out)
        igrad_out = convert_to_image_multi(igrad_out)
        grad_shap_out = convert_to_image_multi(grad_shap_out)

    fig, axes = plt.subplots(2, 5)
    # (row, col, attribution, original image, panel title, extra visualize_image_attr kwargs).
    # Row 0 shows the first input, row 1 the second ("Aux") input.
    panels = [
        (0, 0, occ_out[0][0], inputs[0][0], "Original Image", {"method": "original_image"}),
        (0, 1, occ_out[0][0], None, "Occ (5x5)", {"sign": "all"}),
        (0, 2, saliency_out[0][0], None, "Saliency", {"sign": "all"}),
        (0, 3, igrad_out[0][0], None, "Integrated Grad", {"sign": "all"}),
        (0, 4, grad_shap_out[0], None, "GradSHAP", {}),
        (1, 0, occ_out[1], inputs[1], "Original Aux", {"method": "original_image"}),
        (1, 1, occ_out[1], None, "Occ (1x1)", {"sign": "all"}),
        (1, 2, saliency_out[1], None, "Saliency", {"sign": "all"}),
        (1, 3, igrad_out[1], None, "Integrated Grad", {"sign": "all"}),
        (1, 4, grad_shap_out[1], None, "GradSHAP", {}),
    ]
    for row, col, attr, original, panel_title, extra in panels:
        fig, axes[row, col] = visualization.visualize_image_attr(
            attr,
            original,
            title=panel_title,
            show_colorbar=True,
            plt_fig_axis=(fig, axes[row, col]),
            use_pyplot=False,
            **extra,
        )

    fig.suptitle(
        title +
        f" Label: {labels.cpu().numpy()} Pred: {targets.cpu().numpy()}")
    plt.savefig(
        f"{title}_{'single' if single else 'multi'}_{'Failed' if correct else 'Success'}_baseline{baselines[0]}.png",
        dpi=300,
    )
    # plt.subplots opens a new figure on every call; close it so figures do not accumulate.
    plt.close(fig)
示例#7
0
def run_saliency_methods(saliency_methods,
                         pretrained_model,
                         test_shape,
                         train_loader,
                         test_loader,
                         device,
                         model_type,
                         model_name,
                         saliency_dir,
                         tsr_graph_dir=None,
                         tsr_inputs_to_graph=()):
    _, num_timesteps, num_features = test_shape

    run_grad = "Grad" in saliency_methods
    run_grad_tsr = "Grad_TSR" in saliency_methods
    run_ig = "IG" in saliency_methods
    run_ig_tsr = "IG_TSR" in saliency_methods
    run_dl = "DL" in saliency_methods
    run_gs = "GS" in saliency_methods
    run_dls = "DLS" in saliency_methods
    run_dls_tsr = "DLS_TSR" in saliency_methods
    run_sg = "SG" in saliency_methods
    run_shapley_sampling = "ShapleySampling" in saliency_methods
    run_feature_permutation = "FeaturePermutation" in saliency_methods
    run_feature_ablation = "FeatureAblation" in saliency_methods
    run_occlusion = "Occlusion" in saliency_methods
    run_fit = "FIT" in saliency_methods
    run_ifit = "IFIT" in saliency_methods
    run_wfit = "WFIT" in saliency_methods
    run_iwfit = "IWFIT" in saliency_methods

    if run_grad or run_grad_tsr:
        Grad = Saliency(pretrained_model)
    if run_grad:
        rescaledGrad = np.zeros(test_shape)
    if run_grad_tsr:
        rescaledGrad_TSR = np.zeros(test_shape)

    if run_ig or run_ig_tsr:
        IG = IntegratedGradients(pretrained_model)
    if run_ig:
        rescaledIG = np.zeros(test_shape)
    if run_ig_tsr:
        rescaledIG_TSR = np.zeros(test_shape)

    if run_dl:
        rescaledDL = np.zeros(test_shape)
        DL = DeepLift(pretrained_model)

    if run_gs:
        rescaledGS = np.zeros(test_shape)
        GS = GradientShap(pretrained_model)

    if run_dls or run_dls_tsr:
        DLS = DeepLiftShap(pretrained_model)
    if run_dls:
        rescaledDLS = np.zeros(test_shape)
    if run_dls_tsr:
        rescaledDLS_TSR = np.zeros(test_shape)

    if run_sg:
        rescaledSG = np.zeros(test_shape)
        Grad_ = Saliency(pretrained_model)
        SG = NoiseTunnel(Grad_)

    if run_shapley_sampling:
        rescaledShapleySampling = np.zeros(test_shape)
        SS = ShapleyValueSampling(pretrained_model)

    if run_gs:
        rescaledFeaturePermutation = np.zeros(test_shape)
        FP = FeaturePermutation(pretrained_model)

    if run_feature_ablation:
        rescaledFeatureAblation = np.zeros(test_shape)
        FA = FeatureAblation(pretrained_model)

    if run_occlusion:
        rescaledOcclusion = np.zeros(test_shape)
        OS = Occlusion(pretrained_model)

    if run_fit:
        rescaledFIT = np.zeros(test_shape)
        FIT = FITExplainer(pretrained_model, ft_dim_last=True)
        generator = JointFeatureGenerator(num_features, data='none')
        # TODO: Increase epochs
        FIT.fit_generator(generator, train_loader, test_loader, n_epochs=300)

    if run_ifit:
        rescaledIFIT = np.zeros(test_shape)
    if run_wfit:
        rescaledWFIT = np.zeros(test_shape)
    if run_iwfit:
        rescaledIWFIT = np.zeros(test_shape)

    idx = 0
    mask = np.zeros((num_timesteps, num_features), dtype=int)
    for i in range(num_timesteps):
        mask[i, :] = i

    for i, (samples, labels) in enumerate(test_loader):
        input = samples.reshape(-1, num_timesteps, num_features).to(device)
        input = Variable(input, volatile=False, requires_grad=True)

        batch_size = input.shape[0]
        baseline_single = torch.from_numpy(np.random.random(
            input.shape)).to(device)
        baseline_multiple = torch.from_numpy(
            np.random.random((input.shape[0] * 5, input.shape[1],
                              input.shape[2]))).to(device)
        inputMask = np.zeros((input.shape))
        inputMask[:, :, :] = mask
        inputMask = torch.from_numpy(inputMask).to(device)
        mask_single = torch.from_numpy(mask).to(device)
        mask_single = mask_single.reshape(1, num_timesteps,
                                          num_features).to(device)
        labels = torch.tensor(labels.int().tolist()).to(device)

        if run_grad:
            attributions = Grad.attribute(input, target=labels)
            rescaledGrad[
                idx:idx +
                batch_size, :, :] = Helper.givenAttGetRescaledSaliency(
                    num_timesteps, num_features, attributions)
        if run_grad_tsr:
            rescaledGrad_TSR[idx:idx + batch_size, :, :] = get_tsr_saliency(
                Grad,
                input,
                labels,
                graph_dir=tsr_graph_dir,
                graph_name=f'{model_name}_{model_type}_Grad_TSR',
                inputs_to_graph=tsr_inputs_to_graph,
                cur_batch=i)

        if run_ig:
            attributions = IG.attribute(input,
                                        baselines=baseline_single,
                                        target=labels)
            rescaledIG[idx:idx +
                       batch_size, :, :] = Helper.givenAttGetRescaledSaliency(
                           num_timesteps, num_features, attributions)
        if run_ig_tsr:
            rescaledIG_TSR[idx:idx + batch_size, :, :] = get_tsr_saliency(
                IG,
                input,
                labels,
                baseline=baseline_single,
                graph_dir=tsr_graph_dir,
                graph_name=f'{model_name}_{model_type}_IG_TSR',
                inputs_to_graph=tsr_inputs_to_graph,
                cur_batch=i)

        if run_dl:
            attributions = DL.attribute(input,
                                        baselines=baseline_single,
                                        target=labels)
            rescaledDL[idx:idx +
                       batch_size, :, :] = Helper.givenAttGetRescaledSaliency(
                           num_timesteps, num_features, attributions)

        if run_gs:
            attributions = GS.attribute(input,
                                        baselines=baseline_multiple,
                                        stdevs=0.09,
                                        target=labels)
            rescaledGS[idx:idx +
                       batch_size, :, :] = Helper.givenAttGetRescaledSaliency(
                           num_timesteps, num_features, attributions)

        if run_dls:
            attributions = DLS.attribute(input,
                                         baselines=baseline_multiple,
                                         target=labels)
            rescaledDLS[idx:idx +
                        batch_size, :, :] = Helper.givenAttGetRescaledSaliency(
                            num_timesteps, num_features, attributions)
        if run_dls_tsr:
            rescaledDLS_TSR[idx:idx + batch_size, :, :] = get_tsr_saliency(
                DLS,
                input,
                labels,
                baseline=baseline_multiple,
                graph_dir=tsr_graph_dir,
                graph_name=f'{model_name}_{model_type}_DLS_TSR',
                inputs_to_graph=tsr_inputs_to_graph,
                cur_batch=i)

        if run_sg:
            attributions = SG.attribute(input, target=labels)
            rescaledSG[idx:idx +
                       batch_size, :, :] = Helper.givenAttGetRescaledSaliency(
                           num_timesteps, num_features, attributions)

        if run_shapley_sampling:
            attributions = SS.attribute(input,
                                        baselines=baseline_single,
                                        target=labels,
                                        feature_mask=inputMask)
            rescaledShapleySampling[
                idx:idx +
                batch_size, :, :] = Helper.givenAttGetRescaledSaliency(
                    num_timesteps, num_features, attributions)

        if run_feature_permutation:
            attributions = FP.attribute(input,
                                        target=labels,
                                        perturbations_per_eval=input.shape[0],
                                        feature_mask=mask_single)
            rescaledFeaturePermutation[
                idx:idx +
                batch_size, :, :] = Helper.givenAttGetRescaledSaliency(
                    num_timesteps, num_features, attributions)

        if run_feature_ablation:
            attributions = FA.attribute(input, target=labels)
            # perturbations_per_eval= input.shape[0],\
            # feature_mask=mask_single)
            rescaledFeatureAblation[
                idx:idx +
                batch_size, :, :] = Helper.givenAttGetRescaledSaliency(
                    num_timesteps, num_features, attributions)

        if run_occlusion:
            attributions = OS.attribute(input,
                                        sliding_window_shapes=(1,
                                                               num_features),
                                        target=labels,
                                        baselines=baseline_single)
            rescaledOcclusion[
                idx:idx +
                batch_size, :, :] = Helper.givenAttGetRescaledSaliency(
                    num_timesteps, num_features, attributions)

        if run_fit:
            attributions = torch.from_numpy(FIT.attribute(input, labels))
            rescaledFIT[idx:idx +
                        batch_size, :, :] = Helper.givenAttGetRescaledSaliency(
                            num_timesteps, num_features, attributions)

        if run_ifit:
            attributions = torch.from_numpy(
                inverse_fit_attribute(input,
                                      pretrained_model,
                                      ft_dim_last=True))
            rescaledIFIT[idx:idx + batch_size, :, :] = attributions

        if run_wfit:
            attributions = torch.from_numpy(
                wfit_attribute(input,
                               pretrained_model,
                               N=test_shape[1],
                               ft_dim_last=True,
                               single_label=True))
            rescaledWFIT[idx:idx + batch_size, :, :] = attributions

        if run_iwfit:
            attributions = torch.from_numpy(
                wfit_attribute(input,
                               pretrained_model,
                               N=test_shape[1],
                               ft_dim_last=True,
                               single_label=True,
                               inverse=True))
            rescaledIWFIT[idx:idx + batch_size, :, :] = attributions

        idx += batch_size

    if run_grad:
        print("Saving Grad", model_name + "_" + model_type)
        np.save(
            saliency_dir + model_name + "_" + model_type + "_Grad_rescaled",
            rescaledGrad)
    if run_grad_tsr:
        print("Saving Grad_TSR", model_name + "_" + model_type)
        np.save(
            saliency_dir + model_name + "_" + model_type +
            "_Grad_TSR_rescaled", rescaledGrad_TSR)

    if run_ig:
        print("Saving IG", model_name + "_" + model_type)
        np.save(saliency_dir + model_name + "_" + model_type + "_IG_rescaled",
                rescaledIG)
    if run_ig_tsr:
        print("Saving IG_TSR", model_name + "_" + model_type)
        np.save(
            saliency_dir + model_name + "_" + model_type + "_IG_TSR_rescaled",
            rescaledIG_TSR)

    if run_dl:
        print("Saving DL", model_name + "_" + model_type)
        np.save(saliency_dir + model_name + "_" + model_type + "_DL_rescaled",
                rescaledDL)

    if run_gs:
        print("Saving GS", model_name + "_" + model_type)
        np.save(saliency_dir + model_name + "_" + model_type + "_GS_rescaled",
                rescaledGS)

    if run_dls:
        print("Saving DLS", model_name + "_" + model_type)
        np.save(saliency_dir + model_name + "_" + model_type + "_DLS_rescaled",
                rescaledDLS)
    if run_dls_tsr:
        print("Saving DLS_TSR", model_name + "_" + model_type)
        np.save(
            saliency_dir + model_name + "_" + model_type + "_DLS_TSR_rescaled",
            rescaledDLS_TSR)

    if run_sg:
        print("Saving SG", model_name + "_" + model_type)
        np.save(saliency_dir + model_name + "_" + model_type + "_SG_rescaled",
                rescaledSG)

    if run_shapley_sampling:
        print("Saving ShapleySampling", model_name + "_" + model_type)
        np.save(
            saliency_dir + model_name + "_" + model_type +
            "_ShapleySampling_rescaled", rescaledShapleySampling)

    if run_feature_permutation:
        print("Saving FeaturePermutation", model_name + "_" + model_type)
        np.save(
            saliency_dir + model_name + "_" + model_type +
            "_FeaturePermutation_rescaled", rescaledFeaturePermutation)

    if run_feature_ablation:
        print("Saving FeatureAblation", model_name + "_" + model_type)
        np.save(
            saliency_dir + model_name + "_" + model_type +
            "_FeatureAblation_rescaled", rescaledFeatureAblation)

    if run_occlusion:
        print("Saving Occlusion", model_name + "_" + model_type)
        np.save(
            saliency_dir + model_name + "_" + model_type +
            "_Occlusion_rescaled", rescaledOcclusion)

    if run_fit:
        print("Saving FIT", model_name + "_" + model_type)
        np.save(saliency_dir + model_name + "_" + model_type + "_FIT_rescaled",
                rescaledFIT)

    if run_ifit:
        print("Saving IFIT", model_name + "_" + model_type)
        np.save(
            saliency_dir + model_name + "_" + model_type + "_IFIT_rescaled",
            rescaledIFIT)

    if run_wfit:
        print("Saving WFIT", model_name + "_" + model_type)
        np.save(
            saliency_dir + model_name + "_" + model_type + "_WFIT_rescaled",
            rescaledWFIT)

    if run_iwfit:
        print("Saving IWFIT", model_name + "_" + model_type)
        np.save(
            saliency_dir + model_name + "_" + model_type + "_IWFIT_rescaled",
            rescaledIWFIT)
示例#8
0
def saliency(classifier_model,
             config,
             dataset_features,
             GNNgraph_list,
             current_fold=None,
             cuda=0):
    """Compute Captum ``Saliency`` node attributions for every graph and class label.

    For each graph, its ``label`` attribute is temporarily set to every label in
    ``dataset_features["label_dict"]`` and a saliency attribution is computed with
    that label as the target. The graph's original label is always restored.
    Attributions are summed over dimension 1 (presumably the node-feature axis --
    confirm against ``graph_to_tensor``) and passed through ``standardize_scores``.

    :param classifier_model: trained classifier model
    :param config: parsed configuration file of config.yml; reads
        ``config["interpretability_methods"]["saliency"]`` keys ``"sample_ids"``
        and ``"number_of_samples"``
    :param dataset_features: a dictionary of dataset features obtained from
        load_data.py; reads ``"label_dict"``, ``"feat_dim"`` and ``"edge_feat_dim"``
    :param GNNgraph_list: a list of GNNgraphs obtained from the dataset
    :param current_fold: has no use in this method
    :param cuda: whether to use GPU to perform conversion to tensor
    :return: ``(output_for_metrics_calculation, output_for_generating_saliency_map,
        execution_time)`` where

        * ``output_for_metrics_calculation`` is a list with one dict per graph,
          ``{'graph': graph, <label>: standardized scores, ...}`` for every label;
        * ``output_for_generating_saliency_map`` maps ``"saliency_class_<label>"``
          to a list of ``(graph, scores for the graph's own label)`` tuples, filled
          from the configured ``sample_ids`` or, failing that, from up to
          ``number_of_samples`` randomly chosen graphs per class (empty dict when
          neither is configured);
        * ``execution_time`` is the mean wall-clock time in seconds of one
          ``Saliency.attribute`` call (0.0 when no attribution was computed).
    """
    interpretability_config = config["interpretability_methods"]["saliency"]
    label_values = list(dataset_features["label_dict"].values())

    # Perform Saliency on the classifier model
    sl = Saliency(classifier_model)

    output_for_metrics_calculation = []
    output_for_generating_saliency_map = {}

    # Per-call attribution timings, averaged into execution_time below.
    tmp_timing_list = []

    for GNNgraph in GNNgraph_list:
        output = {'graph': GNNgraph}
        original_label = GNNgraph.label
        for label in label_values:
            # The graph object is passed to the model as a forward argument, so
            # relabel it for the duration of the attribution and always restore it.
            GNNgraph.label = label
            try:
                node_feat, n2n, subg = graph_to_tensor(
                    [GNNgraph], dataset_features["feat_dim"],
                    dataset_features["edge_feat_dim"], cuda)

                start_generation = perf_counter()
                attribution = sl.attribute(node_feat,
                                           additional_forward_args=(n2n, subg,
                                                                    [GNNgraph]),
                                           target=label)
                tmp_timing_list.append(perf_counter() - start_generation)
            finally:
                GNNgraph.label = original_label

            attribution_score = torch.sum(attribution, dim=1).tolist()
            output[label] = standardize_scores(attribution_score)
        output_for_metrics_calculation.append(output)

    execution_time = (sum(tmp_timing_list) / len(tmp_timing_list)
                      if tmp_timing_list else 0.0)

    sample_ids = interpretability_config["sample_ids"]
    if sample_ids is not None:
        # Accepts a single id (int or str) or a comma-separated string of ids.
        sample_graph_ids = {int(part) for part in str(sample_ids).split(',')}

        # Same key scheme as the random-sampling branch below.
        output_for_generating_saliency_map.update({
            "saliency_class_%s" % str(label): [] for label in label_values
        })

        for tmp_output in output_for_metrics_calculation:
            graph = tmp_output['graph']
            tmp_label = graph.label
            if graph.graph_id in sample_graph_ids:
                output_for_generating_saliency_map[
                    "saliency_class_%s" % str(tmp_label)].append(
                        (graph, tmp_output[tmp_label]))

    elif interpretability_config["number_of_samples"] > 0:
        n_samples = interpretability_config["number_of_samples"]
        # Randomly sample from existing list:
        graph_idxes = list(range(len(output_for_metrics_calculation)))
        random.shuffle(graph_idxes)
        output_for_generating_saliency_map.update({
            "saliency_class_%s" % str(label): [] for label in label_values
        })

        # Keep at most n_samples graphs per class, in shuffled order.
        for index in graph_idxes:
            tmp_output = output_for_metrics_calculation[index]
            tmp_label = tmp_output['graph'].label
            bucket = output_for_generating_saliency_map[
                "saliency_class_%s" % str(tmp_label)]
            if len(bucket) < n_samples:
                bucket.append((tmp_output['graph'], tmp_output[tmp_label]))

    return output_for_metrics_calculation, output_for_generating_saliency_map, execution_time
示例#9
0
def captum_routine(model, dataloader, radiusneigh):
    """Compute Captum attributions over ``dataloader`` and produce saliency plots.

    Relies on names defined elsewhere in the module: ``method`` (``"saliency"``
    or ``"intgradients"``), ``device``, ``boxsize``, ``batch_size``,
    ``model_forward``, ``radius_graph``, ``visualize_points_3D`` and
    ``feature_importance_plot``.

    For every batch it:

    * attributes output index 0 of the model to the node features ``data.x``;
    * scatters the mean absolute attribution of each node against its distance
      from the origin (``|data.pos| * boxsize``), excluding distance-0 nodes;
    * counts graphs, and graphs that ``radius_graph(..., r=radiusneigh)`` turns
      into complete graphs;
    * draws a 3D saliency graph for each graph with relative prediction error
      below 0.015 and at least 10 nodes.

    Afterwards it saves ``Plots/distance_attribute_<method>.pdf``, writes the
    absolute value of the batch-averaged per-feature attributions to
    ``Outputs/feature_importance.txt`` and draws the feature importance plot.

    :param model: trained model; called as ``model(data)``
    :param dataloader: iterable of graph batches exposing ``x``, ``pos``, ``y``
        and ``batch``
    :param radiusneigh: linking radius passed to ``radius_graph``
    :raises ValueError: if the module-level ``method`` is not supported
    """
    attribution_classes = {
        "saliency": Saliency,
        "intgradients": IntegratedGradients,  # noted upstream as "not working"
    }
    if method not in attribution_classes:
        raise ValueError("Unknown attribution method: %r" % (method,))
    attribution_class = attribution_classes[method]

    figdist, axdist = plt.subplots()

    model.eval()
    atrs_feat = []

    nhalos, nhalos_compl = 0, 0

    for data in dataloader:

        data.to(device)
        data.x.requires_grad_(True)

        # Apply the method and get attributions. model_forward (defined
        # elsewhere) evaluates the model for this batch given node features.
        atrmethod = attribution_class(
            lambda datax: model_forward(datax, model, data))
        attributions = atrmethod.attribute(data.x, target=0)

        attributions_np = attributions.detach().cpu().numpy()
        atrs_feat.append(attributions_np.mean(axis=0))
        atr_col = np.abs(attributions_np).mean(axis=1)
        dists = np.sqrt(np.sum(data.pos.detach().cpu().numpy()**2.,
                               axis=1)) * boxsize

        # Scatter plot of saliency vs distance to the center
        sizes = 10.**(data.x[:, 3].detach().cpu().numpy() + 2.)
        indxs = np.argwhere(dists > 0.).reshape(-1)
        axdist.scatter(dists[indxs],
                       atr_col[indxs],
                       s=sizes[indxs] * 0.2,
                       c="blue",
                       alpha=0.3)

        out = model(data)
        y_out = out[:, 0]
        err = (y_out.reshape(-1) - data.y) / data.y
        err = np.abs(err.detach().cpu().numpy())

        # Plot saliency graph
        data.x.requires_grad_(False)
        # Host copy of the node-to-graph assignment: NumPy cannot read CUDA
        # tensors, so comparing against data.batch directly fails on GPU.
        batch_ids = data.batch.detach().cpu().numpy()
        for ibatch in range(batch_size):

            # Choose only the subhalos of the graph within the batch
            indexes = np.flatnonzero(batch_ids == ibatch)
            if indexes.shape[0] == 0:  # Avoid possible segmentation fault
                continue
            datagraph = data.x[indexes]

            edge_index = radius_graph(datagraph[:, :3], r=radiusneigh)

            # Divide by 2 since edges are counted doubled if not directed
            num_nodes = datagraph.shape[0]
            num_edges = edge_index.shape[1] // 2

            if num_edges == num_nodes * (num_nodes - 1) // 2:
                nhalos_compl += 1
            nhalos += 1

            if err[ibatch] < 0.015 and num_nodes >= 10:
                visualize_points_3D(datagraph, ibatch, atr_col[indexes],
                                    edge_index)

    frac_complete = (float(nhalos_compl) / float(nhalos)
                     if nhalos else float("nan"))
    print("Number of graphs:", nhalos, "out of which ", nhalos_compl,
          "are complete. Fraction:", frac_complete)

    axdist.set_ylabel("Saliency")
    axdist.set_xlabel("Distance [kpc/h]")
    figdist.savefig("Plots/distance_attribute_" + method + ".pdf")
    plt.close(figdist)

    # Feature importance plot
    importances = np.abs(np.array(atrs_feat).mean(0))
    feature_names = [r"$x$", r"$y$", r"$z$", r"$M_*$", r"$v$", r"$R_*$"]

    np.savetxt("Outputs/feature_importance.txt", importances)
    feature_importance_plot(importances, feature_names, method)
            mode = 'L'

        # load each image to PIL
        imgs = [Image.fromarray(img, mode=mode) for img in img_list]
        # apply the stored transform to the input
        imgs = [transform(img) for img in imgs]


        im_tensor = torch.stack(imgs)

        print(im_tensor.shape)
        n_img = torch.tensor([im_tensor.shape[0]], requires_grad=False)
        # wrap in a class that allows saliency readout
        wrapped_model = ModelReadoutWrapper(model)
        saliency = Saliency(wrapped_model)
        grads = saliency.attribute(inputs=im_tensor, target=0, additional_forward_args=n_img)
        # now the model has been invoked and the results are available
        results = wrapped_model.format_readout()
        norm_grads = []

        for i in range(len(grads)):
            # get the original image
            original_image = np.transpose((imgs[i].cpu().detach().numpy() / 2) + 0.5, (1, 2, 0))
            # get the gradients for this image
            img_grad = grads[i, :, :, :]
            # shape must be h,w,c, move first dim to the end
            img_grad = img_grad.permute((1, 2, 0))
            norm_grad = _normalize_image_attr(img_grad.numpy(), sign="all", outlier_perc=2)
            norm_grads.append(norm_grad)
            if save_png:
                att_score = results['attention_outputs'][0][i]
    torch.save(seq_state,
               options.out + "/all_mammal_20S_digestion_sequence_mod.pt")
    torch.save(motif_state,
               options.out + "/all_mammal_20S_digestion_motif_mod.pt")

## identify feature importance
# Captum Saliency on the trained motif model. Each call below passes a tuple of
# three input tensors, so gradients are returned for each of them.
saliency = Saliency(motif_model)

# Load the entire positive training set as one batch (batch_size=len(pos_train))
# and attribute it to output index 1.
# NOTE(review): the first tensor keeps only entries 22: of its last axis --
# presumably the leading columns are not motif features; confirm.
# `dtype` is defined elsewhere in the script; all inputs are cast to it.
pos_saliency_loader = torch.utils.data.DataLoader(pos_train,
                                                  batch_size=len(pos_train))
pos_saliency = next(iter(pos_saliency_loader))
pos_grads = saliency.attribute(
    (pos_saliency[0][:, :, 22:].type(dtype), pos_saliency[1].type(dtype),
     pos_saliency[2].type(dtype)),
    target=1,
    abs=False)  # keep signed gradients instead of their absolute value

# Same procedure for the whole negative training set, attributed to output index 0.
neg_saliency_loader = torch.utils.data.DataLoader(neg_train,
                                                  batch_size=len(neg_train))

neg_saliency = next(iter(neg_saliency_loader))
neg_grads = saliency.attribute(
    (neg_saliency[0][:, :, 22:].type(dtype), neg_saliency[1].type(dtype),
     neg_saliency[2].type(dtype)),
    target=0,
    abs=False)  # keep signed gradients instead of their absolute value

# Per-sample boolean flags over the second positive input: True where the entry
# equals 1 (pos_c_flag) or 0 (pos_i_flag).
# NOTE(review): the meaning of these 0/1 values is not visible here; confirm.
pos_c_flag = [i == 1 for i in pos_saliency[1]]
pos_i_flag = [i == 0 for i in pos_saliency[1]]
示例#12
0
文件: attribute.py 项目: nilsec/scam
def _abs_max_normalize(arr):
    """Return ``arr`` divided by its maximum absolute value.

    An all-zero array (e.g. the residual of two identical images) is returned
    unchanged rather than being turned into NaNs by a division by zero.
    """
    peak = np.max(np.abs(arr))
    return arr / peak if peak > 0 else arr


def _last_conv2d_layer(net):
    """Return the last module of ``net`` whose exact type is ``torch.nn.Conv2d``.

    :raises ValueError: if ``net`` contains no such module
    """
    conv_layers = [module for module in net.modules()
                   if type(module) is torch.nn.Conv2d]
    if not conv_layers:
        raise ValueError("network contains no torch.nn.Conv2d layer")
    return conv_layers[-1]


def get_attribution(real_img,
                    fake_img,
                    real_class,
                    fake_class,
                    net_module,
                    checkpoint_path,
                    input_shape,
                    channels,
                    methods=None,
                    output_classes=6,
                    downsample_factors=None):
    """Compute attribution maps for a real/fake image pair with several methods.

    The classifier is built with ``init_network`` from ``checkpoint_path``. Each
    selected method contributes one or more 2D maps, taken from batch 0 and
    channel 0 of the attribution, together with a matching name.

    :param real_img: the real image (array)
    :param fake_img: the counterfactual image (array of the same shape)
    :param real_class: target class used for ``real_img``
    :param fake_class: target class used for ``fake_img``
    :param net_module: network module name passed to ``init_network``/``get_sgc``
    :param checkpoint_path: path of the trained network checkpoint
    :param input_shape: network input shape; its first two entries are the
        spatial size used to rescale Grad-CAM maps
    :param channels: number of input channels passed to the network builder
    :param methods: subset of ``("ig", "grads", "gc", "ggc", "dl", "ingrad",
        "random", "residual")``; defaults to all of them. ``"grads"`` has no
        branch of its own -- plain gradient maps are produced by ``"ingrad"``.
    :param output_classes: number of classifier outputs
    :param downsample_factors: per-level downsampling factors of the network;
        defaults to ``[(2, 2)] * 4``
    :return: ``(attrs_norm, attrs_names)`` -- the maps as numpy arrays, each
        divided by its maximum absolute value (left unscaled when that maximum
        is 0), and their names
    """
    if methods is None:
        methods = ["ig", "grads", "gc", "ggc", "dl", "ingrad", "random", "residual"]
    if downsample_factors is None:
        downsample_factors = [(2, 2), (2, 2), (2, 2), (2, 2)]

    imgs = [image_to_tensor(normalize_image(real_img).astype(np.float32)),
            image_to_tensor(normalize_image(fake_img).astype(np.float32))]

    classes = [real_class, fake_class]
    net = init_network(checkpoint_path, input_shape, net_module, channels,
                       output_classes=output_classes, eval_net=True,
                       require_grad=False,
                       downsample_factors=downsample_factors)

    def compute_sgc():
        # SCAM maps; produced by the "gc" branch and reused by "ggc".
        return get_sgc(real_img, fake_img, real_class, fake_class, net_module,
                       checkpoint_path, input_shape, channels, None,
                       output_classes=output_classes,
                       downsample_factors=downsample_factors)

    attrs = []
    attrs_names = []
    gc_diff_0 = gc_diff_1 = None

    if "residual" in methods:
        res = np.abs(real_img - fake_img)
        res = res - np.min(res)
        attrs.append(torch.tensor(_abs_max_normalize(res)))
        attrs_names.append("residual")

    if "random" in methods:
        # Smoothed random noise serves as an uninformative reference map.
        rand = np.abs(np.random.randn(*np.shape(real_img)))
        rand = np.abs(scipy.ndimage.gaussian_filter(rand, 4))
        rand = rand - np.min(rand)
        attrs.append(torch.tensor(_abs_max_normalize(rand)))
        attrs_names.append("random")

    if "gc" in methods:
        net.zero_grad()
        layer_gc = LayerGradCam(net, _last_conv2d_layer(net))
        gc_real = layer_gc.attribute(imgs[0], target=classes[0])
        gc_fake = layer_gc.attribute(imgs[1], target=classes[1])

        # Bring the layer-level maps to the input's spatial size (input_shape[:2]).
        spatial_shape = (input_shape[0], input_shape[1])
        gc_real = project_layer_activations_to_input_rescale(
            gc_real.cpu().detach().numpy(), spatial_shape)
        gc_fake = project_layer_activations_to_input_rescale(
            gc_fake.cpu().detach().numpy(), spatial_shape)

        attrs.append(torch.tensor(gc_real[0, 0, :, :]))
        attrs_names.append("gc_real")

        attrs.append(torch.tensor(gc_fake[0, 0, :, :]))
        attrs_names.append("gc_fake")

        # SCAM
        gc_diff_0, gc_diff_1 = compute_sgc()
        attrs.append(gc_diff_0)
        attrs_names.append("gc_diff_0")

        attrs.append(gc_diff_1)
        attrs_names.append("gc_diff_1")

    if "ggc" in methods:
        net.zero_grad()
        guided_gc = GuidedGradCam(net, _last_conv2d_layer(net))
        ggc_real = guided_gc.attribute(imgs[0], target=classes[0])
        ggc_fake = guided_gc.attribute(imgs[1], target=classes[1])

        attrs.append(ggc_real[0, 0, :, :])
        attrs_names.append("ggc_real")

        attrs.append(ggc_fake[0, 0, :, :])
        attrs_names.append("ggc_fake")

        net.zero_grad()
        gbp = GuidedBackprop(net)
        gbp_real = gbp.attribute(imgs[0], target=classes[0])
        gbp_fake = gbp.attribute(imgs[1], target=classes[1])

        attrs.append(gbp_real[0, 0, :, :])
        attrs_names.append("gbp_real")

        attrs.append(gbp_fake[0, 0, :, :])
        attrs_names.append("gbp_fake")

        # The guided SCAM maps need the SCAM maps even when "gc" was not requested.
        if gc_diff_0 is None:
            gc_diff_0, gc_diff_1 = compute_sgc()

        ggc_diff_0 = gbp_real[0, 0, :, :] * gc_diff_0
        ggc_diff_1 = gbp_fake[0, 0, :, :] * gc_diff_1

        attrs.append(ggc_diff_0)
        attrs_names.append("ggc_diff_0")

        attrs.append(ggc_diff_1)
        attrs_names.append("ggc_diff_1")

    # IG
    if "ig" in methods:
        baseline = image_to_tensor(np.zeros(input_shape, dtype=np.float32))
        net.zero_grad()
        ig = IntegratedGradients(net)
        ig_real = ig.attribute(imgs[0], baseline, target=classes[0])
        ig_fake = ig.attribute(imgs[1], baseline, target=classes[1])
        # Each image of the pair used as the baseline for the other.
        ig_diff_0 = ig.attribute(imgs[0], imgs[1], target=classes[0])
        ig_diff_1 = ig.attribute(imgs[1], imgs[0], target=classes[1])

        attrs.append(ig_real[0, 0, :, :])
        attrs_names.append("ig_real")

        attrs.append(ig_fake[0, 0, :, :])
        attrs_names.append("ig_fake")

        attrs.append(ig_diff_0[0, 0, :, :])
        attrs_names.append("ig_diff_0")

        attrs.append(ig_diff_1[0, 0, :, :])
        attrs_names.append("ig_diff_1")

    # DL
    if "dl" in methods:
        net.zero_grad()
        dl = DeepLift(net)
        dl_real = dl.attribute(imgs[0], target=classes[0])
        dl_fake = dl.attribute(imgs[1], target=classes[1])
        dl_diff_0 = dl.attribute(imgs[0], baselines=imgs[1], target=classes[0])
        dl_diff_1 = dl.attribute(imgs[1], baselines=imgs[0], target=classes[1])

        attrs.append(dl_real[0, 0, :, :])
        attrs_names.append("dl_real")

        attrs.append(dl_fake[0, 0, :, :])
        attrs_names.append("dl_fake")

        attrs.append(dl_diff_0[0, 0, :, :])
        attrs_names.append("dl_diff_0")

        attrs.append(dl_diff_1[0, 0, :, :])
        attrs_names.append("dl_diff_1")

    # INGRAD
    if "ingrad" in methods:
        net.zero_grad()
        saliency = Saliency(net)
        grads_real = saliency.attribute(imgs[0], target=classes[0])
        grads_fake = saliency.attribute(imgs[1], target=classes[1])

        attrs.append(grads_real[0, 0, :, :])
        attrs_names.append("grads_real")

        attrs.append(grads_fake[0, 0, :, :])
        attrs_names.append("grads_fake")

        net.zero_grad()
        input_x_gradient = InputXGradient(net)
        ingrad_real = input_x_gradient.attribute(imgs[0], target=classes[0])
        ingrad_fake = input_x_gradient.attribute(imgs[1], target=classes[1])

        # Gradient of one image times the pixel difference to the other image.
        ingrad_diff_0 = grads_fake * (imgs[0] - imgs[1])
        ingrad_diff_1 = grads_real * (imgs[1] - imgs[0])

        attrs.append(torch.abs(ingrad_real[0, 0, :, :]))
        attrs_names.append("ingrad_real")

        attrs.append(torch.abs(ingrad_fake[0, 0, :, :]))
        attrs_names.append("ingrad_fake")

        attrs.append(torch.abs(ingrad_diff_0[0, 0, :, :]))
        attrs_names.append("ingrad_diff_0")

        attrs.append(torch.abs(ingrad_diff_1[0, 0, :, :]))
        attrs_names.append("ingrad_diff_1")

    attrs = [a.detach().cpu().numpy() for a in attrs]
    attrs_norm = [_abs_max_normalize(a) for a in attrs]

    return attrs_norm, attrs_names
def main(args, DatasetsTypes, DataGenerationTypes, models, device):
    for m in range(len(models)):

        for x in range(len(DatasetsTypes)):
            for y in range(len(DataGenerationTypes)):

                if (DataGenerationTypes[y] == None):
                    args.DataName = DatasetsTypes[x] + "_Box"
                else:
                    args.DataName = DatasetsTypes[
                        x] + "_" + DataGenerationTypes[y]

                Training = np.load(args.data_dir + "SimulatedTrainingData_" +
                                   args.DataName + "_F_" +
                                   str(args.NumFeatures) + "_TS_" +
                                   str(args.NumTimeSteps) + ".npy")
                TrainingMetaDataset = np.load(args.data_dir +
                                              "SimulatedTrainingMetaData_" +
                                              args.DataName + "_F_" +
                                              str(args.NumFeatures) + "_TS_" +
                                              str(args.NumTimeSteps) + ".npy")
                TrainingLabel = TrainingMetaDataset[:, 0]

                Testing = np.load(args.data_dir + "SimulatedTestingData_" +
                                  args.DataName + "_F_" +
                                  str(args.NumFeatures) + "_TS_" +
                                  str(args.NumTimeSteps) + ".npy")
                TestingDataset_MetaData = np.load(args.data_dir +
                                                  "SimulatedTestingMetaData_" +
                                                  args.DataName + "_F_" +
                                                  str(args.NumFeatures) +
                                                  "_TS_" +
                                                  str(args.NumTimeSteps) +
                                                  ".npy")
                TestingLabel = TestingDataset_MetaData[:, 0]

                Training = Training.reshape(
                    Training.shape[0], Training.shape[1] * Training.shape[2])
                Testing = Testing.reshape(Testing.shape[0],
                                          Testing.shape[1] * Testing.shape[2])

                scaler = MinMaxScaler()
                scaler.fit(Training)
                Training = scaler.transform(Training)
                Testing = scaler.transform(Testing)

                TrainingRNN = Training.reshape(Training.shape[0],
                                               args.NumTimeSteps,
                                               args.NumFeatures)
                TestingRNN = Testing.reshape(Testing.shape[0],
                                             args.NumTimeSteps,
                                             args.NumFeatures)

                train_dataRNN = data_utils.TensorDataset(
                    torch.from_numpy(TrainingRNN),
                    torch.from_numpy(TrainingLabel))
                train_loaderRNN = data_utils.DataLoader(
                    train_dataRNN, batch_size=args.batch_size, shuffle=True)

                test_dataRNN = data_utils.TensorDataset(
                    torch.from_numpy(TestingRNN),
                    torch.from_numpy(TestingLabel))
                test_loaderRNN = data_utils.DataLoader(
                    test_dataRNN, batch_size=args.batch_size, shuffle=False)

                modelName = "Simulated"
                modelName += args.DataName

                saveModelName = "../Models/" + models[m] + "/" + modelName
                saveModelBestName = saveModelName + "_BEST.pkl"

                pretrained_model = torch.load(saveModelBestName,
                                              map_location=device)
                Test_Acc = checkAccuracy(test_loaderRNN, pretrained_model,
                                         args)
                print('{} {} model BestAcc {:.4f}'.format(
                    args.DataName, models[m], Test_Acc))

                if (Test_Acc >= 90):

                    if (args.GradFlag):
                        rescaledGrad = np.zeros((TestingRNN.shape))
                        Grad = Saliency(pretrained_model)

                    if (args.IGFlag):
                        rescaledIG = np.zeros((TestingRNN.shape))
                        IG = IntegratedGradients(pretrained_model)
                    if (args.DLFlag):
                        rescaledDL = np.zeros((TestingRNN.shape))
                        DL = DeepLift(pretrained_model)
                    if (args.GSFlag):
                        rescaledGS = np.zeros((TestingRNN.shape))
                        GS = GradientShap(pretrained_model)
                    if (args.DLSFlag):
                        rescaledDLS = np.zeros((TestingRNN.shape))
                        DLS = DeepLiftShap(pretrained_model)

                    if (args.SGFlag):
                        rescaledSG = np.zeros((TestingRNN.shape))
                        Grad_ = Saliency(pretrained_model)
                        SG = NoiseTunnel(Grad_)

                    if (args.ShapleySamplingFlag):
                        rescaledShapleySampling = np.zeros((TestingRNN.shape))
                        SS = ShapleyValueSampling(pretrained_model)
                    if (args.GSFlag):
                        rescaledFeaturePermutation = np.zeros(
                            (TestingRNN.shape))
                        FP = FeaturePermutation(pretrained_model)
                    if (args.FeatureAblationFlag):
                        rescaledFeatureAblation = np.zeros((TestingRNN.shape))
                        FA = FeatureAblation(pretrained_model)

                    if (args.OcclusionFlag):
                        rescaledOcclusion = np.zeros((TestingRNN.shape))
                        OS = Occlusion(pretrained_model)

                    idx = 0
                    mask = np.zeros((args.NumTimeSteps, args.NumFeatures),
                                    dtype=int)
                    for i in range(args.NumTimeSteps):
                        mask[i, :] = i

                    for i, (samples, labels) in enumerate(test_loaderRNN):

                        print('[{}/{}] {} {} model accuracy {:.2f}'\
                                .format(i,len(test_loaderRNN), models[m], args.DataName, Test_Acc))

                        input = samples.reshape(-1, args.NumTimeSteps,
                                                args.NumFeatures).to(device)
                        input = Variable(input,
                                         volatile=False,
                                         requires_grad=True)

                        batch_size = input.shape[0]
                        baseline_single = torch.from_numpy(
                            np.random.random(input.shape)).to(device)
                        baseline_multiple = torch.from_numpy(
                            np.random.random(
                                (input.shape[0] * 5, input.shape[1],
                                 input.shape[2]))).to(device)
                        inputMask = np.zeros((input.shape))
                        inputMask[:, :, :] = mask
                        inputMask = torch.from_numpy(inputMask).to(device)
                        mask_single = torch.from_numpy(mask).to(device)
                        mask_single = mask_single.reshape(
                            1, args.NumTimeSteps, args.NumFeatures).to(device)
                        labels = torch.tensor(labels.int().tolist()).to(device)

                        if (args.GradFlag):
                            attributions = Grad.attribute(input, \
                                                          target=labels)
                            rescaledGrad[
                                idx:idx +
                                batch_size, :, :] = Helper.givenAttGetRescaledSaliency(
                                    args, attributions)

                        if (args.IGFlag):
                            attributions = IG.attribute(input,  \
                                                        baselines=baseline_single, \
                                                        target=labels)
                            rescaledIG[
                                idx:idx +
                                batch_size, :, :] = Helper.givenAttGetRescaledSaliency(
                                    args, attributions)

                        if (args.DLFlag):
                            attributions = DL.attribute(input,  \
                                                        baselines=baseline_single, \
                                                        target=labels)
                            rescaledDL[
                                idx:idx +
                                batch_size, :, :] = Helper.givenAttGetRescaledSaliency(
                                    args, attributions)

                        if (args.GSFlag):

                            attributions = GS.attribute(input,  \
                                                        baselines=baseline_multiple, \
                                                        stdevs=0.09,\
                                                        target=labels)
                            rescaledGS[
                                idx:idx +
                                batch_size, :, :] = Helper.givenAttGetRescaledSaliency(
                                    args, attributions)

                        if (args.DLSFlag):

                            attributions = DLS.attribute(input,  \
                                                        baselines=baseline_multiple, \
                                                        target=labels)
                            rescaledDLS[
                                idx:idx +
                                batch_size, :, :] = Helper.givenAttGetRescaledSaliency(
                                    args, attributions)

                        if (args.SGFlag):
                            attributions = SG.attribute(input, \
                                                        target=labels)
                            rescaledSG[
                                idx:idx +
                                batch_size, :, :] = Helper.givenAttGetRescaledSaliency(
                                    args, attributions)

                        if (args.ShapleySamplingFlag):
                            attributions = SS.attribute(input, \
                                            baselines=baseline_single, \
                                            target=labels,\
                                            feature_mask=inputMask)
                            rescaledShapleySampling[
                                idx:idx +
                                batch_size, :, :] = Helper.givenAttGetRescaledSaliency(
                                    args, attributions)

                        if (args.FeaturePermutationFlag):
                            attributions = FP.attribute(input, \
                                            target=labels,
                                            perturbations_per_eval= input.shape[0],\
                                            feature_mask=mask_single)
                            rescaledFeaturePermutation[
                                idx:idx +
                                batch_size, :, :] = Helper.givenAttGetRescaledSaliency(
                                    args, attributions)

                        if (args.FeatureAblationFlag):
                            attributions = FA.attribute(input, \
                                            target=labels)
                            # perturbations_per_eval= input.shape[0],\
                            # feature_mask=mask_single)
                            rescaledFeatureAblation[
                                idx:idx +
                                batch_size, :, :] = Helper.givenAttGetRescaledSaliency(
                                    args, attributions)

                        if (args.OcclusionFlag):
                            attributions = OS.attribute(input, \
                                            sliding_window_shapes=(1,args.NumFeatures),
                                            target=labels,
                                            baselines=baseline_single)
                            rescaledOcclusion[
                                idx:idx +
                                batch_size, :, :] = Helper.givenAttGetRescaledSaliency(
                                    args, attributions)

                        idx += batch_size

                    if (args.plot):
                        index = random.randint(0, TestingRNN.shape[0] - 1)
                        plotExampleBox(TestingRNN[index, :, :],
                                       args.Saliency_Maps_graphs_dir +
                                       args.DataName + "_" + models[m] +
                                       '_sample',
                                       flip=True)

                        print("Plotting sample", index)
                        if (args.GradFlag):
                            plotExampleBox(rescaledGrad[index, :, :],
                                           args.Saliency_Maps_graphs_dir +
                                           args.DataName + "_" + models[m] +
                                           '_Grad',
                                           greyScale=True,
                                           flip=True)

                        if (args.IGFlag):
                            plotExampleBox(rescaledIG[index, :, :],
                                           args.Saliency_Maps_graphs_dir +
                                           args.DataName + "_" + models[m] +
                                           '_IG',
                                           greyScale=True,
                                           flip=True)

                        if (args.DLFlag):
                            plotExampleBox(rescaledDL[index, :, :],
                                           args.Saliency_Maps_graphs_dir +
                                           args.DataName + "_" + models[m] +
                                           '_DL',
                                           greyScale=True,
                                           flip=True)

                        if (args.GSFlag):
                            plotExampleBox(rescaledGS[index, :, :],
                                           args.Saliency_Maps_graphs_dir +
                                           args.DataName + "_" + models[m] +
                                           '_GS',
                                           greyScale=True,
                                           flip=True)

                        if (args.DLSFlag):
                            plotExampleBox(rescaledDLS[index, :, :],
                                           args.Saliency_Maps_graphs_dir +
                                           args.DataName + "_" + models[m] +
                                           '_DLS',
                                           greyScale=True,
                                           flip=True)

                        if (args.SGFlag):
                            plotExampleBox(rescaledSG[index, :, :],
                                           args.Saliency_Maps_graphs_dir +
                                           args.DataName + "_" + models[m] +
                                           '_SG',
                                           greyScale=True,
                                           flip=True)

                        if (args.ShapleySamplingFlag):
                            plotExampleBox(
                                rescaledShapleySampling[index, :, :],
                                args.Saliency_Maps_graphs_dir + args.DataName +
                                "_" + models[m] + '_ShapleySampling',
                                greyScale=True,
                                flip=True)

                        if (args.FeaturePermutationFlag):
                            plotExampleBox(
                                rescaledFeaturePermutation[index, :, :],
                                args.Saliency_Maps_graphs_dir + args.DataName +
                                "_" + models[m] + '_FeaturePermutation',
                                greyScale=True,
                                flip=True)

                        if (args.FeatureAblationFlag):
                            plotExampleBox(
                                rescaledFeatureAblation[index, :, :],
                                args.Saliency_Maps_graphs_dir + args.DataName +
                                "_" + models[m] + '_FeatureAblation',
                                greyScale=True,
                                flip=True)

                        if (args.OcclusionFlag):
                            plotExampleBox(rescaledOcclusion[index, :, :],
                                           args.Saliency_Maps_graphs_dir +
                                           args.DataName + "_" + models[m] +
                                           '_Occlusion',
                                           greyScale=True,
                                           flip=True)

                    if (args.save):
                        if (args.GradFlag):
                            print("Saving Grad", modelName + "_" + models[m])
                            np.save(
                                args.Saliency_dir + modelName + "_" +
                                models[m] + "_Grad_rescaled", rescaledGrad)

                        if (args.IGFlag):
                            print("Saving IG", modelName + "_" + models[m])
                            np.save(
                                args.Saliency_dir + modelName + "_" +
                                models[m] + "_IG_rescaled", rescaledIG)

                        if (args.DLFlag):
                            print("Saving DL", modelName + "_" + models[m])
                            np.save(
                                args.Saliency_dir + modelName + "_" +
                                models[m] + "_DL_rescaled", rescaledDL)

                        if (args.GSFlag):
                            print("Saving GS", modelName + "_" + models[m])
                            np.save(
                                args.Saliency_dir + modelName + "_" +
                                models[m] + "_GS_rescaled", rescaledGS)

                        if (args.DLSFlag):
                            print("Saving DLS", modelName + "_" + models[m])
                            np.save(
                                args.Saliency_dir + modelName + "_" +
                                models[m] + "_DLS_rescaled", rescaledDLS)

                        if (args.SGFlag):
                            print("Saving SG", modelName + "_" + models[m])
                            np.save(
                                args.Saliency_dir + modelName + "_" +
                                models[m] + "_SG_rescaled", rescaledSG)

                        if (args.ShapleySamplingFlag):
                            print("Saving ShapleySampling",
                                  modelName + "_" + models[m])
                            np.save(
                                args.Saliency_dir + modelName + "_" +
                                models[m] + "_ShapleySampling_rescaled",
                                rescaledShapleySampling)

                        if (args.FeaturePermutationFlag):
                            print("Saving FeaturePermutation",
                                  modelName + "_" + models[m])
                            np.save(
                                args.Saliency_dir + modelName + "_" +
                                models[m] + "_FeaturePermutation_rescaled",
                                rescaledFeaturePermutation)

                        if (args.FeatureAblationFlag):
                            print("Saving FeatureAblation",
                                  modelName + "_" + models[m])
                            np.save(
                                args.Saliency_dir + modelName + "_" +
                                models[m] + "_FeatureAblation_rescaled",
                                rescaledFeatureAblation)

                        if (args.OcclusionFlag):
                            print("Saving Occlusion",
                                  modelName + "_" + models[m])
                            np.save(
                                args.Saliency_dir + modelName + "_" +
                                models[m] + "_Occlusion_rescaled",
                                rescaledOcclusion)

                else:
                    logging.basicConfig(filename=args.log_file,
                                        level=logging.DEBUG)

                    logging.debug('{} {} model BestAcc {:.4f}'.format(
                        args.DataName, models[m], Test_Acc))

                    if not os.path.exists(args.ignore_list):
                        with open(args.ignore_list, 'w') as fp:
                            fp.write(args.DataName + '_' + models[m] + '\n')

                    else:
                        with open(args.ignore_list, "a") as fp:
                            fp.write(args.DataName + '_' + models[m] + '\n')
示例#14
0
文件: test.py 项目: vinnamkim/captum
# NOTE(review): this fragment relies on names defined earlier in the file
# (`gcd`, `net`, `input`, `labels`, `ind`, `np`, `Saliency`,
# `IntegratedBlurredGradients`, `IntegratedGradients`,
# `attribute_image_features`); none of them are visible here.
# `input` shadows the Python builtin of the same name.
#
# Attribution scores for the target class `labels[ind]` via `gcd.attribute`,
# configured with kernel_type='gaussian', kernel_size=7, kernel_sigma=1.0,
# method='fro', sampling_method='min', num_samples=1000, sample_std=1.0.
# What algorithm `gcd` implements is not visible from this fragment.
scores = gcd.attribute(input,
                       target=labels[ind].item(),
                       kernel_type='gaussian',
                       kernel_size=7,
                       kernel_sigma=1.0,
                       method='fro',
                       sampling_method='min',
                       num_samples=1000,
                       sample_std=1.0)

# %% [markdown]
# Computes gradient saliency with respect to the target class `labels[ind]`,
# then permutes the axes with `(1, 2, 0)` (presumably CHW -> HWC) so the map
# can be visualized as an image.

# %%
saliency = Saliency(net)
grads = saliency.attribute(input, target=labels[ind].item())
# Squeeze away singleton dims, move to host memory, detach from the autograd
# graph, and convert to NumPy before reordering axes for plotting.
grads = np.transpose(grads.squeeze().cpu().detach().numpy(), (1, 2, 0))

# %% [markdown]
# Applies the integrated *blurred* gradients attribution algorithm
# (`IntegratedBlurredGradients`, presumably a blur-path variant of Integrated
# Gradients) to the test image and reports the approximation (convergence)
# delta. Plain Integrated Gradients computes the integral of the gradients of
# the output prediction for the target class with respect to the input image
# pixels. More details about integrated gradients can be found in the
# original paper: https://arxiv.org/abs/1703.01365

# %%
ig = IntegratedBlurredGradients(net)
# With return_convergence_delta=True the call yields an (attributions, delta)
# pair. n_steps=50 is presumably forwarded to `ig.attribute` by
# attribute_image_features (helper not visible here; TODO confirm).
attr_ig, delta = attribute_image_features(ig,
                                          input,
                                          n_steps=50,
                                          return_convergence_delta=True)
print('Approximation delta: ', abs(delta))
# NOTE(review): mid-file import from a private captum module
# (`captum.attr._core...`); consider moving it to the file's import block.
from captum.attr._core.integrated_blurred_gradients import GaussianFilter
# GaussianFilter built with 2.0 (presumably the Gaussian sigma — confirm).
# `gf` is not used anywhere within this fragment.
gf = GaussianFilter(2.0)
# Rebinds `ig` to plain Integrated Gradients, discarding the blurred variant
# constructed above.
ig = IntegratedGradients(net)