예제 #1
0
파일: DC_GAN.py 프로젝트: 95bhargav/DC-GAN
def _round_and_count(output, target):
    """Binarise outputs and targets by rounding and count elementwise matches.

    Args:
        output: discriminator output for one batch (presumably sigmoid
            probabilities in [0, 1] — TODO confirm against netD).
        target: the label tensor the loss was computed against.

    Returns:
        ``(preds, labs, n_correct)``:
        - ``preds`` is ``round(output)``, detached from the autograd graph so
          it can be stored without keeping graph nodes alive.
        - ``labs`` is ``round(target)``, a new tensor, so later in-place
          ``fill_`` calls on ``target`` do not affect it.
        - ``n_correct`` is the number of equal entries, as a Python float.
    """
    preds = torch.round(output.detach())
    labs = torch.round(target)
    n_correct = (preds == labs).sum().float().item()
    return preds, labs, n_correct


def train(dataloader, fixed_noise, netD, netG, criterion, optimizerD,
          optimizerG, CurrentBatch, epoch, Reporter):
    """Run one epoch of DCGAN training over ``dataloader``.

    For every batch:
    1. The discriminator is updated on the real batch (with smoothed real
       targets) and on a batch generated from fresh noise.
    2. The generator is updated so that the discriminator scores its
       samples as ``real_label``.
    3. Losses, discriminator accuracies (overall / real / fake), rounded
       labels and predictions, and current learning rates are recorded
       on ``Reporter``.

    Args:
        dataloader: iterable of batches; ``data[0]`` is the image tensor.
        fixed_noise: generator input used for the periodic sample plot.
        netD: discriminator module.
        netG: generator module.
        criterion: loss applied as ``criterion(netD(x), label)``.
        optimizerD: optimizer stepping ``netD``.
        optimizerG: optimizer stepping ``netG``.
        CurrentBatch: running batch counter; incremented once per batch.
        epoch: zero-based epoch index (stored on ``Reporter`` as
            ``epoch + 1``).
        Reporter: logger exposing ``DumpValues`` (dict of lists),
            ``Batch`` and ``Store``.

    Returns:
        ``CurrentBatch`` advanced by the number of batches processed.
    """
    for data in dataloader:
        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        optimizerD.zero_grad()

        # -- real batch --
        real_imgs = data[0].to(_DEVICE)
        batch_size = real_imgs.size(0)
        # One-sided label smoothing: real targets are scaled by a factor in
        # [0.9, 1.0). The randn() values are overwritten by uniform_(); the
        # call is kept as-is so the RNG stream of seeded runs is unchanged.
        smooth = float(torch.randn(1, ).uniform_(0.9, 1.0))
        label = torch.full((batch_size, ),
                           real_label * smooth,
                           dtype=real_imgs.dtype,
                           device=_DEVICE)

        output = netD(real_imgs)
        errD_real = criterion(output, label)
        errD_real.backward()

        # Accuracy of D on real images. Rounding presumably maps the
        # smoothed targets back to real_label (assumes real_label == 1).
        real_preds, real_lab, correct_real = _round_and_count(output, label)
        d_accuracy_real = correct_real * 100 / real_lab.size(0)

        # -- fake batch --
        noise = torch.randn(batch_size, nz, 1, 1, device=_DEVICE)
        fake = netG(noise)
        label.fill_(fake_label)
        # detach(): the discriminator loss must not backpropagate into netG.
        output = netD(fake.detach())
        errD_fake = criterion(output, label)
        errD_fake.backward()

        # Accuracy of D on fake images.
        fake_preds, fake_lab, correct_fake = _round_and_count(output, label)
        d_accuracy_fake = correct_fake * 100 / fake_lab.size(0)

        # Overall accuracy of D across both halves.
        d_accuracy = (correct_real + correct_fake) * 100 / (
            real_lab.size(0) + fake_lab.size(0))

        errD = errD_real + errD_fake
        optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        optimizerG.zero_grad()
        output = netD(fake)
        label.fill_(real_label)  # fake labels are real for generator cost
        errG = criterion(output, label)
        errG.backward()
        optimizerG.step()

        # Rounded labels/predictions kept for F1 metrics. Stored as detached
        # CPU tensors so the per-batch history does not pin device memory or
        # autograd graph nodes for the whole run.
        Reporter.DumpValues['Disc_labels'].append(
            torch.cat((real_lab, fake_lab)).cpu())
        Reporter.DumpValues['Disc_preds'].append(
            torch.cat((real_preds, fake_preds)).cpu())
        Reporter.DumpValues['Lr_disc'].append(optimizerD.param_groups[0]['lr'])
        Reporter.DumpValues['Lr_gen'].append(optimizerG.param_groups[0]['lr'])

        Reporter.Store([
            epoch + 1, Reporter.Batch,
            errD.item(),
            errG.item(),
            errD_real.item(),
            errD_fake.item(), d_accuracy, d_accuracy_real, d_accuracy_fake
        ])

        # Plot samples from fixed_noise when Reporter.Batch is a multiple of
        # the epoch length.
        # NOTE(review): this function only advances the local CurrentBatch,
        # never Reporter.Batch. This relies on Reporter.Store advancing it —
        # confirm; otherwise the condition has the same outcome for every
        # batch of the epoch.
        if Reporter.Batch % len(dataloader) == 0:
            with torch.no_grad():
                fake_img = netG(fixed_noise)
            PlotTitle = "train_epoch_" + str(epoch)
            FigureDict = HelpFunc.FigureStorage(
                os.path.join(trainfolder, "train_plots_DC_GAN"),
                dpi=300,
                autosave=True)
            PlotViz(fake_img, FigureDict, PlotTitle)

        CurrentBatch += 1
    return CurrentBatch
