Ejemplo n.º 1
0
# imports
import model_io
import data_io
import render

import importlib.util as imp
import numpy
import numpy as np
if imp.find_spec("cupy"):  #use cupy for GPU support if available
    import cupy
    import cupy as np
na = np.newaxis
# end of imports

nn = model_io.read('../models/MNIST/LeNet-5.nn')  # read model
# load the first MNIST test image; indexing with `na` keeps a leading axis of size 1
X = data_io.read('../data/MNIST/test_images.npy')[na, 0, :]
X = X / 127.5 - 1  # normalize data to range [-1, 1]

Ypred = nn.forward(X)  # forward pass through network
R = nn.lrp(Ypred)  # lrp to explain prediction of X

# When np is cupy, copy the arrays back to host (numpy) memory before rendering.
# Modules are singletons, so an identity check is the idiomatic comparison.
if np is not numpy:
    X = np.asnumpy(X)
    R = np.asnumpy(R)

# render rgb images and save as image
digit = render.digit_to_rgb(X)
hm = render.hm_to_rgb(R, X)  # render heatmap R, use X as outline
render.save_image([digit, hm], '../2nd_py.png')
Ejemplo n.º 2
0
\begin{Verbatim}[frame=single, fontsize=\small]
# imports
import model_io
import data_io
import render

import numpy as np
na = np.newaxis
# end of imports

# read model and first MNIST test image
nn = model_io.read(<model_path>) 
X = data_io.read(<data_path>)[na,0,:]
# normalize data to range [-1, 1]
X = X / 127.5 - 1

# forward pass through network
Ypred = nn.forward(X)
# lrp to explain prediction of X
R = nn.lrp(Ypred)

