Exemplo n.º 1
0
 def informationFlow(self, Nalpha=None, Nbeta=None):
     """Return the negated first output of ``causaleffect.joint_uncond``.

     A copy of ``self.ceparams`` is used so the stored parameters are
     never mutated. ``Nalpha`` and ``Nbeta`` replace the corresponding
     entries of that copy only when they are not ``None``.
     """
     settings = self.ceparams.copy()
     for key, override in (('Nalpha', Nalpha), ('Nbeta', Nbeta)):
         if override is not None:
             settings[key] = override
     # joint_uncond returns a pair; only the first element is used here and
     # its sign is flipped before returning.
     negI, _ = causaleffect.joint_uncond(settings, self.decoder,
                                         self.classifier, self.device)
     return -1. * negI
Exemplo n.º 2
0
def CVAE(
        model="mnist_VAE_CNN",  # currently not used
        steps=20000,
        batch_size=100,
        z_dim=6,
        z_dim_true=4,
        x_dim=10,
        y_dim=1,
        alpha_dim=4,
        ntrain=5000,
        No=15,
        Ni=15,
        lam_ML=0.000001,
        gamma=0.001,
        lr=0.0001,
        b1=0.5,
        b2=0.999,
        use_ce=True,
        objective="IND_UNCOND",
        decoder_net="VAE_CNN",  # options are ['linGauss','nonLinGauss','VAE','VAE_CNN','VAE_Imagenet','VAE_fMNIST']
        classifier_net="cnn",  # options are ['oneHyperplane','twoHyperplane','cnn','cnn_imagenet','cnn_fmnist']
        data_type="mnist",  # options are ["2dpts","mnist","imagenet","fmnist"]
        break_up_ce=True,  # Whether or not to break up the forward passes of the network based on alphas
        randseed=None,
        save_output=False,
        debug_level=2,
        debug_plot=False,
        save_plot=False,
        c_dim=1,
        img_size=28):
    """Train a decoder (and, for VAE variants, an encoder) against a fixed classifier.

    Each step minimizes ``use_ce * causalEffect + lam_ML * nll`` where
    ``causalEffect`` comes from the ``causaleffect`` function selected by
    ``objective`` and ``nll`` from the ``loss_functions`` routine matching
    ``decoder_net``. Only decoder/encoder parameters are optimized; the
    classifier is never updated.

    Args:
        model: Unused.
        steps: Number of training iterations; also the length of every
            per-step array in the returned dict.
        batch_size: Samples drawn per step (DataLoader batch size for
            'imagenet'); also part of the checkpoint file name.
        z_dim: Latent dimension passed to the decoder/encoder constructors.
        z_dim_true: Number of true latent factors used to synthesize
            '2dpts' data.
        x_dim: Data dimension for '2dpts' data and the non-CNN models.
        y_dim: Classifier output size. Overwritten with the number of
            selected classes (3) for 'mnist'/'fmnist'.
        alpha_dim, No, Ni: Stored in ``params`` for the ``causaleffect``
            routines and used in the output directory name. ``No`` is also
            the number of latent samples for 'nonLinGauss'.
        ntrain: Number of synthetic points for '2dpts'; replaced by the
            dataset size for 'mnist'/'fmnist'/'svhn'.
        lam_ML: Weight of the likelihood term in the total loss.
        gamma: Passed to ``loss_functions.linGauss_NLL_loss``.
        lr, b1, b2: Adam learning rate and betas. ``lr`` is also the manual
            gradient-step size for ``What`` when ``decoder_net`` is
            'linGauss'.
        use_ce: Multiplies the causal-effect term in the loss.
        objective: One of 'IND_UNCOND', 'IND_COND', 'JOINT_UNCOND',
            'JOINT_COND'.
        decoder_net: One of 'linGauss', 'nonLinGauss', 'VAE', 'VAE_CNN',
            'VAE_fMNIST', 'VAE_Imagenet'.
        classifier_net: One of 'oneHyperplane', 'twoHyperplane', 'cnn',
            'cnn_fmnist', 'cnn_imagenet'.
        data_type: '2dpts', 'mnist', 'fmnist', 'svhn' or 'imagenet'.
        break_up_ce: Stored in ``params``; forced to False for '2dpts'.
        randseed: If not None, seeds NumPy and torch.
        save_output: If True, write ``params`` and the debug dict to a
            timestamped .mat file in the output directory.
        debug_level: > 0 prints parameters and per-step losses; > 1 also
            records cosine similarities for 'linGauss'.
        debug_plot: If True, every 1000 steps build a plot frame ('2dpts')
            or save a checkpoint, sample images and a .mat file.
        save_plot: If True, collect plot frames and save 'results.gif'.
        c_dim, img_size: Passed to the CNN encoder/decoder constructors and
            to ``sweepLatentFactors``.

    Returns:
        dict: per-step loss arrays plus the last batch under "X" and, for
        every decoder except 'linGauss', its reconstruction under "Xest".

    Raises:
        ValueError: If ``decoder_net``, ``classifier_net`` or ``objective``
            is not one of the supported names.
    """
    # initialization
    params = {
        "steps": steps,
        "batch_size": batch_size,
        "z_dim": z_dim,
        "z_dim_true": z_dim_true,
        "x_dim": x_dim,
        "y_dim": y_dim,
        "alpha_dim": alpha_dim,
        "ntrain": ntrain,
        "No": No,
        "Ni": Ni,
        "lam_ML": lam_ML,
        "gamma": gamma,
        "lr": lr,
        "b1": b1,
        "b2": b2,
        "use_ce": use_ce,
        "objective": objective,
        "decoder_net": decoder_net,
        "classifier_net": classifier_net,
        "data_type": data_type,
        "break_up_ce": break_up_ce,
        "randseed": randseed,
        "save_output": save_output,
        "debug_level": debug_level,
        'c_dim': c_dim,
        'img_size': img_size
    }
    params["data_std"] = 2.
    if debug_level > 0:
        print("Parameters:")
        print(params)

    # Decoders that are trained together with an encoder.
    uses_encoder = decoder_net in ('VAE', 'VAE_CNN', 'VAE_Imagenet',
                                   'VAE_fMNIST')

    # Initialize arrays for storing performance data
    debug = {}
    debug["loss"] = np.zeros((steps))
    debug["loss_ce"] = np.zeros((steps))
    debug["loss_nll"] = np.zeros((steps))
    debug["loss_nll_logdet"] = np.zeros((steps))
    debug["loss_nll_quadform"] = np.zeros((steps))
    debug["loss_nll_mse"] = np.zeros((steps))
    debug["loss_nll_kld"] = np.zeros((steps))
    if decoder_net == 'linGauss':
        for i in range(params["z_dim"]):
            for j in range(params["z_dim_true"]):
                debug["cossim_w%dwhat%d" % (j + 1, i + 1)] = np.zeros((steps))
            for j in range(i + 1, params["z_dim"]):
                debug["cossim_what%dwhat%d" % (i + 1, j + 1)] = np.zeros(
                    (steps))
    if save_plot:
        frames = []
    if data_type == 'mnist' or data_type == 'fmnist':
        class_use = np.array([0, 3, 4])
        class_use_str = np.array2string(class_use)
        # NOTE(review): params["y_dim"] was captured above and keeps the
        # caller-supplied value; only the local y_dim is updated here.
        y_dim = class_use.shape[0]
        newClass = range(0, y_dim)
        # Every second character between the brackets of e.g. '[0 3 4]'
        # yields the class digits ('034') for the directory name.
        save_dir = ('/home/mnorko/Documents/Tensorflow/causal_vae/results/fmnist_class034/'
                    + data_type + '_' + objective + '_zdim' + str(z_dim) +
                    '_alpha' + str(alpha_dim) + '_No' + str(No) + '_Ni' +
                    str(Ni) + '_lam' + str(lam_ML) + '_class' +
                    class_use_str[1:(len(class_use_str) - 1):2] + '/')
    else:
        save_dir = ('/home/mnorko/Documents/Tensorflow/causal_vae/results/imagenet/'
                    + data_type + '_' + objective + '_zdim' + str(z_dim) +
                    '_alpha' + str(alpha_dim) + '_No' + str(No) + '_Ni' +
                    str(Ni) + '_lam' + str(lam_ML) + '_cont_kl0.1/')
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    if data_type == '2dpts':
        break_up_ce = False
        params['break_up_ce'] = False

    # seed random number generator
    if randseed is not None:
        if debug_level > 0:
            print('Setting random seed to ' + str(randseed) + '.')
        np.random.seed(randseed)
        torch.manual_seed(randseed)
    # Device configuration
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Generate data
    if data_type == '2dpts':
        # --- construct projection matrices ---
        # 'true' orthogonal columns used to generate data from latent factors
        Wsquare = np.identity(x_dim)
        W = Wsquare[:, :z_dim_true]
        # 1st column of W
        w1 = np.expand_dims(W[:, 0], axis=1)
        # 2nd column of W
        w2 = np.expand_dims(W[:, 1], axis=1)
        # form projection matrices
        Pw1 = util.formProjMat(w1)
        Pw2 = util.formProjMat(w2)
        # convert to torch matrices
        Pw1_torch = torch.from_numpy(Pw1).float()
        Pw2_torch = torch.from_numpy(Pw2).float()

        # --- construct data ---
        # ntrain instances of alpha and x
        Alpha = params["data_std"] * np.random.randn(ntrain, z_dim_true)
        X = np.matmul(Alpha, W.T)
    elif data_type == 'mnist':
        test_size = 64
        X, Y, tridx = load_mnist_classSelect('train', class_use, newClass)
        vaX, vaY, vaidx = load_mnist_classSelect('val', class_use, newClass)
        sample_inputs = vaX[0:test_size]
        sample_inputs_torch = torch.from_numpy(sample_inputs)
        # reorder axes (0, 3, 1, 2) before feeding the CNN encoder
        sample_inputs_torch = sample_inputs_torch.permute(0, 3, 1,
                                                          2).float().to(device)
        # NOTE(review): params["ntrain"] keeps the caller-supplied value.
        ntrain = X.shape[0]
    elif data_type == 'fmnist':
        test_size = 64
        X, Y, tridx = load_fashion_mnist_classSelect('train', class_use,
                                                     newClass)
        vaX, vaY, vaidx = load_fashion_mnist_classSelect(
            'val', class_use, newClass)
        sample_inputs = vaX[0:test_size]
        sample_inputs_torch = torch.from_numpy(sample_inputs)
        sample_inputs_torch = sample_inputs_torch.permute(0, 3, 1,
                                                          2).float().to(device)
        ntrain = X.shape[0]
    elif data_type == 'svhn':
        # NOTE(review): class_use, newClass and test_size are not defined on
        # this path, and the training loop below has no 'svhn' batch branch,
        # so this data type looks unfinished - confirm intended values.
        X, Y, tridx = load_svhn_classSelect('train', class_use, newClass)
        vaX, vaY, vaidx = load_svhn_classSelect('val', class_use, newClass)
        sample_inputs = vaX[0:test_size]
        sample_inputs_torch = torch.from_numpy(sample_inputs)
        sample_inputs_torch = sample_inputs_torch.float().to(device)
        ntrain = X.shape[0]
    elif data_type == 'imagenet':
        transform_train = transforms.Compose([
            transforms.RandomCrop(128, padding=4, pad_if_needed=True),
            transforms.RandomHorizontalFlip(),
            transforms.Resize([32, 32]),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

        transform_test = transforms.Compose([
            transforms.CenterCrop(128),
            transforms.Resize([32, 32]),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

        # The pickle stores six consecutive objects: train then val metadata.
        # NOTE(review): pickle.load executes arbitrary code; only use a
        # trusted file here.
        with open('imagenet_zebra_gorilla.pkl', 'rb') as file_open:
            train_fileType_2 = pickle.load(file_open)
            train_imgName_2 = pickle.load(file_open)
            train_imgLabel_2 = pickle.load(file_open)

            val_fileType_2 = pickle.load(file_open)
            val_imgName_2 = pickle.load(file_open)
            val_imgLabel_2 = pickle.load(file_open)

        train_set = Imagenet_Gor_Zeb(train_imgName_2,
                                     train_imgLabel_2,
                                     train_fileType_2,
                                     transforms=transform_train)

        trainloader = torch.utils.data.DataLoader(train_set,
                                                  batch_size=batch_size,
                                                  shuffle=True)
        test_size = 16
        val_set = Imagenet_Gor_Zeb(val_imgName_2,
                                   val_imgLabel_2,
                                   val_fileType_2,
                                   transforms=transform_test)
        valloader = torch.utils.data.DataLoader(val_set,
                                                batch_size=test_size,
                                                shuffle=True)

        dataiter = iter(trainloader)
        dataiter_val = iter(valloader)
        # builtin next() instead of the iterator's .next() method, which
        # Python 3 iterators do not provide
        sample_inputs_torch, _ = next(dataiter_val)
        sample_inputs_torch = sample_inputs_torch.to(device)

    # --- initialize decoder ---
    if decoder_net == 'linGauss':
        from linGaussModel import Decoder
        decoder = Decoder(x_dim, z_dim).to(device)
    elif decoder_net == 'nonLinGauss':
        from VAEModel import Decoder_samp
        z_num_samp = No
        decoder = Decoder_samp(x_dim, z_dim).to(device)
    elif decoder_net == 'VAE':
        from VAEModel import Decoder, Encoder
        encoder = Encoder(x_dim, z_dim).to(device)
        encoder.apply(weights_init_normal)
        decoder = Decoder(x_dim, z_dim).to(device)
        decoder.apply(weights_init_normal)
    elif decoder_net in ('VAE_CNN', 'VAE_fMNIST'):
        # Membership test on purpose: `x == 'VAE_CNN' or 'VAE_fMNIST'` is
        # always truthy and would make the branches below unreachable.
        from VAEModel_CNN import Decoder, Encoder
        encoder = Encoder(z_dim, c_dim, img_size).to(device)
        encoder.apply(weights_init_normal)
        decoder = Decoder(z_dim, c_dim, img_size).to(device)
        decoder.apply(weights_init_normal)
    elif decoder_net == 'VAE_Imagenet':
        # start from previously trained encoder/decoder weights
        checkpoint = torch.load(
            '/home/mnorko/Documents/Tensorflow/causal_vae/results/imagenet/imagenet_JOINT_UNCOND_zdim40_alpha0_No20_Ni1_lam0.001/'
            + 'network_batch' + str(batch_size) + '.pt')
        from VAEModel_CNN_imagenet import Decoder, Encoder
        encoder = Encoder(z_dim, c_dim, img_size).to(device)
        decoder = Decoder(z_dim, c_dim, img_size).to(device)
        encoder.load_state_dict(checkpoint['model_state_dict_encoder'])
        decoder.load_state_dict(checkpoint['model_state_dict_decoder'])
    else:
        # Fail here rather than later with an unbound `decoder`.
        raise ValueError(
            "decoder_net should be one of: linGauss nonLinGauss VAE "
            "VAE_CNN VAE_fMNIST VAE_Imagenet (got %r)" % (decoder_net, ))

    # --- initialize classifier ---
    if classifier_net == 'oneHyperplane':
        from hyperplaneClassifierModel import OneHyperplaneClassifier
        classifier = OneHyperplaneClassifier(x_dim, y_dim, Pw1_torch,
                                             ksig=5.).to(device)
        classifier.apply(weights_init_normal)
    elif classifier_net == 'twoHyperplane':
        from hyperplaneClassifierModel import TwoHyperplaneClassifier
        classifier = TwoHyperplaneClassifier(x_dim,
                                             y_dim,
                                             Pw1_torch,
                                             Pw2_torch,
                                             ksig=5.).to(device)
        classifier.apply(weights_init_normal)
    elif classifier_net == 'cnn':
        from cnnClassifierModel import CNN
        classifier = CNN(y_dim).to(device)
        batch_orig = 64
        checkpoint = torch.load('./mnist_batch64_lr0.1_class38/network_batch' +
                                str(batch_orig) + '.pt')
        classifier.load_state_dict(checkpoint['model_state_dict_classifier'])
    elif classifier_net == 'cnn_fmnist':
        from cnnClassifierModel import CNN
        classifier = CNN(y_dim).to(device)
        batch_orig = 64
        checkpoint = torch.load(
            './fmnist_batch64_lr0.1_class034/network_batch' + str(batch_orig) +
            '.pt')
        classifier.load_state_dict(checkpoint['model_state_dict_classifier'])
    elif classifier_net == 'cnn_imagenet':
        from cnnImageNetClassifierModel import CNN
        classImagenetIdx = [366, 340]
        classifier_model = models.vgg16_bn(pretrained=True)
        classifier = CNN(classifier_model, classImagenetIdx).to(device)
    else:
        # Fail here rather than later with an unbound `classifier`.
        raise ValueError(
            "classifier_net should be one of: oneHyperplane twoHyperplane "
            "cnn cnn_fmnist cnn_imagenet (got %r)" % (classifier_net, ))

    # For linGauss the trainable quantity is the matrix What, which is
    # updated by hand in the loop below instead of through the optimizer.
    if decoder_net == 'linGauss':
        What = Variable(torch.mul(torch.randn(x_dim, z_dim, dtype=torch.float),
                                  0.5),
                        requires_grad=True)
    else:
        What = None

    # --- specify optimizer ---
    # NOTE: we only include the decoder parameters in the optimizer
    # because we don't want to update the classifier parameters
    if uses_encoder:
        params_use = list(decoder.parameters()) + list(encoder.parameters())
    else:
        params_use = list(decoder.parameters())
    optimizer_NN = torch.optim.Adam(params_use, lr=lr, betas=(b1, b2))

    # --- train ---
    start_time = time.time()
    for k in range(0, steps):

        # --- reset gradients to zero ---
        # (you always need to do this in pytorch or the gradients are
        # accumulated from one batch to the next)
        optimizer_NN.zero_grad()

        # --- compute negative log likelihood ---
        # randomly subsample batch_size samples of x

        randIdx = np.random.randint(0, ntrain, batch_size)
        if data_type == '2dpts':
            Xbatch = torch.from_numpy(X[randIdx, :]).float()
        elif data_type == 'mnist' or data_type == 'fmnist':
            Xbatch = torch.from_numpy(X[randIdx]).float()
            Xbatch = Xbatch.permute(0, 3, 1, 2)
        elif data_type == 'imagenet':
            try:
                Xbatch, _ = next(dataiter)
            except StopIteration:
                # loader exhausted: start a new pass over the training set
                dataiter = iter(trainloader)
                Xbatch, _ = next(dataiter)
        Xbatch = Xbatch.to(device)
        if decoder_net == 'linGauss':
            nll = loss_functions.linGauss_NLL_loss(Xbatch, What, gamma)
        elif decoder_net == 'nonLinGauss':
            # keep the latent samples on the same device as the decoder
            randBatch = torch.from_numpy(np.random.randn(
                z_num_samp, z_dim)).float().to(device)
            Xest, Xmu, Xlogvar = decoder(randBatch)
            Xcov = torch.exp(Xlogvar)
            nll = loss_functions.nonLinGauss_NLL_loss(Xbatch, Xmu, Xcov)
        elif uses_encoder:
            latent_out, mu, logvar = encoder(Xbatch)
            Xest = decoder(latent_out)
            nll, nll_mse, nll_kld = loss_functions.VAE_LL_loss(
                Xbatch, Xest, logvar, mu)

        # --- compute mutual information causal effect term ---
        if objective == "IND_UNCOND":
            causalEffect, ceDebug = causaleffect.ind_uncond(params,
                                                            decoder,
                                                            classifier,
                                                            device,
                                                            What=What)
        elif objective == "IND_COND":
            causalEffect, ceDebug = causaleffect.ind_cond(params,
                                                          decoder,
                                                          classifier,
                                                          device,
                                                          What=What)
        elif objective == "JOINT_UNCOND":
            causalEffect, ceDebug = causaleffect.joint_uncond(params,
                                                              decoder,
                                                              classifier,
                                                              device,
                                                              What=What)
        elif objective == "JOINT_COND":
            causalEffect, ceDebug = causaleffect.joint_cond(params,
                                                            decoder,
                                                            classifier,
                                                            device,
                                                            What=What)
        else:
            # Fail here rather than below with an unbound `causalEffect`.
            raise ValueError(
                "objective should be one of: IND_UNCOND IND_COND "
                "JOINT_UNCOND JOINT_COND (got %r)" % (objective, ))

        # --- compute gradients ---
        # total loss
        loss = use_ce * causalEffect + lam_ML * nll
        # backward step to compute the gradients
        loss.backward()
        if decoder_net == 'linGauss':
            # update What with the computed gradient
            What.data.sub_(lr * What.grad.data)
            # reset the What gradients to 0
            What.grad.data.zero_()
        else:
            optimizer_NN.step()

        # --- save debug info for this step ---
        debug["loss"][k] = loss.item()
        debug["loss_ce"][k] = causalEffect.item()
        debug["loss_nll"][k] = (lam_ML * nll).item()
        if uses_encoder:
            debug["loss_nll_mse"][k] = (lam_ML * nll_mse).item()
            debug["loss_nll_kld"][k] = (lam_ML * nll_kld).item()
        if debug_level > 1:
            if decoder_net == 'linGauss':
                # normalize columns so the dot products are cosines
                Wnorm = W / np.linalg.norm(W, axis=0, keepdims=True)
                Whatnorm = What.detach().numpy() / np.linalg.norm(
                    What.detach().numpy(), axis=0, keepdims=True)
                # cosine similarities between columns of W and What
                for i in range(params["z_dim"]):
                    for j in range(params["z_dim_true"]):
                        debug["cossim_w%dwhat%d" %
                              (j + 1, i + 1)][k] = np.matmul(
                                  Wnorm[:, j], Whatnorm[:, i])
                    for j in range(i + 1, params["z_dim"]):
                        debug["cossim_what%dwhat%d" %
                              (i + 1, j + 1)][k] = np.matmul(
                                  Whatnorm[:, i], Whatnorm[:, j])

        # --- print step information ---
        if debug_level > 0:
            print("[Step %d/%d] time: %4.2f  [CE: %g] [ML: %g] [loss: %g]" % \
                  (k, steps, time.time() - start_time, debug["loss_ce"][k],
                   debug["loss_nll"][k], debug["loss"][k]))

        # --- debug plot ---
        if debug_plot and k % 1000 == 0:
            print('Generating plot frame...')
            if data_type == '2dpts':
                # generate samples of p(x | alpha_i = alphahat_i)
                # NOTE(review): lfplot_aihat_vals and lfplot_nsamp are not
                # defined in this function - presumably module-level; verify.
                decoded_points = {}
                decoded_points["ai_vals"] = lfplot_aihat_vals
                decoded_points["samples"] = np.zeros(
                    (2, lfplot_nsamp, len(lfplot_aihat_vals), params["z_dim"]))
                for l in range(params["z_dim"]):  # loop over latent dimensions
                    for i, aihat in enumerate(
                            lfplot_aihat_vals):  # loop over fixed aihat values
                        for m in range(
                                lfplot_nsamp):  # loop over samples to generate
                            z = np.random.randn(params["z_dim"])
                            z[l] = aihat
                            x = decoder(
                                torch.from_numpy(z).float(), What, gamma)
                            decoded_points["samples"][:, m, i,
                                                      l] = x.detach().numpy()
                frame = plotting.debugPlot_frame(X, ceDebug["Xhat"], W, What,
                                                 k, steps, debug, params,
                                                 classifier, decoded_points)
                if save_plot:
                    frames.append(frame)
            elif data_type == 'mnist' or data_type == 'imagenet' or data_type == 'fmnist':
                # checkpoint all networks and the optimizer state
                torch.save(
                    {
                        'step': k,
                        'model_state_dict_classifier': classifier.state_dict(),
                        'model_state_dict_encoder': encoder.state_dict(),
                        'model_state_dict_decoder': decoder.state_dict(),
                        'optimizer_state_dict': optimizer_NN.state_dict(),
                        'loss': loss,
                    }, save_dir + 'network_batch' + str(batch_size) + '.pt')
                sample_latent, mu, var = encoder(sample_inputs_torch)
                sample_inputs_torch_new = sample_inputs_torch.permute(
                    0, 2, 3, 1)
                sample_inputs_np = sample_inputs_torch_new.detach().cpu(
                ).numpy()
                sample_img = decoder(sample_latent)
                # sweep latent factors for the first 10 encoded samples ...
                sample_latent_small = sample_latent[0:10, :]
                imgOut_real, probOut_real, latentOut_real = sweepLatentFactors(
                    sample_latent_small, decoder, classifier, device, img_size,
                    c_dim, y_dim, False)
                # ... and for 10 latent vectors drawn from a standard normal
                rand_latent = torch.from_numpy(np.random.randn(
                    10, z_dim)).float().to(device)
                imgOut_rand, probOut_rand, latentOut_rand = sweepLatentFactors(
                    rand_latent, decoder, classifier, device, img_size, c_dim,
                    y_dim, False)
                samples = sample_img
                samples = samples.permute(0, 2, 3, 1)
                samples = samples.detach().cpu().numpy()
                save_images(samples, [8, 8],
                            '{}train_{:04d}.png'.format(save_dir, k))
                sio.savemat(
                    save_dir + 'sweepLatentFactors.mat', {
                        'imgOut_real': imgOut_real,
                        'probOut_real': probOut_real,
                        'latentOut_real': latentOut_real,
                        'imgOut_rand': imgOut_rand,
                        'probOut_rand': probOut_rand,
                        'latentOut_rand': latentOut_rand,
                        'loss_total': debug["loss"][:k],
                        'loss_ce': debug["loss_ce"][:k],
                        'loss_nll': debug['loss_nll'][:k],
                        'samples_out': samples,
                        'sample_inputs': sample_inputs_np
                    })

    # --- save all debug data ---
    debug["X"] = Xbatch.detach().cpu().numpy()
    if not decoder_net == 'linGauss':
        debug["Xest"] = Xest.detach().cpu().numpy()
    if save_output:
        # digits of the current date (YYYYMMDD) and time (HHMMSS)
        datestamp = ''.join(
            re.findall(r'\d+',
                       str(datetime.datetime.now())[:10]))
        timestamp = ''.join(
            re.findall(r'\d+',
                       str(datetime.datetime.now())[11:19]))
        results_folder = save_dir
        matfilename = 'results_' + datestamp + '_' + timestamp + '.mat'
        sio.savemat(results_folder + matfilename, {
            'params': params,
            'data': debug
        })
        if debug_level > 0:
            print('Finished saving data to ' + matfilename)

    if save_plot:
        print('Saving plot...')
        gif.save(frames, "results.gif", duration=100)
        print('Done!')

    return debug
Exemplo n.º 3
0
def CVAE_test_twohyperplaneVAE(
        model="linGauss_multiHP",  # currently not used
        steps=6000,
        batch_size=100,
        z_dim=4,
        z_dim_true=4,
        x_dim=10,
        y_dim=1,
        alpha_dim=2,
        ntrain=5000,
        No=15,
        Ni=15,
        lam_ML=0.000001,
        gamma=0.001,
        lr=0.0001,
        b1=0.5,
        b2=0.999,
        use_ce=True,
        objective="IND_UNCOND",
        decoder_net="linGauss",
        classifier_net="hyperplane",
        randseed=None,
        save_output=False,
        debug_level=2,
        debug_plot=False,
        save_plot=False):

    # initialization
    params = {
        "steps": steps,
        "batch_size": batch_size,
        "z_dim": z_dim,
        "z_dim_true": z_dim_true,
        "x_dim": x_dim,
        "y_dim": y_dim,
        "alpha_dim": alpha_dim,
        "ntrain": ntrain,
        "No": No,
        "Ni": Ni,
        "lam_ML": lam_ML,
        "gamma": gamma,
        "lr": lr,
        "b1": b1,
        "b2": b2,
        "use_ce": use_ce,
        "objective": objective,
        "decoder_net": decoder_net,
        "classifier_net": classifier_net,
        "randseed": randseed,
        "save_output": save_output,
        "debug_level": debug_level
    }
    params["plot_batchsize"] = 500
    params["data_std"] = 2.
    if debug_level > 0:
        print("Parameters:")
        print(params)

    # Initialize arrays for storing performance data
    debug = {}
    vis_samples = 35
    debug["yhat_min"] = np.zeros((steps))
    debug["loss"] = np.zeros((steps))
    debug["loss_ce"] = np.zeros((steps))
    debug["loss_nll"] = np.zeros((steps))
    debug["loss_nll_logdet"] = np.zeros((steps))
    debug["loss_nll_quadform"] = np.zeros((steps))
    debug["loss_nll_mse"] = np.zeros((steps))
    debug["loss_nll_kld"] = np.zeros((steps))
    debug["What"] = np.zeros((z_dim, x_dim, steps))
    debug["xhat_a1"] = np.zeros((x_dim, vis_samples, steps))
    debug["xhat_a2"] = np.zeros((x_dim, vis_samples, steps))
    debug["Yhat_a1"] = np.zeros((x_dim, vis_samples, steps))
    debug["Yhat_a2"] = np.zeros((x_dim, vis_samples, steps))

    # seed random number generator
    if randseed is not None:
        if debug_level > 0:
            print('Setting random seed to ' + str(randseed) + '.')
        np.random.seed(randseed)
        torch.manual_seed(randseed)

    # --- construct projection matrices ---
    # 'true' orthogonal columns used to generate data from latent factors
    #Wsquare = sp.linalg.orth(np.random.rand(x_dim,x_dim))
    Wsquare = np.identity(x_dim)
    W = Wsquare
    w1 = np.expand_dims(W[:, 0], axis=1)
    w2 = np.expand_dims(W[:, 1], axis=1)
    # form projection matrices
    Pw1 = util.formProjMat(w1)
    Pw2 = util.formProjMat(w2)
    # convert to torch matrices
    Pw1_torch = torch.from_numpy(Pw1).float()
    Pw2_torch = torch.from_numpy(Pw2).float()

    # --- construct data ---
    # ntrain instances of alpha and x
    Alpha = params["data_std"] * np.random.randn(ntrain, z_dim_true)
    X = np.matmul(Alpha, W.T)

    # --- initialize decoder ---
    if decoder_net == 'linGauss':
        from linGaussModel import Decoder
        decoder = Decoder(x_dim, z_dim)
    elif decoder_net == 'nonLinGauss':
        from VAEModel import Decoder_samp
        z_num_samp = No
        decoder = Decoder_samp(x_dim, z_dim)
    elif decoder_net == 'VAE':
        from VAEModel import Decoder, Encoder
        encoder = Encoder(x_dim, z_dim)
        encoder.apply(weights_init_normal)
        decoder = Decoder(x_dim, z_dim)
    else:
        print("Error: decoder_net should be one of: linGauss nonLinGauss VAE")
    decoder.apply(weights_init_normal)

    # --- initialize classifier ---
    if classifier_net == 'oneHyperplane':
        from hyperplaneClassifierModel import OneHyperplaneClassifier
        classifier = OneHyperplaneClassifier(x_dim,
                                             y_dim,
                                             Pw1_torch,
                                             ksig=100.,
                                             a1=w1.reshape((1, 2)))
    elif classifier_net == 'twoHyperplane':
        from hyperplaneClassifierModel import TwoHyperplaneClassifier
        classifier = TwoHyperplaneClassifier(x_dim,
                                             y_dim,
                                             Pw1_torch,
                                             Pw2_torch,
                                             ksig=100.,
                                             a1=w1.reshape((1, 2)),
                                             a2=w2.reshape((1, 2)))
    else:
        print(
            "Error: classifier should be one of: oneHyperplane twoHyperplane")
    classifier.apply(weights_init_normal)

    if decoder_net == 'linGauss':
        What = Variable(torch.mul(torch.randn(x_dim, z_dim, dtype=torch.float),
                                  0.5),
                        requires_grad=True)
    else:
        What = None

    # --- specify optimizer ---
    # NOTE: we only include the decoder parameters in the optimizer
    # because we don't want to update the classifier parameters
    if decoder_net == 'VAE':
        params_use = list(decoder.parameters()) + list(encoder.parameters())
    else:
        params_use = list(decoder.parameters())
    optimizer_NN = torch.optim.Adam(params_use, lr=lr, betas=(b1, b2))

    # --- train ---
    start_time = time.time()
    for k in range(0, steps):

        # --- reset gradients to zero ---
        # (you always need to do this in pytorch or the gradients are
        # accumulated from one batch to the next)
        optimizer_NN.zero_grad()

        # --- compute negative log likelihood ---
        # randomly subsample batch_size samples of x
        randIdx = np.random.randint(0, ntrain, batch_size)
        Xbatch = torch.from_numpy(X[randIdx, :]).float()
        if decoder_net == 'linGauss':
            nll = loss_functions.linGauss_NLL_loss(Xbatch, What, gamma)
        elif decoder_net == 'nonLinGauss':
            randBatch = torch.from_numpy(np.random.randn(z_num_samp,
                                                         z_dim)).float()
            Xest, Xmu, Xlogvar = decoder(randBatch)
            Xcov = torch.exp(Xlogvar)
            nll = loss_functions.nonLinGauss_NLL_loss(Xbatch, Xmu, Xcov)
        elif decoder_net == 'VAE':
            latent_out, mu, logvar = encoder(Xbatch)
            Xest = decoder(latent_out)
            nll, nll_mse, nll_kld = loss_functions.VAE_LL_loss(
                Xbatch, Xest, logvar, mu)

        # --- compute mutual information causal effect term ---
        if objective == "IND_UNCOND":
            causalEffect, ceDebug = causaleffect.ind_uncond(params,
                                                            decoder,
                                                            classifier,
                                                            What=What)
        elif objective == "IND_COND":
            causalEffect, ceDebug = causaleffect.ind_cond(params,
                                                          decoder,
                                                          classifier,
                                                          What=What)
        elif objective == "JOINT_UNCOND":
            causalEffect, ceDebug = causaleffect.joint_uncond(params,
                                                              decoder,
                                                              classifier,
                                                              What=What)
        elif objective == "JOINT_COND":
            causalEffect, ceDebug = causaleffect.joint_cond(params,
                                                            decoder,
                                                            classifier,
                                                            What=What)
        yhat_np = ceDebug["yhat"].detach().numpy()

        # --- compute gradients ---
        # total loss
        if use_ce:
            loss = causalEffect + lam_ML * nll
        else:
            loss = lam_ML * nll
        # backward step to compute the gradients
        loss.backward()
        if decoder_net == 'linGauss':
            # update What with the computed gradient
            What.data.sub_(lr * What.grad.data)
            # reset the What gradients to 0
            What.grad.data.zero_()
        else:
            optimizer_NN.step()

        # --- save debug info for this step ---
        debug["yhat_min"][k] = yhat_np.min()
        # components of objective
        debug["loss"][k] = loss.detach().numpy()
        debug["loss_ce"][k] = causalEffect.detach().numpy()
        debug["loss_nll"][k] = (lam_ML * nll).detach().numpy()
        if debug_level > 1:  # and k == steps-1:
            if decoder_net == 'linGauss':
                debug["What"][:, :, k] = What.detach().numpy()
            elif decoder_net == 'VAE':
                debug["loss_nll_mse"][k] = (lam_ML * nll_mse).detach().numpy()
                debug["loss_nll_kld"][k] = (lam_ML * nll_kld).detach().numpy()
                v_sweep = np.linspace(-5., 5., vis_samples)
                # samples Yhat | x[1], x[2], x[3]
                for ix1, x1 in enumerate(v_sweep):
                    for ix2, x2 in enumerate(v_sweep):
                        xs = np.zeros((250, 3))
                        xs[:, 0] = x1
                        xs[:, 1] = x2
                        xs[:, 2] = params["data_std"] * np.random.randn(250)
                        yhats = classifier(torch.from_numpy(xs).float())[0]
                        debug["yhat_x1x2"][ix1, ix2, k] = np.mean(
                            yhats.detach().numpy()[:, 0])
                for ix1, x1 in enumerate(v_sweep):
                    for ix3, x3 in enumerate(v_sweep):
                        xs = np.zeros((250, 3))
                        xs[:, 0] = x1
                        xs[:, 1] = params["data_std"] * np.random.randn(250)
                        xs[:, 2] = x3
                        yhats = classifier(torch.from_numpy(xs).float())[0]
                        debug["yhat_x1x3"][ix1, ix3, k] = np.mean(
                            yhats.detach().numpy()[:, 0])
                for ix2, x2 in enumerate(v_sweep):
                    for ix3, x3 in enumerate(v_sweep):
                        xs = np.zeros((250, 3))
                        xs[:, 0] = params["data_std"] * np.random.randn(250)
                        xs[:, 1] = x2
                        xs[:, 2] = x3
                        yhats = classifier(torch.from_numpy(xs).float())[0]
                        debug["yhat_x2x3"][ix2, ix3, k] = np.mean(
                            yhats.detach().numpy()[:, 0])
                # samples x | alpha[1], alpha[2], beta
                for ia1, a1 in enumerate(v_sweep):
                    for ia2, a2 in enumerate(v_sweep):
                        zs = np.zeros((250, 3))
                        zs[:, 0] = a1
                        zs[:, 1] = a2
                        zs[:, 2] = np.random.randn(250)
                        xs = decoder(
                            torch.from_numpy(zs).float()).detach().numpy()
                        debug["x1_a1a2"][ia1, ia2, k] = np.mean(xs[:, 0])
                        debug["x2_a1a2"][ia1, ia2, k] = np.mean(xs[:, 1])
                        debug["x3_a1a2"][ia1, ia2, k] = np.mean(xs[:, 2])
                for ia1, a1 in enumerate(v_sweep):
                    for ib, b in enumerate(v_sweep):
                        zs = np.zeros((250, 3))
                        zs[:, 0] = a1
                        zs[:, 1] = np.random.randn(250)
                        zs[:, 2] = b
                        xs = decoder(
                            torch.from_numpy(zs).float()).detach().numpy()
                        debug["x1_a1b"][ia1, ib, k] = np.mean(xs[:, 0])
                        debug["x2_a1b"][ia1, ib, k] = np.mean(xs[:, 1])
                        debug["x3_a1b"][ia1, ib, k] = np.mean(xs[:, 2])
                for ia2, a2 in enumerate(v_sweep):
                    for ib, b in enumerate(v_sweep):
                        zs = np.zeros((250, 3))
                        zs[:, 0] = np.random.randn(250)
                        zs[:, 1] = a2
                        zs[:, 2] = b
                        xs = decoder(
                            torch.from_numpy(zs).float()).detach().numpy()
                        debug["x1_a2b"][ia2, ib, k] = np.mean(xs[:, 0])
                        debug["x2_a2b"][ia2, ib, k] = np.mean(xs[:, 1])
                        debug["x3_a2b"][ia2, ib, k] = np.mean(xs[:, 2])
                # samples yhat | alpha[1], alpha[2], beta
                for ia1, a1 in enumerate(v_sweep):
                    for ia2, a2 in enumerate(v_sweep):
                        zs = np.zeros((250, 3))
                        zs[:, 0] = a1
                        zs[:, 1] = a2
                        zs[:, 2] = params["data_std"] * np.random.randn(250)
                        xs = decoder(torch.from_numpy(zs).float())
                        yhats = classifier(xs)[0]
                        debug["yhat_a1a2"][ia1, ia2, k] = np.mean(
                            yhats.detach().numpy()[:, 0])
                for ia1, a1 in enumerate(v_sweep):
                    for ib, b in enumerate(v_sweep):
                        zs = np.zeros((250, 3))
                        zs[:, 0] = a1
                        zs[:, 1] = np.random.randn(250)
                        zs[:, 2] = b
                        xs = decoder(torch.from_numpy(zs).float())
                        yhats = classifier(xs)[0]
                        debug["yhat_a1b"][ia1, ib, k] = np.mean(
                            yhats.detach().numpy()[:, 0])
                for ia2, a2 in enumerate(v_sweep):
                    for ib, b in enumerate(v_sweep):
                        zs = np.zeros((250, 3))
                        zs[:, 0] = np.random.randn(250)
                        zs[:, 1] = a2
                        zs[:, 2] = b
                        xs = decoder(torch.from_numpy(zs).float())
                        yhats = classifier(xs)[0]
                        debug["yhat_a2b"][ia2, ib, k] = np.mean(
                            yhats.detach().numpy()[:, 0])
                # samples x | alpha[i]
                for ia1, a1 in enumerate(range(-3, 4)):
                    zs = np.zeros((250, 3))
                    zs[:, 0] = a1
                    zs[:, 1] = np.random.randn(250)
                    zs[:, 2] = np.random.randn(250)
                    xs = decoder(torch.from_numpy(zs).float())
                    debug["xhat_a1"][ia1, :, :,
                                     k] = xs.detach().numpy().transpose()
                for ia2, a2 in enumerate(range(-3, 4)):
                    zs = np.zeros((250, 3))
                    zs[:, 0] = np.random.randn(250)
                    zs[:, 1] = a2
                    zs[:, 2] = np.random.randn(250)
                    xs = decoder(torch.from_numpy(zs).float())
                    debug["xhat_a2"][ia2, :, :,
                                     k] = xs.detach().numpy().transpose()

        # --- print step information ---
        if debug_level > 0:
            print("[Step %d/%d] time: %4.2f  [CE: %g] [ML: %g] [loss: %g]" % \
                  (k, steps, time.time() - start_time, debug["loss_ce"][k],
                   debug["loss_nll"][k], debug["loss"][k]))

        # --- debug plot ---
        if debug_plot and k % 500 == 0:
            print('Generating plot frame...')
            # generate samples of p(x | alpha_i = alphahat_i)
            decoded_points = {}
            decoded_points["ai_vals"] = lfplot_aihat_vals
            decoded_points["samples"] = np.zeros(
                (2, lfplot_nsamp, len(lfplot_aihat_vals), params["z_dim"]))
            for l in range(params["z_dim"]):  # loop over latent dimensions
                for i, aihat in enumerate(
                        lfplot_aihat_vals):  # loop over fixed aihat values
                    for m in range(
                            lfplot_nsamp):  # loop over samples to generate
                        z = np.random.randn(params["z_dim"])
                        z[l] = aihat
                        x = decoder(torch.from_numpy(z).float(), What, gamma)
                        decoded_points["samples"][:, m, i,
                                                  l] = x.detach().numpy()
            frame = plotting.debugPlot_frame(X, ceDebug["Xhat"], W, What, k,
                                             steps, debug, params, classifier,
                                             decoded_points)
            if save_plot:
                frames.append(frame)

    # --- save all debug data ---
    debug["X"] = X
    if save_output:
        datestamp = ''.join(
            re.findall(r'\d+',
                       str(datetime.datetime.now())[:10]))
        timestamp = ''.join(
            re.findall(r'\d+',
                       str(datetime.datetime.now())[11:19]))
        results_folder = './results/tests_kSig5_lr0001_' + objective + '_lam' \
            + str(lam_ML) + '_No' + str(No) + '_Ni' + str(Ni) + '_' \
            + datestamp + '_' + timestamp + '/'
        if not os.path.exists(results_folder):
            os.makedirs(results_folder)
        matfilename = 'results_' + datestamp + '_' + timestamp + '.mat'
        sio.savemat(results_folder + matfilename, {
            'params': params,
            'data': debug
        })
        if debug_level > 0:
            print('Finished saving data to ' + matfilename)

    if save_plot:
        print('Saving plot...')
        gif.save(frames, "results.gif", duration=100)
        print('Done!')

    return debug, params
Exemplo n.º 4
0
    def train(self, X, K, L,
              steps = 50000,
              Nalpha = 50,
              Nbeta = 50,
              lam = 0.0001,
              causal_obj = 'JOINT_UNCOND',
              batch_size = 100,
              lr = 0.0001,
              b1 = 0.5,
              b2 = 0.999,
              use_ce = True):
        """Train the encoder/decoder pair on X with a causal-effect objective.

        Every step draws a random minibatch from X (indices are sampled with
        replacement), evaluates the VAE loss and the selected causal-effect
        term, and takes one Adam step on ``use_ce*causalEffect + lam*nll``.
        Only the decoder and encoder parameters are handed to the optimizer.

        Side effects: sets ``self.K``, ``self.L``, ``self.train_params``,
        ``self.ceparams`` and ``self.opt``. Reads ``self.params`` keys
        'debug_print', 'save_output', 'save_dir' and 'batch_size'. When
        'save_output' is truthy a checkpoint is written every 1000 steps (to
        the same file name each time) and the debug data is saved as a .mat
        file at the end.

        Args:
            X: numpy array of training samples, indexed along axis 0. Batches
                are permuted with (0, 3, 1, 2) before reaching the encoder --
                presumably channels-last images (N, H, W, C); confirm against
                the caller.
            K: stored as ``self.K``; contributes to ``z_dim = K + L``.
            L: stored as ``self.L``; contributes to ``z_dim = K + L``.
            steps: number of optimization steps.
            Nalpha: forwarded to the causal-effect estimator via ceparams.
            Nbeta: forwarded to the causal-effect estimator via ceparams.
            lam: weight applied to the VAE loss term.
            causal_obj: one of 'IND_UNCOND', 'IND_COND', 'JOINT_UNCOND',
                'JOINT_COND'.
            batch_size: number of samples drawn per step.
            lr: Adam learning rate.
            b1: first Adam beta.
            b2: second Adam beta.
            use_ce: multiplies the causal-effect term in the loss, so a falsy
                value removes it from the objective.

        Returns:
            dict with per-step arrays 'loss', 'loss_ce', 'loss_nll',
            'loss_nll_mse' and 'loss_nll_kld' ('loss_nll_logdet' and
            'loss_nll_quadform' are allocated but never written here), plus
            the last minibatch 'Xbatch' and its reconstruction 'Xhat' as numpy
            arrays when at least one step was run.

        Raises:
            ValueError: if ``causal_obj`` is not a supported objective.
        """
        # Validate up front: an unrecognized objective would otherwise leave
        # causalEffect unbound when the loss is assembled in the first step.
        objectives = ('IND_UNCOND', 'IND_COND', 'JOINT_UNCOND', 'JOINT_COND')
        if causal_obj not in objectives:
            raise ValueError('Invalid causal objective: %r (expected one of %s)'
                             % (causal_obj, ', '.join(objectives)))

        # initialize
        self.K = K
        self.L = L
        ntrain = X.shape[0]
        # probe the classifier with a single sample to read off its output width
        sample_input = torch.from_numpy(X[0]).unsqueeze(0).float().permute(0,3,1,2)
        M = self.classifier(sample_input.to(self.device))[0].shape[1]
        self.train_params = {
                 'K'                 : K,
                 'L'                 : L,
                 'steps'             : steps,
                 'Nalpha'            : Nalpha,
                 'Nbeta'             : Nbeta,
                 'lambda'            : lam,
                 'causal_obj'        : causal_obj,
                 'batch_size'        : batch_size,
                 'lr'                : lr,
                 'b1'                : b1,
                 'b2'                : b2,
                 'use_ce'            : use_ce}
        self.ceparams = {
                  'Nalpha'           : Nalpha,
                  'Nbeta'            : Nbeta,
                  'K'                : K,
                  'L'                : L,
                  'z_dim'            : K+L,
                  'M'                : M}
        debug = {'loss'              : np.zeros((steps)),
                 'loss_ce'           : np.zeros((steps)),
                 'loss_nll'          : np.zeros((steps)),
                 'loss_nll_logdet'   : np.zeros((steps)),
                 'loss_nll_quadform' : np.zeros((steps)),
                 'loss_nll_mse'      : np.zeros((steps)),
                 'loss_nll_kld'      : np.zeros((steps))}

        # initialize for training
        opt_params = list(self.decoder.parameters()) + list(self.encoder.parameters())
        self.opt = torch.optim.Adam(opt_params, lr=lr, betas=(b1, b2))
        start_time = time.time()
        # the objective is fixed for the whole run, so resolve it once
        estimate_ce = {
            'IND_UNCOND'   : causaleffect.ind_uncond,
            'IND_COND'     : causaleffect.ind_cond,
            'JOINT_UNCOND' : causaleffect.joint_uncond,
            'JOINT_COND'   : causaleffect.joint_cond}[causal_obj]
        # stay None when steps == 0 so the post-loop bookkeeping cannot fail
        Xbatch = Xhat = None

        # training loop
        for k in range(0, steps):

            # reset gradient
            self.opt.zero_grad()

            # compute negative log-likelihood
            randIdx = np.random.randint(0, ntrain, batch_size)
            Xbatch = torch.from_numpy(X[randIdx]).float().permute(0,3,1,2).to(self.device)
            z, mu, logvar = self.encoder(Xbatch)
            Xhat = self.decoder(z)
            nll, nll_mse, nll_kld = loss_functions.VAE_LL_loss(Xbatch, Xhat, logvar, mu)

            # compute causal effect (second return value is unused here)
            causalEffect, _ = estimate_ce(
                self.ceparams, self.decoder, self.classifier, self.device)

            # compute gradient
            loss = use_ce*causalEffect + lam*nll
            loss.backward()
            self.opt.step()

            # save debug info for this step
            debug['loss'][k] = loss.item()
            debug['loss_ce'][k] = causalEffect.item()
            debug['loss_nll'][k] = (lam*nll).item()
            debug['loss_nll_mse'][k] = (lam*nll_mse).item()
            debug['loss_nll_kld'][k] = (lam*nll_kld).item()
            if self.params['debug_print']:
                print("[Step %d/%d] time: %4.2f  [CE: %g] [ML: %g] [loss: %g]" % \
                      (k+1, steps, time.time() - start_time, debug['loss_ce'][k],
                      debug['loss_nll'][k], debug['loss'][k]))
            if self.params['save_output'] and k % 1000 == 0:
                torch.save({
                    'step': k,
                    'model_state_dict_classifier' : self.classifier.state_dict(),
                    'model_state_dict_encoder' : self.encoder.state_dict(),
                    'model_state_dict_decoder' : self.decoder.state_dict(),
                    'optimizer_state_dict' : self.opt.state_dict(),
                    'loss' : loss,
                    }, '%s_batch_%d.pt' % \
                    (self.params['save_dir'], self.params['batch_size']))

        # save/return debug data from entire training run
        if Xbatch is not None:
            debug['Xbatch'] = Xbatch.detach().cpu().numpy()
            debug['Xhat'] = Xhat.detach().cpu().numpy()
        if self.params['save_output']:
            datestamp = ''.join(re.findall(r'\d+', str(datetime.datetime.now())[:10]))
            timestamp = ''.join(re.findall(r'\d+', str(datetime.datetime.now())[11:19]))
            matfilename = 'results_' + datestamp + '_' + timestamp + '.mat'
            # save_dir and the parameter dict live on self.params, matching the
            # checkpoint path built inside the loop
            sio.savemat(self.params['save_dir'] + matfilename,
                        {'params' : self.params, 'data' : debug})
            if self.params['debug_print']:
                print('Finished saving data to ' + matfilename)
        return debug
 # NOTE(review): this is the tail of a loop body; what1, theta_alpha2, ia1, ia2,
 # data and X are defined above the visible excerpt.
 # second column of the generative map: the unit vector at angle theta_alpha2
 what2 = np.array([[np.cos(theta_alpha2)], [np.sin(theta_alpha2)]])
 # place what1 and what2 side by side and convert to a float32 tensor
 What = torch.from_numpy(np.hstack((what1, what2))).float()
 # sample-based estimate of causal effect
 # (each estimator returns a pair; the first element is negated before being
 # stored below, i.e. it is treated as the negative causal effect)
 nce_iu, info_iu = causaleffect.ind_uncond(params,
                                           decoder,
                                           classifier,
                                           device,
                                           What=What)
 nce_ic, _ = causaleffect.ind_cond(params,
                                   decoder,
                                   classifier,
                                   device,
                                   What=What)
 nce_ju, _ = causaleffect.joint_uncond(params,
                                       decoder,
                                       classifier,
                                       device,
                                       What=What)
 nce_jc, _ = causaleffect.joint_cond(params,
                                     decoder,
                                     classifier,
                                     device,
                                     What=What)
 # compute likelihood
 # (log-likelihood is stored as the negated NLL loss evaluated on all of X)
 data["loglik"][ia1, ia2] = -loss_functions.linGauss_NLL_loss(
     torch.from_numpy(X).float(), What, params["gamma"])
 # store results
 # NOTE(review): .numpy() is called without .cpu(), so this presumably expects
 # the estimates to live on the CPU -- confirm for non-CPU devices.
 data["ce_iu"][ia1, ia2] = -nce_iu.detach().numpy()
 data["ce_ic"][ia1, ia2] = -nce_ic.detach().numpy()
 data["ce_ju"][ia1, ia2] = -nce_ju.detach().numpy()
 data["ce_jc"][ia1, ia2] = -nce_jc.detach().numpy()
Exemplo n.º 6
0
# Estimators evaluated for every (theta_alpha, theta_beta) pair; the position
# in this tuple is the index along the last axis of CEs.
_estimators = (
    causaleffect.ind_uncond,
    causaleffect.ind_cond,
    causaleffect.joint_uncond,
    causaleffect.joint_cond,
)

for ia, theta_alpha in enumerate(thetas_alpha):
    for ib, theta_beta in enumerate(thetas_beta):
        print('Computing causal effect for alpha=%.2f (%d/%d), beta=%.2f (%d/%d)...'
              % (theta_alpha, ia, len(thetas_alpha),
                 theta_beta, ib, len(thetas_beta)))

        # generative map for this pair of angles: one unit-vector column each
        what1 = np.array([[np.cos(theta_alpha)],
                          [np.sin(theta_alpha)]])
        what2 = np.array([[np.cos(theta_beta)],
                          [np.sin(theta_beta)]])
        What = torch.from_numpy(np.concatenate((what1, what2), axis=1)).float()

        # sample-based estimates; each estimator's first return value is
        # stored negated, in the same order as the estimator table
        for _slot, _estimate in enumerate(_estimators):
            CEs[ia, ib, _slot] = -_estimate(
                params, decoder, classifier, device, What=What)[0]

# --- save results ---
print('Done! Saving results...')
sio.savemat('results/visualize_causalobj_linear.mat',
            {'CEs': CEs,
             'thetas_alpha': thetas_alpha,
             'thetas_beta': thetas_beta,
             'params': params})
print('Done!')

#%% make debug plot (see visualize_causalobj_plot.m for plots in paper)
# Script fragment: builds an encoder/decoder and a trained classifier, then
# compares two implementations of joint_uncond (ceold vs. ce) over several
# randomly re-initialized decoders.
# NOTE(review): vaX, vaY, test_size, X, device, K, L, M, c_dim, img_size, y_dim,
# vae_file, classifier_file, util, ce and ceold come from above this excerpt.

# first test_size validation samples, moved to channels-first float tensors
sample_inputs = vaX[0:test_size]
sample_labels = vaY[0:test_size]
sample_inputs_torch = torch.from_numpy(sample_inputs)
sample_inputs_torch = sample_inputs_torch.permute(0,3,1,2).float().to(device)     
# NOTE(review): ntrain is not used in the lines visible here
ntrain = X.shape[0]

# --- load VAE ---
from models.CVAE import Decoder, Encoder
# NOTE(review): checkpoint_vae is loaded but never applied in the visible
# lines; encoder/decoder keep the weights set by util.weights_init_normal --
# confirm this is intended.
checkpoint_vae = torch.load(vae_file, map_location=device)
encoder = Encoder(K+L,c_dim,img_size).to(device)
decoder = Decoder(K+L,c_dim,img_size).to(device)
encoder.apply(util.weights_init_normal)
decoder.apply(util.weights_init_normal)

# --- load classifier ---
from models.CNN_classifier import CNN
checkpoint_model = torch.load(classifier_file, map_location=device)
classifier = CNN(y_dim).to(device)
classifier.load_state_dict(checkpoint_model['model_state_dict_classifier'])

# --- compute causal effect ---
# the two implementations take differently keyed parameter dicts
params_old = {'Nalpha' : 50, 'Nbeta' : 50, 'decoder_net' : 'VAE', 'z_dim' : K+L, 'alpha_dim' : K, 'y_dim' : M}
params = {'Nalpha' : 50, 'Nbeta' : 50, 'K' : K, 'L' : L, 'M' : M}
ntrials = 10
for i in range(ntrials):
    # fresh random weights each trial so the comparison covers several decoders
    encoder.apply(util.weights_init_normal)
    decoder.apply(util.weights_init_normal)
    # NOTE(review): ceold.joint_uncond receives device while ce.joint_uncond
    # does not -- verify both signatures. .numpy() without .cpu() presumably
    # assumes CPU results.
    Iold = ceold.joint_uncond(params_old, decoder, classifier, device)[0].detach().numpy()
    I = ce.joint_uncond(params, decoder, classifier)[0].detach().numpy()
    # err is the difference between the two estimates relative to the old one;
    # the trial counter i is printed zero-based
    print('Trial %d/%d: old=%f, new=%f (err=%g)' % \
        (i, ntrials, Iold, I, np.linalg.norm(I-Iold)/np.linalg.norm(Iold)))