'classifier_path': classifier_path,
                'K': K,
                'L': L,
                'train_step': train_steps,
                'Nalpha': Nalpha,
                'Nbeta': Nbeta,
                'lam': lam,
                'batch_size': batch_size,
                'lr': lr,
                'randseed': randseed,
                'traininfo': traininfo
            })
else:  # load pretrained model
    gce = torch.load(os.path.join(gce_path, 'model.pt'), map_location=device)

# --- compute final information flow ---
I = gce.informationFlow()
Is = gce.informationFlow_singledim(range(0, K + L))
print('Information flow of K=%d causal factors on classifier output:' % K)
print(Is[:K])
print('Information flow of L=%d noncausal factors on classifier output:' % L)
print(Is[K:])

# --- generate explanation and create figure ---
sample_ind = np.concatenate(
    (np.where(vaY == 0)[0][:4], np.where(vaY == 1)[0][:4]))
x = torch.from_numpy(vaX[sample_ind])
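# latent values swept per factor when generating explanations (presumably in
# units of the standard-normal prior's standard deviation)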
zs_sweep = [-3., -2., -1., 0., 1., 2., 3.]
Xhats, yhats = gce.explain(x, zs_sweep)
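# invert pixel intensities so digits render dark-on-light in the saved figure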
plotting.plotExplanation(1. - Xhats, yhats, save_path='figs/fig3')
Example #2
                    os.makedirs(gce_path)
                torch.save((gce, traininfo), os.path.join(gce_path, filename))
        else:  # load pretrained model
            gce, traininfo = torch.load(os.path.join(gce_path, filename))
        # get data
        gce.encoder.eval()
        gce.decoder.eval()
        torch.cuda.empty_cache()
        data['loss'][i_f, i_l, :] = traininfo['loss']
        data['loss_ce'][i_f, i_l, :] = traininfo['loss_ce']
        data['loss_nll'][i_f, i_l, :] = traininfo['loss_nll']
        data['Ijoint'][i_f, i_l] = gce.informationFlow()
        # per-dimension information flow for all K+L latent factors
        data['Is'][i_f, i_l, :] = gce.informationFlow_singledim(dims=range(K + L))
        # save figures for explanation
        sample_ind = np.concatenate(
            (np.where(vaY == 0)[0][:3], np.where(vaY == 1)[0][:3],
             np.where(vaY == 2)[0][:2]))
        x = torch.from_numpy(vaX[sample_ind])
        zs_sweep = [-3., -2., -1., 0., 1., 2., 3.]
        Xhats, yhats = gce.explain(x, zs_sweep)
        if not os.path.exists('./figs/fig19/'):
            os.makedirs('./figs/fig19/')
        plotting.plotExplanation(1. - Xhats,
                                 yhats,
                                 save_path='./figs/fig19/%dfilters_lambda%g' %
                                 (nfilt, lam))
        plt.close('all')
# save all results to file
from scipy.io import savemat
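os.makedirs('./results', exist_ok=True)  # savemat does not create directories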
savemat('./results/fig18.mat', {'data': data})
Example #3
encoder = Encoder(K+L, c_dim, x_dim).to(device)
decoder = Decoder(K+L, c_dim, x_dim).to(device)
encoder.apply(util.weights_init_normal)
decoder.apply(util.weights_init_normal)

# %% train GCE
gce = GenerativeCausalExplainer(classifier, decoder, encoder, device)
traininfo = gce.train(X, K, L,
                      steps=train_steps,
                      Nalpha=Nalpha,
                      Nbeta=Nbeta,
                      lam=lam,
                      batch_size=batch_size,
                      lr=lr)
torch.save(gce, 'results/gce_fmnist.pth')
#gce = torch.load('results/gce_fmnist.pth', map_location=device)

# %% compute information flow
I = gce.informationFlow()
Is = gce.informationFlow_singledim(range(0, K + L))
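# report per-factor information flow, as in the other examples above
print('Information flow of K=%d causal factors on classifier output:' % K)
print(Is[:K])
print('Information flow of L=%d noncausal factors on classifier output:' % L)
print(Is[K:])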

# %% generate explanation and create figure
sample_ind = np.concatenate((np.where(vaY == 0)[0][:3],
                             np.where(vaY == 1)[0][:3],
                             np.where(vaY == 2)[0][:2]))
x = torch.from_numpy(vaX[sample_ind])
zs_sweep = [-3., -2., -1., 0., 1., 2., 3.]
Xhats, yhats = gce.explain(x, zs_sweep)
plotting.plotExplanation(1. - Xhats, yhats, save_path='figs/fig_fmnist_qual')

# %%
Example #4
def train_explainer(dataset, classes_used, K, L, lam, print_train_losses=True):
    # --- parameters ---
    # dataset
    data_classes_lst = [int(i) for i in str(classes_used)]
    dataset_name_full = dataset + '_' + str(classes_used)

    # classifier
    classifier_path = './pretrained_models/{}_classifier'.format(
        dataset_name_full)

    # GCE params
    randseed = 0
    gce_path = os.path.join(
        'outputs', dataset_name_full +
        '_gce_K{}_L{}_lambda{}'.format(K, L,
                                       str(lam).replace('.', "")))
    retrain_gce = True  # train explanatory VAE from scratch
    save_gce = True  # save/overwrite pretrained explanatory VAE at gce_path

    # other train params
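    # (Nalpha/Nbeta are Monte Carlo sample counts used by gce.train when
    # estimating the causal-effect term of the objective; naming assumed to
    # follow the GCE reference implementation)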
    train_steps = 3000  # 8000
    Nalpha = 25
    Nbeta = 70
    batch_size = 64
    lr = 5e-4

    # --- initialize ---
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if randseed is not None:
        np.random.seed(randseed)
        torch.manual_seed(randseed)
    ylabels = range(0, len(data_classes_lst))

    # --- load data ---
    from load_mnist import load_mnist_classSelect, load_fashion_mnist_classSelect

    if dataset == 'mnist':
        fn = load_mnist_classSelect
    elif dataset == 'fmnist':
        fn = load_fashion_mnist_classSelect
    elif dataset == 'cifar':
        fn = load_cifar_classSelect  # assumed available; not imported above
    else:
        raise ValueError('dataset not correctly specified: %s' % dataset)

    X, Y, tridx = fn('train', data_classes_lst, ylabels)
    vaX, vaY, vaidx = fn('val', data_classes_lst, ylabels)

    if dataset == "cifar":
        X, vaX = X / 255, vaX / 255

    ntrain, nrow, ncol, c_dim = X.shape
    x_dim = nrow * ncol

    # --- load classifier ---
    from models.CNN_classifier import CNN
    classifier = CNN(len(data_classes_lst), c_dim).to(device)
    checkpoint = torch.load('%s/model.pt' % classifier_path,
                            map_location=device)
    classifier.load_state_dict(checkpoint['model_state_dict_classifier'])

    # --- train/load GCE ---
    from models.CVAEImageNet import Decoder, Encoder
    if retrain_gce:
        encoder = Encoder(K + L, c_dim, x_dim).to(device)
        decoder = Decoder(K + L, c_dim, x_dim).to(device)
        encoder.apply(util.weights_init_normal)
        decoder.apply(util.weights_init_normal)
        gce = GenerativeCausalExplainer(classifier,
                                        decoder,
                                        encoder,
                                        device,
                                        save_output=True,
                                        save_model_params=False,
                                        save_dir=gce_path,
                                        debug_print=print_train_losses)
        traininfo = gce.train(X,
                              K,
                              L,
                              steps=train_steps,
                              Nalpha=Nalpha,
                              Nbeta=Nbeta,
                              lam=lam,
                              batch_size=batch_size,
                              lr=lr)
        if save_gce:
            if not os.path.exists(gce_path):
                os.makedirs(gce_path)
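            # pickling the whole GCE object ties the checkpoint to the
            # current class definitions; a state_dict is more portable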
            torch.save(gce, os.path.join(gce_path, 'model.pt'))
            sio.savemat(
                os.path.join(gce_path, 'training-info.mat'), {
                    'data_classes_lst': data_classes_lst,
                    'classifier_path': classifier_path,
                    'K': K,
                    'L': L,
                    'train_step': train_steps,
                    'Nalpha': Nalpha,
                    'Nbeta': Nbeta,
                    'lam': lam,
                    'batch_size': batch_size,
                    'lr': lr,
                    'randseed': randseed,
                    'traininfo': traininfo
                })
    else:  # load pretrained model
        gce = torch.load(os.path.join(gce_path, 'model.pt'),
                         map_location=device)
        traininfo = None

    # --- compute final information flow ---
    I = gce.informationFlow()
    Is = gce.informationFlow_singledim(range(0, K + L))
    print('Information flow of K=%d causal factors on classifier output:' % K)
    print(Is[:K])
    print('Information flow of L=%d noncausal factors on classifier output:' %
          L)
    print(Is[K:])

    # --- generate explanation and create figure ---
    nr_labels = len(data_classes_lst)
    nr_samples_fig = 8
    sample_ind = np.empty(0, dtype=int)

    # retrieve roughly equal numbers of samples from each class
    for i in range(nr_labels):
        # spread the remaining figure slots evenly over the remaining classes
        samples_per_class = math.ceil(
            (nr_samples_fig - len(sample_ind)) / (nr_labels - i))
        sample_ind = np.int_(
            np.concatenate(
                [sample_ind,
                 np.where(vaY == i)[0][:samples_per_class]]))

    x = torch.from_numpy(vaX[sample_ind])
    zs_sweep = [-3., -2., -1., 0., 1., 2., 3.]
    Xhats, yhats = gce.explain(x, zs_sweep)
    plot_save_dir = os.path.join(gce_path, 'figs/')
    if not os.path.exists(plot_save_dir):
        os.makedirs(plot_save_dir)
    plotting.plotExplanation(1. - Xhats, yhats, save_path=plot_save_dir)

    return traininfo
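
A minimal usage sketch (the dataset name, class string, and hyperparameter values here are illustrative assumptions, not values from the source):

traininfo = train_explainer(dataset='mnist', classes_used=38,
                            K=1, L=7, lam=0.05, print_train_losses=True)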
Example #5
                          batch_size=batch_size,
                          lr=lr)
    if save_gce:
        if not os.path.exists(gce_path):
            os.makedirs(gce_path)
        torch.save(gce, os.path.join(gce_path, 'model.pt'))
        sio.savemat(os.path.join(gce_path, 'training-info.mat'), {
            'data_classes': data_classes, 'classifier_path': classifier_path,
            'K': K, 'L': L, 'train_step': train_steps, 'Nalpha': Nalpha,
            'Nbeta': Nbeta, 'lam': lam, 'batch_size': batch_size, 'lr': lr,
            'randseed': randseed, 'traininfo': traininfo})
else: # load pretrained model
    gce = torch.load(os.path.join(gce_path, 'model.pt'), map_location=device)

# --- compute final information flow ---
I = gce.informationFlow()
Is = gce.informationFlow_singledim(range(0, K+L))
print('Information flow of K=%d causal factors on classifier output:' % K)
print(Is[:K])
print('Information flow of L=%d noncausal factors on classifier output:' % L)
print(Is[K:])


# --- generate explanation and create figure ---
sample_ind = np.concatenate((np.where(vaY == 0)[0][:4],
                             np.where(vaY == 1)[0][:4]))
x = torch.from_numpy(vaX[sample_ind])
zs_sweep = [-3., -2., -1., 0., 1., 2., 3.]
Xhats, yhats = gce.explain(x, zs_sweep)
plotting.plotExplanation(1. - Xhats, yhats, save_path='Fig3CIFAR')  # relative save path
Example #6
                              lam=lam,
                              batch_size=batch_size,
                              lr=lr)
        # get data
        gce.encoder.eval()
        gce.decoder.eval()
        torch.cuda.empty_cache()
        torch.save(gce, 'results/gce_vae_capacity_%dfilters_lambda%g.pth'
                   % (nfilt, lam))
        data['loss'][i_f, i_l, :] = traininfo['loss']
        data['loss_ce'][i_f, i_l, :] = traininfo['loss_ce']
        data['loss_nll'][i_f, i_l, :] = traininfo['loss_nll']
        data['Ijoint'][i_f, i_l] = gce.informationFlow()
        # per-dimension information flow for all K+L latent factors
        data['Is'][i_f, i_l, :] = gce.informationFlow_singledim(dims=range(K + L))
        # save figures for explanation
        sample_ind = np.concatenate(
            (np.where(vaY == 0)[0][:3], np.where(vaY == 1)[0][:3],
             np.where(vaY == 2)[0][:2]))
        x = torch.from_numpy(vaX[sample_ind])
        zs_sweep = [-3., -2., -1., 0., 1., 2., 3.]
        Xhats, yhats = gce.explain(x, zs_sweep)
        plotting.plotExplanation(
            1. - Xhats,
            yhats,
            save_path='figs/fig_vae_capacity_%dfilters_lambda%g' %
            (nfilt, lam))
# save all results to file
from scipy.io import savemat
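import os
os.makedirs('results', exist_ok=True)  # savemat does not create directories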
savemat('results/vae_capacity_data.mat', {'data': data})