# render rgb images and save as image
digit = render.digit_to_rgb(X)
# render heatmap R, use X as outline
hm = render.hm_to_rgb(R,X) 
render.save_image([digit,hm],<i_path>)
\end{Verbatim}
Ejemplo n.º 3
0
    yselect = 3
    yselect = (np.arange(Y.shape[1])[na,:] == yselect)*1.
    R = nn.lrp(ypred*yselect) #compute first layer relvance for an arbitrarily selected class
    '''

    # Undo the input normalization for digit drawing: map the [-1, 1] input back
    # to the range [0, 1] per pixel.
    x = (x + 1.) / 2.

    # When np is cupy, copy input and relevance back to host (numpy) memory
    # before handing them to the renderer.
    if not np == numpy:  # np=cupy
        x = np.asnumpy(x)
        R = np.asnumpy(R)

    # Render the input and the relevance heatmap (input used as outline) as rgb images.
    digit = render.digit_to_rgb(x, scaling=3)
    hm = render.hm_to_rgb(R, X=x, scaling=3, sigma=2)
    # save_image writes the combined image to disk and returns it for display below.
    digit_hm = render.save_image([digit, hm], '../heatmap.png')
    data_io.write(R, '../heatmap.npy')  # also keep the raw relevance array

    # Display the image as written to file.
    plt.imshow(digit_hm, interpolation='none')
    plt.axis('off')
    plt.show()

# modules.Sequential accepts a whole batch of inputs, so the first N samples
# are pushed through forward() and lrp() with a single call each.
N = 256
t_start = time.time()  # start timestamp (not read in the code shown here)
x = X[:N]
y = nn.forward(x)
R = nn.lrp(y)
data_io.write(R, '../Rbatch.npy')
Ejemplo n.º 4
0
    #R = nn.lrp(ypred,'alphabeta',2)    #as Eq(60) from DOI: 10.1371/journal.pone.0130140
    
    '''
    R = nn.lrp(Y[na,i]) #compute first layer relevance according to the true class label
    '''
    
    '''
	yselect = 3
    yselect = (np.arange(Y.shape[1])[na,:] == yselect)*1. 
    R = nn.lrp(yselect) #compute first layer relvance for an arbitrarily selected class 
    '''
    
    # Render the input digit and the relevance heatmap R (with x as outline) as rgb images.
    digit = render.digit_to_rgb(x, scaling = 3)
    hm = render.hm_to_rgb(R, X = x, scaling = 3, sigma = 2)
    # save_image writes the combined image to disk and returns it for display below.
    digit_hm = render.save_image([digit,hm],'../heatmap.png')
    data_io.write(R,'../heatmap.npy')  # also keep the raw relevance array
    
    # Display the image as written to file.
    plt.imshow(digit_hm, interpolation = 'none')
    plt.axis('off')
    plt.show()


#note that modules.Sequential allows for batch processing inputs
'''
x = X[:10,:]
y = nn.forward(x)
R = nn.lrp(y)
data_io.write(R,'../Rbatch.npy')
'''
Ejemplo n.º 5
0
def _replace_ext(path, suffix):
    """Return *path* with its file extension replaced by *suffix*.

    Example: ``_replace_ext('out/a.png', '_kpt.txt') == 'out/a_kpt.txt'``.
    Unlike ``path.replace('.jpg', suffix)`` this works for every input
    extension and never touches directory names.
    """
    return os.path.splitext(path)[0] + suffix


def main(args):
    """Run PRNet reconstruction on every image found under ``args.inputDir``.

    The input directory tree is mirrored under ``args.outputDir``. For each
    image, the outputs selected by the ``args`` flags are written next to the
    mirrored path: the image itself, depth/mesh ``.mat`` files, and keypoint
    and pose text files. With ``args.is3d`` an ``.obj`` mesh is written and
    rendered from three angles into ``args.renderDir``. The ``.obj`` is then
    removed.

    Args:
        args: parsed command-line namespace. Attributes read here: ``gpu``,
            ``isDlib``, ``inputDir``, ``outputDir``, ``renderDir``, ``isShow``,
            ``isTexture``, ``isMask``, ``texture_size``, ``is3d``, ``isMat``,
            ``isPose``, ``isKpt``, ``isDepth``, ``isFront``, ``isImage``.
    """
    if args.isShow or args.isTexture:
        import cv2
        from utils.cv_plot import plot_kpt, plot_vertices, plot_pose_box

    # ---- init PRN
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu  # GPU number, -1 for CPU
    prn = PRN(is_dlib=args.isDlib)

    # ------------- load data
    image_folder = args.inputDir
    save_folder = args.outputDir
    # makedirs also creates missing parent directories (os.mkdir would raise).
    os.makedirs(save_folder, exist_ok=True)

    # Every output below that reads `vertices` must trigger its computation,
    # including the depth export.
    need_vertices = args.is3d or args.isMat or args.isPose or args.isShow or args.isDepth

    for dirpath, _, files in sorted(os.walk(image_folder)):
        if not files:
            continue
        norm_dir = dirpath.replace("\\", "/")
        # Mirror this input directory under the output folder. makedirs is needed
        # because intermediate directories without files are never visited here.
        # count=1: swap only the leading folder prefix, not later occurrences.
        os.makedirs(dirpath.replace(image_folder, save_folder, 1), exist_ok=True)

        for file in files:
            image_path = os.path.join(dirpath, file)
            name = image_path.replace(image_folder, save_folder, 1)
            print('data path:', name)

            # read image
            image = imread(image_path)
            if image.ndim == 2:
                # grayscale: replicate into 3 channels so the code below can index channels
                image = np.stack([image] * 3, axis=-1)
            if image.shape[2] > 3:
                image = image[:, :, :3]  # keep the first 3 channels (drops e.g. alpha)

            # the core: regress position map
            if args.isDlib:
                max_size = max(image.shape[0], image.shape[1])
                if max_size > 1000:
                    image = rescale(image, 1000. / max_size)
                    image = (image * 255).astype(np.uint8)
                pos = prn.process(image)  # use dlib to detect face
            else:
                # use the whole image as the bounding box
                box = np.array([0, image.shape[1] - 1, 0, image.shape[0] - 1])
                pos = prn.process(image, box)
            if pos is None:
                continue

            # Take h, w after the optional rescale so they match the image that
            # `pos` (and hence the vertices) was computed from.
            h, w = image.shape[:2]
            image = image / 255.

            if need_vertices:
                # 3D vertices
                vertices = prn.get_vertices(pos)
                save_vertices = frontalize(vertices) if args.isFront else vertices.copy()
                # mirror the y coordinate within the image height for the exported mesh
                save_vertices[:, 1] = h - 1 - save_vertices[:, 1]

            # Per-vertex colours are needed by the coloured .obj and by the .mat export.
            if args.is3d or args.isMat:
                colors = prn.get_colors(image, vertices)

            if args.isImage:
                imsave(name, image)

            if args.is3d:
                obj_path = _replace_ext(name, '.obj')
                if args.isTexture:
                    if args.texture_size != 256:
                        pos_interpolated = resize(pos, (args.texture_size, args.texture_size), preserve_range=True)
                    else:
                        pos_interpolated = pos.copy()
                    texture = cv2.remap(image, pos_interpolated[:, :, :2].astype(np.float32), None,
                                        interpolation=cv2.INTER_LINEAR, borderMode=cv2.BORDER_CONSTANT,
                                        borderValue=(0))
                    if args.isMask:
                        vertices_vis = get_visibility(vertices, prn.triangles, h, w)
                        uv_mask = get_uv_mask(vertices_vis, prn.triangles, prn.uv_coords, h, w, prn.resolution_op)
                        uv_mask = resize(uv_mask, (args.texture_size, args.texture_size), preserve_range=True)
                        texture = texture * uv_mask[:, :, np.newaxis]
                    # save 3d face with texture (can open with meshlab)
                    write_obj_with_texture(obj_path, save_vertices, prn.triangles, texture,
                                           prn.uv_coords / prn.resolution_op)
                else:
                    # save 3d face with per-vertex colours (can open with meshlab)
                    write_obj_with_colors(obj_path, save_vertices, prn.triangles, colors)

                filepath = obj_path.replace("\\", "/")
                print('filepath:', filepath)
                render_dir = norm_dir.replace(args.inputDir, args.renderDir, 1)
                os.makedirs(render_dir, exist_ok=True)

                color_image1, _ = render_scene(filepath, 4.0, 0.0, 3.0)
                color_image2, _ = render_scene(filepath, 4.0, np.pi / 18.0, 3.0)
                color_image3, _ = render_scene(filepath, 4.0, np.pi / 9.0, 3.0)
                # If any view failed to render, skip the remaining outputs for this
                # image (the .obj is left on disk).
                if color_image1 is None or color_image2 is None or color_image3 is None:
                    continue

                new_path = filepath.replace(args.outputDir, args.renderDir, 1)
                save_image(new_path, '_40_', color_image1)
                save_image(new_path, '_50_', color_image2)
                save_image(new_path, '_60_', color_image3)

                os.remove(obj_path)

            if args.isDepth:
                depth_image = get_depth_image(vertices, prn.triangles, h, w, True)
                depth = get_depth_image(vertices, prn.triangles, h, w)
                imsave(_replace_ext(name, '_depth.jpg'), depth_image)
                sio.savemat(_replace_ext(name, '_depth.mat'), {'depth': depth})

            if args.isMat:
                sio.savemat(_replace_ext(name, '_mesh.mat'),
                            {'vertices': vertices, 'colors': colors, 'triangles': prn.triangles})

            if args.isKpt or args.isShow:
                # get landmarks
                kpt = prn.get_landmarks(pos)
                np.savetxt(_replace_ext(name, '_kpt.txt'), kpt)

            if args.isPose or args.isShow:
                # estimate pose
                camera_matrix, pose = estimate_pose(vertices)
                np.savetxt(_replace_ext(name, '_pose.txt'), pose)
                np.savetxt(_replace_ext(name, '_camera_matrix.txt'), camera_matrix)

            if args.isShow:
                # ---------- Plot
                image_pose = plot_pose_box(image, camera_matrix, kpt)
                cv2.imshow('sparse alignment', plot_kpt(image, kpt))
                cv2.imshow('dense alignment', plot_vertices(image, vertices))
                cv2.imshow('pose', image_pose)
                cv2.waitKey(0)
Ejemplo n.º 6
0
# NOTE(review): `occdir` and `occdir_y` are read here before this snippet
# assigns them (occdir is only reassigned further down, occdir_y is commented
# out below). They must come from earlier code not shown -- confirm which
# pickles they point to.
data_train = unpickle(occdir.format('train'))
Xtrain = np.array(data_train)
Ytrain = unpickle(occdir_y.format('train'))
Ytrain = np.array(Ytrain)
X = Xtrain
Y = Ytrain
# Template for the LRP-occluded pickles; the remaining '{}' takes the split name.
occdir = get_savedir() + '{}_{}_{}.pickle'.format('{}', 'LRP', '0.1')
# occdir_y = get_savedir() + '{}_{}_{}_{}.pickle'.format('{}', 'normal', '0.1','label')

data_train = unpickle(occdir.format('train'))
XR = np.array(data_train)
# Ytrain = unpickle(occdir_y.format('train'))

# Turn the integer class ids in the first column of Y into a one-hot matrix.
I = Y[:, 0].astype(int)
Y = np.zeros([X.shape[0], np.unique(Y).size])
Y[np.arange(Y.shape[0]), I] = 1
nn = model_io.read('../models/MNIST/LeNet-5.nn')

# Inspect the first 10 samples. Iterate over sample indices: the entries of I
# are class ids, not row indices into X / XR.
for i in range(min(10, X.shape[0])):
    x = X[i:i + 1, ...]
    r = XR[i:i + 1, ...]
    ypred = nn.forward(x)
    print('True Class:     ', np.argmax(Y[i]))
    print('Predicted Class:', np.argmax(ypred), '\n')
    digit = render.digit_to_rgb(x, scaling=3)
    Rdigit = render.digit_to_rgb(r, scaling=3)
    # NOTE(review): hm is computed but not displayed; only digit and Rdigit are saved.
    hm = render.hm_to_rgb(r, X=x, scaling=3, sigma=2)
    digit_hm = render.save_image([digit, Rdigit], '../re_heatmap.png')
    plt.imshow(digit_hm, interpolation='none')
    plt.axis('off')
    plt.show()
Ejemplo n.º 7
0
    # Initial relevance: prediction scores multiplied by `mask` (defined before this excerpt).
    Rinit = ypred*mask
    print("Lrp R shape {} : ".format(Rinit.shape))
    #compute first layer relevance according to prediction
    #R = nn.lrp(Rinit)                   #as Eq(56) from DOI: 10.1371/journal.pone.0130140
    R = nn.lrp(Rinit,'epsilon',1.)  # 'epsilon' LRP rule with parameter 1.

    # Sum relevance and the (x+1)/2-rescaled input over axis 3
    # (presumably the channel axis -- TODO confirm).
    R = R.sum(axis=3)
    xs = ((x+1.)/2.).sum(axis=3)

    # When np is cupy, copy arrays back to host (numpy) memory before rendering.
    if not np == numpy: 
        xs = np.asnumpy(xs)
        R = np.asnumpy(R)

    digit = render.digit_to_rgb(xs, scaling = 3)
    hm = render.hm_to_rgb(R, X = xs, scaling = 3, sigma = 2)
    digit_hm = render.save_image([digit,hm],'../heatmap.png')
    data_io.write(R,'../heatmap.npy')
    data_io.write(xs,'../xs.npy')
    print(xs.shape)
    y = xs  # NOTE(review): y is not read in the code shown here
    # Build a regressor matrix (one row per pixel) from the arrays in ../r_array/.
    # NOTE(review): presumably per-layer relevance dumps written by nn.lrp -- confirm.
    a = np.load('../r_array/convolution.npy')
    # valid only if all dimensions of a other than axes 1 and 2 have size 1
    a = np.reshape(a,[a.shape[1]*a.shape[2],1])
    b = np.load('../r_array/rect.npy')    
    # zero-pad axes 1 and 2 by 2 on each side
    b = np.pad(b,((0,0),(2,2),(2,2),(0,0)))
    # NOTE(review): plain reshape (no transpose); rows are pixels only if b.shape[0] == 1
    b = np.reshape(b,[b.shape[1]*b.shape[2],b.shape[0]*b.shape[3]])
    c = np.load('../r_array/sumpoll.npy')
    c = np.pad(c,((0,0),(2,2),(2,2),(0,0)))
    # valid only if c.shape[0] == 1
    c = np.reshape(c,[c.shape[1]*c.shape[2],c.shape[3]])
    
    # Column-stack into one design matrix: [a | b | c].
    new_b = np.hstack((b, c))
    new = np.hstack((a, new_b))
Ejemplo n.º 8
0
@author: Sebastian Bach
@maintainer: Sebastian Bach
@contact: [email protected]
@date: 21.09.2015
@version: 1.0
@copyright: Copyright (c)  2015, Sebastian Bach, Alexander Binder, Gregoire Montavon, Klaus-Robert Mueller
@license : BSD-2-Clause
'''


