Example #1
def main():

    # Set up simulation parameters
    batch_size = 1600  # set batch size
    r = 4              # the grid dimension for the output tests
    test_split = r*r   # number of testing samples to use
    sigma = 0.2        # the noise std
    ndata = 64         # number of data samples
    usepars = [0,1,2,3]    # parameter indices to use
    seed = 1           # seed for generating data
    run_label='gpu0'
    out_dir = "/home/hunter.gabbard/public_html/CBC/cINNamon/gausian_results/multipar/%s/" % run_label

    # generate data
    pos, labels, x, sig, parnames = data.generate(
        tot_dataset_size=2**20,
        ndata=ndata,
        usepars=usepars,
        sigma=sigma,
        seed=seed
    )
    print('generated data')

    # separate the test data for plotting
    pos_test = pos[-test_split:]
    labels_test = labels[-test_split:]
    sig_test = sig[-test_split:]

    # plot the test data examples
    fig, axes = plt.subplots(r,r,figsize=(6,6),sharex='col',sharey='row')
    cnt = 0
    for i in range(r):
        for j in range(r):
            axes[i,j].plot(x,np.array(labels_test[cnt,:]),'.')
            axes[i,j].plot(x,np.array(sig_test[cnt,:]),'-')
            cnt += 1
            axes[i,j].axis([0,1,-1.5,1.5])
            axes[i,j].set_xlabel('time' if i==r-1 else '')
            axes[i,j].set_ylabel('h(t)' if j==0 else '')
    plt.savefig('%stest_distribution.png' % out_dir,dpi=360)
    plt.close()

    # precompute true posterior samples on the test data
    cnt = 0
    N_samp = 1000
    ndim_x = len(usepars)
    samples = np.zeros((r*r,N_samp,ndim_x))
    for i in range(r):
        for j in range(r):
            samples[cnt,:,:] = data.get_lik(np.array(labels_test[cnt,:]).flatten(),sigma=sigma,usepars=usepars,Nsamp=N_samp)
            print(samples[cnt,:10,:])
            cnt += 1

    # initialize plot for showing testing results
    fig, axes = plt.subplots(r,r,figsize=(6,6),sharex='col',sharey='row')

    for k in range(ndim_x):
        parname1 = parnames[k]
        for nextk in range(ndim_x):
            parname2 = parnames[nextk]
            if nextk>k:
                cnt = 0
                for i in range(r):
                     for j in range(r):

                         # plot the samples and the true contours
                         axes[i,j].clear()
                         axes[i,j].scatter(samples[cnt,:,k], samples[cnt,:,nextk],c='b',s=0.5,alpha=0.5)
                         axes[i,j].plot(pos_test[cnt,k],pos_test[cnt,nextk],'+c',markersize=8)
                         axes[i,j].set_xlim([0,1])
                         axes[i,j].set_ylim([0,1])
                         axes[i,j].set_xlabel(parname1 if i==r-1 else '')
                         axes[i,j].set_ylabel(parname2 if j==0 else '')
                         
                         cnt += 1

                # save the results to file
                fig.canvas.draw()
                plt.savefig('%strue_samples_%d%d.png' % (out_dir,k,nextk),dpi=360)

    # setting up the model 
    ndim_x = len(usepars)        # number of posterior parameter dimensions
    ndim_y = ndata    # number of label dimensions (noisy data samples)
    ndim_z = 4        # number of latent space dimensions
    ndim_tot = max(ndim_x,ndim_y+ndim_z)     # must be >= ndim_x and >= ndim_y + ndim_z
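    # e.g. with usepars=[0,1,2,3] and ndata=64 this gives ndim_x=4, ndim_y=64,
    # ndim_z=4 and ndim_tot=68: the x side is zero-padded by 64 dimensions and
    # the (z,y) side by 0, so both faces of the invertible net have equal width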

    # define different parts of the network
    # define input node
    inp = InputNode(ndim_tot, name='input')

    # define hidden layer nodes
    t1 = Node([inp.out0], rev_multiplicative_layer,
              {'F_class': F_fully_connected, 'clamp': 2.0,
               'F_args': {'dropout': 0.2}})

    #t1 = Node([inp.out0], rev_multiplicative_layer,
    #          {'F_class': F_conv, 'clamp': 2.0,
    #           'F_args': {'kernel_size': 3,'leaky_slope': 0.1}})

    #def __init__(self, dims_in, F_class=F_fully_connected, F_args={},
    #             clamp=5.):
    #    super(rev_multiplicative_layer, self).__init__()
    #    channels = dims_in[0][0]
    #
    #    self.split_len1 = channels // 2
    #    self.split_len2 = channels - channels // 2
    #    self.ndims = len(dims_in[0])
    #
    #    self.clamp = clamp
    #    self.max_s = exp(clamp)
    #    self.min_s = exp(-clamp)
    #
    #    self.s1 = F_class(self.split_len1, self.split_len2, **F_args)
    #    self.t1 = F_class(self.split_len1, self.split_len2, **F_args)
    #    self.s2 = F_class(self.split_len2, self.split_len1, **F_args)
    #    self.t2 = F_class(self.split_len2, self.split_len1, **F_args)

    t2 = Node([t1.out0], rev_multiplicative_layer,
              {'F_class': F_fully_connected, 'clamp': 2.0,
               'F_args': {'dropout': 0.2}})

    t3 = Node([t2.out0], rev_multiplicative_layer,
              {'F_class': F_fully_connected, 'clamp': 2.0,
               'F_args': {'dropout': 0.2}})

    t4 = Node([t3.out0], rev_multiplicative_layer,
              {'F_class': F_fully_connected, 'clamp': 2.0,
               'F_args': {'dropout': 0.0}})

    # define output layer node
    outp = OutputNode([t4.out0], name='output')

    nodes = [inp, t1, t2, t3, t4, outp]
    model = ReversibleGraphNet(nodes)
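    # the chain inp -> t1 -> t2 -> t3 -> t4 -> outp forms an invertible map of
    # constant width ndim_tot; each rev_multiplicative_layer splits its input
    # into two halves and couples them (cf. the commented __init__ above)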

    # Train model
    # Training parameters
    n_epochs = 10000
    meta_epoch = 12 # step size (in epochs) of the StepLR scheduler below
    n_its_per_epoch = 12
    batch_size = 1600

    lr = 1e-2
    gamma = 0.01**(1./120)  # per-step decay: lr falls by a factor of 100 over 120 scheduler steps
    l2_reg = 2e-5

    y_noise_scale = 3e-2
    zeros_noise_scale = 3e-2

    # relative weighting of losses:
    lambd_predict = 300. # forward pass
    lambd_latent = 300.  # latent space
    lambd_rev = 400.     # backwards pass

    # padding both the data and the latent space
    # such that they have equal dimension to the parameter space
    #pad_x = torch.zeros(batch_size, ndim_tot - ndim_x)
    #pad_yz = torch.zeros(batch_size, ndim_tot - ndim_y - ndim_z)

    # define optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=(0.8, 0.8),
                             eps=1e-04, weight_decay=l2_reg)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=meta_epoch,
                                                gamma=gamma)

    # define the three loss functions
    loss_backward = MMD_multiscale
    loss_latent = MMD_multiscale
    loss_fit = fit
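
    # For orientation, a minimal sketch of what these two losses compute in
    # practice (the real MMD_multiscale and fit are defined elsewhere in this
    # repository; the kernel widths below are illustrative assumptions):
    def _mmd_multiscale_sketch(p, q, widths=(0.05, 0.2, 0.9)):
        # pairwise squared distances within and between the two sample sets
        pp, qq, pq = torch.mm(p, p.t()), torch.mm(q, q.t()), torch.mm(p, q.t())
        rp = pp.diag().unsqueeze(0).expand_as(pp)
        rq = qq.diag().unsqueeze(0).expand_as(qq)
        dpp = rp.t() + rp - 2.*pp
        dqq = rq.t() + rq - 2.*qq
        dpq = rp.t() + rq - 2.*pq
        # sum inverse multiquadric kernels over several length scales
        PP = QQ = PQ = 0.
        for a in widths:
            PP = PP + a**2 * (a**2 + dpp)**-1
            QQ = QQ + a**2 * (a**2 + dqq)**-1
            PQ = PQ + a**2 * (a**2 + dpq)**-1
        return torch.mean(PP + QQ - 2.*PQ)

    def _fit_sketch(pred, target):
        # plain mean squared error for the supervised forward direction
        return torch.mean((pred - target)**2)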

    # set up training set data loader (the test samples were taken from the
    # end of the arrays above, so exclude them here)
    train_loader = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(pos[:-test_split], labels[:-test_split]),
        batch_size=batch_size, shuffle=True, drop_last=True)

    # initialisation of network weights
    #for mod_list in model.children():
    #    for block in mod_list.children():
    #        for coeff in block.children():
    #            coeff.fc3.weight.data = 0.01*torch.randn(coeff.fc3.weight.shape)
    #model.to(device)

    # start training loop            
    try:
        t_start = time()
        olvec = np.zeros((r,r,int(n_epochs/10)))
        s = 0
        # loop over number of epochs
        for i_epoch in tqdm(range(n_epochs), ascii=True, ncols=80):

            scheduler.step()

            # Initially, the l2 reg. on x and z can give huge gradients, so the
            # lr could be lowered for the first few epochs (the condition below
            # never fires, so this is currently disabled)
            if i_epoch < 0:
                print('inside this iepoch<0 thing')
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr * 1e-2

            # train the model
            train(model,train_loader,n_its_per_epoch,zeros_noise_scale,batch_size,
                ndim_tot,ndim_x,ndim_y,ndim_z,y_noise_scale,optimizer,lambd_predict,
                loss_fit,lambd_latent,loss_latent,lambd_rev,loss_backward,i_epoch)

            # loop over a few cases and plot results in a grid
            if np.remainder(i_epoch,10)==0:
                for k in range(ndim_x):
                    parname1 = parnames[k]
                    for nextk in range(ndim_x):
                        parname2 = parnames[nextk]
                        if nextk>k:
                            cnt = 0

                            # initialize plot for showing testing results
                            fig, axes = plt.subplots(r,r,figsize=(6,6),sharex='col',sharey='row')

                            for i in range(r):
                                for j in range(r):

                                    # convert data into correct format
                                    y_samps = np.tile(np.array(labels_test[cnt,:]),N_samp).reshape(N_samp,ndim_y)
                                    y_samps = torch.tensor(y_samps, dtype=torch.float)
                                    y_samps += y_noise_scale * torch.randn(N_samp, ndim_y)
                                    y_samps = torch.cat([torch.randn(N_samp, ndim_z), zeros_noise_scale * 
                                        torch.zeros(N_samp, ndim_tot - ndim_y - ndim_z),
                                        y_samps], dim=1)
                                    y_samps = y_samps.to(device)

                                    # use the network to predict parameters
                                    rev_x = model(y_samps, rev=True)
                                    rev_x = rev_x.cpu().data.numpy()

                                    # compute the n-d overlap
                                    if k==0 and nextk==1:
                                        ol = data.overlap(samples[cnt,:,:ndim_x],rev_x[:,:ndim_x])
                                        olvec[i,j,s] = ol                                     

                                    # plot the samples and the true contours
                                    axes[i,j].clear()
                                    axes[i,j].scatter(samples[cnt,:,k], samples[cnt,:,nextk],c='b',s=0.2,alpha=0.5)
                                    axes[i,j].scatter(rev_x[:,k], rev_x[:,nextk],c='r',s=0.2,alpha=0.5)
                                    axes[i,j].plot(pos_test[cnt,k],pos_test[cnt,nextk],'+c',markersize=8)
                                    axes[i,j].set_xlim([0,1])
                                    axes[i,j].set_ylim([0,1])
                                    oltxt = '%.2f' % olvec[i,j,s]
                                    axes[i,j].text(0.90, 0.95, oltxt,
                                        horizontalalignment='right',
                                        verticalalignment='top',
                                        transform=axes[i,j].transAxes)
                                    matplotlib.rc('xtick', labelsize=8)     
                                    matplotlib.rc('ytick', labelsize=8) 
                                    axes[i,j].set_xlabel(parname1 if i==r-1 else '')
                                    axes[i,j].set_ylabel(parname2 if j==0 else '')
                                    cnt += 1

                            # save the results to file
                            fig.canvas.draw()
                            plt.savefig('%sposteriors_%d%d_%04d.png' % (out_dir,k,nextk,i_epoch),dpi=360)
                            plt.savefig('%slatest_%d%d.png' % (out_dir,k,nextk),dpi=360)
                            plt.close()
                s += 1

            # plot overlap results
            if np.remainder(i_epoch,10)==0:                
                fig, axes = plt.subplots(1,figsize=(6,6))
                for i in range(r):
                    for j in range(r):
                        axes.semilogx(10*np.arange(olvec.shape[2]),olvec[i,j,:],alpha=0.5)
                axes.grid()
                axes.set_ylabel('overlap')
                axes.set_xlabel('epoch')
                axes.set_ylim([0,1])
                plt.savefig('%soverlap.png' % out_dir,dpi=360)
                plt.close()

    except KeyboardInterrupt:
        pass
    finally:
        print("\n\nTraining took {(time()-t_start)/60:.2f} minutes\n")
Example #2
def plot_y_test(model,Nsamp,usepars,sigma,ndim_x,ndim_y,ndim_z,ndim_tot,outdir,r,i_epoch,conv=False,model_f=None,model_r=None,do_double_nn=False,do_cnn=False):
    """
    Plot examples of test y-data generation
    """

    # generate test data
    x_test, y_test, x, sig_test, parnames = data_maker.generate(
        tot_dataset_size=Nsamp,
        ndata=ndim_y,
        usepars=usepars,
        sigma=sigma,
        seed=1
    )

    out_shape = [-1,ndim_tot]
    if conv:
        in_shape = [-1,1,ndim_tot]
    else:
        in_shape = [-1,ndim_tot]
    fig, axes = plt.subplots(r,r,figsize=(6,6),sharex='col',sharey='row')

    # run the x test data through the model
    x = torch.tensor(x_test[:r*r,:],dtype=torch.float,device=dev).clone().detach()
    y_test = torch.tensor(y_test[:r*r,:],dtype=torch.float,device=dev).clone().detach()
    sig_test = torch.tensor(sig_test[:r*r,:],dtype=torch.float,device=dev).clone().detach()

    # pad the x vector so that (x, padding, noise residual) fills ndim_tot dimensions
    pad_x = torch.zeros(r*r,ndim_tot-ndim_x-ndim_y,device=dev)

    # assemble the padded input: parameters, zero padding, and the noise residual y - signal
    x_padded = torch.cat((x,pad_x,y_test-sig_test),dim=1)

    # apply forward model to the x data
    if do_double_nn:
        if do_cnn:
            data = torch.cat((x,y_test-sig_test), dim=1)
            output = model_f(data.reshape(data.shape[0],1,data.shape[1]))#.reshape(out_shape)
            output_y = output[:,:ndim_y]  # extract the model output y
        else:
            output = model_f(torch.cat((x,y_test-sig_test), dim=1))#.reshape(out_shape)
            output_y = output[:,:ndim_y]  # extract the model output y
    else:
        output = model(x_padded.reshape(in_shape))#.reshape(out_shape)
        output_y = output[:,model.outSchema.timeseries]  # extract the model output y
    y = output_y.cpu().data.numpy()

    cnt = 0
    for i in range(r):
        for j in range(r):

            axes[i,j].clear()
            axes[i,j].plot(np.arange(ndim_y)/float(ndim_y),y[cnt,:],'b-')
            axes[i,j].plot(np.arange(ndim_y)/float(ndim_y),y_test[cnt,:].cpu().data.numpy(),'k',alpha=0.5)
            axes[i,j].set_xlim([0,1])
            #matplotlib.rc('xtick', labelsize=5)
            #matplotlib.rc('ytick', labelsize=5)
            axes[i,j].set_xlabel('t' if i==r-1 else '')
            axes[i,j].set_ylabel('y' if j==0 else '')
            if i==0 and j==0:
                axes[i,j].legend(('pred y','y'))
            cnt += 1

    fig.canvas.draw()
    fig.savefig('%s/ytest_%04d.png' % (outdir,i_epoch),dpi=360)
    fig.savefig('%s/latest/latest_ytest.png' % outdir,dpi=360)
    plt.close(fig)
    return
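
# Example call, matching how this function is used from the training loop
# elsewhere in this file:
# plot_y_test(model, N_samp, usepars, sigma, ndim_x, ndim_y, ndim_z, ndim_tot,
#             out_dir, r, epoch, conv=False)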
Example #3
def main():
    ## If the output directory exists, delete it ##
    if os.path.exists(out_dir):
        shutil.rmtree(out_dir)
    else:    ## Show a message ##
        print("Attention: %s directory not found" % out_dir)

    # setup output directory - if it does not exist
    os.makedirs('%s' % out_dir)
    os.makedirs('%s/latest' % out_dir)
    os.makedirs('%s/animations' % out_dir)


    # generate data
    if not load_dataset:
        pos, labels, x, sig, parnames = data_maker.generate(
            tot_dataset_size=tot_dataset_size,
            ndata=ndata,
            usepars=usepars,
            sigma=sigma,
            seed=seed
        )
        print('generated data')

        hf = h5py.File('benchmark_data_%s.h5py' % run_label, 'w')
        hf.create_dataset('pos', data=pos)
        hf.create_dataset('labels', data=labels)
        hf.create_dataset('x', data=x)
        hf.create_dataset('sig', data=sig)
        hf.create_dataset('parnames', data=np.string_(parnames))

    data = AtmosData([dataLocation1], test_split, resampleWl=None)
    data.split_data_and_init_loaders(batchsize)

    # separate the test data for plotting
    pos_test = data.pos_test
    labels_test = data.labels_test
    sig_test = data.sig_test

    ndim_x = len(usepars)
    print('Computing MCMC posterior samples')
    if do_mcmc or not load_dataset:
        # precompute true posterior samples on the test data
        cnt = 0
        samples = np.zeros((r*r,N_samp,ndim_x))
        for i in range(r):
            for j in range(r):
                samples[cnt,:,:] = data_maker.get_lik(np.array(labels_test[cnt,:]).flatten(),sigma=sigma,usepars=usepars,Nsamp=N_samp)
                print(samples[cnt,:10,:])
                cnt += 1

        # save computationally expensive mcmc/waveform runs
        if load_dataset:
            # define names of parameters to have PE done on them
           
            parnames=['A','t0','tau','phi','w']
            names = [parnames[int(i)] for i in usepars]
            hf = h5py.File('benchmark_data_%s.h5py' % run_label, 'w')
            hf.create_dataset('pos', data=data.pos)
            hf.create_dataset('labels', data=data.labels)
            hf.create_dataset('x', data=data.x)
            hf.create_dataset('sig', data=data.sig)
            hf.create_dataset('parnames', data=np.string_(names))
        hf.create_dataset('samples', data=samples)
        hf.close()

    else:
        samples=h5py.File(dataLocation1, 'r')['samples'][:]
        parnames=h5py.File(dataLocation1, 'r')['parnames'][:]

    # plot the test data examples
    fig, axes = plt.subplots(r,r,figsize=(6,6),sharex='col',sharey='row')
    cnt = 0
    for i in range(r):
        for j in range(r):
            axes[i,j].plot(data.x,np.array(labels_test[cnt,:]),'-', label='noisy')
            axes[i,j].plot(data.x,np.array(sig_test[cnt,:]),'-', label='noise-free')
            axes[i,j].legend(loc='upper left')
            cnt += 1
            axes[i,j].axis([0,1,-1.5,1.5])
            axes[i,j].set_xlabel('time' if i==r-1 else '')
            axes[i,j].set_ylabel('h(t)' if j==0 else '')
    plt.savefig('%s/test_distribution.png' % out_dir,dpi=360)
    plt.close()

    # initialize plot for showing testing results
    fig, axes = plt.subplots(r,r,figsize=(6,6),sharex='col',sharey='row')

    for k in range(ndim_x):
        parname1 = parnames[k]
        for nextk in range(ndim_x):
            parname2 = parnames[nextk]
            if nextk>k:
                cnt = 0
                for i in range(r):
                     for j in range(r):

                         # plot the samples and the true contours
                         axes[i,j].clear()
                         axes[i,j].scatter(samples[cnt,:,k], samples[cnt,:,nextk],c='b',s=0.5,alpha=0.5, label='MCMC')
                         axes[i,j].plot(pos_test[cnt,k],pos_test[cnt,nextk],'+c',markersize=8, label='MCMC Truth')
                         axes[i,j].set_xlim([0,1])
                         axes[i,j].set_ylim([0,1])
                         axes[i,j].set_xlabel(parname1 if i==r-1 else '')
                         axes[i,j].set_ylabel(parname2 if j==0 else '')
                         axes[i,j].legend(loc='upper left')
                         
                         cnt += 1

                # save the results to file
                fig.canvas.draw()
                plt.savefig('%s/true_samples_%d%d.png' % (out_dir,k,nextk),dpi=360)


    def store_pars(f,pars):
        for i in pars.keys():
            f.write("%s: %s\n" % (i,str(pars[i])))
        f.close()

    # store hyperparameters for posterity
    f=open("%s_run-pars.txt" % run_label,"w+")
    pars_to_store={"sigma":sigma,"ndata":ndata,"T":T,"seed":seed,"n_neurons":n_neurons,"bound":bound,"conv_nn":conv_nn,"filtsize":filtsize,"dropout":dropout,
                   "clamp":clamp,"ndim_z":ndim_z,"tot_epoch":tot_epoch,"lr":lr, "latentAlphas":latentAlphas, "backwardAlphas":backwardAlphas,
                   "zerosNoiseScale":zerosNoiseScale,"wPred":wPred,"wLatent":wLatent,"wRev":wRev,"tot_dataset_size":tot_dataset_size,
                   "numInvLayers":numInvLayers,"batchsize":batchsize}
    store_pars(f,pars_to_store)

    if extra_z: inRepr = [('amp', 1), ('t0', 1), ('tau', 1), ('!!PAD',), ('yNoise', data.atmosOut.shape[1])]
    else: inRepr = [('amp', 1), ('t0', 1), ('tau', 1), ('!!PAD',)]
    outRepr = [('LatentSpace', ndim_z), ('!!PAD',), ('timeseries', data.atmosOut.shape[1])]
    if do_double_nn:
        model_f = nn_double_f((ndim_x+ndim_y),(ndim_y+ndim_z))
        model_r = nn_double_r((ndim_y+ndim_z),(ndim_x+ndim_y))
    else:
        model = RadynversionNet(inRepr, outRepr, dropout=dropout, zeroPadding=0, minSize=ndim_tot, numInvLayers=numInvLayers)

    # Construct the class that trains the model, the initial weighting between the losses, learning rate, and the initial number of epochs to train for.

    # load previous model if asked
    if load_model: model.load_state_dict(torch.load('models/gpu5_model.pt'))  # hard-coded checkpoint; presumably 'models/%s_model.pt' % run_label

    if do_double_nn:
        trainer = DoubleNetTrainer(model_f, model_r, data, dev, load_model=load_model)
        trainer.training_params(tot_epoch, lr=lr, fadeIn=fadeIn,
                                loss_latent=Loss.mmd_multiscale_on(dev, alphas=latentAlphas),
                                loss_fit=Loss.mse,ndata=ndata,sigma=sigma,seed=seed,batchSize=batchsize,usepars=usepars)
    else:
        trainer = RadynversionTrainer(model, data, dev, load_model=load_model)
        trainer.training_params(tot_epoch, lr=lr, fadeIn=fadeIn, zerosNoiseScale=zerosNoiseScale, wPred=wPred, wLatent=wLatent, wRev=wRev,
                                loss_latent=Loss.mmd_multiscale_on(dev, alphas=latentAlphas),
                                loss_backward=Loss.mmd_multiscale_on(dev, alphas=backwardAlphas),
                                loss_fit=Loss.mse,ndata=ndata,sigma=sigma,seed=seed,n_neurons=n_neurons,batchSize=batchsize,usepars=usepars,
                                y_noise_scale=y_noise_scale)
    totalEpochs = 0

    # Train the model for these first epochs with a nice graph that updates during training.

    losses = []
    wRevScale_tot = []
    beta_score_hist=[]
    beta_score_loop_hist=[]
    lossVec = [[] for _ in range(4)]
    lossLabels = ['L2 Line', 'MMD Latent', 'MMD Reverse', 'L2 Reverse']
    out = None
    alphaRange, mmdF, mmdB, idxF, idxB = [1,1], [1,1], [1,1], 0, 0

    try:
        tStart = time()
        olvec = np.zeros((r,r,int(trainer.numEpochs/plot_cadence)))
        adksVec = np.zeros((r,r,ndim_x,4,int(trainer.numEpochs/plot_cadence)))
        s = 0

        for epoch in range(trainer.numEpochs):
            print('Epoch %s/%s' % (str(epoch),str(trainer.numEpochs)))
            totalEpochs += 1

            if do_double_nn:
                trainer.scheduler_f.step()
        
                loss, indLosses = trainer.train(epoch,gen_inf_temp=gen_inf_temp,extra_z=extra_z, do_cnn=do_cnn)
            else:
                trainer.scheduler.step()

                loss, indLosses = trainer.train(epoch,gen_inf_temp=gen_inf_temp,extra_z=extra_z,do_covar=do_covar)

            if do_double_nn:
                # save trained model
                torch.save(model_f.state_dict(), 'models/%s_model_f.pt' % run_label)
                torch.save(model_r.state_dict(), 'models/%s_model_r.pt' % run_label)
            else:
                # save trained model
                torch.save(model.state_dict(), 'models/%s_model.pt' % run_label)

            # loop over a few cases and plot results in a grid
            if np.remainder(epoch,plot_cadence)==0:
                for k in range(ndim_x):
                    parname1 = parnames[k]
                    for nextk in range(ndim_x):
                        parname2 = parnames[nextk]
                        if nextk>k:
                            cnt = 0

                            # initialize plot for showing testing results
                            fig, axes = plt.subplots(r,r,figsize=(6,6),sharex='col',sharey='row')

                            for i in range(r):
                                for j in range(r):

                                    # convert data into correct format
                                    y_samps = np.tile(np.array(labels_test[cnt,:]),N_samp).reshape(N_samp,ndim_y)
                                    y_samps = torch.tensor(y_samps, dtype=torch.float)

                                    # optionally add noise to the y data (disabled here)
                                    #y = y_samps + y_noise_scale * torch.randn(N_samp, ndim_y, device=dev)
                                    if do_double_nn:
                                        y_samps = torch.cat([torch.randn(N_samp, ndim_z),
                                            y_samps], dim=1)
                                    else:
                                        y_samps = torch.cat([torch.randn(N_samp, ndim_z), zerosNoiseScale * 
                                            torch.zeros(N_samp, ndim_tot - ndim_y - ndim_z),
                                            y_samps], dim=1)
                                    y_samps = y_samps.to(dev)

                                    # use the network to predict parameters
                                    if do_double_nn:
                                        if do_cnn: 
                                            y_samps = y_samps.reshape(y_samps.shape[0],1,y_samps.shape[1]) 
                                            rev_x = model_r(y_samps)
                                        else:
                                            rev_x = model_r(y_samps)
                                    else: rev_x = model(y_samps, rev=True)
                                    rev_x = rev_x.cpu().data.numpy()

                                    # compute the n-d overlap
                                    if k==0 and nextk==1:
                                        ol = data_maker.overlap(samples[cnt,:,:ndim_x],rev_x[:,:ndim_x])
                                        olvec[i,j,s] = ol                                     

                                        # print A-D and K-S test
                                        ks_mcmc_arr, ks_inn_arr, ad_mcmc_arr, ad_inn_arr = result_stat_tests(rev_x[:, usepars], samples[cnt,:,:ndim_x], cnt, parnames)
                                        for p in usepars:
                                            for c in range(4):
                                                adksVec[i,j,p,c,s] = np.array([ks_mcmc_arr,ks_inn_arr,ad_mcmc_arr,ad_inn_arr])[c,p] 
                      
                                    # plot the samples and the true contours
                                    axes[i,j].clear()
                                    if latent:
                                        colors = z.cpu().detach().numpy()
                                        colors = np.linalg.norm(colors,axis=1)
                                        axes[i,j].scatter(rev_x[:,k], rev_x[:,nextk],c=colors,s=1.0,cmap='hsv',alpha=0.75, label='INN')
                                    else:
                                        axes[i,j].scatter(samples[cnt,:,k], samples[cnt,:,nextk],c='b',s=0.2,alpha=0.5, label='MCMC')
                                        axes[i,j].scatter(rev_x[:,k], rev_x[:,nextk],c='r',s=0.2,alpha=0.5, label='INN')
                                        axes[i,j].set_xlim([0,1])
                                        axes[i,j].set_ylim([0,1])                                   
                                    axes[i,j].plot(pos_test[cnt,k],pos_test[cnt,nextk],'+c',markersize=8, label='Truth')
                                    oltxt = '%.2f' % olvec[i,j,s]
                                    axes[i,j].text(0.90, 0.95, oltxt,
                                        horizontalalignment='right',
                                        verticalalignment='top',
                                        transform=axes[i,j].transAxes)
                                    matplotlib.rc('xtick', labelsize=8)     
                                    matplotlib.rc('ytick', labelsize=8) 
                                    axes[i,j].set_xlabel(parname1 if i==r-1 else '')
                                    axes[i,j].set_ylabel(parname2 if j==0 else '')
                                    if i == 0 and j == 0: axes[i,j].legend(loc='upper left', fontsize='x-small')
                                    cnt += 1

                            # save the results to file
                            fig.canvas.draw()
                            if latent:
                                plt.savefig('%s/latent_map_%d%d_%04d.png' % (out_dir,k,nextk,epoch),dpi=360)
                                plt.savefig('%s/latest/latent_map_%d%d.png' % (out_dir,k,nextk),dpi=360)
                            else:
                                plt.savefig('%s/posteriors_%d%d_%04d.png' % (out_dir,k,nextk,epoch),dpi=360)
                                plt.savefig('%s/latest/posteriors_%d%d.png' % (out_dir,k,nextk),dpi=360)
                            plt.close(fig)
                s += 1

            # plot overlap results
            if np.remainder(epoch,plot_cadence)==0:                
                fig, axes = plt.subplots(1,figsize=(6,6))
                for i in range(r):
                    for j in range(r):
                        color = next(axes._get_lines.prop_cycler)['color']
                        axes.semilogx(np.arange(epoch, step=plot_cadence),olvec[i,j,:int((epoch)/plot_cadence)],alpha=0.5, color=color)
                        axes.plot([int(epoch)],[olvec[i,j,int(epoch/plot_cadence)]],'.', color=color)
                axes.grid()
                axes.set_ylabel('overlap')
                axes.set_xlabel('epoch')
                axes.set_ylim([0,1])
                plt.savefig('%s/latest/overlap_logscale.png' % out_dir, dpi=360)
                plt.close(fig)      

                fig, axes = plt.subplots(1,figsize=(6,6))
                for i in range(r):
                    for j in range(r):
                        color = next(axes._get_lines.prop_cycler)['color']
                        axes.plot(np.arange(epoch, step=plot_cadence),olvec[i,j,:int((epoch)/plot_cadence)],alpha=0.5, color=color)
                        axes.plot([int(epoch)],[olvec[i,j,int(epoch/plot_cadence)]],'.', color=color)
                axes.grid()
                axes.set_ylabel('overlap')
                axes.set_xlabel('epoch')
                axes.set_ylim([0,1])
                plt.savefig('%s/latest/overlap.png' % out_dir, dpi=360)
                plt.close(fig)

                if do_double_nn:
                    # plot predicted time series vs. actual time series examples
                    model=None
                    plot_y_test(model,N_samp,usepars,sigma,ndim_x,ndim_y,ndim_z,ndim_tot,out_dir,r,epoch,conv=False,model_f=model_f,model_r=model_r,do_double_nn=do_double_nn,do_cnn=do_cnn)

                    # make y_dist_plot
                    plot_y_dist(model,ndim_x,ndim_y,ndim_z,ndim_tot,usepars,sigma,out_dir,epoch,conv=False,model_f=model_f,model_r=model_r,do_double_nn=do_double_nn,do_cnn=do_cnn)

                    plot_x_evolution(model,ndim_x,ndim_y,ndim_z,ndim_tot,sigma,parnames,out_dir,epoch,conv=False,model_f=model_f,model_r=model_r,do_double_nn=do_double_nn,do_cnn=do_cnn)

                    plot_z_dist(model,ndim_x,ndim_y,ndim_z,ndim_tot,usepars,sigma,out_dir,epoch,conv=False,model_f=model_f,model_r=model_r,do_double_nn=do_double_nn,do_cnn=do_cnn)

                else:
                    # plot predicted time series vs. actual time series examples
                    plot_y_test(model,N_samp,usepars,sigma,ndim_x,ndim_y,ndim_z,ndim_tot,out_dir,r,epoch,conv=False)

                    # make y_dist_plot
                    plot_y_dist(model,ndim_x,ndim_y,ndim_z,ndim_tot,usepars,sigma,out_dir,epoch,conv=False)

                    plot_x_evolution(model,ndim_x,ndim_y,ndim_z,ndim_tot,sigma,parnames,out_dir,epoch,conv=False)

                    plot_z_dist(model,ndim_x,ndim_y,ndim_z,ndim_tot,usepars,sigma,out_dir,epoch,conv=False)

                # plot evolution of y (model_f/model_r only exist in the
                # double-nn case, so they are not passed here)
                if not do_double_nn:
                    for c in range(ndim_x):
                        plot_y_evolution(model,c,parnames,ndim_x,ndim_y,ndim_z,ndim_tot,out_dir,epoch,conv=False)

                # plot ad and ks results [ks_mcmc_arr,ks_inn_arr,ad_mcmc_arr,ad_inn_arr]
                for p in range(ndim_x):
                    fig_ks, axis_ks = plt.subplots(1,figsize=(6,6)) 
                    fig_ad, axis_ad = plt.subplots(1,figsize=(6,6))
                    for i in range(r):
                        for j in range(r):
                            color_ks = next(axis_ks._get_lines.prop_cycler)['color'] 
                            axis_ks.semilogx(np.arange(tot_epoch, step=plot_cadence),adksVec[i,j,p,0,:],'--',alpha=0.5,color=color_ks)
                            axis_ks.semilogx(np.arange(tot_epoch, step=plot_cadence),adksVec[i,j,p,1,:],alpha=0.5,color=color_ks)
                            axis_ks.plot([int(epoch)],[adksVec[i,j,p,1,int(epoch/plot_cadence)]],'.', color=color_ks) 
                            axis_ks.set_yscale('log')

                            color_ad = next(axis_ad._get_lines.prop_cycler)['color']
                            axis_ad.semilogx(np.arange(tot_epoch, step=plot_cadence),adksVec[i,j,p,2,:],'--',alpha=0.5,color=color_ad)
                            axis_ad.semilogx(np.arange(tot_epoch, step=plot_cadence),adksVec[i,j,p,3,:],alpha=0.5,color=color_ad)
                            axis_ad.plot([int(epoch)],[adksVec[i,j,p,3,int(epoch/plot_cadence)]],'.',color=color_ad)
                            axis_ad.set_yscale('log')

                    axis_ks.set_xlabel('Epoch')
                    axis_ad.set_xlabel('Epoch')
                    axis_ks.set_ylabel('KS Statistic')
                    axis_ad.set_ylabel('AD Statistic')
                    fig_ks.savefig('%s/latest/ks_%s_stat.png' % (out_dir,parnames[p]), dpi=360)
                    fig_ad.savefig('%s/latest/ad_%s_stat.png' % (out_dir,parnames[p]), dpi=360)
                    plt.close(fig_ks)
                    plt.close(fig_ad)

            if (epoch % 10 == 0) and (epoch > 5):

                #fig, axis = plt.subplots(4,1, figsize=(10,8))
                #fig.canvas.draw()
                #axis[0].clear()
                #axis[1].clear()
                #axis[2].clear()
                #axis[3].clear()
                for i in range(len(indLosses)):
                    lossVec[i].append(indLosses[i])
                losses.append(loss)
                #fig.suptitle('Current Loss: %.2e, min loss: %.2e' % (loss, np.nanmin(np.abs(losses))))
                #axis[0].semilogy(np.arange(len(losses)), np.abs(losses))
                #for i, lo in enumerate(lossVec):
                #    axis[1].semilogy(np.arange(len(losses)), lo, '--', label=lossLabels[i])
                #axis[1].legend(loc='upper left')
                #tNow = time()
                #elapsed = int(tNow - tStart)
                #eta = int((tNow - tStart) / (epoch + 1) * trainer.numEpochs) - elapsed

                #if epoch % 2 == 0:
                #    mses = trainer.test(samples,maxBatches=1,extra_z=extra_z)
                #    lineProfiles = mses[2]
                
                if epoch % 10 == 0 and review_mmd and epoch >=600:
                    print('Reviewing alphas')
                    alphaRange, mmdF, mmdB, idxF, idxB = trainer.review_mmd()
                
                #axis[3].semilogx(alphaRange, mmdF, label='Latent Space')
                #axis[3].semilogx(alphaRange, mmdB, label='Backward')
                #axis[3].semilogx(alphaRange[idxF], mmdF[idxF], 'ro')
                #axis[3].semilogx(alphaRange[idxB], mmdB[idxB], 'ro')
                #axis[3].legend()

                #testTime = time() - tNow
                #axis[2].plot(lineProfiles[0, model.outSchema.timeseries].cpu().numpy())
                #for a in axis:
                #    a.grid()
                #axis[3].set_xlabel('Epochs: %d, Elapsed: %d s, ETA: %d s (Testing: %d s)' % (epoch, elapsed, eta, testTime))
            
                
                #fig.canvas.draw()
                #fig.savefig('%slosses-wave-tot.pdf' % out_dir)
                #plt.close(fig)

                # make latent space plots
                """
                if epoch % plot_cadence == 0:
                    labels_z = []
                    for lab_idx in range(ndim_z):
                        labels_z.append(r"latent%d" % lab_idx)
                    fig_latent = corner.corner(lineProfiles[:, model.outSchema.LatentSpace].cpu().numpy(), 
                                               plot_contours=False, labels=labels_z)
                    fig_latent.savefig('%s/latest/latent_space.pdf' % out_dir)
                    print('Plotted latent space')
                    plt.close(fig_latent)
                """

                # make loss plot (log y-scale)
                fig_loss, axes_loss = plt.subplots(1,figsize=(10,8))
                wRevScale_tot.append(trainer.wRevScale)
                axes_loss.grid()
                axes_loss.set_ylabel('Loss')
                axes_loss.set_xlabel('Epochs elapsed: %s' % epoch)
                axes_loss.semilogy(np.arange(len(losses)), np.abs(losses), label='Total')
                for i, lo in enumerate(lossVec):
                    axes_loss.semilogy(np.arange(len(losses)), lo, label=lossLabels[i])
                axes_loss.semilogy(np.arange(len(losses)), wRevScale_tot, label='fadeIn')
                axes_loss.legend(loc='upper left')
                plt.savefig('%s/latest/losses.png' % out_dir)
                plt.close(fig_loss)

                # make loss plot (log-log scale)
                fig_loss, axes_loss = plt.subplots(1,figsize=(10,8))
                axes_loss.grid()
                axes_loss.set_ylabel('Loss')
                axes_loss.set_xlabel('Epochs elapsed: %s' % epoch)
                axes_loss.plot(np.arange(len(losses)), np.abs(losses), label='Total')
                for i, lo in enumerate(lossVec):
                    axes_loss.plot(np.arange(len(losses)), lo, label=lossLabels[i])
                axes_loss.plot(np.arange(len(losses)), wRevScale_tot, label='fadeIn')
                axes_loss.set_xscale('log')
                axes_loss.set_yscale('log')
                axes_loss.legend(loc='upper left')
                plt.savefig('%s/latest/losses_logscale.png' % out_dir)
                plt.close(fig_loss)

    except KeyboardInterrupt:
        pass
    finally:
        print("\n\nTraining took {(time()-tStart)/60:.2f} minutes\n")
Example #4
def plot_z_dist(model,ndim_x,ndim_y,ndim_z,ndim_tot,usepars,sigma,outdir,i_epoch,conv=False,model_f=None,model_r=None,do_double_nn=False,do_cnn=False):
    """
    Plots the distribution of latent z variables
    """
    Nsamp = 250
    out_shape = [-1,ndim_tot]
    if conv:
        in_shape = [-1,1,ndim_tot]
    else:
        in_shape = [-1,ndim_tot]

    # generate test data
    x_test, y_test, x, sig_test, parnames = data_maker.generate(
        tot_dataset_size=Nsamp,
        ndata=ndim_y,
        usepars=usepars,
        sigma=sigma,
        seed=1
    )
   
    # run the x test data through the model
    x = torch.tensor(x_test,dtype=torch.float,device=dev).clone().detach()
    y_test = torch.tensor(y_test,dtype=torch.float,device=dev).clone().detach()
    sig_test = torch.tensor(sig_test,dtype=torch.float,device=dev).clone().detach()

    # pad the x vector so that (x, padding, noise residual) fills ndim_tot dimensions
    pad_x = torch.zeros(Nsamp,ndim_tot-ndim_x-ndim_y,device=dev)

    # assemble the padded input: parameters, zero padding, and the noise residual y - signal
    x_padded = torch.cat((x,pad_x,y_test-sig_test),dim=1)

    # apply forward model to the x data
    if do_double_nn:
        if do_cnn:
            data = torch.cat((x,y_test-sig_test), dim=1)
            output = model_f(data.reshape(data.shape[0],1,data.shape[1]))#.reshape(out_shape)
            output_z = output[:,ndim_y:]  # extract the model output z
        else:
            output = model_f(torch.cat((x,y_test-sig_test), dim=1))#.reshape(out_shape)
            output_z = output[:,ndim_y:]  # extract the model output z
    else:
        output = model(x_padded.reshape(in_shape))#.reshape(out_shape)
        output_z = output[:,model.outSchema.LatentSpace]  # extract the model latent output z
    z = output_z.cpu().data.numpy()
    C = np.cov(z.transpose())
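    # for a well-trained network the latent covariance should be close to the
    # identity, since z is trained to match a standard normal (compare the
    # N(0,1) reference curve drawn further below)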

    fig, axes = plt.subplots(1,figsize=(5,5))
    im = axes.imshow(np.abs(C))

    # We want to show all ticks...
    axes.set_xticks(np.arange(ndim_z))
    axes.set_yticks(np.arange(ndim_z))

    # Rotate the tick labels and set their alignment.
    plt.setp(axes.get_xticklabels(), rotation=45, ha="right",
         rotation_mode="anchor")

    # Loop over data dimensions and create text annotations.
    for i in range(ndim_z):
        for j in range(ndim_z):
            text = axes.text(j,i,'%.2f' % C[i,j], fontsize=3,
                       ha="center",va="center",color="w")

    fig.tight_layout()
    fig.savefig('%s/cov_z_%04d.png' % (outdir,i_epoch),dpi=360)
    fig.savefig('%s/latest/latest_cov_z.png' % outdir,dpi=360)
    plt.close(fig)

    fig, axes = plt.subplots(ndim_z,ndim_z,figsize=(5,5))
    for c in range(ndim_z):
        for d in range(ndim_z):
            if d<c:
                patches = []
                axes[c,d].clear()
                matplotlib.rc('xtick', labelsize=8)
                matplotlib.rc('ytick', labelsize=8)
                axes[c,d].plot(z[:,c],z[:,d],'.r',markersize=0.5)
                circle1 = Circle((0.0, 0.0), 1.0,fill=False,linestyle='--')
                patches.append(circle1)
                circle2 = Circle((0.0, 0.0), 2.0,fill=False,linestyle='--')
                patches.append(circle2)
                circle3 = Circle((0.0, 0.0), 3.0,fill=False,linestyle='--')
                patches.append(circle3)
                p = PatchCollection(patches, alpha=0.2)
                axes[c,d].add_collection(p)
                axes[c,d].set_yticklabels([])
                axes[c,d].set_xticklabels([])
                axes[c,d].set_xlim([-3,3])
                axes[c,d].set_ylim([-3,3])
            else:
                axes[c,d].axis('off')
            axes[c,d].set_xlabel('')
            axes[c,d].set_ylabel('')

    fig.savefig('%s/scatter_z_%04d.png' % (outdir,i_epoch),dpi=360)
    fig.savefig('%s/latest/latest_scatter_z.png' % outdir,dpi=360)
    plt.close(fig)
    
    fig, axes = plt.subplots(1,figsize=(5,5))
    delta = np.transpose(z[:,:])
    dyvec = np.linspace(-10*1.0,10*1.0,250)
    for d in delta:
        plt.hist(np.array(d).flatten(),25,density=True,histtype='stepfilled',alpha=0.5)
    plt.hist(np.array(delta).flatten(),25,density=True,histtype='step',linestyle='dashed')
    plt.plot(dyvec,norm.pdf(dyvec,loc=0,scale=1.0),'k-')
    plt.xlabel('predicted z')
    plt.ylabel('p(z)') 

    fig.savefig('%s/dist_z_%04d.png' % (outdir,i_epoch),dpi=360)
    fig.savefig('%s/latest/latest_dist_z.png' % outdir,dpi=360)
    plt.close(fig)

    return
Example #5
def plot_y_dist(model,ndim_x,ndim_y,ndim_z,ndim_tot,usepars,sigma,outdir,i_epoch,conv=False,model_f=None,model_r=None,do_double_nn=False,do_cnn=False):
    """
    Plots the joint distributions of y variables
    """
    Nsamp = 1000
    out_shape = [-1,ndim_tot]
    if conv:
        in_shape = [-1,1,ndim_tot]
    else:
        in_shape = [-1,ndim_tot]

    # generate test data
    x_test, y_test, x, sig_test, parnames = data_maker.generate(
        tot_dataset_size=Nsamp,
        ndata=ndim_y,
        usepars=usepars,
        sigma=sigma,
        seed=1
    )

    # run the x test data through the model
    x = torch.tensor(x_test,dtype=torch.float,device=dev).clone().detach()
    y_test = torch.tensor(y_test,dtype=torch.float,device=dev).clone().detach()
    sig_test = torch.tensor(sig_test,dtype=torch.float,device=dev).clone().detach()

    # pad the x vector so that (x, padding, noise residual) fills ndim_tot dimensions
    pad_x = torch.zeros(Nsamp,ndim_tot-ndim_x-ndim_y,device=dev)

    # assemble the padded input: parameters, zero padding, and the noise residual y - signal
    x_padded = torch.cat((x,pad_x,y_test-sig_test),dim=1)

    # apply forward model to the x data
    if do_double_nn:
        if do_cnn:
            data = torch.cat((x,y_test-sig_test), dim=1)
            output = model_f(data.reshape(data.shape[0],1,data.shape[1]))#.reshape(out_shape)
            output_y = output[:,:ndim_y]  # extract the model output y
        else:
            output = model_f(torch.cat((x,y_test-sig_test), dim=1))#.reshape(out_shape)
            output_y = output[:,:ndim_y]  # extract the model output y    
    else:
        output = model(x_padded.reshape(in_shape))
        output_y = output[:, model.outSchema.timeseries]
    y = output_y.cpu().data.numpy()
    sig_test = sig_test.cpu().data.numpy()
    dy = y - sig_test
    C = np.cov(dy.transpose())
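    # covariance of the y residuals: if the prediction error behaves like the
    # independent Gaussian noise on the data, this matrix should be roughly
    # diagonal (the histogram below is compared against a sqrt(2)*sigma normal)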

    fig, axes = plt.subplots(1,figsize=(5,5))

    im = axes.imshow(C)

    # We want to show all ticks...
    axes.set_xticks(np.arange(ndim_y))
    axes.set_yticks(np.arange(ndim_y))

    # Rotate the tick labels and set their alignment.
    plt.setp(axes.get_xticklabels(), rotation=45, ha="right",
         rotation_mode="anchor")

    # Loop over data dimensions and create text annotations.
    for i in range(ndim_y):
        for j in range(ndim_y):
            text = axes.text(j,i,'%.2f' % C[i,j], fontsize=3,
                       ha="center",va="center",color="w")

    fig.tight_layout()
    plt.savefig('%s/cov_y_%04d.png' % (outdir,i_epoch),dpi=360)
    plt.savefig('%s/latest/latest_cov_y.png' % outdir,dpi=360)
    plt.close(fig)

    fig, axes = plt.subplots(1,figsize=(5,5))
    delta = np.transpose(y[:,:]-sig_test[:,:])
    dyvec = np.linspace(-10*sigma,10*sigma,250)
    for d in delta:
        plt.hist(np.array(d).flatten(),25,density=True,histtype='stepfilled',alpha=0.5)
    plt.hist(np.array(delta).flatten(),25,density=True,histtype='step',linestyle='dashed')
    plt.plot(dyvec,norm.pdf(dyvec,loc=0,scale=np.sqrt(2.0)*sigma),'k-')
    plt.xlabel('y-y_pred')
    plt.ylabel('p(y-y_pred)')
    plt.savefig('%s/y_dist_%04d.png' % (outdir,i_epoch),dpi=360)
    plt.savefig('%s/latest/y_dist.png' % outdir,dpi=360)
    plt.close(fig)
    return
Example #6
    def train(self, epoch, gen_inf_temp=False, extra_z=False, do_covar=False):
        self.model.train()

        lTot = 0
        miniBatchIdx = 0
        if self.fadeIn:
            # cubic ramp: the backward losses fade in over the first 400 epochs
            wRevScale = min(epoch / 400.0, 1)**3
            self.wRevScale = wRevScale
        else:
            wRevScale = 1.0
            self.wRevScale = wRevScale
        # the zero-padding noise fades out as the backward losses fade in
        noiseScale = (1.0 - wRevScale) * self.zerosNoiseScale

        pad_fn = lambda *x: noiseScale * torch.randn(*x, device=self.dev) #+ 10 * torch.ones(*x, device=self.dev)
        randn = lambda *x: torch.randn(*x, device=self.dev)
        losses = [0, 0, 0, 0]

        for x, y, y_sig in self.atmosData.trainLoader:
            miniBatchIdx += 1

            if miniBatchIdx > self.miniBatchesPerEpoch:
                break

            # if true, generate templates on the fly during training
            if gen_inf_temp:
                del x, y
                pos, labels, _, y_sig, _ = data_maker.generate(
                                                    tot_dataset_size=2*self.batchSize,
                                                    ndata=self.ndata,
                                                    usepars=self.usepars,
                                                    sigma=self.sigma,
                                                    seed=np.random.randint(int(1e9))
                                                    )
                loader = torch.utils.data.DataLoader(
                        torch.utils.data.TensorDataset(torch.tensor(pos), torch.tensor(labels), torch.tensor(y_sig)),
                        batch_size=self.batchSize, shuffle=True, drop_last=True)

                # grab a single batch of the freshly generated data
                for x, y, y_sig in loader:
                    break
                    
            n = y - y_sig
            x, y, y_sig, n = x.to(self.dev), y.to(self.dev), y_sig.to(self.dev), n.to(self.dev)
            yClean = y.clone()

            # loss factors
            loss_factor_fwd_mmd_z = 1.0 #min(epoch / 400.0, 1)**3
            loss_factor_rev_mse_n = min(epoch / 500.0, 1)**3 # 1000
            loss_factor_fwd_mse_y = min(epoch / 750.0, 1)**3  # 2000
            loss_factor_rev_mse_x = min(epoch / 10000.0, 1)**3 # 3000
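            # each factor above is a cubic ramp like wRevScale, so the
            # individual loss terms are switched on at different training stages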

            if extra_z:
                xzp = self.model.inSchema.fill({'amp': x[:, 0],
                                           't0': x[:, 1],
                                           'tau': x[:, 2],
                                           'yNoise': n[:]},
                                          zero_pad_fn=pad_fn)
            else: 
                xp = self.model.inSchema.fill({'amp': x[:, 0], 
                                           't0': x[:, 1], 
                                           'tau': x[:, 2]},
                                          zero_pad_fn=pad_fn)
            if self.y_noise_scale:
                y += self.y_noise_scale * torch.randn(self.batchSize, self.ndata, dtype=torch.float, device=self.dev)

            yzp = self.model.outSchema.fill({'timeseries': y[:], 
                                             'LatentSpace': randn},
                                            zero_pad_fn=pad_fn)
            y_sig_zp = self.model.outSchema.fill({'timeseries': y_sig[:],
                                             'LatentSpace': randn},
                                            zero_pad_fn=pad_fn)
            yzpRevRand = self.model.outSchema.fill({'timeseries': yClean[:],
                                                    'LatentSpace': randn},
                                                   zero_pad_fn=pad_fn)

            self.optim.zero_grad()

            if extra_z: out= self.model(xzp)
            else: out = self.model(xp)

            # lForward = self.wPred * (self.loss_fit(y[:, 0], out[:, self.model.outSchema.Halpha]) + 
            #                          self.loss_fit(y[:, 1], out[:, self.model.outSchema.Ca8542]))
#             lForward = self.wPred * self.loss_fit(yzp[:, :self.model.outSchema.LatentSpace[0]], out[:, :self.model.outSchema.LatentSpace[0]])

            # add z space onto x-space
            #if extra_z:
            #    out_fmse = torch.cat((self.model(yzpRevRand, rev=True)[:, self.model.inSchema.LatentSpace], xzp[:, self.model.outSchema.LatentSpace[-1]+1:]),
            #                          dim=1)
            #    out_fmse = self.model(out_fmse)
            #    lForward = self.wPred * self.loss_fit(yzp[:, self.model.outSchema.LatentSpace[-1]+1:],
            #                                          out_fmse[:, self.model.outSchema.LatentSpace[-1]+1:])
            # use a covariance loss on forward mean squared error
            if do_covar:
                # try covariance fit
                output_cov = self.cov((out[:, self.model.outSchema.LatentSpace[-1]+1:]-y_sig_zp[:, self.model.outSchema.LatentSpace[-1]+1:]).transpose(0,1))
                ycov_mat = self.sigma*self.sigma*torch.eye((self.ndata+self.n_neurons),device=self.dev)

                lForward = self.wPred * self.loss_fit(output_cov.flatten(),
                                                      ycov_mat.flatten())

            else:
                # compute mean squared error on only y
                lForward = loss_factor_fwd_mse_y * self.wPred * self.loss_fit(yzp[:, self.model.outSchema.LatentSpace[-1]+1:], 
                                                      out[:, self.model.outSchema.LatentSpace[-1]+1:])
            losses[0] += lForward.data.item() / self.wPred
            #lForward_extra = self.wPred * self.loss_fit(yzp[:, self.model.outSchema.LatentSpace[-1]+1:],
            #                                          out[:, self.model.outSchema.LatentSpace[-1]+1:])
            #lForward += lForward_extra 
            #losses[0] += lForward_extra.data.item() / self.wPred            

            do_z_covar = False  # hard-disabled: the z moment-matching losses below are never used
            if do_z_covar:
                # compute MMD loss on forward time series prediction
                lforward21Pred=out[:, self.model.outSchema.timeseries].data
                lforward21Target=yzp[:, self.model.outSchema.timeseries]
                lForward21 = self.wLatent * self.loss_latent(lforward21Pred, lforward21Target)

                losses[1] += lForward21.data.item() / self.wLatent
                lForward += lForward21

                # compute variance loss on z (2nd moment)
                lforward22Pred = self.cov((out[:, self.model.outSchema.LatentSpace]).transpose(0,1))
                lforward22Target = 1.0 * torch.eye((out[:, self.model.outSchema.LatentSpace].shape[1]),device=self.dev)

                lForward22 = self.wLatent * self.loss_fit(lforward22Pred.flatten(),
                                                        lforward22Target.flatten())
                
                losses[1] += lForward22.data.item() / self.wLatent
                lForward += lForward22

                # compute mean loss (1st moment) on z
                lForwardFirstMom = self.wLatent * self.loss_fit(torch.mean(out[:, self.model.outSchema.LatentSpace]),
                                                                torch.tensor(0.0))

                losses[1] += lForwardFirstMom.data.item() / self.wLatent
                lForward += lForwardFirstMom

                # compute 3rd moment loss on z
                lForwardThirdMom = self.wLatent * self.loss_fit(torch.mean(out[:, self.model.outSchema.LatentSpace])**3 + 3*torch.mean(out[:, self.model.outSchema.LatentSpace])*lforward22Pred.flatten(),
                                                                torch.tensor(0.0))

                losses[1] += lForwardThirdMom.data.item() / self.wLatent
                lForward += lForwardThirdMom

                # compute 4th moment loss on z
                lForwardFourthMom = self.wLatent * self.loss_fit(torch.mean(out[:, self.model.outSchema.LatentSpace])**4 + 6*(torch.mean(out[:, self.model.outSchema.LatentSpace])**2)*lforward22Pred.flatten() + (3*lforward22Pred.flatten()**2),
                                                                (3.0 * torch.eye((out[:, self.model.outSchema.LatentSpace].shape[1]),device=self.dev)).flatten())

                losses[1] += lForwardFourthMom.data.item() / self.wLatent
                lForward += lForwardFourthMom

                # compute 5th moment loss on z
                lForwardFifthMom = self.wLatent * self.loss_fit(torch.mean(out[:, self.model.outSchema.LatentSpace])**5 + 10*(torch.mean(out[:, self.model.outSchema.LatentSpace])**3)*lforward22Pred.flatten() + (15*torch.mean(out[:, self.model.outSchema.LatentSpace])*lforward22Pred.flatten()**2),
                                                                (0.0 * torch.eye((out[:, self.model.outSchema.LatentSpace].shape[1]),device=self.dev)).flatten())

                losses[1] += lForwardFifthMom.data.item() / self.wLatent
                lForward += lForwardFifthMom

            else:
                # compute forward MMD on the z data (note: the two concatenated
                # tensors below are built but currently unused)
                outLatentGradOnly = torch.cat((out[:, self.model.outSchema.timeseries].data,
                                               out[:, self.model.outSchema.LatentSpace]),
                                              dim=1)
                unpaddedTarget = torch.cat((yzp[:, self.model.outSchema.timeseries],
                                            yzp[:, self.model.outSchema.LatentSpace]),
                                           dim=1)
            
                lForward2 = loss_factor_fwd_mmd_z * self.wLatent * self.loss_latent(out[:, self.model.outSchema.LatentSpace], yzp[:, self.model.outSchema.LatentSpace])
                losses[1] += lForward2.data.item() / self.wLatent
                lForward += lForward2
            
            lTot += lForward.data.item()

            lForward.backward()

            yzpRev = self.model.outSchema.fill({'timeseries': yClean[:],
                                                'LatentSpace': out[:, self.model.outSchema.LatentSpace].data},
                                               zero_pad_fn=pad_fn)

            if extra_z:
                outRev = self.model(yzpRev, rev=True)
                #outRev = torch.cat((outRev[:, self.model.inSchema.amp[0]:self.model.inSchema.tau[-1]+1],
                #                        outRev[:, self.model.inSchema.yNoise]),
                #                       dim=1)
                outRevRand = self.model(yzpRevRand, rev=True)
                outRevRand = torch.cat((outRevRand[:, self.model.inSchema.amp[0]:self.model.inSchema.tau[-1]+1],
                                        outRevRand[:, self.model.inSchema.yNoise]),
                                       dim=1)                
            else:
                outRev = self.model(yzpRev, rev=True)
                outRevRand = self.model(yzpRevRand, rev=True)

            # This should have been outRevRand!
            # xBack = torch.cat((outRevRand[:, self.model.inSchema.ne],
            #                    outRevRand[:, self.model.inSchema.temperature],
            #                    outRevRand[:, self.model.inSchema.vel]),
            #                   dim=1)
            # lBackward = self.wRev * wRevScale * self.loss_backward(xBack, x.reshape(self.miniBatchSize, -1))
            if extra_z:
                #xzp_bMMD=torch.cat((xzp[:, self.model.inSchema.amp[0]:self.model.inSchema.tau[-1]+1],
                #                   xzp[:, self.model.inSchema.yNoise]), dim=1)
                #lBackward = self.wRev * wRevScale * self.loss_fit(outRev,
                #                                                  xzp_bMMD)
                # smooth L1 (Huber) loss for the backward fits (note: despite
                # the variable name, this is not a KL divergence)
                kld_loss = torch.nn.SmoothL1Loss(reduction='mean')
                lBackward1 = loss_factor_rev_mse_x * self.wRev * kld_loss(outRev[:, self.model.inSchema.amp[0]:self.model.inSchema.tau[-1]+1],
                                                                  xzp[:, self.model.inSchema.amp[0]:self.model.inSchema.tau[-1]+1])
                lBackward2 = loss_factor_rev_mse_n * self.wRev * kld_loss(outRev[:, self.model.inSchema.yNoise],
                                                                  xzp[:, self.model.inSchema.yNoise])
                lBackward = lBackward1 + lBackward2

            else:
                lBackward = self.wRev * wRevScale * self.loss_backward(outRev[:, self.model.inSchema.amp[0]:self.model.inSchema.tau[-1]+1], 
                                                                       xp[:, self.model.inSchema.amp[0]:self.model.inSchema.tau[-1]+1])

            scale = wRevScale if wRevScale != 0 else 1.0
            if extra_z:
                losses[2] += (lBackward1.data.item() + lBackward2.data.item()) / (self.wRev * scale)
            else:
                losses[2] += lBackward.data.item() / (self.wRev * scale)
            #TODO: may need to uncomment this
            #lBackward2 += 0.5 * self.wPred * self.loss_fit(outRev, xp)
            if extra_z:
                #lBackward2 = 0.5 * self.wPred * self.loss_fit(outRev,
                #                                               xzp_bMMD)
                losses[3] += 0.0
                lTot += lBackward.data.item()
            else:
                lBackward2 = 0.5 * self.wPred * self.loss_fit(outRev[:, self.model.inSchema.amp[0]:self.model.inSchema.tau[-1]+1], 
                                                                  xp[:, self.model.inSchema.amp[0]:self.model.inSchema.tau[-1]+1])
                losses[3] += (lBackward2.data.item() / self.wPred) * 2
                lBackward += lBackward2
            
                lTot += lBackward.data.item()

            lBackward.backward()

            for p in self.model.parameters():
                if p.grad is not None:
                    p.grad.data.clamp_(-15.0, 15.0)

            self.optim.step()

        losses = [l / miniBatchIdx for l in losses]
        return lTot / miniBatchIdx, losses
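
All of the examples in this listing build their losses from two helpers that are not shown: a multiscale MMD and a plain L2 fit. The sketch below follows the standard FrEIA toy-example implementations (inverse multiquadratic kernels summed over a few widths); the kernel widths and the global `device` are assumed values, not necessarily the ones used here.

import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def MMD_multiscale(x, y):
    # pairwise inner products within and across the two sample sets
    xx, yy, zz = torch.mm(x, x.t()), torch.mm(y, y.t()), torch.mm(x, y.t())
    rx = xx.diag().unsqueeze(0).expand_as(xx)
    ry = yy.diag().unsqueeze(0).expand_as(yy)

    # squared pairwise distances
    dxx = rx.t() + rx - 2. * xx
    dyy = ry.t() + ry - 2. * yy
    dxy = rx.t() + ry - 2. * zz

    XX = torch.zeros(xx.shape, device=device)
    YY = torch.zeros(xx.shape, device=device)
    XY = torch.zeros(xx.shape, device=device)

    # sum inverse multiquadratic kernels over several widths (assumed values)
    for a in [0.2, 0.5, 0.9, 1.3]:
        XX += a**2 * (a**2 + dxx)**-1
        YY += a**2 * (a**2 + dyy)**-1
        XY += a**2 * (a**2 + dxy)**-1

    # MMD estimate: small when the two sample sets match in distribution
    return torch.mean(XX + YY - 2. * XY)

def fit(input, target):
    # plain mean-squared error for the forward prediction loss
    return torch.mean((input - target)**2)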
Example #7
0
def main():
    # Set up data
    batch_size = 1600  # set batch size
    test_split = 10000  # number of testing samples to use

    # generate data
    # makes a torch.tensor() with arrays of (n_samples X parameters) and (n_samples X data)
    # labels are the colours and pos are the x,y coords
    # however, labels are 1-hot encoded
    pos, labels = data.generate(labels='all', tot_dataset_size=2**20)

    # rename the colours for plotting
    #c = np.where(labels[:test_split])[1]
    #c = labels[:test_split,:]
    plt.figure(figsize=(6, 6))
    r = 4
    fig, axs = plt.subplots(r, r)
    cnt = 0
    for i in range(r):
        for j in range(r):
            axs[i, j].plot(np.arange(3) + 1, np.array(pos[cnt, :]), '.')
            axs[i, j].plot([1, 3], [labels[cnt, 0], labels[cnt, 0]], 'k-')
            axs[i, j].plot([1, 3], [
                labels[cnt, 0] + labels[cnt, 1],
                labels[cnt, 0] + labels[cnt, 1]
            ], 'k--')
            axs[i, j].plot([1, 3], [
                labels[cnt, 0] - labels[cnt, 1],
                labels[cnt, 0] - labels[cnt, 1]
            ], 'k--')
            axs[i, j].set_ylim([-1, 2])
            cnt += 1
    plt.savefig('/data/public_html/chrism/FrEIA/test_distribution.png')
    plt.close()

    # setting up the model
    ndim_tot = 16  # total dimension after padding (must be >= ndim_x and >= ndim_y + ndim_z)
    ndim_x = 2  # number of parameter dimensions (mu,sig)
    ndim_y = 3  # number of label dimensions (data)
    ndim_z = 2  # number of latent space dimensions

    # define different parts of the network
    # define input node
    inp = InputNode(ndim_tot, name='input')

    # define hidden layer nodes
    t1 = Node([inp.out0], rev_multiplicative_layer, {
        'F_class': F_fully_connected,
        'clamp': 2.0,
        'F_args': {
            'dropout': 0.0
        }
    })

    t2 = Node([t1.out0], rev_multiplicative_layer, {
        'F_class': F_fully_connected,
        'clamp': 2.0,
        'F_args': {
            'dropout': 0.0
        }
    })

    t3 = Node([t2.out0], rev_multiplicative_layer, {
        'F_class': F_fully_connected,
        'clamp': 2.0,
        'F_args': {
            'dropout': 0.0
        }
    })

    # define output layer node
    outp = OutputNode([t3.out0], name='output')

    nodes = [inp, t1, t2, t3, outp]
    model = ReversibleGraphNet(nodes)

    # Train model
    # Training parameters
    n_epochs = 3000
    meta_epoch = 12  # step size (in epochs) of the StepLR learning-rate scheduler
    n_its_per_epoch = 4
    batch_size = 1600

    lr = 1e-2
    gamma = 0.01**(1. / 120)
    l2_reg = 2e-5

    y_noise_scale = 3e-2
    zeros_noise_scale = 3e-2

    # relative weighting of losses:
    lambd_predict = 300.  # forward pass
    lambd_latent = 300.  # latent space
    lambd_rev = 400.  # backwards pass

    # padding both the data and the latent space
    # such that they have equal dimension to the parameter space
    pad_x = torch.zeros(batch_size, ndim_tot - ndim_x)
    pad_yz = torch.zeros(batch_size, ndim_tot - ndim_y - ndim_z)

    print(pad_x.shape, pad_yz.shape)

    # define optimizer
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=lr,
                                 betas=(0.8, 0.8),
                                 eps=1e-04,
                                 weight_decay=l2_reg)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=meta_epoch,
                                                gamma=gamma)

    # define the three loss functions
    loss_backward = MMD_multiscale
    loss_latent = MMD_multiscale
    loss_fit = fit

    # set up test set data loader
    test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(
        pos[:test_split], labels[:test_split]),
                                              batch_size=batch_size,
                                              shuffle=True,
                                              drop_last=True)

    # set up training set data loader
    train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(
        pos[test_split:], labels[test_split:]),
                                               batch_size=batch_size,
                                               shuffle=True,
                                               drop_last=True)

    # initialisation of network weights
    for mod_list in model.children():
        for block in mod_list.children():
            for coeff in block.children():
                coeff.fc3.weight.data = 0.01 * torch.randn(
                    coeff.fc3.weight.shape)

    model.to(device)

    # initialize figure for showing training progress
    # (needed below: axes[1] is cleared and redrawn every epoch)
    fig, axes = plt.subplots(1, 2, figsize=(8, 4))
    axes[0].set_xticks([])
    axes[0].set_yticks([])
    axes[0].set_title('Predicted labels (Forwards Process)')
    axes[1].set_xticks([])
    axes[1].set_yticks([])
    axes[1].set_title('Generated Samples (Backwards Process)')
    fig.canvas.draw()

    # number of test samples to use after training
    N_samp = 4096

    # choose test samples to use after training
    x_samps = torch.cat([x for x, y in test_loader], dim=0)[:N_samp]
    y_samps = torch.cat([y for x, y in test_loader], dim=0)[:N_samp]
    print(np.array(y_samps))
    #c = np.where(y_samps)[1]
    c = np.array(y_samps).reshape(-1, 3)
    y_samps += y_noise_scale * torch.randn(N_samp, ndim_y)
    y_samps = torch.cat([
        torch.randn(N_samp, ndim_z), zeros_noise_scale *
        torch.zeros(N_samp, ndim_tot - ndim_y - ndim_z), y_samps
    ],
                        dim=1)
    y_samps = y_samps.to(device)
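    # at test time the observation y is held fixed while fresh latent draws
    # z ~ N(0, 1) are prepended, so model(y_samps, rev=True) below yields one
    # posterior sample of x per row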
    #y_samps = np.random.normal(loc=0.5,scale=0.75,size=3).reshape(-1,3)

    # start training loop
    try:
        #     print('#Epoch \tIt/s \tl_total')
        t_start = time()
        # loop over number of epochs
        for i_epoch in tqdm(range(n_epochs), ascii=True, ncols=80):

            scheduler.step()

            # Initially, the l2 reg. on x and z can give huge gradients; set
            # the lr lower for the first few epochs (this branch is currently
            # disabled since i_epoch is never negative)
            if i_epoch < 0:
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr * 1e-2

    #         print(i_epoch, end='\t ')
            train(model, train_loader, n_its_per_epoch, zeros_noise_scale,
                  batch_size, ndim_tot, ndim_x, ndim_y, ndim_z, y_noise_scale,
                  optimizer, lambd_predict, loss_fit, lambd_latent,
                  loss_latent, lambd_rev, loss_backward, i_epoch)

            # predict the mu and sig of test data
            rev_x = model(y_samps, rev=True)
            rev_x = rev_x.cpu().data.numpy()
            #print(rev_x)

            # predict the label given a location
            #pred_c = model(torch.cat((x_samps, torch.zeros(N_samp, ndim_tot - ndim_x)),
            #                         dim=1).to(device)).data[:, -8:].argmax(dim=1)
            #pred_c = model(torch.cat((x_samps, torch.zeros(N_samp, ndim_tot - ndim_x)),
            #                         dim=1).to(device)).data[:, -1:].argmax(dim=1)

            #axes[0].clear()
            #axes[0].scatter(tmp_x_samps[:,0], tmp_x_samps[:,1], c=pred_c, cmap='Set1', s=1., vmin=0, vmax=9)
            #axes[0].axis('equal')
            #axes[0].axis([-3,3,-3,3])
            #axes[0].set_xticks([])
            #axes[0].set_yticks([])

            axes[1].clear()
            axes[1].scatter(rev_x[:, 0],
                            rev_x[:, 1],
                            c=c,
                            cmap='Set1',
                            s=1.,
                            vmin=0,
                            vmax=9)
            axes[1].axis('equal')
            axes[1].axis([-3, 3, -3, 3])
            axes[1].set_xticks([])
            axes[1].set_yticks([])

            fig.canvas.draw()
            plt.savefig('/data/public_html/chrism/FrEIA/training_pred.png')

    except KeyboardInterrupt:
        pass
    finally:
        print("\n\nTraining took {(time()-t_start)/60:.2f} minutes\n")
Example #8
0
    def train(self, epoch, gen_inf_temp=False, extra_z=False, do_cnn=False):
        self.model_f.train()
        self.model_r.train()

        lTot = 0
        miniBatchIdx = 0

        randn = torch.randn(self.batchSize,
                            self.ndata,
                            dtype=torch.float,
                            device=self.dev)

        optimizer_f = self.optim_f  #optim.SGD(self.model_f.parameters(), lr=0.01, weight_decay= 1e-6, momentum = 0.9, nesterov = True)
        optimizer_r = self.optim_r  #optim.SGD(self.model_r.parameters(), lr=0.01, weight_decay= 1e-6, momentum = 0.9, nesterov = True)

        losses = [0, 0, 0, 0]
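        # running sums for the four loss components:
        # [forward fit on y, latent fit on z, reverse fit on (x, n), unused]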

        # get data
        for x, y, y_sig in self.atmosData.trainLoader:
            miniBatchIdx += 1

            if miniBatchIdx > self.miniBatchesPerEpoch:
                break

            # if true, generate templates on the fly during training
            if gen_inf_temp:
                del x, y
                pos, labels, _, y_sig, _ = data_maker.generate(
                    tot_dataset_size=2 * self.batchSize,
                    ndata=self.ndata,
                    usepars=self.usepars,
                    sigma=self.sigma,
                    seed=np.random.randint(int(1e9)))
                loader = torch.utils.data.DataLoader(
                    torch.utils.data.TensorDataset(torch.tensor(pos),
                                                   torch.tensor(labels),
                                                   torch.tensor(y_sig)),
                    batch_size=self.batchSize,
                    shuffle=True,
                    drop_last=True)

                x, y, y_sig = next(iter(loader))

            n = y - y_sig
            x, y, y_sig, n = x.to(self.dev), y.to(self.dev), y_sig.to(
                self.dev), n.to(self.dev)
            yClean = y.clone()

            optimizer_f.zero_grad()
            optimizer_r.zero_grad()

            #################
            # forward process
            #################
            ## 1. forward propagation
            data = torch.cat((x[:], n[:]), dim=1)
            if do_cnn: data = data.reshape(data.shape[0], 1, data.shape[1])
            output = self.model_f(data)

            ## 2. loss calculation
            loss_y = self.loss_fit(output[:, :y.shape[1]], y[:])

            ## 3. backward propagation
            loss_y.backward(retain_graph=True)

            losses[0] += loss_y.data.item()

            loss_z = self.loss_latent(output[:, y.shape[1]:], randn)

            ## 3. backward propagation
            loss_z.backward()

            ## 4. weight optimization
            optimizer_f.step()

            losses[1] += loss_z.data.item()

            #################
            # reverse process
            #################
            ## 1. forward propagation
            output = torch.cat((y[:], output[:, y.shape[1]:]), dim=1)
            if do_cnn:
                output = output.reshape(output.shape[0], 1, output.shape[1])
            output = self.model_r(output.data)

            ## 2. loss calculation
            target = torch.cat((x[:], n[:]), dim=1)
            loss_r = self.loss_fit(output, target)

            ## 3. backward propagation
            loss_r.backward()

            ## 4. weight optimization
            optimizer_r.step()

            losses[2] += loss_r.data.item()
            losses[3] += 0.0  # dummy loss for now
            lTot += loss_y.data.item() + loss_z.data.item() + loss_r.data.item()

        losses = [l / miniBatchIdx for l in losses]
        return lTot / miniBatchIdx, losses
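
Unlike the invertible examples, this trainer uses two separate networks: model_f maps (x, n) to (y, z) and model_r maps (y, z) back to (x, n). Posterior samples for a new observation then come from holding y fixed and redrawing z. A minimal sketch, where the helper name and the no-grad handling are my own:

import torch

def sample_posterior(model_r, y_obs, n_samp, ndata, dev):
    # repeat the single observed timeseries n_samp times
    y_rep = y_obs.reshape(1, -1).repeat(n_samp, 1).to(dev)
    # one fresh latent draw per posterior sample
    z = torch.randn(n_samp, ndata, device=dev)
    with torch.no_grad():
        out = model_r(torch.cat((y_rep, z), dim=1))
    # the leading columns hold the parameter estimates x, the rest the noise n
    return out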
Example #9
0
def main():
    # generate data
    if not load_dataset:
        pos, labels, x, sig, parnames = data_maker.generate(
            tot_dataset_size=tot_dataset_size,
            ndata=ndata,
            usepars=usepars,
            sigma=sigma,
            seed=seed)
        print('generated data')

        hf = h5py.File('benchmark_data_%s.h5py' % run_label, 'w')
        hf.create_dataset('pos', data=pos)
        hf.create_dataset('labels', data=labels)
        hf.create_dataset('x', data=x)
        hf.create_dataset('sig', data=sig)
        hf.create_dataset('parnames', data=np.string_(parnames))

    data = AtmosData([dataLocation1], test_split, resampleWl=None)
    data.split_data_and_init_loaders(batchsize)

    # separate the test data for plotting
    pos_test = data.pos_test
    labels_test = data.labels_test
    sig_test = data.sig_test

    ndim_x = len(usepars)
    print('Computing MCMC posterior samples')
    if do_mcmc or not load_dataset:
        # precompute true posterior samples on the test data
        cnt = 0
        samples = np.zeros((r * r, N_samp, ndim_x))
        for i in range(r):
            for j in range(r):
                samples[cnt, :, :] = data_maker.get_lik(
                    np.array(labels_test[cnt, :]).flatten(),
                    np.array(pos_test[cnt, :]),
                    out_dir,
                    cnt,
                    sigma=sigma,
                    usepars=usepars,
                    Nsamp=N_samp)
                print(samples[cnt, :10, :])
                cnt += 1

        # save computationally expensive mcmc/waveform runs
        if load_dataset:
            hf = h5py.File('benchmark_data_%s.h5py' % run_label, 'w')
            hf.create_dataset('pos', data=data.pos)
            hf.create_dataset('labels', data=data.labels)
            hf.create_dataset('x', data=data.x)
            hf.create_dataset('sig', data=data.sig)
            hf.create_dataset('parnames', data=np.string_(parnames))
        hf.create_dataset('samples', data=samples)
        hf.close()

    else:
        samples = h5py.File(dataLocation1, 'r')['samples'][:]
        parnames = h5py.File(dataLocation1, 'r')['parnames'][:]

    # plot the test data examples
    plt.figure(figsize=(6, 6))
    fig, axes = plt.subplots(r, r, figsize=(6, 6), sharex='col', sharey='row')
    cnt = 0
    for i in range(r):
        for j in range(r):
            axes[i, j].plot(data.x, np.array(labels_test[cnt, :]), '.')
            axes[i, j].plot(data.x, np.array(sig_test[cnt, :]), '-')
            cnt += 1
            axes[i, j].axis([0, 1, -1.5, 1.5])
            axes[i, j].set_xlabel('time' if i == r - 1 else '')
            axes[i, j].set_ylabel('h(t)' if j == 0 else '')
    plt.savefig('%stest_distribution.png' % out_dir, dpi=360)
    plt.close()

    # initialize plot for showing testing results
    fig, axes = plt.subplots(r, r, figsize=(6, 6), sharex='col', sharey='row')

    for k in range(ndim_x):
        parname1 = parnames[k]
        for nextk in range(ndim_x):
            parname2 = parnames[nextk]
            if nextk > k:
                cnt = 0
                for i in range(r):
                    for j in range(r):

                        # plot the samples and the true contours
                        axes[i, j].clear()
                        axes[i, j].scatter(samples[cnt, :, k],
                                           samples[cnt, :, nextk],
                                           c='b',
                                           s=0.5,
                                           alpha=0.5)
                        axes[i, j].plot(pos_test[cnt, k],
                                        pos_test[cnt, nextk],
                                        '+c',
                                        markersize=8)
                        axes[i, j].set_xlim([0, 1])
                        axes[i, j].set_ylim([0, 1])
                        axes[i, j].set_xlabel(parname1 if i == r - 1 else '')
                        axes[i, j].set_ylabel(parname2 if j == 0 else '')

                        cnt += 1

                # save the results to file
                fig.canvas.draw()
                plt.savefig('%strue_samples_%d%d.png' % (out_dir, k, nextk),
                            dpi=360)

    def store_pars(f, pars):
        for i in pars.keys():
            f.write("%s: %s\n" % (i, str(pars[i])))
        f.close()

    # store hyperparameters for posterity
    f = open("%s_run-pars.txt" % run_label, "w+")
    pars_to_store = {
        "sigma": sigma,
        "ndata": ndata,
        "T": T,
        "seed": seed,
        "n_neurons": n_neurons,
        "bound": bound,
        "conv_nn": conv_nn,
        "filtsize": filtsize,
        "dropout": dropout,
        "clamp": clamp,
        "ndim_z": ndim_z,
        "tot_epoch": tot_epoch,
        "lr": lr,
        "latentAlphas": latentAlphas,
        "backwardAlphas": backwardAlphas,
        "zerosNoiseScale": zerosNoiseScale,
        "wPred": wPred,
        "wLatent": wLatent,
        "wRev": wRev,
        "tot_dataset_size": tot_dataset_size,
        "numInvLayers": numInvLayers,
        "batchsize": batchsize
    }
    store_pars(f, pars_to_store)

    # setup output directory - if it does not exist
    os.system('mkdir -p %s' % out_dir)

    inRepr = [('amp', 1), ('t0', 1), ('tau', 1), ('phi', 1), ('!!PAD', )]
    outRepr = [('LatentSpace', ndim_z), ('!!PAD', ),
               ('timeseries', data.atmosOut.shape[1])]
    model = RadynversionNet(inRepr,
                            outRepr,
                            dropout=dropout,
                            zeroPadding=0,
                            minSize=ndim_tot,
                            numInvLayers=numInvLayers)

    # Construct the class that trains the model, the initial weighting between the losses, learning rate, and the initial number of epochs to train for.

    trainer = RadynversionTrainer(model, data, dev)
    trainer.training_params(
        tot_epoch,
        lr=lr,
        zerosNoiseScale=zerosNoiseScale,
        wPred=wPred,
        wLatent=wLatent,
        wRev=wRev,
        loss_latent=Loss.mmd_multiscale_on(dev, alphas=latentAlphas),
        loss_backward=Loss.mmd_multiscale_on(dev, alphas=backwardAlphas),
        loss_fit=Loss.mse)
    totalEpochs = 0

    # Train the model for these first epochs with a nice graph that updates during training.

    losses = []
    beta_score_hist = []
    beta_score_loop_hist = []
    lossVec = [[] for _ in range(4)]
    lossLabels = ['L2 Line', 'MMD Latent', 'MMD Reverse', 'L2 Reverse']
    out = None
    alphaRange, mmdF, mmdB, idxF, idxB = [1, 1], [1, 1], [1, 1], 0, 0

    try:
        tStart = time()
        olvec = np.zeros((r, r, int(tot_epoch / plot_cadence)))
        s = 0

        for epoch in range(trainer.numEpochs):
            print('Epoch %s/%s' % (str(epoch), str(trainer.numEpochs)))
            totalEpochs += 1

            trainer.scheduler.step()

            loss, indLosses = trainer.train(epoch)

            # loop over a few cases and plot results in a grid
            if np.remainder(epoch, plot_cadence) == 0:
                for k in range(ndim_x):
                    parname1 = parnames[k]
                    for nextk in range(ndim_x):
                        parname2 = parnames[nextk]
                        if nextk > k:
                            cnt = 0

                            # initialize 2D plots for showing testing results
                            fig, axes = plt.subplots(r,
                                                     r,
                                                     figsize=(6, 6),
                                                     sharex='col',
                                                     sharey='row')

                            # initialize 1D plots for showing testing results
                            fig_1d, axes_1d = plt.subplots(r,
                                                           r,
                                                           figsize=(6, 6),
                                                           sharex='col',
                                                           sharey='row')

                            for i in range(r):
                                for j in range(r):

                                    # convert data into correct format
                                    y_samps = np.tile(
                                        np.array(labels_test[cnt, :]),
                                        N_samp).reshape(N_samp, ndim_y)
                                    y_samps = torch.tensor(y_samps,
                                                           dtype=torch.float)
                                    y_samps += y_noise_scale * torch.randn(
                                        N_samp, ndim_y)
                                    y_samps = torch.cat([
                                        torch.randn(N_samp,
                                                    ndim_z), zerosNoiseScale *
                                        torch.zeros(N_samp, ndim_tot - ndim_y -
                                                    ndim_z), y_samps
                                    ],
                                                        dim=1)
                                    y_samps = y_samps.to(dev)

                                    # use the network to predict parameters
                                    rev_x = model(y_samps, rev=True)
                                    rev_x = rev_x.cpu().data.numpy()

                                    # compute the n-d overlap
                                    if k == 0 and nextk == 1:
                                        ol = data_maker.overlap(
                                            samples[cnt, :, :ndim_x],
                                            rev_x[:, :ndim_x])
                                        olvec[i, j, s] = ol

                                    def confidence_bd(samp_array):
                                        """
                                        compute the 90% (5th-95th percentile)
                                        confidence bounds of a sample array
                                        """
                                        cf_perc = 0.05
                                        sorted_arr = np.sort(samp_array)
                                        cf_bd_lidx = sorted_arr[int(len(samp_array) * cf_perc)]
                                        cf_bd_ridx = sorted_arr[int(len(samp_array) * (1.0 - cf_perc))]

                                        return [cf_bd_lidx, cf_bd_ridx]

                                    # plot the 2D samples and the true contours
                                    true_cfbd_x = confidence_bd(samples[cnt, :,
                                                                        k])
                                    true_cfbd_y = confidence_bd(samples[cnt, :,
                                                                        nextk])
                                    pred_cfbd_x = confidence_bd(rev_x[:, k])
                                    pred_cfbd_y = confidence_bd(rev_x[:,
                                                                      nextk])
                                    axes[i, j].clear()
                                    axes[i, j].scatter(samples[cnt, :, k],
                                                       samples[cnt, :, nextk],
                                                       c='b',
                                                       s=0.2,
                                                       alpha=0.5)
                                    axes[i, j].scatter(rev_x[:, k],
                                                       rev_x[:, nextk],
                                                       c='r',
                                                       s=0.2,
                                                       alpha=0.5)
                                    axes[i, j].plot(pos_test[cnt, k],
                                                    pos_test[cnt, nextk],
                                                    '+c',
                                                    markersize=8)
                                    #axes[i,j].axvline(x=true_cfbd_x[0], linewidth=0.5, color='b')
                                    #axes[i,j].axvline(x=true_cfbd_x[1], linewidth=0.5, color='b')
                                    #axes[i,j].axhline(y=true_cfbd_y[0], linewidth=0.5, color='b')
                                    #axes[i,j].axhline(y=true_cfbd_y[1], linewidth=0.5, color='b')
                                    #axes[i,j].axvline(x=pred_cfbd_x[0], linewidth=0.5, color='r')
                                    #axes[i,j].axvline(x=pred_cfbd_x[1], linewidth=0.5, color='r')
                                    #axes[i,j].axhline(y=pred_cfbd_y[0], linewidth=0.5, color='r')
                                    #axes[i,j].axhline(y=pred_cfbd_y[1], linewidth=0.5, color='r')
                                    axes[i, j].set_xlim([0, 1])
                                    axes[i, j].set_ylim([0, 1])
                                    oltxt = '%.2f' % olvec[i, j, s]
                                    axes[i,
                                         j].text(0.90,
                                                 0.95,
                                                 oltxt,
                                                 horizontalalignment='right',
                                                 verticalalignment='top',
                                                 transform=axes[i,
                                                                j].transAxes)
                                    matplotlib.rc('xtick', labelsize=8)
                                    matplotlib.rc('ytick', labelsize=8)
                                    axes[i, j].set_xlabel(parname1 if i == r - 1 else '')
                                    axes[i, j].set_ylabel(parname2 if j == 0 else '')

                                    # plot the 1D samples and the 5% confidence bounds
                                    axes_1d[i, j].clear()
                                    axes_1d[i, j].hist(samples[cnt, :, k],
                                                       color='b',
                                                       bins=100,
                                                       alpha=0.5)
                                    axes_1d[i, j].hist(rev_x[:, k],
                                                       color='r',
                                                       bins=100,
                                                       alpha=0.5)
                                    axes_1d[i, j].axvline(x=pos_test[cnt, k],
                                                          linewidth=0.5,
                                                          color='black')
                                    axes_1d[i, j].axvline(x=confidence_bd(
                                        samples[cnt, :, k])[0],
                                                          linewidth=0.5,
                                                          color='b')
                                    axes_1d[i, j].axvline(x=confidence_bd(
                                        samples[cnt, :, k])[1],
                                                          linewidth=0.5,
                                                          color='b')
                                    axes_1d[i, j].axvline(x=confidence_bd(
                                        rev_x[:, k])[0],
                                                          linewidth=0.5,
                                                          color='r')
                                    axes_1d[i, j].axvline(x=confidence_bd(
                                        rev_x[:, k])[1],
                                                          linewidth=0.5,
                                                          color='r')
                                    axes_1d[i, j].set_xlim([0, 1])
                                    axes_1d[i, j].text(
                                        0.90,
                                        0.95,
                                        oltxt,
                                        horizontalalignment='right',
                                        verticalalignment='top',
                                        transform=axes_1d[i, j].transAxes)
                                    axes_1d[i, j].set_xlabel(parname1 if i == r - 1 else '')

                                    cnt += 1
                            # save the results to file
                            fig_1d.canvas.draw()
                            fig_1d.savefig('%sposteriors-1d_%d_%04d.png' %
                                           (out_dir, k, epoch),
                                           dpi=360)
                            fig_1d.savefig('%slatest-1d_%d.png' % (out_dir, k),
                                           dpi=360)
                            #fig_1d.close()

                            fig.canvas.draw()
                            fig.savefig('%sposteriors-2d_%d%d_%04d.png' %
                                        (out_dir, k, nextk, epoch),
                                        dpi=360)
                            fig.savefig('%slatest-2d_%d%d.png' %
                                        (out_dir, k, nextk),
                                        dpi=360)
                            #fig.close()
                s += 1

            # plot overlap results
            if np.remainder(epoch, plot_cadence) == 0:
                fig_log = plt.figure(figsize=(6, 6))
                axes_log = fig_log.add_subplot(1, 1, 1)
                for i in range(r):
                    for j in range(r):
                        axes_log.semilogx(np.arange(tot_epoch,
                                                    step=plot_cadence),
                                          olvec[i, j, :],
                                          alpha=0.5)
                axes_log.grid()
                axes_log.set_ylabel('overlap')
                axes_log.set_xlabel('epoch (log)')
                axes_log.set_ylim([0, 1])
                plt.savefig('%soverlap_logscale.png' % out_dir, dpi=360)
                plt.close()

                fig = plt.figure(figsize=(6, 6))
                axes = fig.add_subplot(1, 1, 1)
                for i in range(r):
                    for j in range(r):
                        axes.plot(np.arange(tot_epoch, step=plot_cadence),
                                  olvec[i, j, :],
                                  alpha=0.5)
                axes.grid()
                axes.set_ylabel('overlap')
                axes.set_xlabel('epoch')
                axes.set_ylim([0, 1])
                plt.savefig('%soverlap.png' % out_dir, dpi=360)
                plt.close()

            if np.remainder(epoch, plot_cadence) == 0 and (epoch > 5):
                fig, axis = plt.subplots(4, 1, figsize=(10, 8))
                #fig.show()
                fig.canvas.draw()
                axis[0].clear()
                axis[1].clear()
                axis[2].clear()
                axis[3].clear()
                for i in range(len(indLosses)):
                    lossVec[i].append(indLosses[i])
                losses.append(loss)
                fig.suptitle('Current Loss: %.2e, min loss: %.2e' %
                             (loss, np.nanmin(np.abs(losses))))
                axis[0].semilogy(np.arange(len(losses)), np.abs(losses))
                for i, lo in enumerate(lossVec):
                    axis[1].semilogy(np.arange(len(losses)),
                                     lo,
                                     '--',
                                     label=lossLabels[i])
                axis[1].legend(loc='upper left')
                tNow = time()
                elapsed = int(tNow - tStart)
                eta = int((tNow - tStart) /
                          (epoch + 1) * trainer.numEpochs) - elapsed

                if epoch % 2 == 0:
                    mses = trainer.test(maxBatches=1)
                    lineProfiles = mses[2]

                if epoch % 10 == 0:
                    alphaRange, mmdF, mmdB, idxF, idxB = trainer.review_mmd()

                axis[3].semilogx(alphaRange, mmdF, label='Latent Space')
                axis[3].semilogx(alphaRange, mmdB, label='Backward')
                axis[3].semilogx(alphaRange[idxF], mmdF[idxF], 'ro')
                axis[3].semilogx(alphaRange[idxB], mmdB[idxB], 'ro')
                axis[3].legend()

                testTime = time() - tNow
                axis[2].plot(
                    lineProfiles[0, model.outSchema.timeseries].cpu().numpy())
                for a in axis:
                    a.grid()
                axis[3].set_xlabel(
                    'Epochs: %d, Elapsed: %d s, ETA: %d s (Testing: %d s)' %
                    (epoch, elapsed, eta, testTime))

                fig.canvas.draw()
                fig.savefig('%slosses.pdf' % out_dir)

    except KeyboardInterrupt:
        pass
    finally:
        print("\n\nTraining took {(time()-tStart)/60:.2f} minutes\n")
Example #10
0
def main():

    # Set up simulation parameters
    batch_size = 1600  # set batch size
    r = 3  # the grid dimension for the output tests
    test_split = r * r  # number of testing samples to use
    sig_model = 'sg'  # the signal model to use
    sigma = 0.2  # the noise std
    ndata = 32  # number of data samples
    bound = [0.0, 1.0, 0.0, 1.0]  # effective bound for likelihood
    seed = 1  # seed for generating data

    # generate data
    pos, labels, x, sig = data.generate(model=sig_model,
                                        tot_dataset_size=2**20,
                                        ndata=ndata,
                                        sigma=sigma,
                                        prior_bound=bound,
                                        seed=seed)

    # separate the test data for plotting
    pos_test = pos[-test_split:]
    labels_test = labels[-test_split:]
    sig_test = sig[-test_split:]

    # plot the test data examples
    plt.figure(figsize=(6, 6))
    fig, axes = plt.subplots(r, r, figsize=(6, 6))
    cnt = 0
    for i in range(r):
        for j in range(r):
            axes[i, j].plot(x, np.array(labels_test[cnt, :]), '.')
            axes[i, j].plot(x, np.array(sig_test[cnt, :]), '-')
            cnt += 1
            axes[i, j].axis([0, 1, -1.5, 1.5])
    plt.savefig('/data/public_html/chrism/FrEIA/test_distribution.png',
                dpi=360)
    plt.close()

    # setting up the model
    ndim_x = 2  # number of posterior parameter dimensions (x,y)
    ndim_y = ndata  # number of label dimensions (noisy data samples)
    ndim_z = 8  # number of latent space dimensions
    ndim_tot = max(ndim_x,
                   ndim_y + ndim_z)  # must be > ndim_x and > ndim_y + ndim_z

    # define different parts of the network
    # define input node
    inp = InputNode(ndim_tot, name='input')

    # define hidden layer nodes
    t1 = Node([inp.out0], rev_multiplicative_layer, {
        'F_class': F_fully_connected,
        'clamp': 2.0,
        'F_args': {
            'dropout': 0.2
        }
    })

    t2 = Node([t1.out0], rev_multiplicative_layer, {
        'F_class': F_fully_connected,
        'clamp': 2.0,
        'F_args': {
            'dropout': 0.2
        }
    })

    t3 = Node([t2.out0], rev_multiplicative_layer, {
        'F_class': F_fully_connected,
        'clamp': 2.0,
        'F_args': {
            'dropout': 0.2
        }
    })

    # define output layer node
    outp = OutputNode([t3.out0], name='output')

    nodes = [inp, t1, t2, t3, outp]
    model = ReversibleGraphNet(nodes)

    # Train model
    # Training parameters
    n_epochs = 1000
    meta_epoch = 12  # step size (in epochs) of the StepLR learning-rate scheduler
    n_its_per_epoch = 12
    batch_size = 1600

    lr = 1e-2
    gamma = 0.01**(1. / 120)
    l2_reg = 2e-5

    y_noise_scale = 3e-2
    zeros_noise_scale = 3e-2

    # relative weighting of losses:
    lambd_predict = 300.  # forward pass
    lambd_latent = 300.  # latent space
    lambd_rev = 400.  # backwards pass

    # padding both the data and the latent space
    # such that they have equal dimension to the parameter space
    #pad_x = torch.zeros(batch_size, ndim_tot - ndim_x)
    #pad_yz = torch.zeros(batch_size, ndim_tot - ndim_y - ndim_z)

    # define optimizer
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=lr,
                                 betas=(0.8, 0.8),
                                 eps=1e-04,
                                 weight_decay=l2_reg)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=meta_epoch,
                                                gamma=gamma)

    # define the three loss functions
    loss_backward = MMD_multiscale
    loss_latent = MMD_multiscale
    loss_fit = fit

    # set up training set data loader
    train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(
        pos[test_split:], labels[test_split:]),
                                               batch_size=batch_size,
                                               shuffle=True,
                                               drop_last=True)

    # initialisation of network weights
    for mod_list in model.children():
        for block in mod_list.children():
            for coeff in block.children():
                coeff.fc3.weight.data = 0.01 * torch.randn(
                    coeff.fc3.weight.shape)
    model.to(device)

    # initialize plot for showing testing results
    fig, axes = plt.subplots(r, r, figsize=(6, 6))

    # number of test samples to use after training
    N_samp = 256

    # precompute true likelihood on the test data
    Ngrid = 64
    cnt = 0
    lik = np.zeros((r, r, Ngrid * Ngrid))
    for i in range(r):
        for j in range(r):
            mvec, cvec, temp = data.get_lik(np.array(
                labels_test[cnt, :]).flatten(),
                                            n_grid=Ngrid,
                                            sig_model=sig_model,
                                            sigma=sigma,
                                            xvec=x,
                                            bound=bound)
            lik[i, j, :] = temp.flatten()
            cnt += 1

    # start training loop
    try:
        t_start = time()
        # loop over number of epochs
        for i_epoch in tqdm(range(n_epochs), ascii=True, ncols=80):

            scheduler.step()

            # Initially, the l2 reg. on x and z can give huge gradients; set
            # the lr lower for the first few epochs (this branch is currently
            # disabled since i_epoch is never negative)
            if i_epoch < 0:
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr * 1e-2

            # train the model
            train(model, train_loader, n_its_per_epoch, zeros_noise_scale,
                  batch_size, ndim_tot, ndim_x, ndim_y, ndim_z, y_noise_scale,
                  optimizer, lambd_predict, loss_fit, lambd_latent,
                  loss_latent, lambd_rev, loss_backward, i_epoch)

            # loop over a few cases and plot results in a grid
            cnt = 0
            for i in range(r):
                for j in range(r):

                    # convert data into correct format
                    y_samps = np.tile(np.array(labels_test[cnt, :]),
                                      N_samp).reshape(N_samp, ndim_y)
                    y_samps = torch.tensor(y_samps, dtype=torch.float)
                    #y_samps += y_noise_scale * torch.randn(N_samp, ndim_y)
                    y_samps = torch.cat(
                        [
                            torch.randn(N_samp, ndim_z),  #zeros_noise_scale * 
                            torch.zeros(N_samp, ndim_tot - ndim_y - ndim_z),
                            y_samps
                        ],
                        dim=1)
                    y_samps = y_samps.to(device)

                    # use the network to predict parameters
                    rev_x = model(y_samps, rev=True)
                    rev_x = rev_x.cpu().data.numpy()

                    # plot the samples and the true contours
                    axes[i, j].clear()
                    axes[i, j].contour(mvec,
                                       cvec,
                                       lik[i, j, :].reshape(Ngrid, Ngrid),
                                       levels=[0.68, 0.9, 0.99])
                    axes[i, j].scatter(rev_x[:, 0],
                                       rev_x[:, 1],
                                       s=0.5,
                                       alpha=0.5)
                    axes[i, j].plot(pos_test[cnt, 0],
                                    pos_test[cnt, 1],
                                    '+r',
                                    markersize=8)
                    axes[i, j].axis(bound)

                    cnt += 1

            # save the results to file
            fig.canvas.draw()
            plt.savefig('/data/public_html/chrism/FrEIA/posteriors_%s.png' %
                        i_epoch,
                        dpi=360)
            plt.savefig('/data/public_html/chrism/FrEIA/latest.png', dpi=360)

    except KeyboardInterrupt:
        pass
    finally:
        print("\n\nTraining took {(time()-t_start)/60:.2f} minutes\n")