Example No. 1
\begin{Verbatim}[frame=single, fontsize=\small]
# imports
import model_io
import data_io
import render

import numpy as np
na = np.newaxis
# end of imports

# read model and first MNIST test image
nn = model_io.read(<model_path>) 
X = data_io.read(<data_path>)[na,0,:]
# normalize data to range [-1, 1]
X = X / 127.5 - 1

# forward pass through network
Ypred = nn.forward(X)
# lrp to explain prediction of X
R = nn.lrp(Ypred)

# render rgb images and save as image
digit = render.digit_to_rgb(X)
# render heatmap R, use X as outline
hm = render.hm_to_rgb(R,X) 
render.save_image([digit,hm],<i_path>)
\end{Verbatim}
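The example above explains the network's own prediction by passing Ypred straight to nn.lrp. The later examples instead mask the output so that only one chosen class receives relevance. A minimal sketch of that pattern, assuming nn, X and Ypred are loaded as above (the class index 3 is illustrative):
\begin{Verbatim}[frame=single, fontsize=\small]
# explain an arbitrarily selected class instead
# of the predicted one (index 3 is illustrative)
target = 3
mask = (np.arange(Ypred.shape[1])[na,:] == target)*1.
R = nn.lrp(Ypred * mask)

# render and save next to the digit as before
hm = render.hm_to_rgb(R, X)
render.save_image([render.digit_to_rgb(X), hm], <i_path>)
\end{Verbatim}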
Example No. 2
    #R = nn.lrp(ypred*Y[na,i]) #compute first layer relevance according to the true class label
    '''
    yselect = 3
    yselect = (np.arange(Y.shape[1])[na,:] == yselect)*1.
    R = nn.lrp(ypred*yselect) #compute first layer relevance for an arbitrarily selected class
    '''

    #undo input normalization for digit drawing. get it back to range [0,1] per pixel
    x = (x + 1.) / 2.

    if np is not numpy:  # np may be cupy; convert back to NumPy arrays
        x = np.asnumpy(x)
        R = np.asnumpy(R)

    #render input and heatmap as rgb images
    digit = render.digit_to_rgb(x, scaling=3)
    hm = render.hm_to_rgb(R, X=x, scaling=3, sigma=2)
    digit_hm = render.save_image([digit, hm], '../heatmap.png')
    data_io.write(R, '../heatmap.npy')

    #display the image as written to file
    plt.imshow(digit_hm, interpolation='none')
    plt.axis('off')
    plt.show()

#note that modules.Sequential allows for batch processing inputs
if True:
    N = 256
    t_start = time.time()
    x = X[:N, ...]
    y = nn.forward(x)
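    # The fragment above stops after the batched forward pass. Based on the
    # batch-processing note and the commented block in Example No. 3, a hedged
    # sketch of how the batched LRP call and timing could continue (the exact
    # continuation in the original source may differ):
    R = nn.lrp(y)                    # one relevance map per input in the batch
    if np is not numpy:              # np may be cupy; convert for rendering
        R = np.asnumpy(R)
    t_end = time.time()
    print('LRP for {} inputs took {:.3f}s'.format(N, t_end - t_start))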
Example No. 3
    R = nn.lrp(ypred)                   #as Eq(56) from DOI: 10.1371/journal.pone.0130140
    #R = nn.lrp(ypred,'epsilon',100)    #as Eq(58) from DOI: 10.1371/journal.pone.0130140
    #R = nn.lrp(ypred,'alphabeta',2)    #as Eq(60) from DOI: 10.1371/journal.pone.0130140
    
    '''
    R = nn.lrp(Y[na,i]) #compute first layer relevance according to the true class label
    '''
    
    '''
    yselect = 3
    yselect = (np.arange(Y.shape[1])[na,:] == yselect)*1.
    R = nn.lrp(yselect) #compute first layer relevance for an arbitrarily selected class
    '''
    
    #render input and heatmap as rgb images
    digit = render.digit_to_rgb(x, scaling = 3)
    hm = render.hm_to_rgb(R, X = x, scaling = 3, sigma = 2)
    digit_hm = render.save_image([digit,hm],'../heatmap.png')
    data_io.write(R,'../heatmap.npy')
    
    #display the image as written to file
    plt.imshow(digit_hm, interpolation = 'none')
    plt.axis('off')
    plt.show()


#note that modules.Sequential allows for batch processing inputs
'''
x = X[:10,:]
y = nn.forward(x)
R = nn.lrp(y)
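
# Example No. 3 lists three relevance-initialization variants in its comments:
# the basic rule of Eq. (56), the epsilon rule of Eq. (58) and the alpha-beta
# rule of Eq. (60) from DOI: 10.1371/journal.pone.0130140. A hedged sketch
# comparing them on the same prediction, assuming nn, x and ypred are set up
# as in the example (the heatmap filenames are illustrative):
variants = {
    'basic':     nn.lrp(ypred),                  # Eq. (56)
    'epsilon':   nn.lrp(ypred, 'epsilon', 100),  # Eq. (58)
    'alphabeta': nn.lrp(ypred, 'alphabeta', 2),  # Eq. (60)
}
for name, R_v in variants.items():
    hm = render.hm_to_rgb(R_v, X=x, scaling=3, sigma=2)
    render.save_image([render.digit_to_rgb(x, scaling=3), hm],
                      '../heatmap_{}.png'.format(name))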
Example No. 4
def occlude_dataset(DNN,
                    attribution,
                    percentiles,
                    test=False,
                    keep=False,
                    random=False,
                    batch_size=128,
                    savedir=''):
    print("Condition of test : {}".format(test))
    if test:
        Xs = Xtest
        ys = Ytest
    else:
        Xs = Xtrain
        ys = Ytrain

    print("initial batch_size is : {}".format(batch_size))
    total_batch = math.ceil(len(Xs) / batch_size)
    print("batch size is :{}".format(total_batch))
    hmaps = []
    data = []
    label = []
    for i in tqdm(range(total_batch)):

        #         batch_xs = Xs[i*batch_size:(i+1)*batch_size]
        # #         batch_xs_scaled = scale(batch_xs)
        if 'LRP' in attribution:
            #                 for t in It[:10]:
            x = Xs[i:i + 1, ...]
            y = ys[i:i + 1, ...]
            ypred = DNN.forward(x)
            #                 print('True Class:     ', np.argmax(ys[i]))
            #                 print('Predicted Class:', np.argmax(ypred),'\n')
            m = np.zeros_like(ypred)
            m[:, np.argmax(ypred)] = 1
            Rinit = ypred * m
            Rinit = Rinit.astype(np.float64)  # make the cast effective (np.float is deprecated)
            R = DNN.lrp(Rinit, 'epsilon', 1.)
            R = R.sum(axis=3)
            if np is not numpy:
                R = np.asnumpy(R)
            if test:
                LRP_test = render.digit_to_rgb(R, scaling=3)
            attrs = R
#                 attrs = np.sum(np.where(attrs > 0, attrs, 0.0), axis=-1)
#                 print("print lrp : {}".format(attrs.shape))
        elif 'proposed_method' in attribution:
            #                 for t in It[:10]:
            x = Xs[i:i + 1, ...]
            y = ys[i:i + 1, ...]
            ypred = DNN.forward(x)
            #                 print('True Class:     ', np.argmax(ys[i]))
            #                 print('Predicted Class:', np.argmax(ypred),'\n')
            m = np.zeros_like(ypred)
            m[:, np.argmax(ypred)] = 1
            Rinit = ypred * m
            Rinit = Rinit.astype(np.float64)  # make the cast effective (np.float is deprecated)
            R = DNN.lrp(Rinit, 'epsilon', 1.)
            R = R.sum(axis=3)
            if np is not numpy:
                xs = np.asnumpy(x)
                R = np.asnumpy(R)
            else:
                xs = x
            tar = xs
            a = np.load('../r_array/convolution.npy')
            a = np.reshape(a, [a.shape[1] * a.shape[2], 1])
            b = np.load('../r_array/rect.npy')
            b = np.pad(b, ((0, 0), (2, 2), (2, 2), (0, 0)))
            b = np.reshape(b,
                           [b.shape[1] * b.shape[2], b.shape[0] * b.shape[3]])
            c = np.load('../r_array/sumpoll.npy')
            c = np.pad(c, ((0, 0), (2, 2), (2, 2), (0, 0)))
            c = np.reshape(c, [c.shape[1] * c.shape[2], c.shape[3]])
            new_b = np.hstack((b, c))
            new = np.hstack((a, new_b))
            tar = np.reshape(tar, [tar.shape[0] * tar.shape[1] * tar.shape[2]])
            y_tran = tar.transpose()
            new = sm.add_constant(new)
            #             print(new.shape)
            #             print(y_tran.shape)
            model = sm.GLSAR(y_tran, new, rho=2)
            result = model.iterative_fit(maxiter=30)
            find = result.resid
            check = np.reshape(find, [1, 32, 32])
            if test:
                proposed_test = render.digit_to_rgb(check, scaling=3)
            attrs = check
#                 attrs = np.sum(np.where(attrs > 0, attrs, 0.0), axis=-1)
#                 print("print propose : {}".format(attrs.shape))
        else:
            x = Xs[i:i + 1, ...]
            y = ys[i:i + 1, ...]
            if np is not numpy:
                xs = np.asnumpy(x)
            else:
                xs = x
            if test:
                digit = render.digit_to_rgb(xs, scaling=3)
            attrs = xs
#                 attrs = np.sum(np.where(attrs > 0, attrs, 0.0), axis=-1)
#                 print("print normal : {}".format(attrs.shape))
        attrs += np.random.normal(scale=1e-4, size=attrs.shape)
        #         print("print random normal : {}".format(attrs.shape))
        hmaps.append(attrs)
        data.append(x)
        label.append(y)


#         print("print final : {}".format(len(hmaps)))
    print("Interpretation is done, concatenate...")
    hmaps = np.concatenate(hmaps, axis=0)
    data = np.concatenate(data, axis=0)

    print("concatenate is done...")
    print("print final : {}".format(hmaps.shape))
    #     percentiles = [0, 10, 20, 30, 40, 50, 60, 70, 80, 90]
    for percent in tqdm(percentiles):

        #         dataset = []
        #         y_target = []
        #         for i in It[:10]:

        #             batch_xs, batch_ys = Xs[i*batch_size:(i+1)*batch_size], ys[i*batch_size:(i+1)*batch_size]
        #             x = Xs[i:i+1,...]
        #             y = ys[i:i+1,...]
        # batch_attrs = hmaps[i:i+1,...]
        batch_attrs = hmaps
        occluded_images = remove(data, batch_attrs, percent, keep)

        #             dataset.append(scale(occluded_images))
        #             y_target.append(y)
        #             del occluded_images
        #
        print("save start")
        #         print("dataset shape : {}".format(dataset))
        print("Save directory is {}".format(savedir))
        save(
            occluded_images, savedir + '{}_{}_{}.pickle'.format(
                'test' if test else 'train', attribution, percent))

        #         save(np.concatenate(dataset, axis=0), savedir + '{}_{}_{}.pickle'.format('test' if test else 'train', attribution, percent))
        #         save(np.concatenate(y_target, axis=0), savedir + '{}_{}_{}_{}.pickle'.format('test' if test else 'train', attribution, percent, 'label'))
        save(
            np.concatenate(label, axis=0),
            savedir + '{}_{}_{}_{}.pickle'.format(
                'test' if test else 'train', attribution, percent, 'label'))
        print("Occlude image {} percentile...".format(percent))
Example No. 5
    print("mask shape is : {}".format(mask.shape))
    # set the initial relevance scores
    Rinit = ypred*mask
    print("LRP Rinit shape: {}".format(Rinit.shape))
    #compute first layer relevance according to prediction
    #R = nn.lrp(Rinit)                   #as Eq(56) from DOI: 10.1371/journal.pone.0130140
    R = nn.lrp(Rinit,'epsilon',1.)

    R = R.sum(axis=3)
    xs = ((x+1.)/2.).sum(axis=3)

    if np is not numpy:
        xs = np.asnumpy(xs)
        R = np.asnumpy(R)

    digit = render.digit_to_rgb(xs, scaling = 3)
    hm = render.hm_to_rgb(R, X = xs, scaling = 3, sigma = 2)
    digit_hm = render.save_image([digit,hm],'../heatmap.png')
    data_io.write(R,'../heatmap.npy')
    data_io.write(xs,'../xs.npy')
    print(xs.shape)
    y = xs
    a = np.load('../r_array/convolution.npy')
    a = np.reshape(a,[a.shape[1]*a.shape[2],1])
    b = np.load('../r_array/rect.npy')    
    b = np.pad(b,((0,0),(2,2),(2,2),(0,0)))
    b = np.reshape(b,[b.shape[1]*b.shape[2],b.shape[0]*b.shape[3]])
    c = np.load('../r_array/sumpoll.npy')
    c = np.pad(c,((0,0),(2,2),(2,2),(0,0)))
    c = np.reshape(c,[c.shape[1]*c.shape[2],c.shape[3]])
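
    # The fragment cuts off here. In Examples No. 4, 6 and 8 the same three
    # arrays feed a GLSAR regression whose residuals serve as the attribution
    # map; a hedged sketch of that continuation (shapes follow those examples):
    new = sm.add_constant(np.hstack((a, np.hstack((b, c)))))
    y_tran = np.reshape(y, [y.shape[0]*y.shape[1]*y.shape[2]]).transpose()

    model = sm.GLSAR(y_tran, new, rho=2)
    result = model.iterative_fit(maxiter=30)
    check = np.reshape(result.resid, [1, 32, 32])  # residuals as attribution map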
    
Example No. 6
def occlude_dataset(DNN, attribution, percentiles, test=False, keep=False, random=False, batch_size=64, savedir=''):
    print("Condition of test : {}".format(test))
    if test:
        Xs = Xtest
        ys = Ytest
    else:
        Xs = Xtrain
        ys = Ytrain
    
    print("initial batch_size is : {}".format(batch_size))
    total_batch = math.ceil(len(Xs) / batch_size)
    print("batch size is :{}".format(total_batch))
    hmaps = []
    data = []
    label = []
    hmaps_00 = []
    hmaps_01 = []
    hmaps_02 = []
    hmaps_03 = []
    hmaps_04 = []
    hmaps_05 = []
    hmaps_06 = []
    hmaps_07 = []
    hmaps_08 = []
    hmaps_09 = []
    hmaps_10 = []
    data_00 = []
    data_01 = []
    data_02 = []
    data_03 = []
    data_04 = []
    data_05 = []
    data_06 = []
    data_07 = []
    data_08 = []
    data_09 = []
    data_10 = []
    label_00 = []
    label_01 = []
    label_02 = []
    label_03 = []
    label_04 = []
    label_05 = []
    label_06 = []
    label_07 = []
    label_08 = []
    label_09 = []
    label_10 = []
    for i in tqdm(range(total_batch)):
        if 'LRP' in attribution:
            x = Xs[i:i+1,...]
            y = ys[i:i+1,...]
            ypred = DNN.forward(x)
            m = np.zeros_like(ypred)
            m[:,np.argmax(ypred)] = 1
            Rinit = ypred*m
            Rinit = Rinit.astype(np.float64)  # make the cast effective (np.float is deprecated)
            R = DNN.lrp(Rinit,'epsilon',1.)
            R = R.sum(axis=3)
            if np is not numpy:
                R = np.asnumpy(R)
            
            attrs = R
            data = x
            
            attrs = scaling(attrs)
            attrs *= 255
            attrs = attrs.astype(np.uint8)
            attrs = scale(attrs)



        elif 'proposed_method' in attribution:
            x = Xs[i:i+1,...]
            y = ys[i:i+1,...]
            ypred = DNN.forward(x)

            m = np.zeros_like(ypred)
            m[:,np.argmax(ypred)] = 1
            Rinit = ypred*m
            Rinit = Rinit.astype(np.float64)  # make the cast effective (np.float is deprecated)
            R = DNN.lrp(Rinit,'epsilon',1.)
            R = R.sum(axis=3)
            if np is not numpy:
                xs = np.asnumpy(x)
                R = np.asnumpy(R)
            else:
                xs = x
            tar = xs
            a = np.load('../r_array/convolution.npy')
            a = np.reshape(a,[a.shape[1]*a.shape[2],1])
            b = np.load('../r_array/rect.npy')    
            b = np.pad(b,((0,0),(2,2),(2,2),(0,0)))
            b = np.reshape(b,[b.shape[1]*b.shape[2],b.shape[0]*b.shape[3]])
            c = np.load('../r_array/sumpoll.npy')
            c = np.pad(c,((0,0),(2,2),(2,2),(0,0)))
            c = np.reshape(c,[c.shape[1]*c.shape[2],c.shape[3]])
            new_b = np.hstack((b, c))
            new = np.hstack((a, new_b))
            tar = np.reshape(tar, [tar.shape[0]*tar.shape[1]*tar.shape[2]])
            y_tran = tar.transpose()
            new = sm.add_constant(new)

            model = sm.GLSAR(y_tran, new, rho = 2)
            result = model.iterative_fit()
            find = result.resid
            check = np.reshape(find,[1,32,32])
            
            attrs = check
            data = x
            
            attrs = scaling(attrs)
            attrs *= 255
            attrs = attrs.astype(np.uint8)
            attrs = scale(attrs)

            
        else:
            x = Xs[i:i+1,...]
            y = ys[i:i+1,...]
            if np is not numpy:
                xs = np.asnumpy(x)
            else:
                xs = x
            
            attrs = xs
            data = x
            
#             attrs = scaling(attrs)
            attrs *= 255
            attrs = attrs.astype(np.uint8)
            attrs = scale(attrs)


        for percent in tqdm(percentiles):
            batch_attrs = attrs
            if attribution == 'normal':
                if percent == 0.0:
                    occluded_images = random_remove(data, percent, keep)
                    raw_image = render.digit_to_rgb(data, scaling = 1)
                    prdigit = render.digit_to_rgb(occluded_images, scaling = 1)
                    test_image = render.save_image([raw_image, prdigit],'../{}_KAR_check_point_{}.png'.format(attribution, percent))
                    hmaps_00.append(occluded_images)
                    data_00.append(data)
                    label_00.append(y)
                elif percent == 0.1:
                    occluded_images = random_remove(data, percent, keep)
                    raw_image = render.digit_to_rgb(data, scaling = 1)
                    prdigit = render.digit_to_rgb(occluded_images, scaling = 1)
                    test_image = render.save_image([raw_image, prdigit],'../{}_KAR_check_point_{}.png'.format(attribution, percent))
                    hmaps_01.append(occluded_images)
                    data_01.append(data)
                    label_01.append(y)
                elif percent == 0.2:
                    occluded_images = random_remove(data, percent, keep)
                    raw_image = render.digit_to_rgb(data, scaling = 1)
                    prdigit = render.digit_to_rgb(occluded_images, scaling = 1)
                    test_image = render.save_image([raw_image, prdigit],'../{}_KAR_check_point_{}.png'.format(attribution, percent))
                    hmaps_02.append(occluded_images)
                    data_02.append(data)
                    label_02.append(y)
                elif percent == 0.3:
                    occluded_images = random_remove(data, percent, keep)
                    raw_image = render.digit_to_rgb(data, scaling = 1)
                    prdigit = render.digit_to_rgb(occluded_images, scaling = 1)
                    test_image = render.save_image([raw_image, prdigit],'../{}_KAR_check_point_{}.png'.format(attribution, percent))
                    hmaps_03.append(occluded_images)
                    data_03.append(data)
                    label_03.append(y)
                elif percent == 0.4:
                    occluded_images = random_remove(data, percent, keep)
                    raw_image = render.digit_to_rgb(data, scaling = 1)
                    prdigit = render.digit_to_rgb(occluded_images, scaling = 1)
                    test_image = render.save_image([raw_image, prdigit],'../{}_KAR_check_point_{}.png'.format(attribution, percent))
                    hmaps_04.append(occluded_images)
                    data_04.append(data)
                    label_04.append(y)
                elif percent == 0.5:
                    occluded_images = random_remove(data, percent, keep)
                    raw_image = render.digit_to_rgb(data, scaling = 1)
                    prdigit = render.digit_to_rgb(occluded_images, scaling = 1)
                    test_image = render.save_image([raw_image, prdigit],'../{}_KAR_check_point_{}.png'.format(attribution, percent))
                    hmaps_05.append(occluded_images)
                    data_05.append(data)
                    label_05.append(y)
                elif percent == 0.6:
                    occluded_images = random_remove(data, percent, keep)
                    raw_image = render.digit_to_rgb(data, scaling = 1)
                    prdigit = render.digit_to_rgb(occluded_images, scaling = 1)
                    test_image = render.save_image([raw_image, prdigit],'../{}_KAR_check_point_{}.png'.format(attribution, percent))
                    hmaps_06.append(occluded_images)
                    data_06.append(data)
                    label_06.append(y)
                elif percent == 0.7:
                    occluded_images = random_remove(data, percent, keep)
                    raw_image = render.digit_to_rgb(data, scaling = 1)
                    prdigit = render.digit_to_rgb(occluded_images, scaling = 1)
                    test_image = render.save_image([raw_image, prdigit],'../{}_KAR_check_point_{}.png'.format(attribution, percent))
                    hmaps_07.append(occluded_images)
                    data_07.append(data)
                    label_07.append(y)
                elif percent == 0.8:
                    occluded_images = random_remove(data, percent, keep)
                    raw_image = render.digit_to_rgb(data, scaling = 1)
                    prdigit = render.digit_to_rgb(occluded_images, scaling = 1)
                    test_image = render.save_image([raw_image, prdigit],'../{}_KAR_check_point_{}.png'.format(attribution, percent))
                    hmaps_08.append(occluded_images)
                    data_08.append(data)
                    label_08.append(y)
                elif percent == 0.9:
                    occluded_images = random_remove(data, percent, keep)
                    raw_image = render.digit_to_rgb(data, scaling = 1)
                    prdigit = render.digit_to_rgb(occluded_images, scaling = 1)
                    test_image = render.save_image([raw_image, prdigit],'../{}_KAR_check_point_{}.png'.format(attribution, percent))
                    hmaps_09.append(occluded_images)
                    data_09.append(data)
                    label_09.append(y)
                elif percent == 1:
                    occluded_images = random_remove(data, percent, keep)
                    raw_image = render.digit_to_rgb(data, scaling = 1)
                    prdigit = render.digit_to_rgb(occluded_images, scaling = 1)
                    test_image = render.save_image([raw_image, prdigit],'../{}_KAR_check_point_{}.png'.format(attribution, percent))
                    hmaps_10.append(occluded_images)
                    data_10.append(data)
                    label_10.append(y)
                else:
                    raise ValueError("attribution?")
            else:
                if percent == 0.0:
                    print(" percent : {}".format(percent))
                    occluded_images = remove(data, batch_attrs, percent, keep)
                    raw_image = render.digit_to_rgb(data, scaling = 1)
                    prdigit = render.digit_to_rgb(occluded_images, scaling = 1)
                    test_image = render.save_image([raw_image, prdigit],'../{}_KAR_check_point_{}.png'.format(attribution, percent))
                    hmaps_00.append(occluded_images)
                    data_00.append(data)
                    label_00.append(y)
                elif percent == 0.1:
                    print(" percent : {}".format(percent))
                    occluded_images = remove(data, batch_attrs, percent, keep)
                    raw_image = render.digit_to_rgb(data, scaling = 1)
                    prdigit = render.digit_to_rgb(occluded_images, scaling = 1)
                    test_image = render.save_image([raw_image, prdigit],'../{}_KAR_check_point_{}.png'.format(attribution, percent))
                    hmaps_01.append(occluded_images)
                    data_01.append(data)
                    label_01.append(y)
                elif percent == 0.2:
                    print(" percent : {}".format(percent))
                    occluded_images = remove(data, batch_attrs, percent, keep)
                    raw_image = render.digit_to_rgb(data, scaling = 1)
                    prdigit = render.digit_to_rgb(occluded_images, scaling = 1)
                    test_image = render.save_image([raw_image, prdigit],'../{}_KAR_check_point_{}.png'.format(attribution, percent))
                    hmaps_02.append(occluded_images)
                    data_02.append(data)
                    label_02.append(y)
                elif percent == 0.3:
                    print(" percent : {}".format(percent))
                    occluded_images = remove(data, batch_attrs, percent, keep)
                    raw_image = render.digit_to_rgb(data, scaling = 1)
                    prdigit = render.digit_to_rgb(occluded_images, scaling = 1)
                    test_image = render.save_image([raw_image, prdigit],'../{}_KAR_check_point_{}.png'.format(attribution, percent))
                    hmaps_03.append(occluded_images)
                    data_03.append(data)
                    label_03.append(y)
                elif percent == 0.4:
                    print(" percent : {}".format(percent))
                    occluded_images = remove(data, batch_attrs, percent, keep)
                    raw_image = render.digit_to_rgb(data, scaling = 1)
                    prdigit = render.digit_to_rgb(occluded_images, scaling = 1)
                    test_image = render.save_image([raw_image, prdigit],'../{}_KAR_check_point_{}.png'.format(attribution, percent))
                    hmaps_04.append(occluded_images)
                    data_04.append(x)
                    label_04.append(y)
                elif percent == 0.5:
                    print(" percent : {}".format(percent))
                    occluded_images = remove(data, batch_attrs, percent, keep)
                    raw_image = render.digit_to_rgb(data, scaling = 1)
                    prdigit = render.digit_to_rgb(occluded_images, scaling = 1)
                    test_image = render.save_image([raw_image, prdigit],'../{}_KAR_check_point_{}.png'.format(attribution, percent))
                    hmaps_05.append(occluded_images)
                    data_05.append(x)
                    label_05.append(y)
                elif percent == 0.6:
                    print(" percent : {}".format(percent))
                    occluded_images = remove(data, batch_attrs, percent, keep)
                    raw_image = render.digit_to_rgb(data, scaling = 1)
                    prdigit = render.digit_to_rgb(occluded_images, scaling = 1)
                    test_image = render.save_image([raw_image, prdigit],'../{}_KAR_check_point_{}.png'.format(attribution, percent))
                    hmaps_06.append(occluded_images)
                    data_06.append(x)
                    label_06.append(y)
                elif percent == 0.7:
                    print(" percent : {}".format(percent))
                    occluded_images = remove(data, batch_attrs, percent, keep)
                    raw_image = render.digit_to_rgb(data, scaling = 1)
                    prdigit = render.digit_to_rgb(occluded_images, scaling = 1)
                    test_image = render.save_image([raw_image, prdigit],'../{}_KAR_check_point_{}.png'.format(attribution, percent))
                    hmaps_07.append(occluded_images)
                    data_07.append(x)
                    label_07.append(y)
                elif percent == 0.8:
                    print(" percent : {}".format(percent))
                    occluded_images = remove(data, batch_attrs, percent, keep)
                    raw_image = render.digit_to_rgb(data, scaling = 1)
                    prdigit = render.digit_to_rgb(occluded_images, scaling = 1)
                    test_image = render.save_image([raw_image, prdigit],'../{}_KAR_check_point_{}.png'.format(attribution, percent))
                    hmaps_08.append(occluded_images)
                    data_08.append(x)
                    label_08.append(y)
                elif percent == 0.9:
                    print(" percent : {}".format(percent))
                    occluded_images = remove(data, batch_attrs, percent, keep)
                    raw_image = render.digit_to_rgb(data, scaling = 1)
                    prdigit = render.digit_to_rgb(occluded_images, scaling = 1)
                    test_image = render.save_image([raw_image, prdigit],'../{}_KAR_check_point_{}.png'.format(attribution, percent))
                    hmaps_09.append(occluded_images)
                    data_09.append(x)
                    label_09.append(y)
                elif percent == 1:
                    print(" percent : {}".format(percent))
                    occluded_images = remove(data, batch_attrs, percent, keep)
                    raw_image = render.digit_to_rgb(data, scaling = 1)
                    prdigit = render.digit_to_rgb(occluded_images, scaling = 1)
                    test_image = render.save_image([raw_image, prdigit],'../{}_KAR_check_point_{}.png'.format(attribution, percent))
                    hmaps_10.append(occluded_images)
                    data_10.append(x)
                    label_10.append(y)
                else:
                    raise ValueError("Error : {}".format(percent))

        
        
        
#         hmaps.append(attrs)
#         data.append(x)
#         label.append(y)
#         print("print final : {}".format(len(hmaps)))
    print("Interpretation is done, concatenate...")
    hmaps_00 = np.concatenate(hmaps_00, axis=0)
    hmaps_01 = np.concatenate(hmaps_01, axis=0)
    hmaps_02 = np.concatenate(hmaps_02, axis=0)
    hmaps_03 = np.concatenate(hmaps_03, axis=0)
    hmaps_04 = np.concatenate(hmaps_04, axis=0)
    hmaps_05 = np.concatenate(hmaps_05, axis=0)
    hmaps_06 = np.concatenate(hmaps_06, axis=0)
    hmaps_07 = np.concatenate(hmaps_07, axis=0)
    hmaps_08 = np.concatenate(hmaps_08, axis=0)
    hmaps_09 = np.concatenate(hmaps_09, axis=0)
    hmaps_10 = np.concatenate(hmaps_10, axis=0)
    data_00 = np.concatenate(data_00, axis = 0)
    data_01 = np.concatenate(data_01, axis = 0)
    data_02 = np.concatenate(data_02, axis = 0)
    data_03 = np.concatenate(data_03, axis = 0)
    data_04 = np.concatenate(data_04, axis = 0)
    data_05 = np.concatenate(data_05, axis = 0)
    data_06 = np.concatenate(data_06, axis = 0)
    data_07 = np.concatenate(data_07, axis = 0)
    data_08 = np.concatenate(data_08, axis = 0)
    data_09 = np.concatenate(data_09, axis = 0)
    data_10 = np.concatenate(data_10, axis = 0)
    label_00 = np.concatenate(label_00, axis = 0)
    label_01 = np.concatenate(label_01, axis = 0)
    label_02 = np.concatenate(label_02, axis = 0)
    label_03 = np.concatenate(label_03, axis = 0)
    label_04 = np.concatenate(label_04, axis = 0)
    label_05 = np.concatenate(label_05, axis = 0)
    label_06 = np.concatenate(label_06, axis = 0)
    label_07 = np.concatenate(label_07, axis = 0)
    label_08 = np.concatenate(label_08, axis = 0)
    label_09 = np.concatenate(label_09, axis = 0)
    label_10 = np.concatenate(label_10, axis = 0)
    
#     hmaps = np.concatenate(hmaps, axis=0)
#     data = np.concatenate(data, axis = 0)
    
    print("concatenate is done...")
#     print("print final : {}".format(hmaps.shape))
#     percentiles = [0, 10, 20, 30, 40, 50, 60, 70, 80, 90]
    for percent in tqdm(percentiles):
        if percent == 0.0:
            save(hmaps_00, savedir + '{}_{}_{}.pickle'.format('test' if test else 'train', attribution, percent))
            save(label_00, savedir + '{}_{}_{}_{}.pickle'.format('test' if test else 'train', attribution, percent, 'label'))
            print("Occlude image {} percentile...".format(percent))
        elif percent == 0.1:
            save(hmaps_01, savedir + '{}_{}_{}.pickle'.format('test' if test else 'train', attribution, percent))
            save(label_01, savedir + '{}_{}_{}_{}.pickle'.format('test' if test else 'train', attribution, percent, 'label'))
            print("Occlude image {} percentile...".format(percent))
        elif percent == 0.2:
            save(hmaps_02, savedir + '{}_{}_{}.pickle'.format('test' if test else 'train', attribution, percent))
            save(label_02, savedir + '{}_{}_{}_{}.pickle'.format('test' if test else 'train', attribution, percent, 'label'))
            print("Occlude image {} percentile...".format(percent))
        elif percent == 0.3:
            save(hmaps_03, savedir + '{}_{}_{}.pickle'.format('test' if test else 'train', attribution, percent))
            save(label_03, savedir + '{}_{}_{}_{}.pickle'.format('test' if test else 'train', attribution, percent, 'label'))
            print("Occlude image {} percentile...".format(percent))
        elif percent == 0.4:
            save(hmaps_04, savedir + '{}_{}_{}.pickle'.format('test' if test else 'train', attribution, percent))
            save(label_04, savedir + '{}_{}_{}_{}.pickle'.format('test' if test else 'train', attribution, percent, 'label'))
            print("Occlude image {} percentile...".format(percent))
        elif percent == 0.5:
            save(hmaps_05, savedir + '{}_{}_{}.pickle'.format('test' if test else 'train', attribution, percent))
            save(label_05, savedir + '{}_{}_{}_{}.pickle'.format('test' if test else 'train', attribution, percent, 'label'))
            print("Occlude image {} percentile...".format(percent))
        elif percent == 0.6:
            save(hmaps_06, savedir + '{}_{}_{}.pickle'.format('test' if test else 'train', attribution, percent))
            save(label_06, savedir + '{}_{}_{}_{}.pickle'.format('test' if test else 'train', attribution, percent, 'label'))
            print("Occlude image {} percentile...".format(percent))
        elif percent == 0.7:
            save(hmaps_07, savedir + '{}_{}_{}.pickle'.format('test' if test else 'train', attribution, percent))
            save(label_07, savedir + '{}_{}_{}_{}.pickle'.format('test' if test else 'train', attribution, percent, 'label'))
            print("Occlude image {} percentile...".format(percent))
        elif percent == 0.8:
            save(hmaps_08, savedir + '{}_{}_{}.pickle'.format('test' if test else 'train', attribution, percent))
            save(label_08, savedir + '{}_{}_{}_{}.pickle'.format('test' if test else 'train', attribution, percent, 'label'))
            print("Occlude image {} percentile...".format(percent))
        elif percent == 0.9:
            save(hmaps_09, savedir + '{}_{}_{}.pickle'.format('test' if test else 'train', attribution, percent))
            save(label_09, savedir + '{}_{}_{}_{}.pickle'.format('test' if test else 'train', attribution, percent, 'label'))
            print("Occlude image {} percentile...".format(percent))
        elif percent == 1:
            save(hmaps_10, savedir + '{}_{}_{}.pickle'.format('test' if test else 'train', attribution, percent))
            save(label_10, savedir + '{}_{}_{}_{}.pickle'.format('test' if test else 'train', attribution, percent, 'label'))
            print("Occlude image {} percentile...".format(percent))
        else:
            print("error")
    del hmaps_00,hmaps_01,hmaps_02,hmaps_03,hmaps_04,hmaps_05,hmaps_06,hmaps_07,hmaps_08,hmaps_09,hmaps_10,data_00,data_01,data_02,data_03,data_04,data_05,data_06,data_07,data_08,data_09,data_10,label_00,label_01,label_02,label_03,label_04,label_05,label_06,label_07,label_08,label_09,label_10
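

# The eleven hmaps_XX / data_XX / label_XX lists and the long if/elif ladders
# above can be expressed more compactly with containers keyed by percentile.
# A hedged refactoring sketch of the same bookkeeping (remove, random_remove,
# save, Xtrain/Ytrain and Xtest/Ytest are the same assumed globals/helpers;
# behaviour should match the function above):
from collections import defaultdict

def occlude_dataset_compact_sketch(DNN, attribution, percentiles, test=False,
                                   keep=False, savedir=''):
    Xs, ys = (Xtest, Ytest) if test else (Xtrain, Ytrain)
    occluded = defaultdict(list)
    labels = defaultdict(list)
    for i in tqdm(range(len(Xs))):
        x = Xs[i:i+1, ...]
        y = ys[i:i+1, ...]
        attrs = x  # placeholder; compute the LRP / proposed_method attrs as above
        for percent in percentiles:
            if attribution == 'normal':
                img = random_remove(x, percent, keep)
            else:
                img = remove(x, attrs, percent, keep)
            occluded[percent].append(img)
            labels[percent].append(y)
    prefix = 'test' if test else 'train'
    for percent in percentiles:
        save(np.concatenate(occluded[percent], axis=0),
             savedir + '{}_{}_{}.pickle'.format(prefix, attribution, percent))
        save(np.concatenate(labels[percent], axis=0),
             savedir + '{}_{}_{}_{}.pickle'.format(prefix, attribution, percent, 'label'))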
Example No. 7
with tf.Session() as sess:
    sess.run(init)
    saver.restore(sess, model_path)
    for inx in I[:12]:
        test_x = mnist.test.images[inx]
        test_x = (test_x - 0.5) * 2
        test_y = mnist.test.labels[inx]
        relevance = sess.run(R, feed_dict={x: test_x[np.newaxis, :]})
        # import pdb; pdb.set_trace()
        pred_y = sess.run(pred, feed_dict={x: test_x[np.newaxis, :]})

        digit = render.digit_to_rgb(test_x, scaling = 3)
        hm = render.hm_to_rgb(relevance, X = test_x, scaling = 3, sigma = 2)
        digit_hm = render.save_image([digit,hm],'./heatmap.png')
        data_io.write(relevance,'./heatmap.npy')

        print ('True Class:     {}'.format(np.argmax(test_y)))
        print ('Predicted Class: {}\n'.format(np.argmax(pred_y)))

        #display the image as written to file
        plt.imshow(digit_hm, interpolation = 'none', cmap=plt.cm.binary)
        plt.axis('off')
        plt.show()
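
        # Note: './heatmap.png' above is overwritten on every index, so only
        # the last of the 12 digits survives on disk. A hedged variant that
        # keeps one file per test index (the filename pattern is illustrative):
        render.save_image([digit, hm], './heatmap_{:05d}.png'.format(inx))
        data_io.write(relevance, './heatmap_{:05d}.npy'.format(inx))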
Example No. 8
def occlude_dataset(DNN, attribution, percentiles, test=False, keep=False, random=False, batch_size=128, savedir=''):
    '''
    Compute the LRP relevance scores for XAI.
    percentiles: masking percentages
    test: whether to run on the test set only
    keep: argument selecting KAR vs. ROAR
    '''
    
    print("Condition of test : {}".format(test))
    if test:
        Xs = Xtest
        ys = Ytest
    else:
        Xs = Xtrain
        ys = Ytrain
    
    print("initial batch_size is : {}".format(batch_size))
    total_batch = math.ceil(len(Xs) / batch_size)
    print("batch size is :{}".format(total_batch))
    hmaps = []
    data = []
    label = []
    
    ## compute the relevance scores
    for i in tqdm(range(total_batch)):
        if 'LRP' in attribution:
            x = Xs[i:i+1,...]
            y = ys[i:i+1,...]
            ypred = DNN.forward(x)

            m = np.zeros_like(ypred)
            m[:,np.argmax(ypred)] = 1
            Rinit = ypred*m
            Rinit = Rinit.astype(np.float64)  # make the cast effective (np.float is deprecated)
            R = DNN.lrp(Rinit,'epsilon',1.)
            R = R.sum(axis=3)
            if np is not numpy:
                R = np.asnumpy(R)
            if test:
                LRP_test = render.digit_to_rgb(R, scaling = 3)
            attrs = R

        elif 'proposed_method' in attribution:
            x = Xs[i:i+1,...]
            y = ys[i:i+1,...]
            ypred = DNN.forward(x)

            m = np.zeros_like(ypred)
            m[:,np.argmax(ypred)] = 1
            Rinit = ypred*m
            Rinit = Rinit.astype(np.float64)  # make the cast effective (np.float is deprecated)
            R = DNN.lrp(Rinit,'epsilon',1.)
            R = R.sum(axis=3)
            if np is not numpy:
                xs = np.asnumpy(x)
                R = np.asnumpy(R)
            else:
                xs = x

            ## run GLS
            tar = xs
            a = np.load('../r_array/convolution.npy')
            a = np.reshape(a,[a.shape[1]*a.shape[2],1])
            b = np.load('../r_array/rect.npy')    
            b = np.pad(b,((0,0),(2,2),(2,2),(0,0)))
            b = np.reshape(b,[b.shape[1]*b.shape[2],b.shape[0]*b.shape[3]])
            c = np.load('../r_array/sumpoll.npy')
            c = np.pad(c,((0,0),(2,2),(2,2),(0,0)))
            c = np.reshape(c,[c.shape[1]*c.shape[2],c.shape[3]])
            new_b = np.hstack((b, c))
            new = np.hstack((a, new_b))
            tar = np.reshape(tar, [tar.shape[0]*tar.shape[1]*tar.shape[2]])
            y_tran = tar.transpose()
            new = sm.add_constant(new)

            model = sm.GLSAR(y_tran, new, rho = 2)
            result = model.iterative_fit(maxiter = 30)
            find = result.resid
            check = np.reshape(find,[1,32,32])
            if test:
                proposed_test = render.digit_to_rgb(check, scaling = 3)
            attrs = check
        else:
            x = Xs[i:i+1,...]
            y = ys[i:i+1,...]
            if np is not numpy:
                xs = np.asnumpy(x)
            else:
                xs = x
            if test:
                digit = render.digit_to_rgb(xs, scaling = 3)
            attrs = xs
        attrs += np.random.normal(scale=1e-4, size=attrs.shape)
        hmaps.append(attrs)
        data.append(x)
        label.append(y)
        
    ## derive the heatmaps
    print("Interpretation is done, concatenate...")
    hmaps = np.concatenate(hmaps, axis=0)
    data = np.concatenate(data, axis = 0)
    
    print("concatenate is done...")
    print("print final : {}".format(hmaps.shape))

    for percent in tqdm(percentiles):
        batch_attrs = hmaps
        occluded_images = remove(data, batch_attrs, percent, keep)
        print("save start")
        print("Save directory is {}".format(savedir))
        save(occluded_images, savedir + '{}_{}_{}.pickle'.format('test' if test else 'train', attribution, percent))
        save(np.concatenate(label, axis = 0), savedir + '{}_{}_{}_{}.pickle'.format('test' if test else 'train', attribution, percent, 'label'))
        print("Occlude image {} percentile...".format(percent))