# imports
import model_io
import data_io
import render

import numpy as np
na = np.newaxis
# end of imports

# Read the pretrained network and pick the first MNIST test image; the
# leading `na` index keeps an axis of size 1 in front.
nn = model_io.read('../models/MNIST/long-rect.nn')
X = data_io.read('../data/MNIST/test_images.npy')[na, 0, :]
# Rescale pixel values to the range [-1, 1] (presumably from [0, 255]).
X = X / 127.5 - 1

# Predict, then explain that prediction of X with LRP.
Ypred = nn.forward(X)
R = nn.lrp(Ypred)

# Render the digit and the relevance heatmap (digit used as outline), then
# write both into a single image file.
digit = render.digit_to_rgb(X)
hm = render.hm_to_rgb(R, X)
render.save_image([digit, hm], './hm_py.png')
Ejemplo n.º 9
0
def _to_host_array(arr):
    """Return *arr* as a host (numpy) array when ``np`` is cupy; otherwise return it unchanged."""
    return np.asnumpy(arr) if np is not numpy else arr


def _top_class_lrp(DNN, x):
    """Epsilon-LRP (parameter 1.) relevance of the arg-max class for one sample.

    The prediction vector is masked so that only the arg-max class keeps its
    score, that masked vector is propagated back with ``DNN.lrp``, and the
    result is summed over axis 3 (presumably channels -- TODO confirm).
    ``np.argmax`` runs over the whole prediction array, so *x* must hold
    exactly one sample.
    """
    ypred = DNN.forward(x)
    mask = np.zeros_like(ypred)
    mask[:, np.argmax(ypred)] = 1
    R = DNN.lrp(ypred * mask, 'epsilon', 1.)
    return R.sum(axis=3)


