def main(): # Set up simulation parameters batch_size = 1600 # set batch size r = 4 # the grid dimension for the output tests test_split = r*r # number of testing samples to use sigma = 0.2 # the noise std ndata = 64 # number of data samples usepars = [0,1,2,3] # parameter indices to use seed = 1 # seed for generating data run_label='gpu0' out_dir = "/home/hunter.gabbard/public_html/CBC/cINNamon/gausian_results/multipar/%s/" % run_label # generate data pos, labels, x, sig, parnames = data.generate( tot_dataset_size=2**20, ndata=ndata, usepars=usepars, sigma=sigma, seed=seed ) print('generated data') # seperate the test data for plotting pos_test = pos[-test_split:] labels_test = labels[-test_split:] sig_test = sig[-test_split:] # plot the test data examples plt.figure(figsize=(6,6)) fig, axes = plt.subplots(r,r,figsize=(6,6),sharex='col',sharey='row') cnt = 0 for i in range(r): for j in range(r): axes[i,j].plot(x,np.array(labels_test[cnt,:]),'.') axes[i,j].plot(x,np.array(sig_test[cnt,:]),'-') cnt += 1 axes[i,j].axis([0,1,-1.5,1.5]) axes[i,j].set_xlabel('time') if i==r-1 else axes[i,j].set_xlabel('') axes[i,j].set_ylabel('h(t)') if j==0 else axes[i,j].set_ylabel('') plt.savefig('%stest_distribution.png' % out_dir,dpi=360) plt.close() # precompute true posterior samples on the test data cnt = 0 N_samp = 1000 ndim_x = len(usepars) samples = np.zeros((r*r,N_samp,ndim_x)) for i in range(r): for j in range(r): samples[cnt,:,:] = data.get_lik(np.array(labels_test[cnt,:]).flatten(),sigma=sigma,usepars=usepars,Nsamp=N_samp) print(samples[cnt,:10,:]) cnt += 1 # initialize plot for showing testing results fig, axes = plt.subplots(r,r,figsize=(6,6),sharex='col',sharey='row') for k in range(ndim_x): parname1 = parnames[k] for nextk in range(ndim_x): parname2 = parnames[nextk] if nextk>k: cnt = 0 for i in range(r): for j in range(r): # plot the samples and the true contours axes[i,j].clear() axes[i,j].scatter(samples[cnt,:,k], samples[cnt,:,nextk],c='b',s=0.5,alpha=0.5) axes[i,j].plot(pos_test[cnt,k],pos_test[cnt,nextk],'+c',markersize=8) axes[i,j].set_xlim([0,1]) axes[i,j].set_ylim([0,1]) axes[i,j].set_xlabel(parname1) if i==r-1 else axes[i,j].set_xlabel('') axes[i,j].set_ylabel(parname2) if j==0 else axes[i,j].set_ylabel('') cnt += 1 # save the results to file fig.canvas.draw() plt.savefig('%strue_samples_%d%d.png' % (out_dir,k,nextk),dpi=360) # setting up the model ndim_x = len(usepars) # number of posterior parameter dimensions (x,y) ndim_y = ndata # number of label dimensions (noisy data samples) ndim_z = 4 # number of latent space dimensions? 
ndim_tot = max(ndim_x,ndim_y+ndim_z) # must be > ndim_x and > ndim_y + ndim_z # define different parts of the network # define input node inp = InputNode(ndim_tot, name='input') # define hidden layer nodes t1 = Node([inp.out0], rev_multiplicative_layer, {'F_class': F_fully_connected, 'clamp': 2.0, 'F_args': {'dropout': 0.2}}) #t1 = Node([inp.out0], rev_multiplicative_layer, # {'F_class': F_conv, 'clamp': 2.0, # 'F_args': {'kernel_size': 3,'leaky_slope': 0.1}}) #def __init__(self, dims_in, F_class=F_fully_connected, F_args={}, # clamp=5.): # super(rev_multiplicative_layer, self).__init__() # channels = dims_in[0][0] # # self.split_len1 = channels // 2 # self.split_len2 = channels - channels // 2 # self.ndims = len(dims_in[0]) # # self.clamp = clamp # self.max_s = exp(clamp) # self.min_s = exp(-clamp) # # self.s1 = F_class(self.split_len1, self.split_len2, **F_args) # self.t1 = F_class(self.split_len1, self.split_len2, **F_args) # self.s2 = F_class(self.split_len2, self.split_len1, **F_args) # self.t2 = F_class(self.split_len2, self.split_len1, **F_args) t2 = Node([t1.out0], rev_multiplicative_layer, {'F_class': F_fully_connected, 'clamp': 2.0, 'F_args': {'dropout': 0.2}}) t3 = Node([t2.out0], rev_multiplicative_layer, {'F_class': F_fully_connected, 'clamp': 2.0, 'F_args': {'dropout': 0.2}}) t4 = Node([t3.out0], rev_multiplicative_layer, {'F_class': F_fully_connected, 'clamp': 2.0, 'F_args': {'dropout': 0.0}}) # define output layer node outp = OutputNode([t4.out0], name='output') nodes = [inp, t1, t2, t3, t4, outp] model = ReversibleGraphNet(nodes) # Train model # Training parameters n_epochs = 10000 meta_epoch = 12 # what is this??? n_its_per_epoch = 12 batch_size = 1600 lr = 1e-2 gamma = 0.01**(1./120) l2_reg = 2e-5 y_noise_scale = 3e-2 zeros_noise_scale = 3e-2 # relative weighting of losses: lambd_predict = 300. # forward pass lambd_latent = 300. # laten space lambd_rev = 400. # backwards pass # padding both the data and the latent space # such that they have equal dimension to the parameter space #pad_x = torch.zeros(batch_size, ndim_tot - ndim_x) #pad_yz = torch.zeros(batch_size, ndim_tot - ndim_y - ndim_z) # define optimizer optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=(0.8, 0.8), eps=1e-04, weight_decay=l2_reg) scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=meta_epoch, gamma=gamma) # define the three loss functions loss_backward = MMD_multiscale loss_latent = MMD_multiscale loss_fit = fit # set up training set data loader train_loader = torch.utils.data.DataLoader( torch.utils.data.TensorDataset(pos[test_split:], labels[test_split:]), batch_size=batch_size, shuffle=True, drop_last=True) # initialisation of network weights #for mod_list in model.children(): # for block in mod_list.children(): # for coeff in block.children(): # coeff.fc3.weight.data = 0.01*torch.randn(coeff.fc3.weight.shape) #model.to(device) # start training loop try: t_start = time() olvec = np.zeros((r,r,int(n_epochs/10))) s = 0 # loop over number of epochs for i_epoch in tqdm(range(n_epochs), ascii=True, ncols=80): scheduler.step() # Initially, the l2 reg. 
on x and z can give huge gradients, set # the lr lower for this if i_epoch < 0: print('inside this iepoch<0 thing') for param_group in optimizer.param_groups: param_group['lr'] = lr * 1e-2 # train the model train(model,train_loader,n_its_per_epoch,zeros_noise_scale,batch_size, ndim_tot,ndim_x,ndim_y,ndim_z,y_noise_scale,optimizer,lambd_predict, loss_fit,lambd_latent,loss_latent,lambd_rev,loss_backward,i_epoch) # loop over a few cases and plot results in a grid if np.remainder(i_epoch,10)==0: for k in range(ndim_x): parname1 = parnames[k] for nextk in range(ndim_x): parname2 = parnames[nextk] if nextk>k: cnt = 0 # initialize plot for showing testing results fig, axes = plt.subplots(r,r,figsize=(6,6),sharex='col',sharey='row') for i in range(r): for j in range(r): # convert data into correct format y_samps = np.tile(np.array(labels_test[cnt,:]),N_samp).reshape(N_samp,ndim_y) y_samps = torch.tensor(y_samps, dtype=torch.float) y_samps += y_noise_scale * torch.randn(N_samp, ndim_y) y_samps = torch.cat([torch.randn(N_samp, ndim_z), zeros_noise_scale * torch.zeros(N_samp, ndim_tot - ndim_y - ndim_z), y_samps], dim=1) y_samps = y_samps.to(device) # use the network to predict parameters rev_x = model(y_samps, rev=True) rev_x = rev_x.cpu().data.numpy() # compute the n-d overlap if k==0 and nextk==1: ol = data.overlap(samples[cnt,:,:ndim_x],rev_x[:,:ndim_x]) olvec[i,j,s] = ol # plot the samples and the true contours axes[i,j].clear() axes[i,j].scatter(samples[cnt,:,k], samples[cnt,:,nextk],c='b',s=0.2,alpha=0.5) axes[i,j].scatter(rev_x[:,k], rev_x[:,nextk],c='r',s=0.2,alpha=0.5) axes[i,j].plot(pos_test[cnt,k],pos_test[cnt,nextk],'+c',markersize=8) axes[i,j].set_xlim([0,1]) axes[i,j].set_ylim([0,1]) oltxt = '%.2f' % olvec[i,j,s] axes[i,j].text(0.90, 0.95, oltxt, horizontalalignment='right', verticalalignment='top', transform=axes[i,j].transAxes) matplotlib.rc('xtick', labelsize=8) matplotlib.rc('ytick', labelsize=8) axes[i,j].set_xlabel(parname1) if i==r-1 else axes[i,j].set_xlabel('') axes[i,j].set_ylabel(parname2) if j==0 else axes[i,j].set_ylabel('') cnt += 1 # save the results to file fig.canvas.draw() plt.savefig('%sposteriors_%d%d_%04d.png' % (out_dir,k,nextk,i_epoch),dpi=360) plt.savefig('%slatest_%d%d.png' % (out_dir,k,nextk),dpi=360) plt.close() s += 1 # plot overlap results if np.remainder(i_epoch,10)==0: fig, axes = plt.subplots(1,figsize=(6,6)) for i in range(r): for j in range(r): axes.semilogx(10*np.arange(olvec.shape[2]),olvec[i,j,:],alpha=0.5) axes.grid() axes.set_ylabel('overlap') axes.set_xlabel('epoch') axes.set_ylim([0,1]) plt.savefig('%soverlap.png' % out_dir,dpi=360) plt.close() except KeyboardInterrupt: pass finally: print("\n\nTraining took {(time()-t_start)/60:.2f} minutes\n")
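# -----------------------------------------------------------------------------
# The script above assigns loss_backward = MMD_multiscale, loss_latent = MMD_multiscale
# and loss_fit = fit, but neither function is defined in this excerpt. The sketch
# below is a minimal, assumed implementation (not necessarily the project's own):
# a maximum mean discrepancy built from inverse multiquadratic kernels at several
# widths, plus a plain mean squared error for the forward fit term.
def MMD_multiscale(x, y, widths=(0.2, 0.5, 0.9, 1.3)):
    # pairwise squared distances within and between the two sample sets
    xx, yy, xy = torch.mm(x, x.t()), torch.mm(y, y.t()), torch.mm(x, y.t())
    rx = xx.diag().unsqueeze(0).expand_as(xx)
    ry = yy.diag().unsqueeze(0).expand_as(yy)
    dxx = rx.t() + rx - 2.0 * xx
    dyy = ry.t() + ry - 2.0 * yy
    dxy = rx.t() + ry - 2.0 * xy
    XX, YY, XY = (torch.zeros_like(xx) for _ in range(3))
    for a in widths:
        # inverse multiquadratic kernel k(u,v) = a^2 / (a^2 + |u-v|^2)
        XX += a**2 * (a**2 + dxx)**-1
        YY += a**2 * (a**2 + dyy)**-1
        XY += a**2 * (a**2 + dxy)**-1
    return torch.mean(XX + YY - 2.0 * XY)

def fit(pred, target):
    # simple mean squared error used as the forward (prediction) loss
    return torch.mean((pred - target)**2)
# -----------------------------------------------------------------------------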
def plot_y_test(model,Nsamp,usepars,sigma,ndim_x,ndim_y,ndim_z,ndim_tot,outdir,r,i_epoch,conv=False,model_f=None,model_r=None,do_double_nn=False,do_cnn=False): """ Plot examples of test y-data generation """ # generate test data x_test, y_test, x, sig_test, parnames = data_maker.generate( tot_dataset_size=Nsamp, ndata=ndim_y, usepars=usepars, sigma=sigma, seed=1 ) out_shape = [-1,ndim_tot] if conv==True: in_shape = [-1,1,ndim_tot] else: in_shape = [-1,ndim_tot] fig, axes = plt.subplots(r,r,figsize=(6,6),sharex='col',sharey='row') # run the x test data through the model x = torch.tensor(x_test[:r*r,:],dtype=torch.float,device=dev).clone().detach() y_test = torch.tensor(y_test[:r*r,:],dtype=torch.float,device=dev).clone().detach() sig_test = torch.tensor(sig_test[:r*r,:],dtype=torch.float,device=dev).clone().detach() # make the new padding for the noisy data and latent vector data pad_x = torch.zeros(r*r,ndim_tot-ndim_x-ndim_y,device=dev) # make a padded zy vector (with all new noise) x_padded = torch.cat((x,pad_x,y_test-sig_test),dim=1) # apply forward model to the x data if do_double_nn: if do_cnn: data = torch.cat((x,y_test-sig_test), dim=1) output = model_f(data.reshape(data.shape[0],1,data.shape[1]))#.reshape(out_shape) output_y = output[:,:ndim_y] # extract the model output y else: output = model_f(torch.cat((x,y_test-sig_test), dim=1))#.reshape(out_shape) output_y = output[:,:ndim_y] # extract the model output y else: output = model(x_padded.reshape(in_shape))#.reshape(out_shape) output_y = output[:,model.outSchema.timeseries] # extract the model output y y = output_y.cpu().data.numpy() cnt = 0 for i in range(r): for j in range(r): axes[i,j].clear() axes[i,j].plot(np.arange(ndim_y)/float(ndim_y),y[cnt,:],'b-') axes[i,j].plot(np.arange(ndim_y)/float(ndim_y),y_test[cnt,:].cpu().data.numpy(),'k',alpha=0.5) axes[i,j].set_xlim([0,1]) #matplotlib.rc('xtick', labelsize=5) #matplotlib.rc('ytick', labelsize=5) axes[i,j].set_xlabel('t') if i==r-1 else axes[i,j].set_xlabel('') axes[i,j].set_ylabel('y') if j==0 else axes[i,j].set_ylabel('') if i==0 and j==0: axes[i,j].legend(('pred y','y')) cnt += 1 fig.canvas.draw() fig.savefig('%s/ytest_%04d.png' % (outdir,i_epoch),dpi=360) fig.savefig('%s/latest/latest_ytest.png' % outdir,dpi=360) plt.close(fig) return
def main(): ## If file exists, delete it ## if os.path.exists(out_dir): shutil.rmtree(out_dir) else: ## Show a message ## print("Attention: %s file not found" % out_dir) # setup output directory - if it does not exist os.makedirs('%s' % out_dir) os.makedirs('%s/latest' % out_dir) os.makedirs('%s/animations' % out_dir) # generate data if not load_dataset: pos, labels, x, sig, parnames = data_maker.generate( tot_dataset_size=tot_dataset_size, ndata=ndata, usepars=usepars, sigma=sigma, seed=seed ) print('generated data') hf = h5py.File('benchmark_data_%s.h5py' % run_label, 'w') hf.create_dataset('pos', data=pos) hf.create_dataset('labels', data=labels) hf.create_dataset('x', data=x) hf.create_dataset('sig', data=sig) hf.create_dataset('parnames', data=np.string_(parnames)) data = AtmosData([dataLocation1], test_split, resampleWl=None) data.split_data_and_init_loaders(batchsize) # seperate the test data for plotting pos_test = data.pos_test labels_test = data.labels_test sig_test = data.sig_test ndim_x = len(usepars) print('Computing MCMC posterior samples') if do_mcmc or not load_dataset: # precompute true posterior samples on the test data cnt = 0 samples = np.zeros((r*r,N_samp,ndim_x)) for i in range(r): for j in range(r): samples[cnt,:,:] = data_maker.get_lik(np.array(labels_test[cnt,:]).flatten(),sigma=sigma,usepars=usepars,Nsamp=N_samp) print(samples[cnt,:10,:]) cnt += 1 # save computationaly expensive mcmc/waveform runs if load_dataset==True: # define names of parameters to have PE done on them parnames=['A','t0','tau','phi','w'] names = [parnames[int(i)] for i in usepars] hf = h5py.File('benchmark_data_%s.h5py' % run_label, 'w') hf.create_dataset('pos', data=data.pos) hf.create_dataset('labels', data=data.labels) hf.create_dataset('x', data=data.x) hf.create_dataset('sig', data=data.sig) hf.create_dataset('parnames', data=np.string_(names)) hf.create_dataset('samples', data=np.string_(samples)) hf.close() else: samples=h5py.File(dataLocation1, 'r')['samples'][:] parnames=h5py.File(dataLocation1, 'r')['parnames'][:] # plot the test data examples plt.figure(figsize=(6,6)) fig, axes = plt.subplots(r,r,figsize=(6,6),sharex='col',sharey='row') cnt = 0 for i in range(r): for j in range(r): axes[i,j].plot(data.x,np.array(labels_test[cnt,:]),'-', label='noisy') axes[i,j].plot(data.x,np.array(sig_test[cnt,:]),'-', label='noise-free') axes[i,j].legend(loc='upper left') cnt += 1 axes[i,j].axis([0,1,-1.5,1.5]) axes[i,j].set_xlabel('time') if i==r-1 else axes[i,j].set_xlabel('') axes[i,j].set_ylabel('h(t)') if j==0 else axes[i,j].set_ylabel('') plt.savefig('%s/test_distribution.png' % out_dir,dpi=360) plt.close() # initialize plot for showing testing results fig, axes = plt.subplots(r,r,figsize=(6,6),sharex='col',sharey='row') for k in range(ndim_x): parname1 = parnames[k] for nextk in range(ndim_x): parname2 = parnames[nextk] if nextk>k: cnt = 0 for i in range(r): for j in range(r): # plot the samples and the true contours axes[i,j].clear() axes[i,j].scatter(samples[cnt,:,k], samples[cnt,:,nextk],c='b',s=0.5,alpha=0.5, label='MCMC') axes[i,j].plot(pos_test[cnt,k],pos_test[cnt,nextk],'+c',markersize=8, label='MCMC Truth') axes[i,j].set_xlim([0,1]) axes[i,j].set_ylim([0,1]) axes[i,j].set_xlabel(parname1) if i==r-1 else axes[i,j].set_xlabel('') axes[i,j].set_ylabel(parname2) if j==0 else axes[i,j].set_ylabel('') axes[i,j].legend(loc='upper left') cnt += 1 # save the results to file fig.canvas.draw() plt.savefig('%s/true_samples_%d%d.png' % (out_dir,k,nextk),dpi=360) def store_pars(f,pars): for i in 
pars.keys(): f.write("%s: %s\n" % (i,str(pars[i]))) f.close() # store hyperparameters for posterity f=open("%s_run-pars.txt" % run_label,"w+") pars_to_store={"sigma":sigma,"ndata":ndata,"T":T,"seed":seed,"n_neurons":n_neurons,"bound":bound,"conv_nn":conv_nn,"filtsize":filtsize,"dropout":dropout, "clamp":clamp,"ndim_z":ndim_z,"tot_epoch":tot_epoch,"lr":lr, "latentAlphas":latentAlphas, "backwardAlphas":backwardAlphas, "zerosNoiseScale":zerosNoiseScale,"wPred":wPred,"wLatent":wLatent,"wRev":wRev,"tot_dataset_size":tot_dataset_size, "numInvLayers":numInvLayers,"batchsize":batchsize} store_pars(f,pars_to_store) if extra_z: inRepr = [('amp', 1), ('t0', 1), ('tau', 1), ('!!PAD',), ('yNoise', data.atmosOut.shape[1])] else: inRepr = [('amp', 1), ('t0', 1), ('tau', 1), ('!!PAD',)] outRepr = [('LatentSpace', ndim_z), ('!!PAD',), ('timeseries', data.atmosOut.shape[1])] if do_double_nn: model_f = nn_double_f((ndim_x+ndim_y),(ndim_y+ndim_z)) model_r = nn_double_r((ndim_y+ndim_z),(ndim_x+ndim_y)) else: model = RadynversionNet(inRepr, outRepr, dropout=dropout, zeroPadding=0, minSize=ndim_tot, numInvLayers=numInvLayers) # Construct the class that trains the model, the initial weighting between the losses, learning rate, and the initial number of epochs to train for. # load previous model if asked if load_model: model.load_state_dict(torch.load('models/gpu5_model.pt')) #% run_label)) if do_double_nn: trainer = DoubleNetTrainer(model_f, model_r, data, dev, load_model=load_model) trainer.training_params(tot_epoch, lr=lr, fadeIn=fadeIn, loss_latent=Loss.mmd_multiscale_on(dev, alphas=latentAlphas), loss_fit=Loss.mse,ndata=ndata,sigma=sigma,seed=seed,batchSize=batchsize,usepars=usepars) else: trainer = RadynversionTrainer(model, data, dev, load_model=load_model) trainer.training_params(tot_epoch, lr=lr, fadeIn=fadeIn, zerosNoiseScale=zerosNoiseScale, wPred=wPred, wLatent=wLatent, wRev=wRev, loss_latent=Loss.mmd_multiscale_on(dev, alphas=latentAlphas), loss_backward=Loss.mmd_multiscale_on(dev, alphas=backwardAlphas), loss_fit=Loss.mse,ndata=ndata,sigma=sigma,seed=seed,n_neurons=n_neurons,batchSize=batchsize,usepars=usepars, y_noise_scale=y_noise_scale) totalEpochs = 0 # Train the model for these first epochs with a nice graph that updates during training. 
losses = [] wRevScale_tot = [] beta_score_hist=[] beta_score_loop_hist=[] lossVec = [[] for _ in range(4)] lossLabels = ['L2 Line', 'MMD Latent', 'MMD Reverse', 'L2 Reverse'] out = None alphaRange, mmdF, mmdB, idxF, idxB = [1,1], [1,1], [1,1], 0, 0 try: tStart = time() olvec = np.zeros((r,r,int(trainer.numEpochs/plot_cadence))) adksVec = np.zeros((r,r,ndim_x,4,int(trainer.numEpochs/plot_cadence))) s = 0 for epoch in range(trainer.numEpochs): print('Epoch %s/%s' % (str(epoch),str(trainer.numEpochs))) totalEpochs += 1 if do_double_nn: trainer.scheduler_f.step() loss, indLosses = trainer.train(epoch,gen_inf_temp=gen_inf_temp,extra_z=extra_z, do_cnn=do_cnn) else: trainer.scheduler.step() loss, indLosses = trainer.train(epoch,gen_inf_temp=gen_inf_temp,extra_z=extra_z,do_covar=do_covar) if do_double_nn: # save trained model torch.save(model_f.state_dict(), 'models/%s_model_f.pt' % run_label) torch.save(model_r.state_dict(), 'models/%s_model_r.pt' % run_label) else: # save trained model torch.save(model.state_dict(), 'models/%s_model.pt' % run_label) # loop over a few cases and plot results in a grid if np.remainder(epoch,plot_cadence)==0: for k in range(ndim_x): parname1 = parnames[k] for nextk in range(ndim_x): parname2 = parnames[nextk] if nextk>k: cnt = 0 # initialize plot for showing testing results fig, axes = plt.subplots(r,r,figsize=(6,6),sharex='col',sharey='row') for i in range(r): for j in range(r): # convert data into correct format y_samps = np.tile(np.array(labels_test[cnt,:]),N_samp).reshape(N_samp,ndim_y) y_samps = torch.tensor(y_samps, dtype=torch.float) # add noise to y data (why?) #y = y_samps + y_noise_scale * torch.randn(N_samp, ndim_y, device=dev) if do_double_nn: y_samps = torch.cat([torch.randn(N_samp, ndim_z), y_samps], dim=1) else: y_samps = torch.cat([torch.randn(N_samp, ndim_z), zerosNoiseScale * torch.zeros(N_samp, ndim_tot - ndim_y - ndim_z), y_samps], dim=1) y_samps = y_samps.to(dev) # use the network to predict parameters if do_double_nn: if do_cnn: y_samps = y_samps.reshape(y_samps.shape[0],1,y_samps.shape[1]) rev_x = model_r(y_samps) else: rev_x = model_r(y_samps) else: rev_x = model(y_samps, rev=True) rev_x = rev_x.cpu().data.numpy() # compute the n-d overlap if k==0 and nextk==1: ol = data_maker.overlap(samples[cnt,:,:ndim_x],rev_x[:,:ndim_x]) olvec[i,j,s] = ol # print A-D and K-S test ks_mcmc_arr, ks_inn_arr, ad_mcmc_arr, ad_inn_arr = result_stat_tests(rev_x[:, usepars], samples[cnt,:,:ndim_x], cnt, parnames) for p in usepars: for c in range(4): adksVec[i,j,p,c,s] = np.array([ks_mcmc_arr,ks_inn_arr,ad_mcmc_arr,ad_inn_arr])[c,p] # plot the samples and the true contours axes[i,j].clear() if latent==True: colors = z.cpu().detach().numpy() colors = np.linalg.norm(colors,axis=1) axes[i,j].scatter(rev_x[:,k], rev_x[:,nextk],c=colors,s=1.0,cmap='hsv',alpha=0.75, label='INN') else: axes[i,j].scatter(samples[cnt,:,k], samples[cnt,:,nextk],c='b',s=0.2,alpha=0.5, label='MCMC') axes[i,j].scatter(rev_x[:,k], rev_x[:,nextk],c='r',s=0.2,alpha=0.5, label='INN') axes[i,j].set_xlim([0,1]) axes[i,j].set_ylim([0,1]) axes[i,j].plot(pos_test[cnt,k],pos_test[cnt,nextk],'+c',markersize=8, label='Truth') oltxt = '%.2f' % olvec[i,j,s] axes[i,j].text(0.90, 0.95, oltxt, horizontalalignment='right', verticalalignment='top', transform=axes[i,j].transAxes) matplotlib.rc('xtick', labelsize=8) matplotlib.rc('ytick', labelsize=8) axes[i,j].set_xlabel(parname1) if i==r-1 else axes[i,j].set_xlabel('') axes[i,j].set_ylabel(parname2) if j==0 else axes[i,j].set_ylabel('') if i == 0 and j == 0: 
axes[i,j].legend(loc='upper left', fontsize='x-small') cnt += 1 # save the results to file fig.canvas.draw() if latent==True: plt.savefig('%s/latent_map_%d%d_%04d.png' % (out_dir,k,nextk,epoch),dpi=360) plt.savefig('%s/latest/latent_map_%d%d.png' % (out_dir,k,nextk),dpi=360) else: plt.savefig('%s/posteriors_%d%d_%04d.png' % (out_dir,k,nextk,epoch),dpi=360) plt.savefig('%s/latest/posteriors_%d%d.png' % (out_dir,k,nextk),dpi=360) plt.close(fig) s += 1 # plot overlap results if np.remainder(epoch,plot_cadence)==0: fig, axes = plt.subplots(1,figsize=(6,6)) for i in range(r): for j in range(r): color = next(axes._get_lines.prop_cycler)['color'] axes.semilogx(np.arange(epoch, step=plot_cadence),olvec[i,j,:int((epoch)/plot_cadence)],alpha=0.5, color=color) axes.plot([int(epoch)],[olvec[i,j,int(epoch/plot_cadence)]],'.', color=color) axes.grid() axes.set_ylabel('overlap') axes.set_xlabel('epoch') axes.set_ylim([0,1]) plt.savefig('%s/latest/overlap_logscale.png' % out_dir, dpi=360) plt.close(fig) fig, axes = plt.subplots(1,figsize=(6,6)) for i in range(r): for j in range(r): color = next(axes._get_lines.prop_cycler)['color'] axes.plot(np.arange(epoch, step=plot_cadence),olvec[i,j,:int((epoch)/plot_cadence)],alpha=0.5, color=color) axes.plot([int(epoch)],[olvec[i,j,int(epoch/plot_cadence)]],'.', color=color) axes.grid() axes.set_ylabel('overlap') axes.set_xlabel('epoch') axes.set_ylim([0,1]) plt.savefig('%s/latest/overlap.png' % out_dir, dpi=360) plt.close(fig) if do_double_nn: # plot predicted time series vs. actually time series examples model=None plot_y_test(model,N_samp,usepars,sigma,ndim_x,ndim_y,ndim_z,ndim_tot,out_dir,r,epoch,conv=False,model_f=model_f,model_r=model_r,do_double_nn=do_double_nn,do_cnn=do_cnn) # make y_dist_plot plot_y_dist(model,ndim_x,ndim_y,ndim_z,ndim_tot,usepars,sigma,out_dir,epoch,conv=False,model_f=model_f,model_r=model_r,do_double_nn=do_double_nn,do_cnn=do_cnn) plot_x_evolution(model,ndim_x,ndim_y,ndim_z,ndim_tot,sigma,parnames,out_dir,epoch,conv=False,model_f=model_f,model_r=model_r,do_double_nn=do_double_nn,do_cnn=do_cnn) plot_z_dist(model,ndim_x,ndim_y,ndim_z,ndim_tot,usepars,sigma,out_dir,epoch,conv=False,model_f=model_f,model_r=model_r,do_double_nn=do_double_nn,do_cnn=do_cnn) else: # plot predicted time series vs. 
actually time series examples plot_y_test(model,N_samp,usepars,sigma,ndim_x,ndim_y,ndim_z,ndim_tot,out_dir,r,epoch,conv=False) # make y_dist_plot plot_y_dist(model,ndim_x,ndim_y,ndim_z,ndim_tot,usepars,sigma,out_dir,epoch,conv=False) plot_x_evolution(model,ndim_x,ndim_y,ndim_z,ndim_tot,sigma,parnames,out_dir,epoch,conv=False) plot_z_dist(model,ndim_x,ndim_y,ndim_z,ndim_tot,usepars,sigma,out_dir,epoch,conv=False) # plot evolution of y if not do_double_nn: for c in range(ndim_x): plot_y_evolution(model,c,parnames,ndim_x,ndim_y,ndim_z,ndim_tot,out_dir,epoch,conv=False,model_f=model_f,model_r=model_r,do_double_nn=True,do_cnn=do_cnn) # plot ad and ks results [ks_mcmc_arr,ks_inn_arr,ad_mcmc_arr,ad_inn_arr] for p in range(ndim_x): fig_ks, axis_ks = plt.subplots(1,figsize=(6,6)) fig_ad, axis_ad = plt.subplots(1,figsize=(6,6)) for i in range(r): for j in range(r): color_ks = next(axis_ks._get_lines.prop_cycler)['color'] axis_ks.semilogx(np.arange(tot_epoch, step=plot_cadence),adksVec[i,j,p,0,:],'--',alpha=0.5,color=color_ks) axis_ks.semilogx(np.arange(tot_epoch, step=plot_cadence),adksVec[i,j,p,1,:],alpha=0.5,color=color_ks) axis_ks.plot([int(epoch)],[adksVec[i,j,p,1,int(epoch/plot_cadence)]],'.', color=color_ks) axis_ks.set_yscale('log') color_ad = next(axis_ad._get_lines.prop_cycler)['color'] axis_ad.semilogx(np.arange(tot_epoch, step=plot_cadence),adksVec[i,j,p,2,:],'--',alpha=0.5,color=color_ad) axis_ad.semilogx(np.arange(tot_epoch, step=plot_cadence),adksVec[i,j,p,3,:],alpha=0.5,color=color_ad) axis_ad.plot([int(epoch)],[adksVec[i,j,p,3,int(epoch/plot_cadence)]],'.',color=color_ad) axis_ad.set_yscale('log') axis_ks.set_xlabel('Epoch') axis_ad.set_xlabel('Epoch') axis_ks.set_ylabel('KS Statistic') axis_ad.set_ylabel('AD Statistic') fig_ks.savefig('%s/latest/ks_%s_stat.png' % (out_dir,parnames[p]), dpi=360) fig_ad.savefig('%s/latest/ad_%s_stat.png' % (out_dir,parnames[p]), dpi=360) plt.close(fig_ks) plt.close(fig_ad) if ((epoch % 10 == 0) & (epoch>5)): #fig, axis = plt.subplots(4,1, figsize=(10,8)) #fig.canvas.draw() #axis[0].clear() #axis[1].clear() #axis[2].clear() #axis[3].clear() for i in range(len(indLosses)): lossVec[i].append(indLosses[i]) losses.append(loss) #fig.suptitle('Current Loss: %.2e, min loss: %.2e' % (loss, np.nanmin(np.abs(losses)))) #axis[0].semilogy(np.arange(len(losses)), np.abs(losses)) #for i, lo in enumerate(lossVec): # axis[1].semilogy(np.arange(len(losses)), lo, '--', label=lossLabels[i]) #axis[1].legend(loc='upper left') #tNow = time() #elapsed = int(tNow - tStart) #eta = int((tNow - tStart) / (epoch + 1) * trainer.numEpochs) - elapsed #if epoch % 2 == 0: # mses = trainer.test(samples,maxBatches=1,extra_z=extra_z) # lineProfiles = mses[2] if epoch % 10 == 0 and review_mmd and epoch >=600: print('Reviewing alphas') alphaRange, mmdF, mmdB, idxF, idxB = trainer.review_mmd() #axis[3].semilogx(alphaRange, mmdF, label='Latent Space') #axis[3].semilogx(alphaRange, mmdB, label='Backward') #axis[3].semilogx(alphaRange[idxF], mmdF[idxF], 'ro') #axis[3].semilogx(alphaRange[idxB], mmdB[idxB], 'ro') #axis[3].legend() #testTime = time() - tNow #axis[2].plot(lineProfiles[0, model.outSchema.timeseries].cpu().numpy()) #for a in axis: # a.grid() #axis[3].set_xlabel('Epochs: %d, Elapsed: %d s, ETA: %d s (Testing: %d s)' % (epoch, elapsed, eta, testTime)) #fig.canvas.draw() #fig.savefig('%slosses-wave-tot.pdf' % out_dir) #plt.close(fig) # make latent space plots """ if epoch % plot_cadence == 0: labels_z = [] for lab_idx in range(ndim_z): labels_z.append(r"latent%d" % lab_idx) fig_latent 
= corner.corner(lineProfiles[:, model.outSchema.LatentSpace].cpu().numpy(), plot_contours=False, labels=labels_z) fig_latent.savefig('%s/latest/latent_space.pdf' % out_dir) print('Plotted latent space') plt.close(fig_latent) """ # make non-logscale loss plot fig_loss, axes_loss = plt.subplots(1,figsize=(10,8)) wRevScale_tot.append(trainer.wRevScale) axes_loss.grid() axes_loss.set_ylabel('Loss') axes_loss.set_xlabel('Epochs elapsed: %s' % epoch) axes_loss.semilogy(np.arange(len(losses)), np.abs(losses), label='Total') for i, lo in enumerate(lossVec): axes_loss.semilogy(np.arange(len(losses)), lo, label=lossLabels[i]) axes_loss.semilogy(np.arange(len(losses)), wRevScale_tot, label='fadeIn') axes_loss.legend(loc='upper left') plt.savefig('%s/latest/losses.png' % out_dir) plt.close(fig_loss) # make log scale loss plot fig_loss, axes_loss = plt.subplots(1,figsize=(10,8)) axes_loss.grid() axes_loss.set_ylabel('Loss') axes_loss.set_xlabel('Epochs elapsed: %s' % epoch) axes_loss.plot(np.arange(len(losses)), np.abs(losses), label='Total') for i, lo in enumerate(lossVec): axes_loss.plot(np.arange(len(losses)), lo, label=lossLabels[i]) axes_loss.plot(np.arange(len(losses)), wRevScale_tot, label='fadeIn') axes_loss.set_xscale('log') axes_loss.set_yscale('log') axes_loss.legend(loc='upper left') plt.savefig('%s/latest/losses_logscale.png' % out_dir) plt.close(fig_loss) except KeyboardInterrupt: pass finally: print(f"\n\nTraining took {(time()-tStart)/60:.2f} minutes\n")
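# -----------------------------------------------------------------------------
# result_stat_tests() is called in the training loop above but not defined in this
# excerpt. A minimal sketch (assumption) of what it might compute: per-parameter
# two-sample KS and Anderson-Darling statistics, comparing the INN samples with
# the MCMC samples, and one half of the MCMC chain with the other half as a
# self-consistency reference.
from scipy.stats import ks_2samp, anderson_ksamp

def result_stat_tests(inn_samps, mcmc_samps, cnt, parnames):
    ks_mcmc, ks_inn, ad_mcmc, ad_inn = [], [], [], []
    n = mcmc_samps.shape[0]
    for p in range(mcmc_samps.shape[1]):
        half_a, half_b = mcmc_samps[:n//2, p], mcmc_samps[n//2:, p]
        ks_mcmc.append(ks_2samp(half_a, half_b)[0])                      # MCMC vs itself
        ks_inn.append(ks_2samp(inn_samps[:, p], mcmc_samps[:, p])[0])    # INN vs MCMC
        ad_mcmc.append(anderson_ksamp([half_a, half_b])[0])
        ad_inn.append(anderson_ksamp([inn_samps[:, p], mcmc_samps[:, p]])[0])
    return np.array(ks_mcmc), np.array(ks_inn), np.array(ad_mcmc), np.array(ad_inn)
# -----------------------------------------------------------------------------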
def plot_z_dist(model,ndim_x,ndim_y,ndim_z,ndim_tot,usepars,sigma,outdir,i_epoch,conv=False,model_f=None,model_r=None,do_double_nn=False,do_cnn=False): """ Plots the distribution of latent z variables """ Nsamp = 250 out_shape = [-1,ndim_tot] if conv==True: in_shape = [-1,1,ndim_tot] else: in_shape = [-1,ndim_tot] # generate test data x_test, y_test, x, sig_test, parnames = data_maker.generate( tot_dataset_size=Nsamp, ndata=ndim_y, usepars=usepars, sigma=sigma, seed=1 ) # run the x test data through the model x = torch.tensor(x_test,dtype=torch.float,device=dev).clone().detach() y_test = torch.tensor(y_test,dtype=torch.float,device=dev).clone().detach() sig_test = torch.tensor(sig_test,dtype=torch.float,device=dev).clone().detach() # make the new padding for the noisy data and latent vector data pad_x = torch.zeros(Nsamp,ndim_tot-ndim_x-ndim_y,device=dev) # make a padded zy vector (with all new noise) x_padded = torch.cat((x,pad_x,y_test-sig_test),dim=1) # apply forward model to the x data if do_double_nn: if do_cnn: data = torch.cat((x,y_test-sig_test), dim=1) output = model_f(data.reshape(data.shape[0],1,data.shape[1]))#.reshape(out_shape) output_z = output[:,ndim_y:] # extract the model output y else: output = model_f(torch.cat((x,y_test-sig_test), dim=1))#.reshape(out_shape) output_z = output[:,ndim_y:] # extract the model output y else: output = model(x_padded.reshape(in_shape))#.reshape(out_shape) output_z = output[:,model.outSchema.LatentSpace] # extract the model output y z = output_z.cpu().data.numpy() C = np.cov(z.transpose()) fig, axes = plt.subplots(1,figsize=(5,5)) im = axes.imshow(np.abs(C)) # We want to show all ticks... axes.set_xticks(np.arange(ndim_z)) axes.set_yticks(np.arange(ndim_z)) # Rotate the tick labels and set their alignment. plt.setp(axes.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor") # Loop over data dimensions and create text annotations. 
for i in range(ndim_z): for j in range(ndim_z): text = axes.text(j,i,'%.2f' % C[i,j], fontsize=3, ha="center",va="center",color="w") fig.tight_layout() fig.savefig('%s/cov_z_%04d.png' % (outdir,i_epoch),dpi=360) fig.savefig('%s/latest/latest_cov_z.png' % outdir,dpi=360) plt.close(fig) fig, axes = plt.subplots(ndim_z,ndim_z,figsize=(5,5)) for c in range(ndim_z): for d in range(ndim_z): if d<c: patches = [] axes[c,d].clear() matplotlib.rc('xtick', labelsize=8) matplotlib.rc('ytick', labelsize=8) axes[c,d].plot(z[:,c],z[:,d],'.r',markersize=0.5) circle1 = Circle((0.0, 0.0), 1.0,fill=False,linestyle='--') patches.append(circle1) circle2 = Circle((0.0, 0.0), 2.0,fill=False,linestyle='--') patches.append(circle2) circle3 = Circle((0.0, 0.0), 3.0,fill=False,linestyle='--') patches.append(circle3) p = PatchCollection(patches, alpha=0.2) axes[c,d].add_collection(p) axes[c,d].set_yticklabels([]) axes[c,d].set_xticklabels([]) axes[c,d].set_xlim([-3,3]) axes[c,d].set_ylim([-3,3]) else: axes[c,d].axis('off') axes[c,d].set_xlabel('') axes[c,d].set_ylabel('') fig.savefig('%s/scatter_z_%04d.png' % (outdir,i_epoch),dpi=360) fig.savefig('%s/latest/latest_scatter_z.png' % outdir,dpi=360) plt.close(fig) fig, axes = plt.subplots(1,figsize=(5,5)) delta = np.transpose(z[:,:]) dyvec = np.linspace(-10*1.0,10*1.0,250) for d in delta: plt.hist(np.array(d).flatten(),25,density=True,histtype='stepfilled',alpha=0.5) plt.hist(np.array(delta).flatten(),25,density=True,histtype='step',linestyle='dashed') plt.plot(dyvec,norm.pdf(dyvec,loc=0,scale=1.0),'k-') plt.xlabel('predicted z') plt.ylabel('p(z)') fig.savefig('%s/dist_z_%04d.png' % (outdir,i_epoch),dpi=360) fig.savefig('%s/latest/latest_dist_z.png' % outdir,dpi=360) plt.close(fig) return
def plot_y_dist(model,ndim_x,ndim_y,ndim_z,ndim_tot,usepars,sigma,outdir,i_epoch,conv=False,model_f=None,model_r=None,do_double_nn=False,do_cnn=False):
    """
    Plots the joint distributions of y variables
    """
    Nsamp = 1000
    out_shape = [-1,ndim_tot]
    if conv==True:
        in_shape = [-1,1,ndim_tot]
    else:
        in_shape = [-1,ndim_tot]

    # generate test data
    x_test, y_test, x, sig_test, parnames = data_maker.generate(
        tot_dataset_size=Nsamp,
        ndata=ndim_y,
        usepars=usepars,
        sigma=sigma,
        seed=1
    )

    # run the x test data through the model
    x = torch.tensor(x_test,dtype=torch.float,device=dev).clone().detach()
    y_test = torch.tensor(y_test,dtype=torch.float,device=dev).clone().detach()
    sig_test = torch.tensor(sig_test,dtype=torch.float,device=dev).clone().detach()

    # make the new padding for the noisy data and latent vector data
    pad_x = torch.zeros(Nsamp,ndim_tot-ndim_x-ndim_y,device=dev)

    # make a padded zy vector (with all new noise)
    x_padded = torch.cat((x,pad_x,y_test-sig_test),dim=1)

    # apply forward model to the x data
    if do_double_nn:
        if do_cnn:
            data = torch.cat((x,y_test-sig_test), dim=1)
            output = model_f(data.reshape(data.shape[0],1,data.shape[1]))#.reshape(out_shape)
            output_y = output[:,:ndim_y] # extract the model output y
        else:
            output = model_f(torch.cat((x,y_test-sig_test), dim=1))#.reshape(out_shape)
            output_y = output[:,:ndim_y] # extract the model output y
    else:
        output = model(x_padded.reshape(in_shape))
        output_y = output[:, model.outSchema.timeseries]
    y = output_y.cpu().data.numpy()
    sig_test = sig_test.cpu().data.numpy()

    # covariance of the y residuals
    dy = y - sig_test
    C = np.cov(dy.transpose())

    fig, axes = plt.subplots(1,figsize=(5,5))
    im = axes.imshow(C)

    # We want to show all ticks...
    axes.set_xticks(np.arange(ndim_y))
    axes.set_yticks(np.arange(ndim_y))

    # Rotate the tick labels and set their alignment.
    plt.setp(axes.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")

    # Loop over data dimensions and create text annotations.
    for i in range(ndim_y):
        for j in range(ndim_y):
            text = axes.text(j,i,'%.2f' % C[i,j], fontsize=3, ha="center",va="center",color="w")
    fig.tight_layout()
    plt.savefig('%s/cov_y_%04d.png' % (outdir,i_epoch),dpi=360)
    plt.savefig('%s/latest/latest_cov_y.png' % outdir,dpi=360)
    plt.close(fig)

    # distribution of the y residuals compared with the expected Gaussian
    fig, axes = plt.subplots(1,figsize=(5,5))
    delta = np.transpose(y[:,:]-sig_test[:,:])
    dyvec = np.linspace(-10*sigma,10*sigma,250)
    for d in delta:
        plt.hist(np.array(d).flatten(),25,density=True,histtype='stepfilled',alpha=0.5)
    plt.hist(np.array(delta).flatten(),25,density=True,histtype='step',linestyle='dashed')
    plt.plot(dyvec,norm.pdf(dyvec,loc=0,scale=np.sqrt(2.0)*sigma),'k-')
    plt.xlabel('y-y_pred')
    plt.ylabel('p(y-y_pred)')
    plt.savefig('%s/y_dist_%04d.png' % (outdir,i_epoch),dpi=360)
    plt.savefig('%s/latest/y_dist.png' % outdir,dpi=360)
    plt.close(fig)

    return
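# -----------------------------------------------------------------------------
# The trainers constructed in the main() functions of these scripts receive
# Loss.mmd_multiscale_on(dev, alphas=...) and Loss.mse, which are not shown here.
# A minimal sketch (assumption) of that wrapper, reusing the MMD_multiscale sketch
# given earlier: a factory that binds an MMD loss to a device and kernel widths,
# alongside a plain MSE.
class Loss:
    @staticmethod
    def mmd_multiscale_on(dev, alphas=None):
        widths = tuple(alphas) if alphas is not None else (0.2, 0.5, 0.9, 1.3)
        def loss(x, y):
            return MMD_multiscale(x.to(dev), y.to(dev), widths=widths)
        return loss

    @staticmethod
    def mse(pred, target):
        return torch.mean((pred - target)**2)
# -----------------------------------------------------------------------------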
def train(self, epoch, gen_inf_temp=False, extra_z=False, do_covar=False): self.model.train() lTot = 0 miniBatchIdx = 0 if self.fadeIn: # normally at 0.4 wRevScale = min(epoch / 400.0, 1)**3 self.wRevScale = wRevScale else: wRevScale = 1.0 self.wRevScale = wRevScale noiseScale = (1.0 - wRevScale) * self.zerosNoiseScale pad_fn = lambda *x: noiseScale * torch.randn(*x, device=self.dev) #+ 10 * torch.ones(*x, device=self.dev) randn = lambda *x: torch.randn(*x, device=self.dev) losses = [0, 0, 0, 0] for x, y, y_sig in self.atmosData.trainLoader: miniBatchIdx += 1 if miniBatchIdx > self.miniBatchesPerEpoch: break # if true, generate templates on the fly during training if gen_inf_temp: del x, y pos, labels, _, y_sig, _ = data_maker.generate( tot_dataset_size=2*self.batchSize, ndata=self.ndata, usepars=self.usepars, sigma=self.sigma, seed=np.random.randint(int(1e9)) ) loader = torch.utils.data.DataLoader( torch.utils.data.TensorDataset(torch.tensor(pos), torch.tensor(labels), torch.tensor(y_sig)), batch_size=self.batchSize, shuffle=True, drop_last=True) for x, y, y_sig in loader: x = x y = y y_sig = y_sig break n = y - y_sig x, y, y_sig, n = x.to(self.dev), y.to(self.dev), y_sig.to(self.dev), n.to(self.dev) yClean = y.clone() # loss factors loss_factor_fwd_mmd_z = 1.0 #min(epoch / 400.0, 1)**3 loss_factor_rev_mse_n = min(epoch / 500.0, 1)**3 # 1000 loss_factor_fwd_mse_y = min(epoch / 750.0, 1)**3 # 2000 loss_factor_rev_mse_x = min(epoch / 10000.0, 1)**3 # 3000 if extra_z: xzp = self.model.inSchema.fill({'amp': x[:, 0], 't0': x[:, 1], 'tau': x[:, 2], 'yNoise': n[:]}, zero_pad_fn=pad_fn) else: xp = self.model.inSchema.fill({'amp': x[:, 0], 't0': x[:, 1], 'tau': x[:, 2]}, zero_pad_fn=pad_fn) if self.y_noise_scale: y += self.y_noise_scale * torch.randn(self.batchSize, self.ndata, dtype=torch.float, device=self.dev) yzp = self.model.outSchema.fill({'timeseries': y[:], 'LatentSpace': randn}, zero_pad_fn=pad_fn) y_sig_zp = self.model.outSchema.fill({'timeseries': y_sig[:], 'LatentSpace': randn}, zero_pad_fn=pad_fn) yzpRevRand = self.model.outSchema.fill({'timeseries': yClean[:], 'LatentSpace': randn}, zero_pad_fn=pad_fn) self.optim.zero_grad() if extra_z: out= self.model(xzp) else: out = self.model(xp) # lForward = self.wPred * (self.loss_fit(y[:, 0], out[:, self.model.outSchema.Halpha]) + # self.loss_fit(y[:, 1], out[:, self.model.outSchema.Ca8542])) # lForward = self.wPred * self.loss_fit(yzp[:, :self.model.outSchema.LatentSpace[0]], out[:, :self.model.outSchema.LatentSpace[0]]) # add z space onto x-space #if extra_z: # out_fmse = torch.cat((self.model(yzpRevRand, rev=True)[:, self.model.inSchema.LatentSpace], xzp[:, self.model.outSchema.LatentSpace[-1]+1:]), # dim=1) # out_fmse = self.model(out_fmse) # lForward = self.wPred * self.loss_fit(yzp[:, self.model.outSchema.LatentSpace[-1]+1:], # out_fmse[:, self.model.outSchema.LatentSpace[-1]+1:]) # use a covariance loss on forward mean squared error if do_covar: # try covariance fit output_cov = self.cov((out[:, self.model.outSchema.LatentSpace[-1]+1:]-y_sig_zp[:, self.model.outSchema.LatentSpace[-1]+1:]).transpose(0,1)) ycov_mat = self.sigma*self.sigma*torch.eye((self.ndata+self.n_neurons),device=self.dev) lForward = self.wPred * self.loss_fit(output_cov.flatten(), ycov_mat.flatten()) else: # compute mean squared error on only y lForward = loss_factor_fwd_mse_y * self.wPred * self.loss_fit(yzp[:, self.model.outSchema.LatentSpace[-1]+1:], out[:, self.model.outSchema.LatentSpace[-1]+1:]) losses[0] += lForward.data.item() / self.wPred #lForward_extra = 
self.wPred * self.loss_fit(yzp[:, self.model.outSchema.LatentSpace[-1]+1:], # out[:, self.model.outSchema.LatentSpace[-1]+1:]) #lForward += lForward_extra #losses[0] += lForward_extra.data.item() / self.wPred #if do_covar: do_z_covar=False if do_z_covar: # compute MMD loss on forward time series prediction lforward21Pred=out[:, self.model.outSchema.timeseries].data lforward21Target=yzp[:, self.model.outSchema.timeseries] lForward21 = self.wLatent * self.loss_latent(lforward21Pred, lforward21Target) losses[1] += lForward21.data.item() / self.wLatent lForward += lForward21 # compute variance loss on z (2nd moment) lforward22Pred = self.cov((out[:, self.model.outSchema.LatentSpace]).transpose(0,1)) lforward22Target = 1.0 * torch.eye((out[:, self.model.outSchema.LatentSpace].shape[1]),device=self.dev) lForward22 = self.wLatent * self.loss_fit(lforward22Pred.flatten(), lforward22Target.flatten()) losses[1] += lForward22.data.item() / self.wLatent lForward += lForward22 # compute mean loss (1st moment) on z lForwardFirstMom = self.wLatent * self.loss_fit(torch.mean(out[:, self.model.outSchema.LatentSpace]), torch.tensor(0.0)) losses[1] += lForwardFirstMom.data.item() / self.wLatent lForward += lForwardFirstMom # compute 3rd moment loss on z lForwardThirdMom = self.wLatent * self.loss_fit(torch.mean(out[:, self.model.outSchema.LatentSpace])**3 + 3*torch.mean(out[:, self.model.outSchema.LatentSpace])*lforward22Pred.flatten(), torch.tensor(0.0)) losses[1] += lForwardThirdMom.data.item() / self.wLatent lForward += lForwardThirdMom # compute 4th moment loss on z lForwardFourthMom = self.wLatent * self.loss_fit(torch.mean(out[:, self.model.outSchema.LatentSpace])**4 + 6*(torch.mean(out[:, self.model.outSchema.LatentSpace])**2)*lforward22Pred.flatten() + (3*lforward22Pred.flatten()**2), (3.0 * torch.eye((out[:, self.model.outSchema.LatentSpace].shape[1]),device=self.dev)).flatten()) losses[1] += lForwardFourthMom.data.item() / self.wLatent lForward += lForwardFourthMom # compute 5th moment loss on z lForwardFifthMom = self.wLatent * self.loss_fit(torch.mean(out[:, self.model.outSchema.LatentSpace])**5 + 10*(torch.mean(out[:, self.model.outSchema.LatentSpace])**3)*lforward22Pred.flatten() + (15*torch.mean(out[:, self.model.outSchema.LatentSpace])*lforward22Pred.flatten()**2), (0.0 * torch.eye((out[:, self.model.outSchema.LatentSpace].shape[1]),device=self.dev)).flatten()) losses[1] += lForwardFifthMom.data.item() / self.wLatent lForward += lForwardFifthMom else: # compute forward MMD on z data outLatentGradOnly = torch.cat((out[:, self.model.outSchema.timeseries].data, out[:, self.model.outSchema.LatentSpace]), dim=1) unpaddedTarget = torch.cat((yzp[:, self.model.outSchema.timeseries], yzp[:, self.model.outSchema.LatentSpace]), dim=1) lForward2 = loss_factor_fwd_mmd_z * self.wLatent * self.loss_latent(out[:, self.model.outSchema.LatentSpace], yzp[:, self.model.outSchema.LatentSpace]) losses[1] += lForward2.data.item() / self.wLatent lForward += lForward2 lTot += lForward.data.item() lForward.backward() yzpRev = self.model.outSchema.fill({'timeseries': yClean[:], 'LatentSpace': out[:, self.model.outSchema.LatentSpace].data}, zero_pad_fn=pad_fn) if extra_z: outRev = self.model(yzpRev, rev=True) #outRev = torch.cat((outRev[:, self.model.inSchema.amp[0]:self.model.inSchema.tau[-1]+1], # outRev[:, self.model.inSchema.yNoise]), # dim=1) outRevRand = self.model(yzpRevRand, rev=True) outRevRand = torch.cat((outRevRand[:, self.model.inSchema.amp[0]:self.model.inSchema.tau[-1]+1], outRevRand[:, 
self.model.inSchema.yNoise]), dim=1) else: outRev = self.model(yzpRev, rev=True) outRevRand = self.model(yzpRevRand, rev=True) # THis guy should have been OUTREVRAND!!! # xBack = torch.cat((outRevRand[:, self.model.inSchema.ne], # outRevRand[:, self.model.inSchema.temperature], # outRevRand[:, self.model.inSchema.vel]), # dim=1) # lBackward = self.wRev * wRevScale * self.loss_backward(xBack, x.reshape(self.miniBatchSize, -1)) if extra_z: #xzp_bMMD=torch.cat((xzp[:, self.model.inSchema.amp[0]:self.model.inSchema.tau[-1]+1], # xzp[:, self.model.inSchema.yNoise]), dim=1) #lBackward = self.wRev * wRevScale * self.loss_fit(outRev, # xzp_bMMD) kld_loss = torch.nn.SmoothL1Loss(size_average=None, reduce=None, reduction='mean') lBackward1 = loss_factor_rev_mse_x * self.wRev * kld_loss(outRev[:, self.model.inSchema.amp[0]:self.model.inSchema.tau[-1]+1], xzp[:, self.model.inSchema.amp[0]:self.model.inSchema.tau[-1]+1]) lBackward2 = loss_factor_rev_mse_n * self.wRev * kld_loss(outRev[:, self.model.inSchema.yNoise], xzp[:, self.model.inSchema.yNoise]) lBackward = lBackward1 + lBackward2 else: lBackward = self.wRev * wRevScale * self.loss_backward(outRev[:, self.model.inSchema.amp[0]:self.model.inSchema.tau[-1]+1], xp[:, self.model.inSchema.amp[0]:self.model.inSchema.tau[-1]+1]) scale = wRevScale if wRevScale != 0 else 1.0 losses[2] += (lBackward1.data.item() / (self.wRev * scale)) + (lBackward1.data.item() / (self.wRev * scale)) #TODO: may need to uncomment this #lBackward2 += 0.5 * self.wPred * self.loss_fit(outRev, xp) if extra_z: #lBackward2 = 0.5 * self.wPred * self.loss_fit(outRev, # xzp_bMMD) losses[3] += 0.0 lTot += lBackward.data.item() else: lBackward2 = 0.5 * self.wPred * self.loss_fit(outRev[:, self.model.inSchema.amp[0]:self.model.inSchema.tau[-1]+1], xp[:, self.model.inSchema.amp[0]:self.model.inSchema.tau[-1]+1]) losses[3] += lBackward2.data.item() / self.wPred * 2 lBackward += lBackward2 lTot += lBackward.data.item() lBackward.backward() for p in self.model.parameters(): p.grad.data.clamp_(-15.0, 15.0) self.optim.step() losses = [l / miniBatchIdx for l in losses] return lTot / miniBatchIdx, losses
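# -----------------------------------------------------------------------------
# The do_covar / do_z_covar branches above call self.cov(...), which is not shown
# in this excerpt. A minimal sketch (assumption) of such a helper, written as it
# might appear inside the trainer class, matching the call pattern
# self.cov(M.transpose(0, 1)) where each row of the argument is one variable.
def cov(self, m):
    m = m - m.mean(dim=1, keepdim=True)          # centre each variable (row)
    return m.matmul(m.t()) / (m.shape[1] - 1)    # unbiased sample covariance
# -----------------------------------------------------------------------------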
def main(): # Set up data batch_size = 1600 # set batch size test_split = 10000 # number of testing samples to use # generate data # makes a torch.tensor() with arrays of (n_samples X parameters) and (n_samples X data) # labels are the colours and pos are the x,y coords # however, labels are 1-hot encoded pos, labels = data.generate(labels='all', tot_dataset_size=2**20) # just simply renaming the colors properly. #c = np.where(labels[:test_split])[1] #c = labels[:test_split,:] plt.figure(figsize=(6, 6)) r = 4 fig, axs = plt.subplots(r, r) cnt = 0 for i in range(r): for j in range(r): axs[i, j].plot(np.arange(3) + 1, np.array(pos[cnt, :]), '.') axs[i, j].plot([1, 3], [labels[cnt, 0], labels[cnt, 0]], 'k-') axs[i, j].plot([1, 3], [ labels[cnt, 0] + labels[cnt, 1], labels[cnt, 0] + labels[cnt, 1] ], 'k--') axs[i, j].plot([1, 3], [ labels[cnt, 0] - labels[cnt, 1], labels[cnt, 0] - labels[cnt, 1] ], 'k--') axs[i, j].set_ylim([-1, 2]) cnt += 1 plt.savefig('/data/public_html/chrism/FrEIA/test_distribution.png') plt.close() # setting up the model ndim_tot = 16 # ? ndim_x = 2 # number of parameter dimensions (mu,sig) ndim_y = 3 # number of label dimensions (data) ndim_z = 2 # number of latent space dimensions? # define different parts of the network # define input node inp = InputNode(ndim_tot, name='input') # define hidden layer nodes t1 = Node([inp.out0], rev_multiplicative_layer, { 'F_class': F_fully_connected, 'clamp': 2.0, 'F_args': { 'dropout': 0.0 } }) t2 = Node([t1.out0], rev_multiplicative_layer, { 'F_class': F_fully_connected, 'clamp': 2.0, 'F_args': { 'dropout': 0.0 } }) t3 = Node([t2.out0], rev_multiplicative_layer, { 'F_class': F_fully_connected, 'clamp': 2.0, 'F_args': { 'dropout': 0.0 } }) # define output layer node outp = OutputNode([t3.out0], name='output') nodes = [inp, t1, t2, t3, outp] model = ReversibleGraphNet(nodes) # Train model # Training parameters n_epochs = 3000 meta_epoch = 12 # what is this??? n_its_per_epoch = 4 batch_size = 1600 lr = 1e-2 gamma = 0.01**(1. / 120) l2_reg = 2e-5 y_noise_scale = 3e-2 zeros_noise_scale = 3e-2 # relative weighting of losses: lambd_predict = 300. # forward pass lambd_latent = 300. # laten space lambd_rev = 400. 
# backwards pass # padding both the data and the latent space # such that they have equal dimension to the parameter space pad_x = torch.zeros(batch_size, ndim_tot - ndim_x) pad_yz = torch.zeros(batch_size, ndim_tot - ndim_y - ndim_z) print(pad_x.shape, pad_yz.shape) # define optimizer optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=(0.8, 0.8), eps=1e-04, weight_decay=l2_reg) scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=meta_epoch, gamma=gamma) # define the three loss functions loss_backward = MMD_multiscale loss_latent = MMD_multiscale loss_fit = fit # set up test set data loader test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset( pos[:test_split], labels[:test_split]), batch_size=batch_size, shuffle=True, drop_last=True) # set up training set data loader train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset( pos[test_split:], labels[test_split:]), batch_size=batch_size, shuffle=True, drop_last=True) # initialisation of network weights for mod_list in model.children(): for block in mod_list.children(): for coeff in block.children(): coeff.fc3.weight.data = 0.01 * torch.randn( coeff.fc3.weight.shape) model.to(device) # initialize gif for showing training procedure #fig, axes = plt.subplots(1, 2, figsize=(8,4)) #axes[0].set_xticks([]) #axes[0].set_yticks([]) #axes[0].set_title('Predicted labels (Forwards Process)') #axes[1].set_xticks([]) #axes[1].set_yticks([]) #axes[1].set_title('Generated Samples (Backwards Process)') #fig.show() #fig.canvas.draw() # number of test samples to use after training N_samp = 4096 # choose test samples to use after training x_samps = torch.cat([x for x, y in test_loader], dim=0)[:N_samp] y_samps = torch.cat([y for x, y in test_loader], dim=0)[:N_samp] print(np.array(y_samps)) #c = np.where(y_samps)[1] c = np.array(y_samps).reshape(-1, 3) y_samps += y_noise_scale * torch.randn(N_samp, ndim_y) y_samps = torch.cat([ torch.randn(N_samp, ndim_z), zeros_noise_scale * torch.zeros(N_samp, ndim_tot - ndim_y - ndim_z), y_samps ], dim=1) y_samps = y_samps.to(device) #y_samps = np.random.normal(loc=0.5,scale=0.75,size=3).reshape(-1,3) # start training loop try: # print('#Epoch \tIt/s \tl_total') t_start = time() # loop over number of epochs for i_epoch in tqdm(range(n_epochs), ascii=True, ncols=80): scheduler.step() # Initially, the l2 reg. 
on x and z can give huge gradients, set # the lr lower for this if i_epoch < 0: for param_group in optimizer.param_groups: param_group['lr'] = lr * 1e-2 # print(i_epoch, end='\t ') train(model, train_loader, n_its_per_epoch, zeros_noise_scale, batch_size, ndim_tot, ndim_x, ndim_y, ndim_z, y_noise_scale, optimizer, lambd_predict, loss_fit, lambd_latent, loss_latent, lambd_rev, loss_backward, i_epoch) # predict the mu and sig of test data rev_x = model(y_samps, rev=True) rev_x = rev_x.cpu().data.numpy() #print(rev_x) # predict the label given a location #pred_c = model(torch.cat((x_samps, torch.zeros(N_samp, ndim_tot - ndim_x)), # dim=1).to(device)).data[:, -8:].argmax(dim=1) #pred_c = model(torch.cat((x_samps, torch.zeros(N_samp, ndim_tot - ndim_x)), # dim=1).to(device)).data[:, -1:].argmax(dim=1) #axes[0].clear() #axes[0].scatter(tmp_x_samps[:,0], tmp_x_samps[:,1], c=pred_c, cmap='Set1', s=1., vmin=0, vmax=9) #axes[0].axis('equal') #axes[0].axis([-3,3,-3,3]) #axes[0].set_xticks([]) #axes[0].set_yticks([]) # NOTE: fig and axes below are created by the 'initialize gif' block above, which is currently commented out, so this plotting will fail unless that block is re-enabled axes[1].clear() axes[1].scatter(rev_x[:, 0], rev_x[:, 1], c=c, cmap='Set1', s=1., vmin=0, vmax=9) axes[1].axis('equal') axes[1].axis([-3, 3, -3, 3]) axes[1].set_xticks([]) axes[1].set_yticks([]) fig.canvas.draw() plt.savefig('/data/public_html/chrism/FrEIA/training_pred.png') except KeyboardInterrupt: pass finally: print(f"\n\nTraining took {(time()-t_start)/60:.2f} minutes\n")
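# -----------------------------------------------------------------------------
# Both toy scripts above call a free-standing train(...) that is not shown in this
# excerpt. The sketch below is an assumed, minimal version of one such epoch of
# INN training with the three weighted losses (forward fit on the padded y part,
# MMD on the latent z, and MMD on the backward-predicted x), using the padding
# convention described in the comments above. Illustrative only, not the project's
# actual implementation.
def train(model, train_loader, n_its_per_epoch, zeros_noise_scale, batch_size,
          ndim_tot, ndim_x, ndim_y, ndim_z, y_noise_scale, optimizer,
          lambd_predict, loss_fit, lambd_latent, loss_latent, lambd_rev,
          loss_backward, i_epoch=0):
    model.train()
    for i_batch, (x, y) in enumerate(train_loader):
        if i_batch >= n_its_per_epoch:
            break
        optimizer.zero_grad()

        # pad x up to ndim_tot and build the [z, pad, y] target vector
        pad_x = zeros_noise_scale * torch.randn(batch_size, ndim_tot - ndim_x)
        pad_yz = zeros_noise_scale * torch.randn(batch_size, ndim_tot - ndim_y - ndim_z)
        y = y + y_noise_scale * torch.randn(batch_size, ndim_y)
        x_in = torch.cat((x, pad_x), dim=1)
        y_target = torch.cat((torch.randn(batch_size, ndim_z), pad_yz, y), dim=1)

        # forward pass: fit the data part and push z towards a standard normal
        out = model(x_in)
        l = lambd_predict * loss_fit(out[:, ndim_z:], y_target[:, ndim_z:])
        l = l + lambd_latent * loss_latent(out[:, :ndim_z], y_target[:, :ndim_z])

        # backward pass: samples pushed back through the network should match x
        out_rev = model(y_target, rev=True)
        l = l + lambd_rev * loss_backward(out_rev[:, :ndim_x], x_in[:, :ndim_x])

        l.backward()
        optimizer.step()
    return l.item()
# -----------------------------------------------------------------------------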
def train(self, epoch, gen_inf_temp=False, extra_z=False, do_cnn=False): self.model_f.train() self.model_r.train() lTot = 0 miniBatchIdx = 0 randn = torch.randn(self.batchSize, self.ndata, dtype=torch.float, device=self.dev) optimizer_f = self.optim_f #optim.SGD(self.model_f.parameters(), lr=0.01, weight_decay= 1e-6, momentum = 0.9, nesterov = True) optimizer_r = self.optim_r #optim.SGD(self.model_r.parameters(), lr=0.01, weight_decay= 1e-6, momentum = 0.9, nesterov = True) losses = [0, 0, 0, 0] # get data for x, y, y_sig in self.atmosData.trainLoader: miniBatchIdx += 1 if miniBatchIdx > self.miniBatchesPerEpoch: break # if true, generate templates on the fly during training if gen_inf_temp: del x, y pos, labels, _, y_sig, _ = data_maker.generate( tot_dataset_size=2 * self.batchSize, ndata=self.ndata, usepars=self.usepars, sigma=self.sigma, seed=np.random.randint(int(1e9))) loader = torch.utils.data.DataLoader( torch.utils.data.TensorDataset(torch.tensor(pos), torch.tensor(labels), torch.tensor(y_sig)), batch_size=self.batchSize, shuffle=True, drop_last=True) for x, y, y_sig in loader: x = x y = y y_sig = y_sig break n = y - y_sig x, y, y_sig, n = x.to(self.dev), y.to(self.dev), y_sig.to( self.dev), n.to(self.dev) yClean = y.clone() optimizer_f.zero_grad() optimizer_r.zero_grad() ################# # forward process ################# ## 1. forward propagation data = torch.cat((x[:], n[:]), dim=1) if do_cnn: data = data.reshape(data.shape[0], 1, data.shape[1]) output = self.model_f(data) ## 2. loss calculation target = torch.cat((y[:], randn), dim=1) loss_y = self.loss_fit(output[:, :y.shape[1]], y[:]) ## 3. backward propagation loss_y.backward(retain_graph=True) losses[0] += loss_y.data.item() loss_z = self.loss_latent(output[:, y.shape[1]:], randn) ## 3. backward propagation loss_z.backward() ## 4. weight optimization optimizer_f.step() losses[1] += loss_z.data.item() ################# # reverse process ################# ## 1. forward propagation output = torch.cat((y[:], output[:, y.shape[1]:]), dim=1) if do_cnn: output = output.reshape(output.shape[0], 1, output.shape[1]) output = self.model_r(output.data) ## 2. loss calculation target = torch.cat((x[:], n[:]), dim=1) loss_r = self.loss_fit(output, target) ## 3. backward propagation loss_r.backward() ## 4. weight optimization optimizer_r.step() losses[2] += loss_r.data.item() losses[3] += 0.0 # dummy loss for now lTot += losses[0] + losses[1] + losses[2] + losses[3] losses = [l / miniBatchIdx for l in losses] return lTot / miniBatchIdx, losses
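# -----------------------------------------------------------------------------
# nn_double_f and nn_double_r, constructed elsewhere as
# nn_double_f(ndim_x+ndim_y, ndim_y+ndim_z) and nn_double_r(ndim_y+ndim_z, ndim_x+ndim_y),
# are not defined in this excerpt. A minimal sketch (assumption, non-CNN path only)
# of the two separate, non-invertible networks used in the do_double_nn mode:
# plain fully-connected maps from (x, n) to (y, z) and from (y, z) back to (x, n).
import torch.nn as nn

class nn_double_f(nn.Module):
    def __init__(self, n_in, n_out, n_hidden=256):
        super(nn_double_f, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(n_in, n_hidden), nn.ReLU(),
            nn.Linear(n_hidden, n_hidden), nn.ReLU(),
            nn.Linear(n_hidden, n_out))

    def forward(self, data):
        return self.net(data)

class nn_double_r(nn_double_f):
    # same architecture, mapping (y, z) back to (x, n)
    pass
# -----------------------------------------------------------------------------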
def main(): # generate data # generate data if not load_dataset: pos, labels, x, sig, parnames = data_maker.generate( tot_dataset_size=tot_dataset_size, ndata=ndata, usepars=usepars, sigma=sigma, seed=seed) print('generated data') hf = h5py.File('benchmark_data_%s.h5py' % run_label, 'w') hf.create_dataset('pos', data=pos) hf.create_dataset('labels', data=labels) hf.create_dataset('x', data=x) hf.create_dataset('sig', data=sig) hf.create_dataset('parnames', data=np.string_(parnames)) data = AtmosData([dataLocation1], test_split, resampleWl=None) data.split_data_and_init_loaders(batchsize) # seperate the test data for plotting pos_test = data.pos_test labels_test = data.labels_test sig_test = data.sig_test ndim_x = len(usepars) print('Computing MCMC posterior samples') if do_mcmc or not load_dataset: # precompute true posterior samples on the test data cnt = 0 samples = np.zeros((r * r, N_samp, ndim_x)) for i in range(r): for j in range(r): samples[cnt, :, :] = data_maker.get_lik( np.array(labels_test[cnt, :]).flatten(), np.array(pos_test[cnt, :]), out_dir, cnt, sigma=sigma, usepars=usepars, Nsamp=N_samp) print(samples[cnt, :10, :]) cnt += 1 # save computationaly expensive mcmc/waveform runs if load_dataset == True: hf = h5py.File('benchmark_data_%s.h5py' % run_label, 'w') hf.create_dataset('pos', data=data.pos) hf.create_dataset('labels', data=data.labels) hf.create_dataset('x', data=data.x) hf.create_dataset('sig', data=data.sig) hf.create_dataset('parnames', data=parnames) hf.create_dataset('samples', data=np.string_(samples)) hf.close() else: samples = h5py.File(dataLocation1, 'r')['samples'][:] parnames = h5py.File(dataLocation1, 'r')['parnames'][:] # plot the test data examples plt.figure(figsize=(6, 6)) fig, axes = plt.subplots(r, r, figsize=(6, 6), sharex='col', sharey='row') cnt = 0 for i in range(r): for j in range(r): axes[i, j].plot(data.x, np.array(labels_test[cnt, :]), '.') axes[i, j].plot(data.x, np.array(sig_test[cnt, :]), '-') cnt += 1 axes[i, j].axis([0, 1, -1.5, 1.5]) axes[i, j].set_xlabel('time') if i == r - 1 else axes[ i, j].set_xlabel('') axes[i, j].set_ylabel('h(t)') if j == 0 else axes[i, j].set_ylabel('') plt.savefig('%stest_distribution.png' % out_dir, dpi=360) plt.close() # initialize plot for showing testing results fig, axes = plt.subplots(r, r, figsize=(6, 6), sharex='col', sharey='row') for k in range(ndim_x): parname1 = parnames[k] for nextk in range(ndim_x): parname2 = parnames[nextk] if nextk > k: cnt = 0 for i in range(r): for j in range(r): # plot the samples and the true contours axes[i, j].clear() axes[i, j].scatter(samples[cnt, :, k], samples[cnt, :, nextk], c='b', s=0.5, alpha=0.5) axes[i, j].plot(pos_test[cnt, k], pos_test[cnt, nextk], '+c', markersize=8) axes[i, j].set_xlim([0, 1]) axes[i, j].set_ylim([0, 1]) axes[i, j].set_xlabel( parname1) if i == r - 1 else axes[i, j].set_xlabel('') axes[i, j].set_ylabel(parname2) if j == 0 else axes[ i, j].set_ylabel('') cnt += 1 # save the results to file fig.canvas.draw() plt.savefig('%strue_samples_%d%d.png' % (out_dir, k, nextk), dpi=360) def store_pars(f, pars): for i in pars.keys(): f.write("%s: %s\n" % (i, str(pars[i]))) f.close() # store hyperparameters for posterity f = open("%s_run-pars.txt" % run_label, "w+") pars_to_store = { "sigma": sigma, "ndata": ndata, "T": T, "seed": seed, "n_neurons": n_neurons, "bound": bound, "conv_nn": conv_nn, "filtsize": filtsize, "dropout": dropout, "clamp": clamp, "ndim_z": ndim_z, "tot_epoch": tot_epoch, "lr": lr, "latentAlphas": latentAlphas, "backwardAlphas": 
backwardAlphas, "zerosNoiseScale": zerosNoiseScale, "wPred": wPred, "wLatent": wLatent, "wRev": wRev, "tot_dataset_size": tot_dataset_size, "numInvLayers": numInvLayers, "batchsize": batchsize } store_pars(f, pars_to_store) # setup output directory - if it does not exist os.system('mkdir -p %s' % out_dir) inRepr = [('amp', 1), ('t0', 1), ('tau', 1), ('phi', 1), ('!!PAD', )] outRepr = [('LatentSpace', ndim_z), ('!!PAD', ), ('timeseries', data.atmosOut.shape[1])] model = RadynversionNet(inRepr, outRepr, dropout=dropout, zeroPadding=0, minSize=ndim_tot, numInvLayers=numInvLayers) # Construct the class that trains the model, the initial weighting between the losses, learning rate, and the initial number of epochs to train for. trainer = RadynversionTrainer(model, data, dev) trainer.training_params( tot_epoch, lr=lr, zerosNoiseScale=zerosNoiseScale, wPred=wPred, wLatent=wLatent, wRev=wRev, loss_latent=Loss.mmd_multiscale_on(dev, alphas=latentAlphas), loss_backward=Loss.mmd_multiscale_on(dev, alphas=backwardAlphas), loss_fit=Loss.mse) totalEpochs = 0 # Train the model for these first epochs with a nice graph that updates during training. losses = [] beta_score_hist = [] beta_score_loop_hist = [] lossVec = [[] for _ in range(4)] lossLabels = ['L2 Line', 'MMD Latent', 'MMD Reverse', 'L2 Reverse'] out = None alphaRange, mmdF, mmdB, idxF, idxB = [1, 1], [1, 1], [1, 1], 0, 0 try: tStart = time() olvec = np.zeros((r, r, int(tot_epoch / plot_cadence))) s = 0 for epoch in range(trainer.numEpochs): print('Epoch %s/%s' % (str(epoch), str(trainer.numEpochs))) totalEpochs += 1 trainer.scheduler.step() loss, indLosses = trainer.train(epoch) # loop over a few cases and plot results in a grid if np.remainder(epoch, plot_cadence) == 0: for k in range(ndim_x): parname1 = parnames[k] for nextk in range(ndim_x): parname2 = parnames[nextk] if nextk > k: cnt = 0 # initialize 2D plots for showing testing results fig, axes = plt.subplots(r, r, figsize=(6, 6), sharex='col', sharey='row') # initialize 1D plots for showing testing results fig_1d, axes_1d = plt.subplots(r, r, figsize=(6, 6), sharex='col', sharey='row') for i in range(r): for j in range(r): # convert data into correct format y_samps = np.tile( np.array(labels_test[cnt, :]), N_samp).reshape(N_samp, ndim_y) y_samps = torch.tensor(y_samps, dtype=torch.float) y_samps += y_noise_scale * torch.randn( N_samp, ndim_y) y_samps = torch.cat([ torch.randn(N_samp, ndim_z), zerosNoiseScale * torch.zeros(N_samp, ndim_tot - ndim_y - ndim_z), y_samps ], dim=1) y_samps = y_samps.to(dev) # use the network to predict parameters rev_x = model(y_samps, rev=True) rev_x = rev_x.cpu().data.numpy() # compute the n-d overlap if k == 0 and nextk == 1: ol = data_maker.overlap( samples[cnt, :, :ndim_x], rev_x[:, :ndim_x]) olvec[i, j, s] = ol def confidence_bd(samp_array): """ compute confidence bounds for a given array """ cf_bd_sum_lidx = 0 cf_bd_sum_ridx = 0 cf_bd_sum_left = 0 cf_bd_sum_right = 0 cf_perc = 0.05 cf_bd_sum_lidx = np.sort(samp_array)[ int(len(samp_array) * cf_perc)] cf_bd_sum_ridx = np.sort(samp_array)[ int( len(samp_array) * (1.0 - cf_perc))] return [cf_bd_sum_lidx, cf_bd_sum_ridx] # plot the 2D samples and the true contours true_cfbd_x = confidence_bd(samples[cnt, :, k]) true_cfbd_y = confidence_bd(samples[cnt, :, nextk]) pred_cfbd_x = confidence_bd(rev_x[:, k]) pred_cfbd_y = confidence_bd(rev_x[:, nextk]) axes[i, j].clear() axes[i, j].scatter(samples[cnt, :, k], samples[cnt, :, nextk], c='b', s=0.2, alpha=0.5) axes[i, j].scatter(rev_x[:, k], rev_x[:, nextk], c='r', 
s=0.2, alpha=0.5) axes[i, j].plot(pos_test[cnt, k], pos_test[cnt, nextk], '+c', markersize=8) #axes[i,j].axvline(x=true_cfbd_x[0], linewidth=0.5, color='b') #axes[i,j].axvline(x=true_cfbd_x[1], linewidth=0.5, color='b') #axes[i,j].axhline(y=true_cfbd_y[0], linewidth=0.5, color='b') #axes[i,j].axhline(y=true_cfbd_y[1], linewidth=0.5, color='b') #axes[i,j].axvline(x=pred_cfbd_x[0], linewidth=0.5, color='r') #axes[i,j].axvline(x=pred_cfbd_x[1], linewidth=0.5, color='r') #axes[i,j].axhline(y=pred_cfbd_y[0], linewidth=0.5, color='r') #axes[i,j].axhline(y=pred_cfbd_y[1], linewidth=0.5, color='r') axes[i, j].set_xlim([0, 1]) axes[i, j].set_ylim([0, 1]) oltxt = '%.2f' % olvec[i, j, s] axes[i, j].text(0.90, 0.95, oltxt, horizontalalignment='right', verticalalignment='top', transform=axes[i, j].transAxes) matplotlib.rc('xtick', labelsize=8) matplotlib.rc('ytick', labelsize=8) axes[i, j].set_xlabel( parname1) if i == r - 1 else axes[ i, j].set_xlabel('') axes[i, j].set_ylabel( parname2) if j == 0 else axes[ i, j].set_ylabel('') # plot the 1D samples and the 5% confidence bounds axes_1d[i, j].clear() axes_1d[i, j].hist(samples[cnt, :, k], color='b', bins=100, alpha=0.5) axes_1d[i, j].hist(rev_x[:, k], color='r', bins=100, alpha=0.5) axes_1d[i, j].axvline(x=pos_test[cnt, k], linewidth=0.5, color='black') axes_1d[i, j].axvline(x=confidence_bd( samples[cnt, :, k])[0], linewidth=0.5, color='b') axes_1d[i, j].axvline(x=confidence_bd( samples[cnt, :, k])[1], linewidth=0.5, color='b') axes_1d[i, j].axvline(x=confidence_bd( rev_x[:, k])[0], linewidth=0.5, color='r') axes_1d[i, j].axvline(x=confidence_bd( rev_x[:, k])[1], linewidth=0.5, color='r') axes_1d[i, j].set_xlim([0, 1]) axes_1d[i, j].text( 0.90, 0.95, oltxt, horizontalalignment='right', verticalalignment='top', transform=axes_1d[i, j].transAxes) axes_1d[i, j].set_xlabel( parname1) if i == r - 1 else axes_1d[ i, j].set_xlabel('') cnt += 1 # save the results to file fig_1d.canvas.draw() fig_1d.savefig('%sposteriors-1d_%d_%04d.png' % (out_dir, k, epoch), dpi=360) fig_1d.savefig('%slatest-1d_%d.png' % (out_dir, k), dpi=360) #fig_1d.close() fig.canvas.draw() fig.savefig('%sposteriors-2d_%d%d_%04d.png' % (out_dir, k, nextk, epoch), dpi=360) fig.savefig('%slatest-2d_%d%d.png' % (out_dir, k, nextk), dpi=360) #fig.close() s += 1 # plot overlap results if np.remainder(epoch, plot_cadence) == 0: fig_log = plt.figure(figsize=(6, 6)) axes_log = fig_log.add_subplot(1, 1, 1) for i in range(r): for j in range(r): axes_log.semilogx(np.arange(tot_epoch, step=plot_cadence), olvec[i, j, :], alpha=0.5) axes_log.grid() axes_log.set_ylabel('overlap') axes_log.set_xlabel('epoch (log)') axes_log.set_ylim([0, 1]) plt.savefig('%soverlap_logscale.png' % out_dir, dpi=360) plt.close() fig = plt.figure(figsize=(6, 6)) axes = fig.add_subplot(1, 1, 1) for i in range(r): for j in range(r): axes.plot(np.arange(tot_epoch, step=plot_cadence), olvec[i, j, :], alpha=0.5) axes.grid() axes.set_ylabel('overlap') axes.set_xlabel('epoch') axes.set_ylim([0, 1]) plt.savefig('%soverlap.png' % out_dir, dpi=360) plt.close() #egg = True #if egg==False: if np.remainder(epoch, plot_cadence) == 0 and (epoch > 5): fig, axis = plt.subplots(4, 1, figsize=(10, 8)) #fig.show() fig.canvas.draw() axis[0].clear() axis[1].clear() axis[2].clear() axis[3].clear() for i in range(len(indLosses)): lossVec[i].append(indLosses[i]) losses.append(loss) fig.suptitle('Current Loss: %.2e, min loss: %.2e' % (loss, np.nanmin(np.abs(losses)))) axis[0].semilogy(np.arange(len(losses)), np.abs(losses)) for i, lo in 
enumerate(lossVec): axis[1].semilogy(np.arange(len(losses)), lo, '--', label=lossLabels[i]) axis[1].legend(loc='upper left') tNow = time() elapsed = int(tNow - tStart) eta = int((tNow - tStart) / (epoch + 1) * trainer.numEpochs) - elapsed if epoch % 2 == 0: mses = trainer.test(maxBatches=1) lineProfiles = mses[2] if epoch % 10 == 0: alphaRange, mmdF, mmdB, idxF, idxB = trainer.review_mmd() axis[3].semilogx(alphaRange, mmdF, label='Latent Space') axis[3].semilogx(alphaRange, mmdB, label='Backward') axis[3].semilogx(alphaRange[idxF], mmdF[idxF], 'ro') axis[3].semilogx(alphaRange[idxB], mmdB[idxB], 'ro') axis[3].legend() testTime = time() - tNow axis[2].plot( lineProfiles[0, model.outSchema.timeseries].cpu().numpy()) for a in axis: a.grid() axis[3].set_xlabel( 'Epochs: %d, Elapsed: %d s, ETA: %d s (Testing: %d s)' % (epoch, elapsed, eta, testTime)) fig.canvas.draw() fig.savefig('%slosses.pdf' % out_dir) except KeyboardInterrupt: pass finally: print(f"\n\nTraining took {(time()-tStart)/60:.2f} minutes\n")
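The plotting loop above draws posterior samples by tiling the observed timeseries, appending fresh latent draws and zero padding, and running the network in reverse. For clarity, that step can be read as a standalone helper; the sketch below is illustrative rather than code from this script (the name sample_posterior and its signature are assumptions), but the tensor construction mirrors the y_samps logic used above.

import numpy as np
import torch

def sample_posterior(model, y_obs, N_samp, ndim_y, ndim_z, ndim_tot,
                     y_noise_scale=0.0, dev='cpu'):
    # tile the single observed (noisy) timeseries so every posterior sample
    # is conditioned on the same data
    y = np.tile(np.array(y_obs), N_samp).reshape(N_samp, ndim_y)
    y = torch.tensor(y, dtype=torch.float)
    if y_noise_scale > 0.0:
        y += y_noise_scale * torch.randn(N_samp, ndim_y)
    # assemble [z, padding, y] so the input width matches ndim_tot
    y_samps = torch.cat([torch.randn(N_samp, ndim_z),                      # fresh latent draws
                         torch.zeros(N_samp, ndim_tot - ndim_y - ndim_z),  # zero padding
                         y], dim=1).to(dev)
    # run the invertible network in reverse to map [z, pad, y] -> parameters
    with torch.no_grad():
        rev_x = model(y_samps, rev=True)
    return rev_x.cpu().numpy()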
def main(): # Set up simulation parameters batch_size = 1600 # set batch size r = 3 # the grid dimension for the output tests test_split = r * r # number of testing samples to use sig_model = 'sg' # the signal model to use sigma = 0.2 # the noise std ndata = 32 # number of data samples bound = [0.0, 1.0, 0.0, 1.0] # effective bound for likelihood seed = 1 # seed for generating data # generate data pos, labels, x, sig = data.generate(model=sig_model, tot_dataset_size=2**20, ndata=ndata, sigma=sigma, prior_bound=bound, seed=seed) # separate the test data for plotting pos_test = pos[-test_split:] labels_test = labels[-test_split:] sig_test = sig[-test_split:] # plot the test data examples plt.figure(figsize=(6, 6)) fig, axes = plt.subplots(r, r, figsize=(6, 6)) cnt = 0 for i in range(r): for j in range(r): axes[i, j].plot(x, np.array(labels_test[cnt, :]), '.') axes[i, j].plot(x, np.array(sig_test[cnt, :]), '-') cnt += 1 axes[i, j].axis([0, 1, -1.5, 1.5]) plt.savefig('/data/public_html/chrism/FrEIA/test_distribution.png', dpi=360) plt.close() # setting up the model ndim_x = 2 # number of posterior parameter dimensions (x,y) ndim_y = ndata # number of label dimensions (noisy data samples) ndim_z = 8 # number of latent space dimensions ndim_tot = max(ndim_x, ndim_y + ndim_z) # must be > ndim_x and > ndim_y + ndim_z # define different parts of the network # define input node inp = InputNode(ndim_tot, name='input') # define hidden layer nodes t1 = Node([inp.out0], rev_multiplicative_layer, { 'F_class': F_fully_connected, 'clamp': 2.0, 'F_args': { 'dropout': 0.2 } }) t2 = Node([t1.out0], rev_multiplicative_layer, { 'F_class': F_fully_connected, 'clamp': 2.0, 'F_args': { 'dropout': 0.2 } }) t3 = Node([t2.out0], rev_multiplicative_layer, { 'F_class': F_fully_connected, 'clamp': 2.0, 'F_args': { 'dropout': 0.2 } }) # define output layer node outp = OutputNode([t3.out0], name='output') nodes = [inp, t1, t2, t3, outp] model = ReversibleGraphNet(nodes) # Train model # Training parameters n_epochs = 1000 meta_epoch = 12 # lr scheduler step size (epochs between decay steps) n_its_per_epoch = 12 batch_size = 1600 lr = 1e-2 gamma = 0.01**(1. / 120) l2_reg = 2e-5 y_noise_scale = 3e-2 zeros_noise_scale = 3e-2 # relative weighting of losses: lambd_predict = 300. # forward pass lambd_latent = 300. # latent space lambd_rev = 400. # backwards pass
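The three weights just set (lambd_predict, lambd_latent, lambd_rev) control how the forward fit loss, the latent-space MMD, and the backward-pass MMD are combined. The train() function called inside the training loop below is not reproduced in this excerpt, so the following is only a simplified sketch of one such optimisation step under the usual padding scheme; the name train_step and the single combined backward pass are assumptions, not the script's actual implementation.

import torch

def train_step(model, x, y, optimizer,
               ndim_x, ndim_y, ndim_z, ndim_tot,
               loss_fit, loss_latent, loss_backward,
               lambd_predict, lambd_latent, lambd_rev,
               zeros_noise_scale=3e-2, y_noise_scale=3e-2, device='cpu'):
    batch_size = x.shape[0]
    x, y = x.to(device), y.to(device)

    # pad the parameters up to ndim_tot and build a [z, pad, y] vector of the same width
    pad_x = zeros_noise_scale * torch.randn(batch_size, ndim_tot - ndim_x, device=device)
    pad_yz = zeros_noise_scale * torch.randn(batch_size, ndim_tot - ndim_y - ndim_z, device=device)
    z = torch.randn(batch_size, ndim_z, device=device)
    x_pad = torch.cat((x, pad_x), dim=1)
    y_noisy = y + y_noise_scale * torch.randn(batch_size, ndim_y, device=device)
    yz = torch.cat((z, pad_yz, y_noisy), dim=1)

    optimizer.zero_grad()

    # forward pass x -> [z, pad, y]: L2 fit on the data block, MMD on the latent block
    out = model(x_pad)
    loss = lambd_predict * loss_fit(out[:, ndim_z:], yz[:, ndim_z:])
    loss = loss + lambd_latent * loss_latent(out[:, :ndim_z], z)

    # backward pass [z, pad, y] -> x: MMD between recovered and true parameters
    out_rev = model(yz, rev=True)
    loss = loss + lambd_rev * loss_backward(out_rev[:, :ndim_x], x)

    loss.backward()
    optimizer.step()
    return loss.item()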
# padding both the data and the latent space # such that they have equal dimension to the parameter space #pad_x = torch.zeros(batch_size, ndim_tot - ndim_x) #pad_yz = torch.zeros(batch_size, ndim_tot - ndim_y - ndim_z) # define optimizer optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=(0.8, 0.8), eps=1e-04, weight_decay=l2_reg) scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=meta_epoch, gamma=gamma) # define the three loss functions loss_backward = MMD_multiscale loss_latent = MMD_multiscale loss_fit = fit # set up training set data loader train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset( pos[test_split:], labels[test_split:]), batch_size=batch_size, shuffle=True, drop_last=True) # initialisation of network weights for mod_list in model.children(): for block in mod_list.children(): for coeff in block.children(): coeff.fc3.weight.data = 0.01 * torch.randn( coeff.fc3.weight.shape) model.to(device) # initialize plot for showing testing results fig, axes = plt.subplots(r, r, figsize=(6, 6)) # number of test samples to use after training N_samp = 256 # precompute true likelihood on the test data Ngrid = 64 cnt = 0 lik = np.zeros((r, r, Ngrid * Ngrid)) for i in range(r): for j in range(r): mvec, cvec, temp = data.get_lik(np.array( labels_test[cnt, :]).flatten(), n_grid=Ngrid, sig_model=sig_model, sigma=sigma, xvec=x, bound=bound) lik[i, j, :] = temp.flatten() cnt += 1 # start training loop try: t_start = time() # loop over number of epochs for i_epoch in tqdm(range(n_epochs), ascii=True, ncols=80): scheduler.step() # Initially, the l2 reg. on x and z can give huge gradients, set # the lr lower for this if i_epoch < 0: print('inside this iepoch<0 thing') for param_group in optimizer.param_groups: param_group['lr'] = lr * 1e-2 # train the model train(model, train_loader, n_its_per_epoch, zeros_noise_scale, batch_size, ndim_tot, ndim_x, ndim_y, ndim_z, y_noise_scale, optimizer, lambd_predict, loss_fit, lambd_latent, loss_latent, lambd_rev, loss_backward, i_epoch) # loop over a few cases and plot results in a grid cnt = 0 for i in range(r): for j in range(r): # convert data into correct format y_samps = np.tile(np.array(labels_test[cnt, :]), N_samp).reshape(N_samp, ndim_y) y_samps = torch.tensor(y_samps, dtype=torch.float) #y_samps += y_noise_scale * torch.randn(N_samp, ndim_y) y_samps = torch.cat( [ torch.randn(N_samp, ndim_z), #zeros_noise_scale * torch.zeros(N_samp, ndim_tot - ndim_y - ndim_z), y_samps ], dim=1) y_samps = y_samps.to(device) # use the network to predict parameters rev_x = model(y_samps, rev=True) rev_x = rev_x.cpu().data.numpy() # plot the samples and the true contours axes[i, j].clear() axes[i, j].contour(mvec, cvec, lik[i, j, :].reshape(Ngrid, Ngrid), levels=[0.68, 0.9, 0.99]) axes[i, j].scatter(rev_x[:, 0], rev_x[:, 1], s=0.5, alpha=0.5) axes[i, j].plot(pos_test[cnt, 0], pos_test[cnt, 1], '+r', markersize=8) axes[i, j].axis(bound) cnt += 1 # save the results to file fig.canvas.draw() plt.savefig('/data/public_html/chrism/FrEIA/posteriors_%s.png' % i_epoch, dpi=360) plt.savefig('/data/public_html/chrism/FrEIA/latest.png', dpi=360) except KeyboardInterrupt: pass finally: print(f"\n\nTraining took {(time()-t_start)/60:.2f} minutes\n")
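Both loss_backward and loss_latent above point at MMD_multiscale, whose definition is not included in this excerpt. The sketch below shows a standard multi-scale MMD built from inverse multiquadric kernels, the kind of kernel two-sample loss these scripts rely on; the kernel widths are illustrative defaults (an assumption), not values taken from the code.

import torch

def MMD_multiscale(x, y, widths=(0.05, 0.2, 0.9)):
    # x and y are (N, d) sample batches with the same N; returns a scalar
    # estimate of the maximum mean discrepancy between their distributions
    xx, yy, xy = x @ x.t(), y @ y.t(), x @ y.t()
    rx = xx.diag().unsqueeze(0).expand_as(xx)
    ry = yy.diag().unsqueeze(0).expand_as(yy)
    dxx = rx.t() + rx - 2. * xx   # pairwise squared distances within x
    dyy = ry.t() + ry - 2. * yy   # pairwise squared distances within y
    dxy = rx.t() + ry - 2. * xy   # pairwise squared distances between x and y
    XX, YY, XY = torch.zeros_like(xx), torch.zeros_like(yy), torch.zeros_like(xy)
    for a in widths:
        # inverse multiquadric kernel a^2 / (a^2 + d), summed over several widths
        XX += a**2 * (a**2 + dxx)**-1
        YY += a**2 * (a**2 + dyy)**-1
        XY += a**2 * (a**2 + dxy)**-1
    return torch.mean(XX + YY - 2. * XY)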