Example #1
sample_inputs_torch = torch.from_numpy(sample_inputs)
sample_inputs_torch = sample_inputs_torch.permute(0,3,1,2).float().to(device)     
ntrain = X.shape[0]

# --- load VAE ---
from VAEModel_CNN import Decoder, Encoder
checkpoint_vae = torch.load(vae_file, map_location=device)
encoder = Encoder(z_dim,c_dim,img_size).to(device)
decoder = Decoder(z_dim,c_dim,img_size).to(device)
encoder.load_state_dict(checkpoint_vae['model_state_dict_encoder'])
decoder.load_state_dict(checkpoint_vae['model_state_dict_decoder'])
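# (these state-dict keys match the ones written by the training loop's
# torch.save checkpoint; see the CVAE function in Example #2 below)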

# --- load classifier ---
from cnnClassifierModel import CNN
checkpoint_model = torch.load(classifier_file, map_location=device)
classifier = CNN(y_dim).to(device)
classifier.load_state_dict(checkpoint_model['model_state_dict_classifier'])


#%% generate latent factor sweep plot

sample_ind = np.concatenate((np.where(vaY == 0)[0][:1],
                             np.where(vaY == 1)[0][:1]))
cols = [[0.047,0.482,0.863],[1.000,0.761,0.039],[0.561,0.788,0.227]]
border_size = 0
nsamples = len(sample_ind)
latentsweep_vals = [-3., -2., -1., 0., 1., 2., 3.]
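# values assigned to each latent coordinate during the sweep, in units of
# the (approximately) standard-normal latent prior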
Xhats = np.zeros((z_dim,nsamples,len(latentsweep_vals),img_size,img_size,1))
yhats = np.zeros((z_dim,nsamples,len(latentsweep_vals)))
for isamp in range(nsamples):
    x = torch.from_numpy(np.expand_dims(vaX[sample_ind[isamp]],0))
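    # --- a minimal completion sketch (not in the original fragment): encode
    # the sample, sweep each latent dimension over latentsweep_vals, decode,
    # and classify; interfaces mirror the encoder/decoder/classifier loaded
    # above, and the classifier's first output is assumed to be the
    # class-probability vector ---
    x_torch = x.permute(0, 3, 1, 2).float().to(device)
    zhat, mu, logvar = encoder(x_torch)
    for latent_dim in range(z_dim):
        for (ilatentsweep, latentsweep_val) in enumerate(latentsweep_vals):
            ztilde = zhat.detach().clone()
            ztilde[0, latent_dim] = latentsweep_val
            xhat = decoder(ztilde)
            yhat = classifier(xhat)[0]
            img = xhat.permute(0, 2, 3, 1).detach().cpu().numpy()
            Xhats[latent_dim, isamp, ilatentsweep, :, :, 0] = img[0, :, :, 0]
            yhats[latent_dim, isamp, ilatentsweep] = np.argmax(
                yhat.detach().cpu().numpy())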
Example #2
def CVAE(
        model="mnist_VAE_CNN",  # currently not used
        steps=20000,
        batch_size=100,
        z_dim=6,
        z_dim_true=4,
        x_dim=10,
        y_dim=1,
        alpha_dim=4,
        ntrain=5000,
        No=15,
        Ni=15,
        lam_ML=0.000001,
        gamma=0.001,
        lr=0.0001,
        b1=0.5,
        b2=0.999,
        use_ce=True,
        objective="IND_UNCOND",
        decoder_net="VAE_CNN",  # options are ['linGauss','nonLinGauss','VAE','VAE_CNN','VAE_Imagenet','VAE_fMNIST']
        classifier_net="cnn",  # options are ['oneHyperplane','twoHyperplane','cnn','cnn_imagenet','cnn_fmnist']
        data_type="mnist",  # options are ["2dpts","mnist","imagenet","fmnist"]
        break_up_ce=True,  # Whether or not to break up the forward passes of the network based on alphas
        randseed=None,
        save_output=False,
        debug_level=2,
        debug_plot=False,
        save_plot=False,
        c_dim=1,
        img_size=28):

    # initialization
    params = {
        "steps": steps,
        "batch_size": batch_size,
        "z_dim": z_dim,
        "z_dim_true": z_dim_true,
        "x_dim": x_dim,
        "y_dim": y_dim,
        "alpha_dim": alpha_dim,
        "ntrain": ntrain,
        "No": No,
        "Ni": Ni,
        "lam_ML": lam_ML,
        "gamma": gamma,
        "lr": lr,
        "b1": b1,
        "b2": b2,
        "use_ce": use_ce,
        "objective": objective,
        "decoder_net": decoder_net,
        "classifier_net": classifier_net,
        "data_type": data_type,
        "break_up_ce": break_up_ce,
        "randseed": randseed,
        "save_output": save_output,
        "debug_level": debug_level,
        'c_dim': c_dim,
        'img_size': img_size
    }
    params["data_std"] = 2.
    if debug_level > 0:
        print("Parameters:")
        print(params)

    # Initialize arrays for storing performance data
    debug = {}
    debug["loss"] = np.zeros((steps))
    debug["loss_ce"] = np.zeros((steps))
    debug["loss_nll"] = np.zeros((steps))
    debug["loss_nll_logdet"] = np.zeros((steps))
    debug["loss_nll_quadform"] = np.zeros((steps))
    debug["loss_nll_mse"] = np.zeros((steps))
    debug["loss_nll_kld"] = np.zeros((steps))
    if decoder_net == 'linGauss':
        for i in range(params["z_dim"]):
            for j in range(params["z_dim_true"]):
                debug["cossim_w%dwhat%d" % (j + 1, i + 1)] = np.zeros((steps))
            for j in range(i + 1, params["z_dim"]):
                debug["cossim_what%dwhat%d" % (i + 1, j + 1)] = np.zeros(
                    (steps))
    if save_plot: frames = []
    if data_type == 'mnist' or data_type == 'fmnist':
        class_use = np.array([0, 3, 4])
        class_use_str = np.array2string(class_use)
        y_dim = class_use.shape[0]  # note: params["y_dim"] above still holds the original value
        newClass = range(0, y_dim)
        save_dir = '/home/mnorko/Documents/Tensorflow/causal_vae/results/fmnist_class034/' + data_type + '_' + objective + '_zdim' + str(
            z_dim) + '_alpha' + str(alpha_dim) + '_No' + str(No) + '_Ni' + str(
                Ni) + '_lam' + str(lam_ML) + '_class' + class_use_str[1:(
                    len(class_use_str) - 1):2] + '/'
    else:
        save_dir = '/home/mnorko/Documents/Tensorflow/causal_vae/results/imagenet/' + data_type + '_' + objective + '_zdim' + str(
            z_dim) + '_alpha' + str(alpha_dim) + '_No' + str(No) + '_Ni' + str(
                Ni) + '_lam' + str(lam_ML) + '_cont_kl0.1/'
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    if data_type == '2dpts':
        break_up_ce = False
        params['break_up_ce'] = False

    # seed random number generator
    if randseed is not None:
        if debug_level > 0:
            print('Setting random seed to ' + str(randseed) + '.')
        np.random.seed(randseed)
        torch.manual_seed(randseed)
    # Device configuration
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Generate data
    if data_type == '2dpts':
        # --- construct projection matrices ---
        # 'true' orthogonal columns used to generate data from latent factors
        #Wsquare = sp.linalg.orth(np.random.rand(x_dim,x_dim))
        Wsquare = np.identity(x_dim)
        W = Wsquare[:, :z_dim_true]
        # 1st column of W
        w1 = np.expand_dims(W[:, 0], axis=1)
        # 2nd column of W
        w2 = np.expand_dims(W[:, 1], axis=1)
        # form projection matrices
        Pw1 = util.formProjMat(w1)
        Pw2 = util.formProjMat(w2)
        # convert to torch matrices
        Pw1_torch = torch.from_numpy(Pw1).float()
        Pw2_torch = torch.from_numpy(Pw2).float()

        # --- construct data ---
        # ntrain instances of alpha and x
        Alpha = params["data_std"] * np.random.randn(ntrain, z_dim_true)
        X = np.matmul(Alpha, W.T)
    elif data_type == 'mnist':
        test_size = 64
        X, Y, tridx = load_mnist_classSelect('train', class_use, newClass)
        vaX, vaY, vaidx = load_mnist_classSelect('val', class_use, newClass)
        sample_inputs = vaX[0:test_size]
        sample_inputs_torch = torch.from_numpy(sample_inputs)
        sample_inputs_torch = sample_inputs_torch.permute(0, 3, 1,
                                                          2).float().to(device)
        ntrain = X.shape[0]
    elif data_type == 'fmnist':
        test_size = 64
        X, Y, tridx = load_fashion_mnist_classSelect('train', class_use,
                                                     newClass)
        vaX, vaY, vaidx = load_fashion_mnist_classSelect(
            'val', class_use, newClass)
        sample_inputs = vaX[0:test_size]
        sample_inputs_torch = torch.from_numpy(sample_inputs)
        sample_inputs_torch = sample_inputs_torch.permute(0, 3, 1,
                                                          2).float().to(device)
        ntrain = X.shape[0]
    elif data_type == 'svhn':
        test_size = 64
        X, Y, tridx = load_svhn_classSelect('train', class_use, newClass)
        vaX, vaY, vaidx = load_svhn_classSelect('val', class_use, newClass)
        sample_inputs = vaX[0:test_size]
        sample_inputs_torch = torch.from_numpy(sample_inputs)
        sample_inputs_torch = sample_inputs_torch.float().to(device)
        ntrain = X.shape[0]
    elif data_type == 'imagenet':
        transform_train = transforms.Compose([
            transforms.RandomCrop(128, padding=4, pad_if_needed=True),
            transforms.RandomHorizontalFlip(),
            transforms.Resize([32, 32]),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

        transform_test = transforms.Compose([
            transforms.CenterCrop(128),
            transforms.Resize([32, 32]),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])
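        # (the Normalize constants above are the standard ImageNet channel
        # means/stds used with torchvision's pretrained models)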

        file_open = open('imagenet_zebra_gorilla.pkl', 'rb')
        train_fileType_2 = pickle.load(file_open)
        train_imgName_2 = pickle.load(file_open)
        train_imgLabel_2 = pickle.load(file_open)

        val_fileType_2 = pickle.load(file_open)
        val_imgName_2 = pickle.load(file_open)
        val_imgLabel_2 = pickle.load(file_open)

        file_open.close()
        train_set = Imagenet_Gor_Zeb(train_imgName_2,
                                     train_imgLabel_2,
                                     train_fileType_2,
                                     transforms=transform_train)

        trainloader = torch.utils.data.DataLoader(train_set,
                                                  batch_size=batch_size,
                                                  shuffle=True)
        test_size = 16
        val_set = Imagenet_Gor_Zeb(val_imgName_2,
                                   val_imgLabel_2,
                                   val_fileType_2,
                                   transforms=transform_test)
        valloader = torch.utils.data.DataLoader(val_set,
                                                batch_size=test_size,
                                                shuffle=True)

        dataiter = iter(trainloader)
        dataiter_val = iter(valloader)
        sample_inputs_torch, _ = next(dataiter_val)
        sample_inputs_torch = sample_inputs_torch.to(device)

    # --- initialize decoder ---
    if decoder_net == 'linGauss':
        from linGaussModel import Decoder
        decoder = Decoder(x_dim, z_dim).to(device)
    elif decoder_net == 'nonLinGauss':
        from VAEModel import Decoder_samp
        z_num_samp = No
        decoder = Decoder_samp(x_dim, z_dim).to(device)
    elif decoder_net == 'VAE':
        from VAEModel import Decoder, Encoder
        encoder = Encoder(x_dim, z_dim).to(device)
        encoder.apply(weights_init_normal)
        decoder = Decoder(x_dim, z_dim).to(device)
        decoder.apply(weights_init_normal)
    elif decoder_net == 'VAE_CNN' or decoder_net == 'VAE_fMNIST':
        from VAEModel_CNN import Decoder, Encoder
        encoder = Encoder(z_dim, c_dim, img_size).to(device)
        encoder.apply(weights_init_normal)
        decoder = Decoder(z_dim, c_dim, img_size).to(device)
        decoder.apply(weights_init_normal)
    elif decoder_net == 'VAE_Imagenet':
        checkpoint = torch.load(
            '/home/mnorko/Documents/Tensorflow/causal_vae/results/imagenet/imagenet_JOINT_UNCOND_zdim40_alpha0_No20_Ni1_lam0.001/'
            + 'network_batch' + str(batch_size) + '.pt')
        from VAEModel_CNN_imagenet import Decoder, Encoder
        encoder = Encoder(z_dim, c_dim, img_size).to(device)
        decoder = Decoder(z_dim, c_dim, img_size).to(device)
        encoder.load_state_dict(checkpoint['model_state_dict_encoder'])
        decoder.load_state_dict(checkpoint['model_state_dict_decoder'])
    else:
        print("Error: decoder_net should be one of: linGauss nonLinGauss VAE")

    # --- initialize classifier ---
    if classifier_net == 'oneHyperplane':
        from hyperplaneClassifierModel import OneHyperplaneClassifier
        classifier = OneHyperplaneClassifier(x_dim, y_dim, Pw1_torch,
                                             ksig=5.).to(device)
        classifier.apply(weights_init_normal)
    elif classifier_net == 'twoHyperplane':
        from hyperplaneClassifierModel import TwoHyperplaneClassifier
        classifier = TwoHyperplaneClassifier(x_dim,
                                             y_dim,
                                             Pw1_torch,
                                             Pw2_torch,
                                             ksig=5.).to(device)
        classifier.apply(weights_init_normal)
    elif classifier_net == 'cnn':
        from cnnClassifierModel import CNN
        classifier = CNN(y_dim).to(device)
        batch_orig = 64
        checkpoint = torch.load('./mnist_batch64_lr0.1_class38/network_batch' +
                                str(batch_orig) + '.pt')
        #checkpoint = torch.load('/home/mnorko/Documents/Tensorflow/causal_vae/results/mnist_batch64_lr0.1_class149/network_batch' + str(batch_orig) + '.pt')
        classifier.load_state_dict(checkpoint['model_state_dict_classifier'])
    elif classifier_net == 'cnn_fmnist':
        from cnnClassifierModel import CNN
        classifier = CNN(y_dim).to(device)
        batch_orig = 64
        checkpoint = torch.load(
            './fmnist_batch64_lr0.1_class034/network_batch' + str(batch_orig) +
            '.pt')
        #checkpoint = torch.load('/home/mnorko/Documents/Tensorflow/causal_vae/results/mnist_batch64_lr0.1_class149/network_batch' + str(batch_orig) + '.pt')
        classifier.load_state_dict(checkpoint['model_state_dict_classifier'])
    elif classifier_net == 'cnn_imagenet':
        from cnnImageNetClassifierModel import CNN
        classImagenetIdx = [366, 340]
        classifier_model = models.vgg16_bn(pretrained=True)
        classifier = CNN(classifier_model, classImagenetIdx).to(device)
    else:
        print("Error: classifier_net should be one of: oneHyperplane, "
              "twoHyperplane, cnn, cnn_fmnist, cnn_imagenet")

    ##

    if decoder_net == 'linGauss':
        What = Variable(torch.mul(torch.randn(x_dim, z_dim, dtype=torch.float),
                                  0.5),
                        requires_grad=True)
    else:
        What = None

    # --- specify optimizer ---
    # NOTE: we include only the encoder/decoder parameters in the optimizer
    # because we don't want to update the (pretrained) classifier parameters
    if decoder_net == 'VAE' or decoder_net == 'VAE_CNN' or decoder_net == 'VAE_Imagenet' or decoder_net == 'VAE_fMNIST':
        params_use = list(decoder.parameters()) + list(encoder.parameters())
    else:
        params_use = list(decoder.parameters())
    optimizer_NN = torch.optim.Adam(params_use, lr=lr, betas=(b1, b2))

    # --- train ---
    start_time = time.time()
    for k in range(0, steps):

        # --- reset gradients to zero ---
        # (you always need to do this in pytorch or the gradients are
        # accumulated from one batch to the next)
        optimizer_NN.zero_grad()

        # --- compute negative log likelihood ---
        # randomly subsample batch_size samples of x

        randIdx = np.random.randint(0, ntrain, batch_size)
        if data_type == '2dpts':
            Xbatch = torch.from_numpy(X[randIdx, :]).float()
        elif data_type == 'mnist' or data_type == 'fmnist':
            Xbatch = torch.from_numpy(X[randIdx]).float()
            Xbatch = Xbatch.permute(0, 3, 1, 2)
        elif data_type == 'imagenet':
            try:
                Xbatch, _ = next(dataiter)
            except StopIteration:
                # restart the loader once an epoch's worth of batches is exhausted
                dataiter = iter(trainloader)
                Xbatch, _ = next(dataiter)
        Xbatch = Xbatch.to(device)
        if decoder_net == 'linGauss':
            nll = loss_functions.linGauss_NLL_loss(Xbatch, What, gamma)
        elif decoder_net == 'nonLinGauss':
            randBatch = torch.from_numpy(np.random.randn(z_num_samp,
                                                         z_dim)).float()
            Xest, Xmu, Xlogvar = decoder(randBatch)
            Xcov = torch.exp(Xlogvar)
            nll = loss_functions.nonLinGauss_NLL_loss(Xbatch, Xmu, Xcov)
        elif decoder_net == 'VAE' or decoder_net == 'VAE_CNN' or decoder_net == 'VAE_Imagenet' or decoder_net == 'VAE_fMNIST':
            latent_out, mu, logvar = encoder(Xbatch)
            Xest = decoder(latent_out)
            nll, nll_mse, nll_kld = loss_functions.VAE_LL_loss(
                Xbatch, Xest, logvar, mu)

        # --- compute mutual information causal effect term ---
        if objective == "IND_UNCOND":
            causalEffect, ceDebug = causaleffect.ind_uncond(params,
                                                            decoder,
                                                            classifier,
                                                            device,
                                                            What=What)
        elif objective == "IND_COND":
            causalEffect, ceDebug = causaleffect.ind_cond(params,
                                                          decoder,
                                                          classifier,
                                                          device,
                                                          What=What)
        elif objective == "JOINT_UNCOND":
            causalEffect, ceDebug = causaleffect.joint_uncond(params,
                                                              decoder,
                                                              classifier,
                                                              device,
                                                              What=What)
        elif objective == "JOINT_COND":
            causalEffect, ceDebug = causaleffect.joint_cond(params,
                                                            decoder,
                                                            classifier,
                                                            device,
                                                            What=What)
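        # note: the four objective variants appear to differ along two axes:
        # IND_* scores each latent alpha_i independently while JOINT_* scores
        # all alphas jointly, and *_COND variants condition on the remaining
        # latent factors while *_UNCOND variants do not.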

        # --- compute gradients ---
        # total loss (use_ce is a bool, so it acts as a 0/1 gate on the
        # causal-effect term; lam_ML weights the likelihood term)
        loss = use_ce * causalEffect + lam_ML * nll
        # backward step to compute the gradients
        loss.backward()
        if decoder_net == 'linGauss':
            # update What with the computed gradient
            What.data.sub_(lr * What.grad.data)
            # reset the What gradients to 0
            What.grad.data.zero_()
        else:
            optimizer_NN.step()

        # --- save debug info for this step ---
        debug["loss"][k] = loss.item()
        debug["loss_ce"][k] = causalEffect.item()
        debug["loss_nll"][k] = (lam_ML * nll).item()
        if decoder_net == 'VAE' or decoder_net == 'VAE_CNN' or decoder_net == 'VAE_Imagenet' or decoder_net == 'VAE_fMNIST':
            debug["loss_nll_mse"][k] = (lam_ML * nll_mse).item()
            debug["loss_nll_kld"][k] = (lam_ML * nll_kld).item()
        if debug_level > 1:
            if decoder_net == 'linGauss':
                Wnorm = W / np.linalg.norm(W, axis=0, keepdims=True)
                Whatnorm = What.detach().numpy() / np.linalg.norm(
                    What.detach().numpy(), axis=0, keepdims=True)
                # cosine similarities between columns of W and What
                for i in range(params["z_dim"]):
                    for j in range(params["z_dim_true"]):
                        debug["cossim_w%dwhat%d" %
                              (j + 1, i + 1)][k] = np.matmul(
                                  Wnorm[:, j], Whatnorm[:, i])
                    for j in range(i + 1, params["z_dim"]):
                        debug["cossim_what%dwhat%d" %
                              (i + 1, j + 1)][k] = np.matmul(
                                  Whatnorm[:, i], Whatnorm[:, j])

        # --- print step information ---
        if debug_level > 0:
            print("[Step %d/%d] time: %4.2f  [CE: %g] [ML: %g] [loss: %g]" % \
                  (k, steps, time.time() - start_time, debug["loss_ce"][k],
                   debug["loss_nll"][k], debug["loss"][k]))

        # --- debug plot ---
        if debug_plot and k % 1000 == 0:
            print('Generating plot frame...')
            if data_type == '2dpts':
                # generate samples of p(x | alpha_i = alphahat_i)
                # (lfplot_aihat_vals and lfplot_nsamp are not set in this
                # function; they are assumed to be defined at module scope)
                decoded_points = {}
                decoded_points["ai_vals"] = lfplot_aihat_vals
                decoded_points["samples"] = np.zeros(
                    (2, lfplot_nsamp, len(lfplot_aihat_vals), params["z_dim"]))
                for l in range(params["z_dim"]):  # loop over latent dimensions
                    for i, aihat in enumerate(
                            lfplot_aihat_vals):  # loop over fixed aihat values
                        for m in range(
                                lfplot_nsamp):  # loop over samples to generate
                            z = np.random.randn(params["z_dim"])
                            z[l] = aihat
                            x = decoder(
                                torch.from_numpy(z).float(), What, gamma)
                            decoded_points["samples"][:, m, i,
                                                      l] = x.detach().numpy()
                frame = plotting.debugPlot_frame(X, ceDebug["Xhat"], W, What,
                                                 k, steps, debug, params,
                                                 classifier, decoded_points)
                if save_plot:
                    frames.append(frame)
            elif data_type == 'mnist' or data_type == 'imagenet' or data_type == 'fmnist':
                torch.save(
                    {
                        'step': k,
                        'model_state_dict_classifier': classifier.state_dict(),
                        'model_state_dict_encoder': encoder.state_dict(),
                        'model_state_dict_decoder': decoder.state_dict(),
                        'optimizer_state_dict': optimizer_NN.state_dict(),
                        'loss': loss,
                    }, save_dir + 'network_batch' + str(batch_size) + '.pt')
                sample_latent, mu, var = encoder(sample_inputs_torch)
                sample_inputs_torch_new = sample_inputs_torch.permute(0, 2, 3, 1)
                sample_inputs_np = sample_inputs_torch_new.detach().cpu().numpy()
                sample_img = decoder(sample_latent)
                sample_latent_small = sample_latent[0:10, :]
                imgOut_real, probOut_real, latentOut_real = sweepLatentFactors(
                    sample_latent_small, decoder, classifier, device, img_size,
                    c_dim, y_dim, False)
                rand_latent = torch.from_numpy(np.random.randn(
                    10, z_dim)).float().to(device)
                imgOut_rand, probOut_rand, latentOut_rand = sweepLatentFactors(
                    rand_latent, decoder, classifier, device, img_size, c_dim,
                    y_dim, False)
                samples = sample_img
                samples = samples.permute(0, 2, 3, 1)
                samples = samples.detach().cpu().numpy()
                save_images(samples, [8, 8],
                            '{}train_{:04d}.png'.format(save_dir, k))
                sio.savemat(
                    save_dir + 'sweepLatentFactors.mat', {
                        'imgOut_real': imgOut_real,
                        'probOut_real': probOut_real,
                        'latentOut_real': latentOut_real,
                        'imgOut_rand': imgOut_rand,
                        'probOut_rand': probOut_rand,
                        'latentOut_rand': latentOut_rand,
                        'loss_total': debug["loss"][:k],
                        'loss_ce': debug["loss_ce"][:k],
                        'loss_nll': debug['loss_nll'][:k],
                        'samples_out': samples,
                        'sample_inputs': sample_inputs_np
                    })

    # --- save all debug data ---
    debug["X"] = Xbatch.detach().cpu().numpy()
    if decoder_net != 'linGauss':
        debug["Xest"] = Xest.detach().cpu().numpy()
    if save_output:
        datestamp = ''.join(
            re.findall(r'\d+',
                       str(datetime.datetime.now())[:10]))
        timestamp = ''.join(
            re.findall(r'\d+',
                       str(datetime.datetime.now())[11:19]))
        #        results_folder = './results/tests_kSig5_lr0001_' + objective + '_lam' \
        #            + str(lam_ML) + '_No' + str(No) + '_Ni' + str(Ni) + '_' \
        #            + datestamp + '_' + timestamp + '/'
        #        if not os.path.exists(results_folder):
        #            os.makedirs(results_folder)
        results_folder = save_dir
        matfilename = 'results_' + datestamp + '_' + timestamp + '.mat'
        sio.savemat(results_folder + matfilename, {
            'params': params,
            'data': debug
        })
        if debug_level > 0:
            print('Finished saving data to ' + matfilename)

    if save_plot:
        print('Saving plot...')
        gif.save(frames, "results.gif", duration=100)
        print('Done!')

    return debug
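
# a minimal, hypothetical usage sketch (not in the original); parameter
# values are illustrative only
if __name__ == "__main__":
    debug_out = CVAE(steps=5000,
                     z_dim=8,
                     alpha_dim=2,
                     objective="JOINT_UNCOND",
                     decoder_net="VAE_CNN",
                     classifier_net="cnn",
                     data_type="mnist",
                     randseed=0)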
def CVAE(model = "mnist_VAE_CNN", # currently not used
         steps          = 20000,
         batch_size     = 32,
         z_dim          = 6,
         z_dim_true     = 4,
         x_dim          = 10,
         y_dim          = 1,
         alpha_dim      = 4,
         ntrain         = 5000,
         No             = 15,
         Ni             = 15,
         lam_ML         = 0.000001,
         gamma          = 0.001,
         lr             = 0.0001,
         b1             = 0.5,
         b2             = 0.999,
         use_ce         = True,
         objective      = "IND_UNCOND",
         decoder_net    = "VAE_CNN", # options are ['linGauss','nonLinGauss','VAE','VAE_CNN','VAE_Imagenet']
         classifier_net = "cnn", # options are ['oneHyperplane','twoHyperplane','cnn','cnn_imagenet']
         data_type      = "mnist", # options are ["2dpts","mnist","imagenet"]
         break_up_ce    = True, # Whether or not to break up the forward passes of the network based on alphas
         randseed       = None,
         save_output    = False,
         debug_level    = 2,
         debug_plot     = False,
         save_plot      = False,
         c_dim          = 1,
         img_size       = 28):

    # initialization
    params = {
        "steps"          : steps,
        "batch_size"     : batch_size,
        "z_dim"          : z_dim,
        "z_dim_true"     : z_dim_true,
        "x_dim"          : x_dim,
        "y_dim"          : y_dim,
        "alpha_dim"      : alpha_dim,
        "ntrain"         : ntrain,
        "No"             : No,
        "Ni"             : Ni,
        "lam_ML"         : lam_ML,
        "gamma"          : gamma,
        "lr"             : lr,
        "b1"             : b1,
        "b2"             : b2,
        "use_ce"         : use_ce,
        "objective"      : objective,
        "decoder_net"    : decoder_net,
        "classifier_net" : classifier_net,
        "data_type"      : data_type,
        "break_up_ce"    : break_up_ce,
        "randseed"       : randseed,
        "save_output"    : save_output,
        "debug_level"    : debug_level,
        'c_dim'          : c_dim,
        'img_size'       : img_size}
    params["data_std"] = 2.
    if debug_level > 0:
        print("Parameters:")
        print(params)
    
    # Initialize arrays for storing performance data
    debug = {}
    debug["loss"]              = np.zeros((steps))
    debug["loss_ce"]           = np.zeros((steps))
    debug["loss_nll"]          = np.zeros((steps))
    debug["loss_nll_logdet"]   = np.zeros((steps))
    debug["loss_nll_quadform"] = np.zeros((steps))
    debug["loss_nll_mse"]      = np.zeros((steps))
    debug["loss_nll_kld"]      = np.zeros((steps))
    if decoder_net == 'linGauss':
        for i in range(params["z_dim"]):
            for j in range(params["z_dim_true"]):
                debug["cossim_w%dwhat%d"%(j+1,i+1)] = np.zeros((steps))
            for j in range(i+1,params["z_dim"]):
                debug["cossim_what%dwhat%d"%(i+1,j+1)] = np.zeros((steps))
    if save_plot: frames = []
    if data_type == 'mnist':
        class_use = np.array([3,8])
        class_use_str = np.array2string(class_use)    
        y_dim = class_use.shape[0]  # note: params["y_dim"] above still holds the original value
        newClass = range(0,y_dim)
        save_dir = ('/home/mnorko/Documents/Tensorflow/causal_vae/results/imagenet/'
                    + data_type + '_' + objective + '_zdim' + str(z_dim)
                    + '_alpha' + str(alpha_dim) + '_No' + str(No) + '_Ni' + str(Ni)
                    + '_lam' + str(lam_ML) + '_class'
                    + class_use_str[1:(len(class_use_str) - 1):2] + '/')
    else:
        save_dir = ('/home/mnorko/Documents/Tensorflow/causal_vae/results/imagenet/'
                    + data_type + '_' + objective + '_zdim' + str(z_dim)
                    + '_alpha' + str(alpha_dim) + '_No' + str(No) + '_Ni' + str(Ni)
                    + '_lam' + str(lam_ML) + '/')

        
    if data_type == '2dpts':
        break_up_ce = False
        params['break_up_ce'] = False

    # seed random number generator
    if randseed is not None:
        if debug_level > 0:
            print('Setting random seed to ' + str(randseed) + '.')
        np.random.seed(randseed)
        torch.manual_seed(randseed)
    # Device configuration
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')    
    # Generate data
    if data_type == '2dpts':    
        # --- construct projection matrices ---
        # 'true' orthogonal columns used to generate data from latent factors
        #Wsquare = sp.linalg.orth(np.random.rand(x_dim,x_dim))
        Wsquare = np.identity(x_dim)
        W = Wsquare[:,:z_dim_true]
        # 1st column of W
        w1 = np.expand_dims(W[:,0], axis=1)
        # 2nd column of W
        w2 = np.expand_dims(W[:,1], axis=1)
        # form projection matrices
        Pw1 = util.formProjMat(w1)
        Pw2 = util.formProjMat(w2)
        # convert to torch matrices
        Pw1_torch = torch.from_numpy(Pw1).float()
        Pw2_torch = torch.from_numpy(Pw2).float()
        
        # --- construct data ---
        # ntrain instances of alpha and x
        Alpha = params["data_std"]*np.random.randn(ntrain, z_dim_true)
        X = np.matmul(Alpha, W.T)
    elif data_type == 'mnist':
        test_size = 64
        X, Y, tridx = load_mnist_classSelect('train',class_use,newClass)
        vaX, vaY, vaidx = load_mnist_classSelect('val',class_use,newClass)
        sample_inputs = vaX[0:test_size]
        sample_inputs_torch = torch.from_numpy(sample_inputs)
        sample_inputs_torch = sample_inputs_torch.permute(0,3,1,2).float().to(device)        
        ntrain = X.shape[0]
    elif data_type == 'imagenet':
        transform_train = transforms.Compose([
            transforms.RandomCrop(128, padding=4, pad_if_needed=True),
            transforms.RandomHorizontalFlip(),
            transforms.Resize([32,32]),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225]),
        ])
        
        
        transform_test = transforms.Compose([
            transforms.CenterCrop(128),
            transforms.Resize([32,32]),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225]),
        ])
        
        
        file_open = open('imagenet_zebra_gorilla.pkl','rb')
        train_fileType_2 = pickle.load(file_open)
        train_imgName_2 = pickle.load(file_open)
        train_imgLabel_2 = pickle.load(file_open)

        val_fileType_2 = pickle.load(file_open)
        val_imgName_2 = pickle.load(file_open)
        val_imgLabel_2 = pickle.load(file_open)

        file_open.close()
        train_set = Imagenet_Gor_Zeb(train_imgName_2, train_imgLabel_2,
                                     train_fileType_2, transforms=transform_train)

        trainloader = torch.utils.data.DataLoader(train_set,
                                                  batch_size=batch_size,
                                                  shuffle=True)
        test_size = 20
        val_set = Imagenet_Gor_Zeb(val_imgName_2, val_imgLabel_2,
                                   val_fileType_2, transforms=transform_test)
        valloader = torch.utils.data.DataLoader(val_set,
                                                batch_size=test_size,
                                                shuffle=True)
        
        dataiter = iter(trainloader)
        dataiter_val = iter(valloader)
        sample_inputs_torch, _ = next(dataiter_val)
        sample_inputs_torch = sample_inputs_torch.to(device)


    
    # --- initialize decoder ---
    checkpoint = torch.load(save_dir + 'network_batch' + str(batch_size) + '.pt')
    if decoder_net == 'linGauss':
        from linGaussModel import Decoder
        decoder = Decoder(x_dim, z_dim).to(device)
    elif decoder_net == 'nonLinGauss':
        from VAEModel import Decoder_samp
        z_num_samp = No
        decoder = Decoder_samp(x_dim,z_dim).to(device)
    elif decoder_net == 'VAE':
        from VAEModel import Decoder, Encoder
        encoder = Encoder(x_dim,z_dim).to(device)
        encoder.apply(weights_init_normal)
        decoder = Decoder(x_dim, z_dim).to(device)
        decoder.apply(weights_init_normal)
    elif decoder_net == 'VAE_CNN':
        from VAEModel_CNN import Decoder, Encoder
        encoder = Encoder(z_dim,c_dim,img_size).to(device)
        decoder = Decoder(z_dim,c_dim,img_size).to(device)
        encoder.load_state_dict(checkpoint['model_state_dict_encoder'])
        decoder.load_state_dict(checkpoint['model_state_dict_decoder'])
    elif decoder_net == 'VAE_Imagenet':
        from VAEModel_CNN_imagenet import Decoder, Encoder
        encoder = Encoder(z_dim,c_dim,img_size).to(device)
        decoder = Decoder(z_dim,c_dim,img_size).to(device)
        encoder.load_state_dict(checkpoint['model_state_dict_encoder'])
        decoder.load_state_dict(checkpoint['model_state_dict_decoder'])
    else:
        print("Error: decoder_net should be one of: linGauss nonLinGauss VAE")
    
    # --- initialize classifier ---
    if classifier_net == 'oneHyperplane':        
        from hyperplaneClassifierModel import OneHyperplaneClassifier
        classifier = OneHyperplaneClassifier(x_dim, y_dim, Pw1_torch, ksig=5.).to(device)
        classifier.apply(weights_init_normal)
    elif classifier_net == 'twoHyperplane':
        from hyperplaneClassifierModel import TwoHyperplaneClassifier
        classifier = TwoHyperplaneClassifier(x_dim, y_dim, Pw1_torch, Pw2_torch, ksig=5.).to(device)
        classifier.apply(weights_init_normal)
    elif classifier_net == 'cnn':
        from cnnClassifierModel import CNN
        classifier = CNN(y_dim).to(device)
        batch_orig = 64
        checkpoint = torch.load('./mnist_batch64_lr0.1_class38/network_batch' + str(batch_orig) + '.pt')
        #checkpoint = torch.load('/home/mnorko/Documents/Tensorflow/causal_vae/results/mnist_batch64_lr0.1_class149/network_batch' + str(batch_orig) + '.pt')
        classifier.load_state_dict(checkpoint['model_state_dict_classifier'])
    elif classifier_net == 'cnn_imagenet':
        from cnnImageNetClassifierModel import CNN
        classImagenetIdx = [366,340]
        classifier_model = models.vgg16_bn(pretrained=True)
        classifier = CNN(classifier_model,classImagenetIdx).to(device)
    else:
        print("Error: classifier_net should be one of: oneHyperplane, "
              "twoHyperplane, cnn, cnn_imagenet")
        
    ## 
    
    
    sample_latent, mu, var = encoder(sample_inputs_torch)
    sample_inputs_torch = sample_inputs_torch.permute(0, 2, 3, 1)
    sample_inputs_np = sample_inputs_torch.detach().cpu().numpy()
    sample_img = decoder(sample_latent)
    sample_img = sample_img.permute(0, 2, 3, 1)
    sample_img_np = sample_img.detach().cpu().numpy()
    sio.savemat(save_dir + 'testInOut.mat',
                {'sample_inputs': sample_inputs_np,
                 'sample_latent': sample_latent.detach().cpu().numpy(),
                 'sample_out': sample_img_np})
Example #4

if opt.model == 'mnist':
    trX, trY, tridx = load_mnist_classSelect('train', class_use, newClass)
    vaX, vaY, vaidx = load_mnist_classSelect('val', class_use, newClass)
    teX, teY, teidx = load_mnist_classSelect('test', class_use, newClass)
elif opt.model == 'fmnist':
    trX, trY, tridx = load_fashion_mnist_classSelect('train', class_use,
                                                     newClass)
    vaX, vaY, vaidx = load_fashion_mnist_classSelect('val', class_use,
                                                     newClass)
    teX, teY, teidx = load_fashion_mnist_classSelect('test', class_use,
                                                     newClass)

batch_idxs = len(trX) // opt.batch_size
batch_idxs_val = len(vaX) // test_size
ce_loss = nn.CrossEntropyLoss()

from cnnClassifierModel import CNN

classifier = CNN(y_dim).to(device)
optimizer = torch.optim.SGD(classifier.parameters(),
                            lr=lr,
                            momentum=opt.momentum)
scheduler = StepLR(optimizer, step_size=1, gamma=opt.gamma)

loss_total = np.zeros((opt.epochs * batch_idxs))
test_loss_total = np.zeros((opt.epochs))
percent_correct = np.zeros((opt.epochs))
start_time = time.time()
counter = 0
for epoch in range(0, opt.epochs):
    for idx in range(0, batch_idxs):
        batch_labels = torch.from_numpy(trY[idx * opt.batch_size:(idx + 1) *
                                            opt.batch_size]).long().to(device)
        batch_images = trX[idx * opt.batch_size:(idx + 1) * opt.batch_size]
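        # --- a minimal completion sketch (not in the original fragment) of one
        # SGD step; batch_images is assumed to be NHWC float data as in the
        # loaders above, and the classifier is assumed to return raw logits
        # compatible with nn.CrossEntropyLoss ---
        batch_images_torch = torch.from_numpy(batch_images)
        batch_images_torch = batch_images_torch.permute(0, 3, 1, 2).float().to(device)
        optimizer.zero_grad()
        logits = classifier(batch_images_torch)
        loss = ce_loss(logits, batch_labels)
        loss.backward()
        optimizer.step()
        loss_total[counter] = loss.item()
        counter += 1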
Example #5
        decoder.load_state_dict(checkpoint['model_state_dict_decoder'])
    else:
        print("Error: decoder_net should be one of: linGauss, nonLinGauss, "
              "VAE, VAE_CNN, VAE_Imagenet")
    
    # --- initialize classifier ---
    if classifier_net == 'oneHyperplane':        
        from hyperplaneClassifierModel import OneHyperplaneClassifier
        classifier = OneHyperplaneClassifier(x_dim, y_dim, Pw1_torch, ksig=5.).to(device)
        classifier.apply(weights_init_normal)
    elif classifier_net == 'twoHyperplane':
        from hyperplaneClassifierModel import TwoHyperplaneClassifier
        classifier = TwoHyperplaneClassifier(x_dim, y_dim, Pw1_torch, Pw2_torch, ksig=5.).to(device)
        classifier.apply(weights_init_normal)
    elif classifier_net == 'cnn':
        from cnnClassifierModel import CNN
        classifier = CNN(y_dim).to(device)
        batch_orig = 64
        checkpoint = torch.load(
            '/home/mnorko/Documents/Tensorflow/causal_vae/results/mnist_batch64_lr0.1_class38/network_batch'
            + str(batch_orig) + '.pt')
        classifier.load_state_dict(checkpoint['model_state_dict_classifier'])
    else:
        print("Error: classifier_net should be one of: oneHyperplane, "
              "twoHyperplane, cnn")
        
    ## 
    
    
    if decoder_net == 'linGauss':
        What = Variable(torch.mul(torch.randn(x_dim, z_dim, dtype=torch.float), 0.5),
                        requires_grad=True)
    else:
        What = None