def _glsar_residual(tar):
    """Residual of a GLSAR (rho=2) fit of the flattened *tar* on stored relevance maps.

    Regressors are read from ``../r_array/{convolution,rect,sumpoll}.npy``
    (rect/sumpoll zero-padded by 2 on axes 1 and 2) plus a constant column.
    NOTE(review): these files are presumably written by the preceding
    ``DNN.lrp`` call, which is why they are re-read for every sample -- confirm.
    """
    a = np.load('../r_array/convolution.npy')
    a = np.reshape(a, [a.shape[1] * a.shape[2], 1])
    b = np.load('../r_array/rect.npy')
    b = np.pad(b, ((0, 0), (2, 2), (2, 2), (0, 0)))
    b = np.reshape(b, [b.shape[1] * b.shape[2], b.shape[0] * b.shape[3]])
    c = np.load('../r_array/sumpoll.npy')
    c = np.pad(c, ((0, 0), (2, 2), (2, 2), (0, 0)))
    c = np.reshape(c, [c.shape[1] * c.shape[2], c.shape[3]])
    design = sm.add_constant(np.hstack((a, np.hstack((b, c)))))
    y_tran = np.reshape(tar, [tar.shape[0] * tar.shape[1] * tar.shape[2]]).transpose()
    model = sm.GLSAR(y_tran, design, rho=2)
    return model.iterative_fit().resid


