def train(CurrentBatch, epoch, Encoder, Generator, optE, optG, criterionMSE, train_data, fixed_data, trainfolder, Reporter):
    """ Train the Encoder/Generator pair for one epoch.

    Inputs:
        CurrentBatch : Global batch counter, carried across epochs.
        epoch        : Epoch number.
        Encoder      : Encoder network.
        Generator    : Generator network.
        optE         : Optimizer for Encoder.
        optG         : Optimizer for Generator.
        criterionMSE : Mean-squared error loss.
        train_data   : Training DataLoader.
        fixed_data   : Testing DataLoader.
        trainfolder  : Location to store the plots.
        Reporter     : Reporter object to show the progress.

    Every ``len(train_data)`` global batches the test loss is recomputed via
    ``test`` and train / test / fixed-noise reconstructions are plotted.

    Returns:
        The updated batch counter (CurrentBatch + number of batches consumed).
    """
    # Latest test loss; it is reported on every batch. Pre-set so the Reporter
    # call cannot hit an unbound local when CurrentBatch does not start at a
    # multiple of len(train_data).
    loss_ae_test = float("nan")
    for image, label in train_data:
        # Re-enter training mode on every batch: the plotting branch below
        # switches both networks to eval(), which previously leaked into all
        # remaining batches of the epoch.
        Encoder.train()
        Generator.train()
        image = image.to(_DEVICE)
        label = label.to(_DEVICE)

        ## Traditional Autoencoder training loop
        optE.zero_grad()
        optG.zero_grad()
        one_hot_labels = torch.nn.functional.one_hot(label, num_classes).float()
        features = Encoder(image, labels=one_hot_labels)
        recon_img = Generator(features, one_hot_labels, mode="one_hot")
        loss = criterionMSE(recon_img, image)
        loss.backward()
        optE.step()
        optG.step()

        # Testing loop built-in
        if CurrentBatch % len(train_data) == 0:
            loss_ae_test = test(Encoder, Generator, fixed_data, criterionMSE, epoch)

            ## Plotting reconstructions
            with torch.no_grad():
                # Visualise in eval mode so the plotting forward passes do not
                # update normalisation running statistics or sample dropout.
                Encoder.eval()
                Generator.eval()

                # next(iter(...)) builds a fresh DataLoader iterator, so this is
                # the first batch of a new pass (random if the loader shuffles).
                imag, label = next(iter(train_data))
                imag = imag.to(_DEVICE)
                label = label.to(_DEVICE)
                one_hot_labels = torch.nn.functional.one_hot(label, num_classes).float()
                feat = Encoder(imag, labels=one_hot_labels)
                recon = Generator(feat, one_hot_labels, mode="one_hot")
                img_title_list = ["Real", "Recon"]
                PlotTitle = "train_epoch_" + str(epoch)
                FigureDict = HelpFunc.FigureDict(os.path.join(trainfolder, "train_plots"), dpi=300)
                PlotViz(imag, recon, img_title_list, FigureDict, PlotTitle, "cifar10")

                image, label = next(iter(fixed_data))
                image = image.to(_DEVICE)
                label = label.to(_DEVICE)
                one_hot_labels = torch.nn.functional.one_hot(label, num_classes).float()
                feat = Encoder(image, labels=one_hot_labels)
                recon_image = Generator(feat, one_hot_labels, mode="one_hot")
                img_title_list = ["Real", "Recon"]
                PlotTitle = "test_epoch_" + str(epoch)
                FigureDict = HelpFunc.FigureDict(os.path.join(trainfolder, "test_plots"), dpi=300)
                PlotViz(image, recon_image, img_title_list, FigureDict, PlotTitle, "cifar10")

                ## Reconstructions on fixed Noise
                fixed_image = Generator(fixed_noise, fixed_labels, mode="one_hot")
                FigureDict = HelpFunc.FigureDict(os.path.join(trainfolder, "noise_recn"), dpi=300)
                fig_ae = plt.figure(figsize=(15, 15), dpi=300)
                fig_ae.suptitle("Reconstructions on Fixed noise and labels")
                # 10x10 grid; assumes fixed_noise / fixed_labels hold >= 100 samples.
                for k in range(100):
                    ax = plt.subplot(10, 10, k + 1)
                    imshow(fixed_image[k], 'cifar10')
                    plt.xticks([])
                    plt.yticks([])
                    ax.set_title("{}".format(class_labels[torch.argmax(fixed_labels[k])]))
                # Lay out once, after all subplots and titles exist, instead of
                # re-running the layout engine on every subplot.
                plt.tight_layout(rect=[0, 0.03, 1, 0.95])
                # Stored under the test plot title (test_epoch_<epoch>) in noise_recn/.
                FigureDict.StoreFig(fig=fig_ae, name=PlotTitle, saving=True)
                plt.show()

        ## Reporting
        Reporter.DUMPDICT['Lr_Enc'].append(optE.param_groups[0]['lr'])
        Reporter.DUMPDICT['Lr_Gen'].append(optG.param_groups[0]['lr'])
        Reporter.SetValues([epoch + 1, CurrentBatch,
                            (time.time() - Reporter.DUMPDICT["starttime"]) / 60,
                            loss.item(), loss_ae_test])
        CurrentBatch += 1
    return CurrentBatch
