def train(CurrentBatch, epoch, Encoder, Generator, optE, optG,
          criterionMSE, train_data, fixed_data, trainfolder,
          Reporter):
    """
    Train function to train the networks.
    
    Inputs: 
        CurrentBatch    : Batch number. \n
        epoch           : epoch number. \n
        Encoder         : Encoder network. \n
        Generator       : Generator network. \n
        optE            : Optimizer for Encoder. \n
        optG            : Optrimizer for Generator. \n
        criterionMSE    : Mean-squared Error loss. \n
        train_data      : Training DataLoader. \n
        fixed_data      : Testing DataLoader. \n
        trainfolder     : Location to store the plots. \n
        Reporter        : Reporter object to show the progress. \n
        
    Returns:
        CurrentBatch+1
    """
    
    
    Encoder.train()
    Generator.train()
    for i, (image, label) in enumerate(train_data):
        image = image.to(_DEVICE)
        label = label.to(_DEVICE)
        
        
        ## Traditional Autoencoder training loop
        optE.zero_grad()
        optG.zero_grad()
        
        one_hot_labels = torch.nn.functional.one_hot(label, num_classes).float()
        features = Encoder(image, labels=one_hot_labels)
        recon_img = Generator(features, one_hot_labels, mode="one_hot")

        # loss = VAE_LOSS(recon_img, image, mu, log_var)  # earlier VAE variant, unused here
        loss = criterionMSE(recon_img, image)
        loss.backward()
        
        optE.step()
        optG.step()
        
        # Built-in testing loop: evaluate once per epoch (CurrentBatch is a
        # multiple of len(train_data) exactly at the first batch of each epoch).
        if CurrentBatch % len(train_data) == 0:
            loss_ae_test = test(Encoder, Generator, fixed_data, criterionMSE, epoch)

            ## Plotting reconstructions
            with torch.no_grad():
                imag, label = next(iter(train_data))
                imag = imag.to(_DEVICE)
                label = label.to(_DEVICE)
                one_hot_labels = torch.nn.functional.one_hot(label, num_classes).float()
                feat = Encoder(imag, labels=one_hot_labels)
                recon = Generator(feat, one_hot_labels, mode="one_hot")

                img_title_list = ["Real", "Recon"]
                PlotTitle = "train_epoch_" + str(epoch)
                FigureDict = HelpFunc.FigureDict(os.path.join(trainfolder, "train_plots"), dpi=300)
                PlotViz(imag, recon, img_title_list, FigureDict,
                        PlotTitle, "cifar10")
            
                Encoder.eval()
                Generator.eval()

                image, label = next(iter(fixed_data))
                image = image.to(_DEVICE)
                label = label.to(_DEVICE)
                one_hot_labels = torch.nn.functional.one_hot(label, num_classes).float()
                feat = Encoder(image, labels=one_hot_labels)
                recon_image = Generator(feat, one_hot_labels, mode="one_hot")

                img_title_list = ["Real", "Recon"]
                PlotTitle = "test_epoch_" + str(epoch)
                FigureDict = HelpFunc.FigureDict(os.path.join(trainfolder, "test_plots"), dpi=300)
                PlotViz(image, recon_image, img_title_list, FigureDict,
                        PlotTitle, "cifar10")
                
                ## Reconstructions on fixed noise
                PlotTitle = "noise_epoch_" + str(epoch)
                fixed_image = Generator(fixed_noise, fixed_labels, mode="one_hot")
                FigureDict = HelpFunc.FigureDict(os.path.join(trainfolder, "noise_recn"), dpi=300)
                fig_ae = plt.figure(figsize=(15, 15), dpi=300)
                fig_ae.suptitle("Reconstructions on fixed noise and labels")

                for j in range(100):
                    ax = plt.subplot(10, 10, j + 1)
                    imshow(fixed_image[j], 'cifar10')
                    plt.xticks([])
                    plt.yticks([])
                    ax.set_title("{}".format(class_labels[torch.argmax(fixed_labels[j])]))
                plt.tight_layout(rect=[0, 0.03, 1, 0.95])
                FigureDict.StoreFig(fig=fig_ae, name=PlotTitle, saving=True)
                plt.show()

            # Restore train mode: the eval() calls above would otherwise leak
            # into the remaining training batches (BatchNorm/Dropout differ).
            Encoder.train()
            Generator.train()
        
        ## Reporting   
        Reporter.DUMPDICT['Lr_Enc'].append(optE.param_groups[0]['lr'])
        Reporter.DUMPDICT['Lr_Gen'].append(optG.param_groups[0]['lr'])
        
        Reporter.SetValues([epoch+1, CurrentBatch,
                            (time.time()-Reporter.DUMPDICT["starttime"])/60,
                            loss.item(), loss_ae_test])
        CurrentBatch += 1
    return CurrentBatch
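
This snippet relies on several names defined elsewhere in the script: a `test` helper plus the globals `_DEVICE`, `num_classes`, `fixed_noise`, `fixed_labels` and `class_labels`. The sketch below shows plausible definitions that match the call sites above; the latent dimension, the device handling and the averaging inside `test` are assumptions, not the original code.

import torch

_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_classes = 10                      # CIFAR-10
latent_dim = 128                      # assumed size of the Encoder's latent code
class_labels = ["airplane", "automobile", "bird", "cat", "deer",
                "dog", "frog", "horse", "ship", "truck"]

# A fixed 10x10 grid of latent codes with ten samples per class, so the
# "noise_recn" figure shows the same inputs every epoch.
fixed_noise = torch.randn(100, latent_dim, device=_DEVICE)
fixed_labels = torch.nn.functional.one_hot(
    torch.arange(100, device=_DEVICE) % num_classes, num_classes).float()


def test(Encoder, Generator, fixed_data, criterionMSE, epoch):
    """Average reconstruction MSE over the held-out loader (a sketch).
    `epoch` is accepted to match the call site but unused here."""
    Encoder.eval()
    Generator.eval()
    total, batches = 0.0, 0
    with torch.no_grad():
        for image, label in fixed_data:
            image, label = image.to(_DEVICE), label.to(_DEVICE)
            one_hot = torch.nn.functional.one_hot(label, num_classes).float()
            recon = Generator(Encoder(image, labels=one_hot), one_hot,
                              mode="one_hot")
            total += criterionMSE(recon, image).item()
            batches += 1
    Encoder.train()
    Generator.train()
    return total / max(batches, 1)
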
Example #2
def train(CurrentBatch, epoch, ae, optE, criterionMSE, train_data, fixed_data,
          trainfolder, Reporter):
    """
    Train function to train the networks.
    
    Inputs: 
        CurrentBatch    : Batch number. \n
        epoch           : epoch number. \n
        Encoder         : Encoder network. \n
        Generator       : Generator network. \n
        optE            : Optimizer for Encoder. \n
        optG            : Optrimizer for Generator. \n
        criterionMSE    : Mean-squared Error loss. \n
        train_data      : Training DataLoader. \n
        fixed_data      : Testing DataLoader. \n
        trainfolder     : Location to store the plots. \n
        Reporter        : Reporter object to show the progress. \n
        
    Returns:
        CurrentBatch+1
    """

    for i, (image, label) in enumerate(train_data):
        ae.train()
        image = image.to(_DEVICE)
        label = label.to(_DEVICE)

        ## Traditional Autoencoder training loop
        optE.zero_grad()
        one_hot_labels = torch.nn.functional.one_hot(label,
                                                     num_classes).float()
        recon_img = ae(image, one_hot_labels, mode="one_hot")
        mse_loss = criterionMSE(recon_img, image)
        mse_loss.backward()

        optE.step()

        # Built-in testing loop: evaluate every 500 batches
        if CurrentBatch % 500 == 0:
            loss_ae_test = test(ae, fixed_data, criterionMSE, epoch)

            ## Plotting reconstructions
            with torch.no_grad():
                ae.train()  # train-mode statistics for the train-set plot
                imag, label = plot_data["image"], plot_data["label"]
                one_hot_labels = torch.nn.functional.one_hot(
                    label, num_classes).float()
                recon = ae(imag, one_hot_labels, mode="one_hot",
                           train_mode="ae")

                img_title_list = ["Real", "Recon"]
                PlotTitle = "train_epoch_" + str(epoch)
                FigureDict = HelpFunc.FigureDict(
                    os.path.join(trainfolder, "train_plots"), dpi=300)
                PlotViz(imag, recon, img_title_list, FigureDict, PlotTitle,
                        "cifar10")

                ae.eval()
                image, label = plot_data['imageT'], plot_data['labelT']
                one_hot_labels = torch.nn.functional.one_hot(
                    label, num_classes).float()
                recon_image = ae(image, one_hot_labels, mode="one_hot",
                                 train_mode="ae")

                img_title_list = ["Real", "Recon"]
                PlotTitle = "test_epoch_" + str(epoch)
                FigureDict = HelpFunc.FigureDict(
                    os.path.join(trainfolder, "test_plots"), dpi=300)
                PlotViz(image, recon_image, img_title_list, FigureDict,
                        PlotTitle, "cifar10")
        else:
            # Between evaluations, reuse the most recent test loss.
            loss_ae_test = loss_test[-1]

        loss_test.append(loss_ae_test)

        ## Reporting
        Reporter.DUMPDICT['Lr_AE'].append(optE.param_groups[0]['lr'])

        Reporter.SetValues([
            epoch + 1, CurrentBatch,
            (time.time() - Reporter.DUMPDICT["starttime"]) / 60,
            mse_loss.item(), loss_test[-1]
        ])
        CurrentBatch += 1
    return CurrentBatch
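
The function above reads two globals that the snippet does not define: `plot_data` (fixed batches for the reconstruction plots) and `loss_test` (the history of test losses). A plausible setup, inferred only from how they are used and therefore an assumption, is:

# Hypothetical setup for the globals read by train() above: cache one fixed
# batch from each loader so the plots show the same images every epoch.
_img, _lbl = next(iter(trainloader))
_imgT, _lblT = next(iter(testloader))
plot_data = {"image": _img.to(_DEVICE), "label": _lbl.to(_DEVICE),
             "imageT": _imgT.to(_DEVICE), "labelT": _lblT.to(_DEVICE)}
loss_test = []  # train() appends to this; [-1] reuses the latest test loss
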
Reporter.DUMPDICT["starttime"]      =   time.time()
Reporter.DUMPDICT['Lr_Enc']         =   []
Reporter.DUMPDICT['Lr_Gen']         =   []
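
# NOTE: enc, gen, optE, optG, schedE and schedG are created earlier in the
# original script. The construction below is a sketch with assumed
# hyperparameters (Adam + StepLR), shown only to make the driver readable.
optE = torch.optim.Adam(enc.parameters(), lr=2e-4, betas=(0.5, 0.999))
optG = torch.optim.Adam(gen.parameters(), lr=2e-4, betas=(0.5, 0.999))
schedE = torch.optim.lr_scheduler.StepLR(optE, step_size=10, gamma=0.5)
schedG = torch.optim.lr_scheduler.StepLR(optG, step_size=10, gamma=0.5)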

epoch = 0
for epoch in range(n_epoch):
    CurrentBatch = train(CurrentBatch, epoch, enc, gen, optE, optG,
                         criterionMSE, trainloader, testloader, trainfolder,
                         Reporter)
    schedE.step()
    schedG.step()
    
    
#%% Plotting 

##############################   
FigureDict = HelpFunc.FigureDict(os.path.join(trainfolder, "AE_metrics_Plots"), dpi=300)
loss = HelpFunc.MovingAverage(Reporter.VALS['Mse_loss'], window=n_epoch)
AE_TestLoss = HelpFunc.MovingAverage(Reporter.VALS['AE_TestLoss'], window=n_epoch)


FigureLoss1 = plt.figure(figsize=(8, 5))
plt.plot(loss, label='MSE_loss')
plt.plot(AE_TestLoss, label='AE_TestLoss')
plt.legend(loc='best')
plt.xlabel('Steps')
plt.ylabel('Loss')
plt.ylim(bottom=0)
plt.title("MSE loss: train = {:.5f} , test = {:.5f}".format(loss[-1], AE_TestLoss[-1]))

plt.minorticks_on()
FigureDict.StoreFig(fig=FigureLoss1, name="Autoencoder_loss", saving=True)
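
`HelpFunc.MovingAverage` is part of the project's helper module and is not shown here. A simple stand-in, assuming it returns a trailing simple moving average, could look like this:

import numpy as np

def MovingAverage(values, window):
    """Trailing simple moving average (stand-in for HelpFunc.MovingAverage)."""
    values = np.asarray(values, dtype=float)
    if window <= 1 or len(values) < window:
        return values
    kernel = np.ones(window) / window
    return np.convolve(values, kernel, mode="valid")
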
Example #4
                              step=ReporterStep)
CurrentBatch = 0
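
# NOTE: `Reporter` comes from the project's HelpFunc module; its construction
# is truncated above (only `step=ReporterStep)` survives). ReporterSketch is a
# hypothetical stand-in -- entirely an assumption -- that matches how the
# object is used in these snippets (DUMPDICT, SetValues, VALS).
class ReporterSketch:
    def __init__(self, columns, step=50):
        self.columns = columns                 # metric names, e.g. "LossMSE"
        self.step = step                       # print progress every `step` calls
        self.DUMPDICT = {}                     # free-form storage (start time, LRs)
        self.VALS = {c: [] for c in columns}   # per-metric history

    def SetValues(self, values):
        # Append one value per column and occasionally print a progress line.
        for col, val in zip(self.columns, values):
            self.VALS[col].append(val)
        if len(self.VALS[self.columns[0]]) % self.step == 0:
            print(" | ".join("{}: {}".format(c, v)
                             for c, v in zip(self.columns, values)))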

Reporter.DUMPDICT["starttime"] = time.time()
Reporter.DUMPDICT['Lr_AE'] = []

epoch = 0
for epoch in range(n_epoch):
    CurrentBatch = train(CurrentBatch, epoch, ae, optE, criterionMSE,
                         trainloader, testloader, trainfolder, Reporter)
    schedE.step()

#%% 10. Plotting

##############################
FigureDict = HelpFunc.FigureDict(trainfolder, dpi=300)

MSE_loss = HelpFunc.MovingAverage(Reporter.VALS['LossMSE'], window=n_epoch)
AE_TestLoss = HelpFunc.MovingAverage(Reporter.VALS['AE_TestLoss'],
                                     window=n_epoch)

FigureLoss1 = plt.figure(figsize=(8, 5))
plt.plot(MSE_loss, label='MSE_loss')
plt.plot(AE_TestLoss, label='AE_TestLoss')
plt.legend(loc='best')
plt.xlabel('Steps')
plt.ylabel('Loss')
plt.ylim(bottom=0)
plt.title("MSE loss: train = {:.4f} , test = {:.4f}".format(
    MSE_loss[-1], AE_TestLoss[-1]))
plt.minorticks_on()
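
Finally, `HelpFunc.FigureDict` appears throughout these snippets. Judging from its call sites (`FigureDict(folder, dpi=...)` and `StoreFig(fig=..., name=..., saving=True)`), a minimal stand-in, again an assumption rather than the original helper, could be:

import os
import matplotlib.pyplot as plt

class FigureDictSketch:
    """Stand-in for HelpFunc.FigureDict: remembers a folder and DPI, saves on request."""
    def __init__(self, folder, dpi=300):
        self.folder = folder
        self.dpi = dpi
        os.makedirs(folder, exist_ok=True)

    def StoreFig(self, fig, name, saving=True):
        if saving:
            fig.savefig(os.path.join(self.folder, name + ".png"),
                        dpi=self.dpi, bbox_inches="tight")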