def _quantize_attrs(attrs):
    """Multiply by 255, cast to uint8 and pass the result through ``scale``.

    The multiply is deliberately not in place: *attrs* may be a view into the
    global dataset, which must not be modified.
    """
    return scale((attrs * 255).astype(np.uint8))


def occlude_dataset(DNN, attribution, percentiles, test=False, keep=False, random=False, batch_size=64, savedir=''):
    """Occlude samples by their attribution maps and pickle the results per percentile.

    For each processed sample an attribution map is built according to
    *attribution*:

    * contains ``'LRP'``: epsilon-LRP relevance of the predicted class;
    * contains ``'proposed_method'``: residual of a GLSAR regression of the
      input on stored relevance maps;
    * anything else: the input itself.

    Then, for every value in *percentiles*, the sample is occluded. This uses
    ``random_remove`` when ``attribution == 'normal'`` and ``remove``
    otherwise. A before/after preview PNG is written to ``../``, and the
    occluded sample and its label are collected. Finally, for each
    percentile, the occluded samples and their labels are written with
    ``save`` under the *savedir* prefix.

    Args:
        DNN: network exposing ``forward`` and ``lrp``.
        attribution: attribution method name (see above); also used in file names.
        percentiles: occlusion levels, each one of 0.0, 0.1, ..., 0.9, 1.
        test: use the global ``Xtest``/``Ytest`` instead of ``Xtrain``/``Ytrain``.
        keep: forwarded to ``remove`` / ``random_remove``.
        random: unused.
        batch_size: only used to compute how many samples are processed (see NOTE below).
        savedir: prefix prepended to the output pickle file names.

    Raises:
        ValueError: if a percentile is not one of the supported levels
            (raised while processing the first sample).
    """
    print("Condition of test : {}".format(test))
    if test:
        Xs, ys = Xtest, Ytest
    else:
        Xs, ys = Xtrain, Ytrain

    print("initial batch_size is : {}".format(batch_size))
    total_batch = math.ceil(len(Xs) / batch_size)
    print("batch size is :{}".format(total_batch))

    # Supported occlusion levels; each level gets its own output bucket.
    levels = (0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1)
    hmaps = [[] for _ in levels]
    labels = [[] for _ in levels]

    # NOTE(review): the loop runs ceil(len(Xs) / batch_size) times but takes ONE
    # sample per iteration (Xs[i:i+1]), so only the first ~len(Xs)/batch_size
    # samples are occluded. Kept as-is; confirm whether range(len(Xs)) was intended.
    for i in tqdm(range(total_batch)):
        x = Xs[i:i + 1, ...]
        y = ys[i:i + 1, ...]
        if 'LRP' in attribution:
            R = _to_host_array(_top_class_lrp(DNN, x))
            attrs = _quantize_attrs(scaling(R))
        elif 'proposed_method' in attribution:
            # The relevance itself is not used in this branch, but the LRP pass is
            # kept: NOTE(review) presumably it writes the ../r_array/*.npy files
            # read by _glsar_residual -- confirm before removing it.
            _top_class_lrp(DNN, x)
            residual = _glsar_residual(_to_host_array(x))
            attrs = _quantize_attrs(scaling(np.reshape(residual, [1, 32, 32])))
        else:
            attrs = _quantize_attrs(_to_host_array(x))
        data = x

        for percent in tqdm(percentiles):
            if percent not in levels:
                raise ValueError("attribution?" if attribution == 'normal' else "Error : {}".format(percent))
            slot = levels.index(percent)
            if attribution == 'normal':
                occluded_images = random_remove(data, percent, keep)
            else:
                print(" percent : {}".format(percent))
                occluded_images = remove(data, attrs, percent, keep)
            raw_image = render.digit_to_rgb(data, scaling=1)
            prdigit = render.digit_to_rgb(occluded_images, scaling=1)
            # Preview of the current sample; the same file is overwritten for every sample.
            render.save_image([raw_image, prdigit], '../{}_KAR_check_point_{}.png'.format(attribution, percent))
            hmaps[slot].append(occluded_images)
            labels[slot].append(y)

    print("Interpretation is done, concatenate...")
    # Levels not listed in `percentiles` (or with no processed samples) stay
    # empty; np.concatenate([]) raises, so those buckets are left as None.
    hmaps = [np.concatenate(h, axis=0) if h else None for h in hmaps]
    labels = [np.concatenate(lab, axis=0) if lab else None for lab in labels]
    print("concatenate is done...")

    split = 'test' if test else 'train'
    for percent in tqdm(percentiles):
        if percent not in levels:
            print("error")
            continue
        slot = levels.index(percent)
        if hmaps[slot] is None:
            print("No occluded samples for percentile {}, nothing saved".format(percent))
            continue
        save(hmaps[slot], savedir + '{}_{}_{}.pickle'.format(split, attribution, percent))
        save(labels[slot], savedir + '{}_{}_{}_{}.pickle'.format(split, attribution, percent, 'label'))
        print("Occlude image {} percentile...".format(percent))