def train(CurrentBatch, epoch, ae, optE, criterionMSE, train_data, fixed_data, trainfolder, Reporter, test_every=500):
    """ Train the autoencoder ``ae`` for one epoch.

    Inputs:
        CurrentBatch : Global batch counter, carried across epochs.
        epoch        : Epoch number.
        ae           : Autoencoder network.
        optE         : Optimizer for ae.
        criterionMSE : Mean-squared error loss.
        train_data   : Training DataLoader.
        fixed_data   : Testing DataLoader (passed to ``test``).
        trainfolder  : Location to store the plots.
        Reporter     : Reporter object to show the progress.
        test_every   : Run ``test`` and plot reconstructions every this many
                       global batches (default 500, the previous fixed value).

    Side effects:
        Appends one test-loss entry per batch to the global list ``loss_test``.
        Between test runs the previous value is repeated; before any test has
        run it is NaN. Plots use the batches stored in the global ``plot_data``.

    Returns:
        The updated batch counter (CurrentBatch + number of batches consumed).
    """
    for image, label in train_data:
        # Reset every batch because the plotting branch switches ae to eval().
        ae.train()
        image = image.to(_DEVICE)
        label = label.to(_DEVICE)

        ## Traditional Autoencoder training loop
        optE.zero_grad()
        one_hot_labels = torch.nn.functional.one_hot(label, num_classes).float()
        recon_img = ae(image, one_hot_labels, mode="one_hot")
        mse_loss = criterionMSE(recon_img, image)
        mse_loss.backward()
        optE.step()

        # Testing loop built-in
        if CurrentBatch % test_every == 0:
            loss_ae_test = test(ae, fixed_data, criterionMSE, epoch)

            ## Plotting reconstructions
            with torch.no_grad():
                # Eval mode for both plots: in train mode, normalisation layers
                # would update their running statistics from the plot batch
                # (even under no_grad), and dropout would randomise the output.
                ae.eval()

                imag, label = plot_data["image"], plot_data["label"]
                one_hot_labels = torch.nn.functional.one_hot(label, num_classes).float()
                # Call the module rather than .forward() so registered hooks run.
                recon = ae(imag, one_hot_labels, mode="one_hot", train_mode="ae")
                img_title_list = ["Real", "Recon"]
                PlotTitle = "train_epoch_" + str(epoch)
                FigureDict = HelpFunc.FigureDict(os.path.join(trainfolder, "train_plots"), dpi=300)
                PlotViz(imag, recon, img_title_list, FigureDict, PlotTitle, "cifar10")

                image, label = plot_data['imageT'], plot_data['labelT']
                one_hot_labels = torch.nn.functional.one_hot(label, num_classes).float()
                recon_image = ae(image, one_hot_labels, mode="one_hot", train_mode="ae")
                img_title_list = ["Real", "Recon"]
                PlotTitle = "test_epoch_" + str(epoch)
                FigureDict = HelpFunc.FigureDict(os.path.join(trainfolder, "test_plots"), dpi=300)
                PlotViz(image, recon_image, img_title_list, FigureDict, PlotTitle, "cifar10")
        elif loss_test:
            # Carry the most recent test loss forward between test steps.
            loss_ae_test = loss_test[-1]
        else:
            # No test has run yet (CurrentBatch did not start at a multiple of
            # test_every); previously this raised IndexError.
            loss_ae_test = float("nan")
        loss_test.append(loss_ae_test)

        ## Reporting
        Reporter.DUMPDICT['Lr_AE'].append(optE.param_groups[0]['lr'])
        Reporter.SetValues([
            epoch + 1, CurrentBatch,
            (time.time() - Reporter.DUMPDICT["starttime"]) / 60,
            mse_loss.item(), loss_test[-1]
        ])
        CurrentBatch += 1
    return CurrentBatch