예제 #2
0
파일: DC_GAN.py 프로젝트: 95bhargav/DC-GAN
# Start the run with empty learning-rate histories and a zeroed batch
# counter; train() appends one learning-rate value per batch to each list.
Reporter.DumpValues["Lr_disc"] = []
Reporter.DumpValues["Lr_gen"] = []
Reporter.Batch = 0

# Pre-bind `epoch` so the name exists even when n_epochs is 0.
epoch = 0
for epoch in range(n_epochs):
    # NOTE(review): train() returns the advanced batch counter, but each
    # epoch is started from Reporter.Batch instead. This is only consistent
    # if Reporter.Store advances Reporter.Batch itself — confirm.
    CurrentBatch = train(
        dataloader=trainloader,
        fixed_noise=fixed_noise,
        netD=netD,
        netG=netG,
        criterion=criterion,
        optimizerD=optimizerD,
        optimizerG=optimizerG,
        CurrentBatch=Reporter.Batch,
        epoch=epoch,
        Reporter=Reporter,
    )
    # Both learning-rate schedulers advance once per epoch, discriminator
    # first.
    schedD.step()
    schedG.step()

#%% 7. Plotting

print("plotting....")
FigureDict = HelpFunc.FigureStorage(os.path.join(trainfolder, "GAN_plots"),
                                    dpi=300,
                                    autosave=True)
############################

# Learning-rate histories (one value appended per training batch) passed
# through HelpFunc.MovingAverage with a window of n_epochs samples.
disc_lr = HelpFunc.MovingAverage(Reporter.DumpValues['Lr_disc'],
                                 window=n_epochs)
gen_lr = HelpFunc.MovingAverage(Reporter.DumpValues['Lr_gen'], window=n_epochs)

# This figure shows learning rates, not losses. The y-axis label reflects
# that. The `FigureLoss` name is kept because code beyond this excerpt may
# refer to it.
FigureLoss = plt.figure(figsize=(8, 4))
plt.plot(disc_lr, label='Disc_lr')
plt.plot(gen_lr, label='Gen_lr')
plt.xlabel('Steps')
plt.ylabel("Learning rate")
plt.legend(loc=1)
plt.ylim(bottom=0)
plt.title("Standard DC GAN")