# Training driver for the Encoder/Generator variant of train().
# Wall-clock start; train() reports (time.time() - starttime) / 60 as elapsed minutes.
Reporter.DUMPDICT["starttime"] = time.time()
# Per-batch learning-rate histories; train() appends optE/optG's current lr each batch.
Reporter.DUMPDICT['Lr_Enc'] = []
Reporter.DUMPDICT['Lr_Gen'] = []
# Bound up front so `epoch` exists even when n_epoch == 0 and the loop never runs.
epoch = 0
for epoch in range(n_epoch):
    # CurrentBatch must be defined before this point (not visible in this chunk);
    # train() returns it advanced by one per batch.
    CurrentBatch=train(CurrentBatch, epoch, enc, gen, optE, optG, criterionMSE, trainloader, testloader ,trainfolder, Reporter)
    # Learning-rate schedulers are stepped once per epoch.
    schedE.step()
    schedG.step()

#%% Plotting ##############################
# Loss curves are written under <trainfolder>/AE_metrics_Plots.
FigureDict = HelpFunc.FigureDict(os.path.join(trainfolder,"AE_metrics_Plots"),dpi =300 )
# Smooth the recorded per-step losses; window=n_epoch is passed straight to
# HelpFunc.MovingAverage (its exact semantics are not visible here — confirm).
# NOTE(review): 'Mse_loss' / 'AE_TestLoss' presumably name the values pushed by
# Reporter.SetValues in train(); verify against the Reporter configuration.
loss= HelpFunc.MovingAverage(Reporter.VALS['Mse_loss'],window=n_epoch)
AE_TestLoss= HelpFunc.MovingAverage(Reporter.VALS['AE_TestLoss'],window=n_epoch)
FigureLoss1=plt.figure(figsize=(8,5))
plt.plot(loss, label = "MSE_loss")
plt.plot(AE_TestLoss, label='AE_TestLoss')
plt.legend(loc='best')
plt.xlabel('Steps')
plt.ylabel('Loss')
plt.ylim(bottom =0 )
# Title shows the last smoothed value of each curve; fails if either series is empty.
plt.title("MSE loss: train = {:.5f} , test = {:.5f}".format(loss[-1],AE_TestLoss[-1]))
plt.minorticks_on()
FigureDict.StoreFig(fig=FigureLoss1, name="Autoencoder_loss", saving=True)
step=ReporterStep)
# ^ NOTE(review): closes a call that begins outside this chunk.
# Training driver for the single-network (ae) variant of train().
# Global batch counter; train() uses it to decide when to test and plot.
CurrentBatch = 0
# Wall-clock start; train() reports (time.time() - starttime) / 60 as elapsed minutes.
Reporter.DUMPDICT["starttime"] = time.time()
# Per-batch learning-rate history; train() appends optE's current lr each batch.
Reporter.DUMPDICT['Lr_AE'] = []
# Bound up front so `epoch` exists even when n_epoch == 0 and the loop never runs.
epoch = 0
for epoch in range(n_epoch):
    CurrentBatch = train(CurrentBatch, epoch, ae, optE, criterionMSE,
                         trainloader, testloader, trainfolder, Reporter)
    # Learning-rate scheduler is stepped once per epoch.
    schedE.step()

#%% 10. Plotting ##############################
# Loss curves are written directly into trainfolder.
FigureDict = HelpFunc.FigureDict(trainfolder, dpi=300)
# Smooth the recorded per-step losses; window=n_epoch is passed straight to
# HelpFunc.MovingAverage (its exact semantics are not visible here — confirm).
# NOTE(review): 'LossMSE' / 'AE_TestLoss' presumably name the values pushed by
# Reporter.SetValues in train(); verify against the Reporter configuration.
MSE_loss = HelpFunc.MovingAverage(Reporter.VALS['LossMSE'],
                                  window=n_epoch)
AE_TestLoss = HelpFunc.MovingAverage(Reporter.VALS['AE_TestLoss'],
                                     window=n_epoch)
FigureLoss1 = plt.figure(figsize=(8, 5))
plt.plot(MSE_loss, label='MSE_loss')
plt.plot(AE_TestLoss, label='AE_TestLoss')
plt.legend(loc='best')
plt.xlabel('Steps')
plt.ylabel('Loss')
plt.ylim(bottom=0)
# Title shows the last smoothed value of each curve; fails if either series is empty.
plt.title("MSE loss: train = {:.4f} , test = {:.4f}".format(
    MSE_loss[-1], AE_TestLoss[-1]))
plt.minorticks_on()