Code Example #1
File: train.py  Project: mridul911/model-zoo-1
def train(args):
    cnt = 0
    for epoch in range(args.n_epoch):

        for batch in args.ds:
            cnt += 1
            valid = np.ones((len(batch), 1))
            fake = np.zeros((len(batch), 1))

            masked_imgs, masked_parts, _ = mask_randomly(args, batch)

            gen_parts = args.gen(masked_imgs)
            d_loss_real = args.dis.train_on_batch(masked_parts, valid)
            d_loss_fake = args.dis.train_on_batch(gen_parts, fake)
            d_loss = 0.5 * (d_loss_real + d_loss_fake)

            g_loss1 = args.gan.train_on_batch(masked_imgs, valid)
            g_loss2 = args.gen.train_on_batch(masked_imgs, masked_parts)
            g_loss = g_loss1 + g_loss2

            with train_summary_writer.as_default():
                tf.summary.scalar("Generator loss", g_loss, step=cnt)
                tf.summary.scalar("Discriminator loss", d_loss, step=cnt)
                tf.summary.scalar("Real Discriminator loss", d_loss_real, step=cnt)
                tf.summary.scalar("Fake Discriminator loss", d_loss_fake, step=cnt)
                tf.summary.scalar("Pixel-wise loss", g_loss2, step=cnt)
                tf.summary.scalar("Adversarial loss", g_loss1, step=cnt)

            if cnt % args.n_update == 0:
                print('>%d, %d, g1=%0.3f, g2=%0.3f, d1=%.3f, d2=%.3f' %
                      (epoch + 1, cnt, g_loss1, g_loss2, d_loss_real, d_loss_fake))

                sample_images(args, cnt, args.valid_ds)
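
The loop above calls mask_randomly without defining it. A minimal sketch of what such a helper could look like for a context-encoder setup, assuming square masks, NHWC batches with values in [0, 1], and a hypothetical args.mask_size attribute:

import numpy as np

def mask_randomly(args, batch):
    # Cut one random square out of each image; return the holed images,
    # the cut-out patches, and the patch coordinates.
    imgs = np.asarray(batch)                    # (N, H, W, C)
    n, h, w, c = imgs.shape
    m = args.mask_size                          # assumed attribute: mask side length
    ys = np.random.randint(0, h - m, n)
    xs = np.random.randint(0, w - m, n)
    masked_imgs = imgs.copy()
    masked_parts = np.empty((n, m, m, c), dtype=imgs.dtype)
    for k in range(n):
        masked_parts[k] = imgs[k, ys[k]:ys[k] + m, xs[k]:xs[k] + m, :]
        masked_imgs[k, ys[k]:ys[k] + m, xs[k]:xs[k] + m, :] = 1.0  # blank the hole
    return masked_imgs, masked_parts, (ys, xs)
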
Code Example #2
    def train(self, epochs, batch_size=128, isTrain=False):
        # Adversarial ground truths
        valid = np.ones((batch_size, ))
        fake = np.zeros((batch_size, ))

        for epoch in range(epochs):
            self.log.on_epoch_begin(self.discriminator, epoch)

            # ---------------------
            #  Train Discriminator
            # ---------------------

            # Select a random batch of images
            idx = np.random.randint(0, self.trn.X.shape[0], batch_size)
            imgs = self.trn.X[idx]  # imgs.shape = (128, 28, 28, 1)

            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))

            # Generate a batch of new images
            gen_imgs = self.generator.predict(noise)

            self.log.on_batch_begin(self.discriminator, batch_size)
            # Train the discriminator
            d_loss_real = self.discriminator.train_on_batch(imgs, valid)
            d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # ---------------------
            #  Train Generator
            # ---------------------

            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))

            # Train the generator (to have the discriminator label samples as valid)
            g_loss = self.gan.train_on_batch(noise, valid)

            print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" %
                  (epoch, d_loss[0], 100 * d_loss[1], g_loss))

            if isTrain:
                if epoch == (epochs - 1):  #only save model at last epoch
                    utils.save_model(self.generator, self.discriminator,
                                     self.activation, self.model_dir)

                # If at save interval => save generated image samples
                if self.do_report is not None and self.do_report(epoch):
                    utils.sample_images(self.generator, self.img_dir, epoch)

            else:  # if not training, save activity for MI calculation
                self.log.on_epoch_end(self.discriminator, d_loss_real[0],
                                      d_loss_fake[0], g_loss, epoch)
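
The class above trains through a combined self.gan model that is never shown. A hedged sketch of how such a stack is typically wired in Keras, so that gan.train_on_batch(noise, valid) updates only the generator (build_gan is an illustrative name, not from the original project):

from tensorflow import keras

def build_gan(generator, discriminator, latent_dim):
    # Freeze the discriminator inside the stacked model; it still trains
    # normally through its own train_on_batch calls.
    discriminator.trainable = False
    z = keras.Input(shape=(latent_dim,))
    validity = discriminator(generator(z))
    gan = keras.Model(z, validity)
    gan.compile(loss="binary_crossentropy",
                optimizer=keras.optimizers.Adam(2e-4, 0.5))
    return gan
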
Code Example #3
        optimizer_D.zero_grad()

        # Loss for real images
        validity_real = discriminator(imgs, labels)
        d_real_loss = adversarial_loss(validity_real, valid)

        # Loss for fake images
        validity_fake = discriminator(gen_imgs.detach(), gen_labels)
        d_fake_loss = adversarial_loss(validity_fake, fake)

        # Total discriminator loss
        d_loss = (d_real_loss + d_fake_loss) / 2

        d_loss.backward()
        optimizer_D.step()
        # scheduler_D.step()

        print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]" %
              (epoch, N_EPOCHS, batch, len(dataloader), d_loss.item(),
               g_loss.item()))

        batches_done = epoch * len(dataloader) + batch
        if batches_done != 0 and batches_done % SAMPLE_INTERVAL == 0:
            sample_images(
                generator=generator,
                latent_dim=LATENT_DIM,
                n_classes=N_CLASSES,
                run_name="nsched",
                batch_count=batches_done,
            )
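
This snippet begins after the generator step. A hedged sketch of the conditional-GAN generator update it presumes (the names match the snippet; the latent and label sampling are assumptions):

optimizer_G.zero_grad()
z = torch.randn(imgs.size(0), LATENT_DIM, device=imgs.device)
gen_labels = torch.randint(0, N_CLASSES, (imgs.size(0),), device=imgs.device)
gen_imgs = generator(z, gen_labels)
# The generator is rewarded when the discriminator labels its output valid.
g_loss = adversarial_loss(discriminator(gen_imgs, gen_labels), valid)
g_loss.backward()
optimizer_G.step()
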
Code Example #4
        batches_done = epoch * len(train_dataloader) + i
        batches_left = args.epoch_num * len(train_dataloader) - batches_done
        time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
        prev_time = time.time()

        # Print log
        sys.stdout.write(
            "\r[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f, adv: %f, cycle: %f, identity: %f] ETA: %s"
            % (epoch + 1, args.epoch_num, i, len(train_dataloader),
               loss_D.data.cpu(), loss_G.data.cpu(), loss_GAN.data.cpu(),
               loss_cycle.data.cpu(), loss_identity.data.cpu(), time_left))

        # If at sample interval save image
        if batches_done % args.sample_interval == 0:
            sample_images(args, G__AB, G__BA, test_dataloader, epoch, batches_done)

    # Update learning rates
    lr_scheduler_G.step(epoch)
    lr_scheduler_D_B.step(epoch)
    lr_scheduler_D_A.step(epoch)

    if args.checkpoint_interval != -1 and epoch % args.checkpoint_interval == 0:
        # Save model checkpoints
        torch.save(G__AB.state_dict(), '%s/%s/G__AB_%d.pth' % (args.model_result_dir, args.dataset_name, epoch))
        torch.save(G__BA.state_dict(), '%s/%s/G__BA_%d.pth' % (args.model_result_dir, args.dataset_name, epoch))
        torch.save(D_A.state_dict(), '%s/%s/D__A_%d.pth' % (args.model_result_dir, args.dataset_name, epoch))
        torch.save(D_B.state_dict(), '%s/%s/D__B_%d.pth' % (args.model_result_dir, args.dataset_name, epoch))
Code Example #5
def train_srcnns(train_loader, val_loader, model, device, args):

    # Loss Function #
    criterion = nn.L1Loss()

    # Optimizers #
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 betas=(0.5, 0.999))
    optimizer_scheduler = get_lr_scheduler(optimizer=optimizer, args=args)

    # Lists #
    losses = list()

    # Train #
    print("Training {} started with total epoch of {}.".format(
        str(args.model).upper(), args.num_epochs))

    for epoch in range(args.num_epochs):
        for i, (high, low) in enumerate(train_loader):

            # Data Preparation #
            high = high.to(device)
            low = low.to(device)

            # Forward Data #
            generated = model(low)

            # Calculate Loss #
            loss = criterion(generated, high)

            # Initialize Optimizer #
            optimizer.zero_grad()

            # Back Propagation and Update #
            loss.backward()
            optimizer.step()

            # Add items to Lists #
            losses.append(loss.item())

            # Print Statistics #
            if (i + 1) % args.print_every == 0:
                print("{} | Epoch [{}/{}] | Iterations [{}/{}] | Loss {:.4f}".
                      format(
                          str(args.model).upper(), epoch + 1, args.num_epochs,
                          i + 1, len(train_loader), np.average(losses)))

                # Save Sample Images #
                sample_images(val_loader, args.batch_size, args.upscale_factor,
                              model, epoch, args.samples_path, device)

        # Adjust Learning Rate #
        optimizer_scheduler.step()

        # Save Model Weights and Inference #
        if (epoch + 1) % args.save_every == 0:
            torch.save(
                model.state_dict(),
                os.path.join(
                    args.weights_path,
                    '{}_Epoch_{}.pkl'.format(model.__class__.__name__,
                                             epoch + 1)))
            inference(val_loader, model, args.upscale_factor, epoch,
                      args.inference_path, device)
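
get_lr_scheduler appears throughout these examples without a definition (and with varying signatures). A plausible sketch of the two-argument form used here, built on torch.optim.lr_scheduler; args.lr_decay_every and args.lr_decay_rate are assumed names:

from torch.optim import lr_scheduler

def get_lr_scheduler(optimizer, args):
    # Multiply the learning rate by gamma every step_size epochs.
    return lr_scheduler.StepLR(optimizer,
                               step_size=args.lr_decay_every,
                               gamma=args.lr_decay_rate)
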
Code Example #6
def train():

    # Fix Seed for Reproducibility #
    torch.manual_seed(9)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(9)

    # Samples, Weights and Results Path #
    paths = [config.samples_path, config.weights_path, config.plots_path]
    paths = [make_dirs(path) for path in paths]

    # Prepare Data Loader #
    train_horse_loader, train_zebra_loader = get_horse2zebra_loader(
        purpose='train', batch_size=config.batch_size)
    test_horse_loader, test_zebra_loader = get_horse2zebra_loader(
        purpose='test', batch_size=config.val_batch_size)
    total_batch = min(len(train_horse_loader), len(train_zebra_loader))

    # Prepare Networks #
    D_A = Discriminator()
    D_B = Discriminator()
    G_A2B = Generator()
    G_B2A = Generator()

    networks = [D_A, D_B, G_A2B, G_B2A]

    for network in networks:
        network.to(device)

    # Loss Function #
    criterion_Adversarial = nn.MSELoss()
    criterion_Cycle = nn.L1Loss()
    criterion_Identity = nn.L1Loss()

    # Optimizers #
    D_A_optim = torch.optim.Adam(D_A.parameters(),
                                 lr=config.lr,
                                 betas=(0.5, 0.999))
    D_B_optim = torch.optim.Adam(D_B.parameters(),
                                 lr=config.lr,
                                 betas=(0.5, 0.999))
    G_optim = torch.optim.Adam(chain(G_A2B.parameters(), G_B2A.parameters()),
                               lr=config.lr,
                               betas=(0.5, 0.999))

    D_A_optim_scheduler = get_lr_scheduler(D_A_optim)
    D_B_optim_scheduler = get_lr_scheduler(D_B_optim)
    G_optim_scheduler = get_lr_scheduler(G_optim)

    # Lists #
    D_losses_A, D_losses_B, G_losses = [], [], []

    # Training #
    print("Training CycleGAN started with total epoch of {}.".format(
        config.num_epochs))
    for epoch in range(config.num_epochs):
        for i, (horse,
                zebra) in enumerate(zip(train_horse_loader,
                                        train_zebra_loader)):

            # Data Preparation #
            real_A = horse.to(device)
            real_B = zebra.to(device)

            # Initialize Optimizers #
            G_optim.zero_grad()
            D_A_optim.zero_grad()
            D_B_optim.zero_grad()

            ###################
            # Train Generator #
            ###################

            set_requires_grad([D_A, D_B], requires_grad=False)

            # Adversarial Loss #
            fake_A = G_B2A(real_B)
            prob_fake_A = D_A(fake_A)
            real_labels = torch.ones(prob_fake_A.size()).to(device)
            G_mse_loss_B2A = criterion_Adversarial(prob_fake_A, real_labels)

            fake_B = G_A2B(real_A)
            prob_fake_B = D_B(fake_B)
            real_labels = torch.ones(prob_fake_B.size()).to(device)
            G_mse_loss_A2B = criterion_Adversarial(prob_fake_B, real_labels)

            # Identity Loss #
            identity_A = G_B2A(real_A)
            G_identity_loss_A = config.lambda_identity * criterion_Identity(
                identity_A, real_A)

            identity_B = G_A2B(real_B)
            G_identity_loss_B = config.lambda_identity * criterion_Identity(
                identity_B, real_B)

            # Cycle Loss #
            reconstructed_A = G_B2A(fake_B)
            G_cycle_loss_ABA = config.lambda_cycle * criterion_Cycle(
                reconstructed_A, real_A)

            reconstructed_B = G_A2B(fake_A)
            G_cycle_loss_BAB = config.lambda_cycle * criterion_Cycle(
                reconstructed_B, real_B)

            # Calculate Total Generator Loss #
            G_loss = (G_mse_loss_B2A + G_mse_loss_A2B + G_identity_loss_A +
                      G_identity_loss_B + G_cycle_loss_ABA + G_cycle_loss_BAB)

            # Back Propagation and Update #
            G_loss.backward(retain_graph=True)
            G_optim.step()

            #######################
            # Train Discriminator #
            #######################

            set_requires_grad([D_A, D_B], requires_grad=True)

            ## Train Discriminator A ##
            # Real Loss #
            prob_real_A = D_A(real_A)
            real_labels = torch.ones(prob_real_A.size()).to(device)
            D_real_loss_A = criterion_Adversarial(prob_real_A, real_labels)

            # Fake Loss #
            prob_fake_A = D_A(fake_A.detach())
            fake_labels = torch.zeros(prob_fake_A.size()).to(device)
            D_fake_loss_A = criterion_Adversarial(prob_fake_A, fake_labels)

            # Calculate Total Discriminator A Loss #
            D_loss_A = config.lambda_identity * (D_real_loss_A +
                                                 D_fake_loss_A).mean()

            # Back propagation and Update #
            D_loss_A.backward(retain_graph=True)
            D_A_optim.step()

            ## Train Discriminator B ##
            # Real Loss #
            prob_real_B = D_B(real_B)
            real_labels = torch.ones(prob_real_B.size()).to(device)
            loss_real_B = criterion_Adversarial(prob_real_B, real_labels)

            # Fake Loss #
            prob_fake_B = D_B(fake_B.detach())
            fake_labels = torch.zeros(prob_fake_B.size()).to(device)
            loss_fake_B = criterion_Adversarial(prob_fake_B, fake_labels)

            # Calculate Total Discriminator B Loss #
            D_loss_B = config.lambda_identity * (loss_real_B +
                                                 loss_fake_B).mean()

            # Back propagation and Update #
            D_loss_B.backward(retain_graph=True)
            D_B_optim.step()

            # Add items to Lists #
            D_losses_A.append(D_loss_A.item())
            D_losses_B.append(D_loss_B.item())
            G_losses.append(G_loss.item())

            ####################
            # Print Statistics #
            ####################

            if (i + 1) % config.print_every == 0:
                print(
                    "CycleGAN | Epoch [{}/{}] | Iterations [{}/{}] | D_A Loss {:.4f} | D_B Loss {:.4f} | G Loss {:.4f}"
                    .format(epoch + 1, config.num_epochs, i + 1, total_batch,
                            np.average(D_losses_A), np.average(D_losses_B),
                            np.average(G_losses)))

                # Save Sample Images #
                sample_images(test_horse_loader, test_zebra_loader, G_A2B,
                              G_B2A, epoch, config.samples_path)

        # Adjust Learning Rate #
        D_A_optim_scheduler.step()
        D_B_optim_scheduler.step()
        G_optim_scheduler.step()

        # Save Model Weights #
        if (epoch + 1) % config.save_every == 0:
            torch.save(
                G_A2B.state_dict(),
                os.path.join(
                    config.weights_path,
                    'CycleGAN_Generator_A2B_Epoch_{}.pkl'.format(epoch + 1)))
            torch.save(
                G_B2A.state_dict(),
                os.path.join(
                    config.weights_path,
                    'CycleGAN_Generator_B2A_Epoch_{}.pkl'.format(epoch + 1)))

    # Make a GIF file #
    make_gifs_train("CycleGAN", config.samples_path)

    # Plot Losses #
    plot_losses(D_losses_A, D_losses_B, G_losses, config.num_epochs,
                config.plots_path)

    print("Training finished.")
Code Example #7
def train_srgans(train_loader, val_loader, generator, discriminator, device,
                 args):

    # Loss Function #
    criterion_Perceptual = PerceptualLoss(args.model).to(device)

    # For SRGAN #
    criterion_MSE = nn.MSELoss()
    criterion_TV = TVLoss()

    # For ESRGAN #
    criterion_BCE = nn.BCEWithLogitsLoss()
    criterion_Content = nn.L1Loss()

    # Optimizers #
    D_optim = torch.optim.Adam(discriminator.parameters(),
                               lr=args.lr,
                               betas=(0.9, 0.999))
    G_optim = torch.optim.Adam(generator.parameters(),
                               lr=args.lr,
                               betas=(0.9, 0.999))

    D_optim_scheduler = get_lr_scheduler(D_optim, args)
    G_optim_scheduler = get_lr_scheduler(G_optim, args)

    # Lists #
    D_losses, G_losses = list(), list()

    # Train #
    print("Training {} started with total epoch of {}.".format(
        str(args.model).upper(), args.num_epochs))

    for epoch in range(args.num_epochs):
        for i, (high, low) in enumerate(train_loader):

            discriminator.train()
            if args.model == "srgan":
                generator.train()

            # Data Preparation #
            high = high.to(device)
            low = low.to(device)

            # Initialize Optimizers #
            D_optim.zero_grad()
            G_optim.zero_grad()

            #######################
            # Train Discriminator #
            #######################

            set_requires_grad(discriminator, requires_grad=True)

            # Generate Fake HR Images #
            fake_high = generator(low)

            if args.model == 'srgan':

                # Forward Data #
                prob_real = discriminator(high)
                prob_fake = discriminator(fake_high.detach())

                # Calculate Total Discriminator Loss #
                D_loss = 1 - prob_real.mean() + prob_fake.mean()

            elif args.model == 'esrgan':

                # Forward Data #
                prob_real = discriminator(high)
                prob_fake = discriminator(fake_high.detach())

                # Relativistic Discriminator #
                diff_r2f = prob_real - prob_fake.mean()
                diff_f2r = prob_fake - prob_real.mean()

                # Labels #
                real_labels = torch.ones(diff_r2f.size()).to(device)
                fake_labels = torch.zeros(diff_f2r.size()).to(device)

                # Adversarial Loss #
                D_loss_real = criterion_BCE(diff_r2f, real_labels)
                D_loss_fake = criterion_BCE(diff_f2r, fake_labels)

                # Calculate Total Discriminator Loss #
                D_loss = (D_loss_real + D_loss_fake).mean()

            # Back Propagation and Update #
            D_loss.backward()
            D_optim.step()

            ###################
            # Train Generator #
            ###################

            set_requires_grad(discriminator, requires_grad=False)

            if args.model == 'srgan':

                # Adversarial Loss #
                prob_fake = discriminator(fake_high).mean()
                G_loss_adversarial = torch.mean(1 - prob_fake)
                G_loss_mse = criterion_MSE(fake_high, high)

                # Perceptual Loss #
                lambda_perceptual = 6e-3
                G_loss_perceptual = criterion_Perceptual(fake_high, high)

                # Total Variation Loss #
                G_loss_tv = criterion_TV(fake_high)

                # Calculate Total Generator Loss #
                G_loss = (args.lambda_adversarial * G_loss_adversarial +
                          G_loss_mse + lambda_perceptual * G_loss_perceptual +
                          args.lambda_tv * G_loss_tv)

            elif args.model == 'esrgan':

                # Forward Data #
                prob_real = discriminator(high)
                prob_fake = discriminator(fake_high)

                # Relativistic Discriminator #
                diff_r2f = prob_real - prob_fake.mean()
                diff_f2r = prob_fake - prob_real.mean()

                # Labels #
                real_labels = torch.ones(diff_r2f.size()).to(device)
                fake_labels = torch.zeros(diff_f2r.size()).to(device)

                # Adversarial Loss #
                G_loss_bce_real = criterion_BCE(diff_f2r, real_labels)
                G_loss_bce_fake = criterion_BCE(diff_r2f, fake_labels)

                G_loss_bce = (G_loss_bce_real + G_loss_bce_fake).mean()

                # Perceptual Loss #
                lambda_perceptual = 1e-2
                G_loss_perceptual = criterion_Perceptual(fake_high, high)

                # Content Loss #
                G_loss_content = criterion_Content(fake_high, high)

                # Calculate Total Generator Loss #
                G_loss = (args.lambda_bce * G_loss_bce +
                          lambda_perceptual * G_loss_perceptual +
                          args.lambda_content * G_loss_content)

            # Back Propagation and Update #
            G_loss.backward()
            G_optim.step()

            # Add items to Lists #
            D_losses.append(D_loss.item())
            G_losses.append(G_loss.item())

            ####################
            # Print Statistics #
            ####################

            if (i + 1) % args.print_every == 0:
                print(
                    "{} | Epoch [{}/{}] | Iterations [{}/{}] | D Loss {:.4f} | G Loss {:.4f}"
                    .format(
                        str(args.model).upper(), epoch + 1, args.num_epochs,
                        i + 1, len(train_loader), np.average(D_losses),
                        np.average(G_losses)))

                # Save Sample Images #
                sample_images(val_loader, args.batch_size, args.scale_factor,
                              generator, epoch, args.samples_path, device)

        # Adjust Learning Rate #
        D_optim_scheduler.step()
        G_optim_scheduler.step()

        # Save Model Weights and Inference #
        if (epoch + 1) % args.save_every == 0:
            torch.save(
                generator.state_dict(),
                os.path.join(
                    args.weights_path,
                    '{}_Epoch_{}.pkl'.format(generator.__class__.__name__,
                                             epoch + 1)))
            inference(val_loader, generator, args.upscale_factor, epoch,
                      args.inference_path, device)
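
TVLoss, the total-variation regularizer used in the SRGAN branch, is not shown. A minimal anisotropic-TV sketch (the original project may weight or square the differences differently):

import torch.nn as nn

class TVLoss(nn.Module):
    def forward(self, x):
        # Penalize differences between neighboring pixels of an (N, C, H, W) batch.
        dh = (x[:, :, 1:, :] - x[:, :, :-1, :]).abs().mean()
        dw = (x[:, :, :, 1:] - x[:, :, :, :-1]).abs().mean()
        return dh + dw
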
Code Example #8
File: model.py  Project: colipsoLocker/GAN
    d_loss = 0.5 * np.add(dA_loss, dB_loss)

    discriminator_A.trainable = False
    discriminator_B.trainable = False
    g_loss = combinedModel.train_on_batch(x, y)

    elapsed_time = datetime.datetime.now() - start_time

    print ("[Epoch %d/%d] [D loss: %f, acc: %3d%%] [G loss: %05f, adv: %05f, recon: %05f, id: %05f] time: %s " \
                                                                        % ( epoch, epochs,
                                                                            d_loss[0], 100*d_loss[1],
                                                                            g_loss[0],
                                                                            np.mean(g_loss[1:3]),
                                                                            np.mean(g_loss[3:5]),
                                                                            np.mean(g_loss[5:6]),
                                                                            elapsed_time))

    if epoch % sample_interval == 0:
        numA = np.random.randint(0, len(valImgsA))
        numB = np.random.randint(0, len(valImgsB))
        sample_images(img_A=(valImgsA[numA] + 1.0) * 127.5,
                      img_B=(valImgsB[numB] + 1.0) * 127.5,
                      g_AB=generator_A2B,
                      g_BA=generator_B2A,
                      epoch=epoch)

generator_A2B.save('./GAN/models/generator_A2B.h5')
generator_B2A.save('./GAN/models/generator_B2A.h5')
discriminator_A.save('./GAN/models/discriminator_A.h5')
discriminator_B.save('./GAN/models/discriminator_B.h5')
combinedModel.save('./GAN/models/combinedModel.h5')
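
The x and y fed to combinedModel.train_on_batch are assembled before this snippet. In Keras CycleGAN implementations they typically pair both image batches with adversarial, cycle, and identity targets, which also matches the g_loss slices printed above (valid_patch, the all-ones PatchGAN target, is an assumed name):

x = [imgs_A, imgs_B]
y = [valid_patch, valid_patch,   # adversarial targets, g_loss[1:3]
     imgs_A, imgs_B,             # cycle-reconstruction targets, g_loss[3:5]
     imgs_A, imgs_B]             # identity targets, g_loss[5:]
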
Code Example #9
        loss_G.backward()
        optimizer_G.step()

        optimizer_D.zero_grad()

        pred_real = discriminator(real_B)
        loss_real = criterion_GAN(pred_real, valid)
        pred_fake = discriminator(fake_B.detach())
        loss_fake = criterion_GAN(pred_fake, fake)

        loss_D = 0.5 * (loss_real + loss_fake)
        loss_D.backward()
        optimizer_D.step()

        message = (
            "\r[Epoch {}/{}] [Batch {}/{}] [D loss: {}] [G loss: {}, pixel: {}, adv: {}]"
            .format(epoch, opt["n_epochs"], i, 25000 // opt["batch_size"],
                    loss_D.item(), loss_G.item(), loss_pixel.item(),
                    loss_GAN.item()))
        print(message)
        logger.info(message)

        if i % opt["sample_interval"] == 0:
            sample_images(data, i, generator, "{}-{}".format(epoch, i))

    if opt['checkpoint_interval'] != -1 and epoch % opt[
            'checkpoint_interval'] == 0:
        torch.save(generator.state_dict(), 'saved_models/generator.pth')
        torch.save(discriminator.state_dict(),
                   'saved_models/discriminator.pth')
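
Here too the generator step precedes the snippet. A hedged sketch of the pix2pix-style composition it implies, ending where the snippet's loss_G.backward() picks up (criterion_pixelwise and lambda_pixel are assumed names):

optimizer_G.zero_grad()
fake_B = generator(real_A)
loss_GAN = criterion_GAN(discriminator(fake_B), valid)
loss_pixel = criterion_pixelwise(fake_B, real_B)
loss_G = loss_GAN + lambda_pixel * loss_pixel
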
Code Example #10
                               betas=(args.b1, args.b2))

if args.epoch != 0:
    # Load pretrained models

    pretrained_path = "saved_models/%s/multi_models_%d.pth" % (
        args.experiment_name, args.epoch)
    checkpoint = torch.load(pretrained_path)
    generator.load_state_dict(checkpoint['generator_state_dict'])
    discriminator.load_state_dict(checkpoint['discriminator_state_dict'])

    optimizer_G.load_state_dict(checkpoint['optimizer_G_state_dict'])
    optimizer_D.load_state_dict(checkpoint['optimizer_D_state_dict'])

    if args.is_eval:
        sample_images(args.experiment_name, val_dataloader, generator,
                      args.epoch, device)
        evaluate_generated_signal_quality(val_dataloader, generator, None,
                                          args.epoch, device)
        sys.exit()
else:
    # Initialize weights
    generator.apply(weights_init_normal)
    discriminator.apply(weights_init_normal)

os.makedirs("saved_models/%s" % args.experiment_name, exist_ok=True)
os.makedirs("sample_signals/%s" % args.experiment_name, exist_ok=True)
os.makedirs("logs/%s" % args.experiment_name, exist_ok=True)
writer = SummaryWriter("logs/%s" % args.experiment_name)

# ----------
#  Training
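
The checkpoint this snippet loads bundles four state dicts into one file. A matching save call, mirroring the keys the load expects, would look like:

torch.save({
    'generator_state_dict': generator.state_dict(),
    'discriminator_state_dict': discriminator.state_dict(),
    'optimizer_G_state_dict': optimizer_G.state_dict(),
    'optimizer_D_state_dict': optimizer_D.state_dict(),
}, "saved_models/%s/multi_models_%d.pth" % (args.experiment_name, epoch))
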
Code Example #11
        # --------------
        #  Log Progress
        # --------------
        # Determine approximate time left
        batches_done = epoch * len(train_dataloader) + i
        batches_left = args.epoch_num * len(train_dataloader) - batches_done
        time_left = datetime.timedelta(seconds=batches_left *
                                       (time.time() - prev_time))
        prev_time = time.time()

        # Print log
        sys.stdout.write(
            "\r[Epoch%d/%d]-[Batch%d/%d]-[Dloss:%f]-[Gloss:%f, loss_pixel:%f, adv:%f] ETA:%s"
            % (epoch + 1, args.epoch_num, i, len(train_dataloader),
               loss_D.data.cpu(), loss_G.data.cpu(), loss_pixel.data.cpu(),
               loss_GAN.data.cpu(), time_left))

        # If at sample interval save image
        if batches_done % args.sample_interval == 0:
            sample_images(generator, test_dataloader, args, epoch,
                          batches_done)

    if args.checkpoint_interval != -1 and epoch % args.checkpoint_interval == 0:
        # Save model checkpoints
        torch.save(
            generator.state_dict(), '%s/%s/generator_%d.pth' %
            (args.model_result_dir, args.dataset_name, epoch))
        torch.save(
            discriminator.state_dict(), '%s/%s/discriminator_%d.pth' %
            (args.model_result_dir, args.dataset_name, epoch))
Code Example #12
def train():

    # Fix Seed for Reproducibility #
    torch.manual_seed(9)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(9)

    # Samples, Weights and Results Path #
    paths = [config.samples_path, config.weights_path, config.plots_path]
    paths = [make_dirs(path) for path in paths]

    # Prepare Data Loader #
    train_loader = get_edges2shoes_loader('train', config.batch_size)
    val_loader = get_edges2shoes_loader('val', config.val_batch_size)
    total_batch = len(train_loader)

    # Prepare Networks #
    G_A2B = Generator().to(device)
    G_B2A = Generator().to(device)
    D_A = Discriminator().to(device)
    D_B = Discriminator().to(device)

    # Loss Function #
    criterion_Adversarial = nn.BCELoss()
    criterion_Recon = nn.MSELoss()
    criterion_Feature = nn.HingeEmbeddingLoss()

    # Optimizers #
    G_optim = torch.optim.Adam(chain(G_A2B.parameters(), G_B2A.parameters()),
                               config.lr,
                               betas=(0.5, 0.999),
                               weight_decay=0.00001)
    D_optim = torch.optim.Adam(chain(D_A.parameters(), D_B.parameters()),
                               config.lr,
                               betas=(0.5, 0.999),
                               weight_decay=0.00001)

    D_optim_scheduler = get_lr_scheduler(D_optim)
    G_optim_scheduler = get_lr_scheduler(G_optim)

    # Lists #
    D_losses, G_losses = [], []

    # Constants #
    iters = 0

    # Training #
    print("Training DiscoGAN started with total epoch of {}.".format(
        config.num_epochs))

    for epoch in range(config.num_epochs):

        for i, (real_A, real_B) in enumerate(train_loader):

            # Data Preparation #
            real_A = real_A.to(device)
            real_B = real_B.to(device)

            # Initialize Models #
            G_A2B.zero_grad()
            G_B2A.zero_grad()
            D_A.zero_grad()
            D_B.zero_grad()

            # Initialize Optimizers #
            D_optim.zero_grad()
            G_optim.zero_grad()

            ################
            # Forward Data #
            ################

            fake_B = G_A2B(real_A)
            fake_A = G_B2A(real_B)

            prob_real_A, A_real_features = D_A(real_A)
            prob_fake_A, A_fake_features = D_A(fake_A)

            prob_real_B, B_real_features = D_B(real_B)
            prob_fake_B, B_fake_features = D_B(fake_B)

            #######################
            # Train Discriminator #
            #######################

            # Discriminator A #
            real_labels = Variable(torch.ones(prob_real_A.size()),
                                   requires_grad=False).to(device)
            D_real_loss_A = criterion_Adversarial(prob_real_A, real_labels)

            fake_labels = Variable(torch.zeros(prob_fake_A.size()),
                                   requires_grad=False).to(device)
            D_fake_loss_A = criterion_Adversarial(prob_fake_A, fake_labels)

            D_loss_A = (D_real_loss_A + D_fake_loss_A).mean()

            # Discriminator B #
            real_labels = Variable(torch.ones(prob_real_B.size()),
                                   requires_grad=False).to(device)
            D_real_loss_B = criterion_Adversarial(prob_real_B, real_labels)

            fake_labels = Variable(torch.zeros(prob_fake_B.size()),
                                   requires_grad=False).to(device)
            D_fake_loss_B = criterion_Adversarial(prob_fake_B, fake_labels)

            D_loss_B = (D_real_loss_B + D_fake_loss_B).mean()

            # Calculate Total Discriminator Loss #
            D_loss = D_loss_A + D_loss_B

            ###################
            # Train Generator #
            ###################

            # Adversarial Loss #
            real_labels = Variable(torch.ones(prob_real_A.size()),
                                   requires_grad=False).to(device)
            G_adv_loss_A = criterion_Adversarial(prob_fake_A, real_labels)

            real_labels = Variable(torch.ones(prob_real_B.size()),
                                   requires_grad=False).to(device)
            G_adv_loss_B = criterion_Adversarial(prob_fake_B, real_labels)

            # Feature Loss #
            G_feature_loss_A = feature_loss(criterion_Feature, A_real_features,
                                            A_fake_features)
            G_feature_loss_B = feature_loss(criterion_Feature, B_real_features,
                                            B_fake_features)

            # Reconstruction Loss #
            fake_ABA = G_B2A(fake_B)
            fake_BAB = G_A2B(fake_A)

            G_recon_loss_A = criterion_Recon(fake_ABA, real_A)
            G_recon_loss_B = criterion_Recon(fake_BAB, real_B)

            if iters < config.decay_gan_loss:
                rate = config.starting_rate
            else:
                print("Now the rate is changed to {}".format(
                    config.changed_rate))
                rate = config.changed_rate

            G_loss_A = (G_adv_loss_A * 0.1 + G_feature_loss_A * 0.9) * (
                1. - rate) + G_recon_loss_A * rate
            G_loss_B = (G_adv_loss_B * 0.1 + G_feature_loss_B * 0.9) * (
                1. - rate) + G_recon_loss_B * rate

            # Calculate Total Generator Loss #
            G_loss = G_loss_A + G_loss_B

            # Back Propagation and Update #
            if iters % config.num_train_gen == 0:
                D_loss.backward()
                D_optim.step()
            else:
                G_loss.backward()
                G_optim.step()

            # Add items to Lists #
            D_losses.append(D_loss.item())
            G_losses.append(G_loss.item())

            ####################
            # Print Statistics #
            ####################

            if (i + 1) % config.print_every == 0:
                print(
                    "DiscoGAN | Epochs [{}/{}] | Iterations [{}/{}] | D Loss {:.4f} | G Loss {:.4f}"
                    .format(epoch + 1, config.num_epochs, i + 1, total_batch,
                            np.average(D_losses), np.average(G_losses)))

                # Save Sample Images #
                sample_images(val_loader, G_A2B, G_B2A, epoch,
                              config.samples_path)

            iters += 1

        # Adjust Learning Rate #
        D_optim_scheduler.step()
        G_optim_scheduler.step()

        # Save Model Weights #
        if (epoch + 1) % config.save_every == 0:
            torch.save(
                G_A2B.state_dict(),
                os.path.join(
                    config.weights_path,
                    'DiscoGAN_Generator_A2B_Epoch_{}.pkl'.format(epoch + 1)))
            torch.save(
                G_B2A.state_dict(),
                os.path.join(
                    config.weights_path,
                    'DiscoGAN_Generator_B2A_Epoch_{}.pkl'.format(epoch + 1)))

    # Make a GIF file #
    make_gifs_train('DiscoGAN', config.samples_path)

    # Plot Losses #
    plot_losses(D_losses, G_losses, config.num_epochs, config.plots_path)

    print("Training finished.")
Code Example #13
def train():

    # Fix Seed for Reproducibility #
    torch.manual_seed(9)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(9)

    # Samples, Weights and Results Path #
    paths = [config.samples_path, config.weights_path, config.plots_path]
    paths = [make_dirs(path) for path in paths]

    # Prepare Data Loader #
    train_loader = get_facades_loader('train', config.batch_size)
    val_loader = get_facades_loader('val', config.val_batch_size)
    total_batch = len(train_loader)

    # Prepare Networks #
    D = Discriminator().to(device)
    G = Generator().to(device)

    # Criterion #
    criterion_Adversarial = nn.BCELoss()
    criterion_Pixelwise = nn.L1Loss()

    # Optimizers #
    D_optim = torch.optim.Adam(D.parameters(),
                               lr=config.lr,
                               betas=(0.5, 0.999))
    G_optim = torch.optim.Adam(G.parameters(),
                               lr=config.lr,
                               betas=(0.5, 0.999))

    D_optim_scheduler = get_lr_scheduler(D_optim)
    G_optim_scheduler = get_lr_scheduler(G_optim)

    # Lists #
    D_losses, G_losses = [], []

    # Training #
    print("Training Pix2Pix started with total epoch of {}.".format(
        config.num_epochs))

    for epoch in range(config.num_epochs):

        for i, (real_A, real_B) in enumerate(train_loader):

            # Data Preparation #
            real_A = real_A.to(device)
            real_B = real_B.to(device)

            # Initialize Optimizers #
            D_optim.zero_grad()
            G_optim.zero_grad()

            ###################
            # Train Generator #
            ###################

            # Prevent Discriminator Update during Generator Update #
            set_requires_grad(D, requires_grad=False)

            # Adversarial Loss #
            fake_B = G(real_A)
            prob_fake = D(fake_B, real_A)
            real_labels = torch.ones(prob_fake.size()).to(device)
            G_loss_fake = criterion_Adversarial(prob_fake, real_labels)

            # Pixel-Wise Loss #
            G_loss_pixelwise = criterion_Pixelwise(fake_B, real_B)

            # Calculate Total Generator Loss #
            G_loss = G_loss_fake + config.l1_lambda * G_loss_pixelwise

            # Back Propagation and Update #
            G_loss.backward()
            G_optim.step()

            #######################
            # Train Discriminator #
            #######################

            # Re-enable Discriminator Updates #
            set_requires_grad(D, requires_grad=True)

            # Adversarial Loss #
            prob_real = D(real_B, real_A)
            real_labels = torch.ones(prob_real.size()).to(device)
            D_real_loss = criterion_Adversarial(prob_real, real_labels)

            fake_B = G(real_A)
            prob_fake = D(fake_B.detach(), real_A)
            fake_labels = torch.zeros(prob_fake.size()).to(device)
            D_fake_loss = criterion_Adversarial(prob_fake, fake_labels)

            # Calculate Total Discriminator Loss #
            D_loss = torch.mean(D_real_loss + D_fake_loss)

            # Back Propagation and Update #
            D_loss.backward()
            D_optim.step()

            # Add items to Lists #
            D_losses.append(D_loss.item())
            G_losses.append(G_loss.item())

            ####################
            # Print Statistics #
            ####################

            if (i + 1) % config.print_every == 0:
                print(
                    "Pix2Pix | Epochs [{}/{}] | Iterations [{}/{}] | D Loss {:.4f} | G Loss {:.4f}"
                    .format(epoch + 1, config.num_epochs, i + 1, total_batch,
                            np.average(D_losses), np.average(G_losses)))

                # Save Sample Images #
                sample_images(val_loader, G, epoch, config.samples_path)

        # Adjust Learning Rate #
        D_optim_scheduler.step()
        G_optim_scheduler.step()

        # Save Model Weights #
        if (epoch + 1) % config.save_every == 0:
            torch.save(
                G.state_dict(),
                os.path.join(
                    config.weights_path,
                    'Pix2Pix_Generator_Epoch_{}.pkl'.format(epoch + 1)))

    # Make a GIF file #
    make_gifs_train('Pix2Pix', config.samples_path)

    # Plot Losses #
    plot_losses(D_losses, G_losses, config.plots_path)

    print("Training finished.")
Code Example #14
def train():

    # Fix Seed for Reproducibility #
    torch.manual_seed(9)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(9)

    # Samples, Weights and Results Path #
    paths = [config.samples_path, config.weights_path, config.plots_path]
    paths = [make_dirs(path) for path in paths]

    # Prepare Data Loader #
    train_loader = get_edges2handbags_loader(purpose='train',
                                             batch_size=config.batch_size)
    val_loader = get_edges2handbags_loader(purpose='val',
                                           batch_size=config.batch_size)
    total_batch = len(train_loader)

    # Prepare Networks #
    D_cVAE = Discriminator()
    D_cLR = Discriminator()
    E = Encoder(config.z_dim)
    G = Generator(config.z_dim)

    networks = [D_cVAE, D_cLR, E, G]
    for network in networks:
        network.to(device)

    # Loss Function #
    criterion_Recon = nn.L1Loss()
    criterion_Adversarial = nn.MSELoss()

    # Optimizers #
    D_cVAE_optim = torch.optim.Adam(D_cVAE.parameters(),
                                    lr=config.lr,
                                    betas=(0.5, 0.999))
    D_cLR_optim = torch.optim.Adam(D_cLR.parameters(),
                                   lr=config.lr,
                                   betas=(0.5, 0.999))
    E_optim = torch.optim.Adam(E.parameters(),
                               lr=config.lr,
                               betas=(0.5, 0.999))
    G_optim = torch.optim.Adam(G.parameters(),
                               lr=config.lr,
                               betas=(0.5, 0.999))

    D_cVAE_optim_scheduler = get_lr_scheduler(D_cVAE_optim)
    D_cLR_optim_scheduler = get_lr_scheduler(D_cLR_optim)
    E_optim_scheduler = get_lr_scheduler(E_optim)
    G_optim_scheduler = get_lr_scheduler(G_optim)

    # Lists #
    D_losses, E_G_losses, G_losses = [], [], []

    # Fixed Noise #
    fixed_noise = torch.randn(config.test_size, config.num_images,
                              config.z_dim).to(device)

    # Training #
    print("Training BicycleGAN started total epoch of {}.".format(
        config.num_epochs))
    for epoch in range(config.num_epochs):
        for i, (sketch, target) in enumerate(train_loader):

            # Data Preparation #
            sketch = sketch.to(device)
            target = target.to(device)

            # Separate Data for D_cVAE-GAN and D_cLR-GAN #
            cVAE_data = {
                'sketch': sketch[0].unsqueeze(dim=0),
                'target': target[0].unsqueeze(dim=0)
            }
            cLR_data = {
                'sketch': sketch[1].unsqueeze(dim=0),
                'target': target[1].unsqueeze(dim=0)
            }

            # Initialize Optimizers #
            D_cVAE_optim.zero_grad()
            D_cLR_optim.zero_grad()
            E_optim.zero_grad()
            G_optim.zero_grad()

            # Train Discriminators #
            set_requires_grad([D_cVAE, D_cLR], requires_grad=True)

            ################################
            # Train Discriminator cVAE-GAN #
            ################################

            # Initialize Optimizers #
            D_cVAE_optim.zero_grad()
            D_cLR_optim.zero_grad()
            E_optim.zero_grad()
            G_optim.zero_grad()

            # Encode Latent Vector #
            mean, std = E(cVAE_data['target'])
            random_z = torch.randn(1, config.z_dim).to(device)
            encoded_z = mean + (random_z * std)

            # Generate Fake Image #
            fake_image_cVAE = G(cVAE_data['sketch'], encoded_z)

            # Forward to Discriminator cVAE-GAN #
            prob_real_D_cVAE_1, prob_real_D_cVAE_2 = D_cVAE(
                cVAE_data['target'])
            prob_fake_D_cVAE_1, prob_fake_D_cVAE_2 = D_cVAE(
                fake_image_cVAE.detach())

            # Adversarial Loss using cVAE_1 #
            real_labels = torch.ones(prob_real_D_cVAE_1.size()).to(device)
            D_cVAE_1_real_loss = criterion_Adversarial(prob_real_D_cVAE_1,
                                                       real_labels)

            fake_labels = torch.zeros(prob_fake_D_cVAE_1.size()).to(device)
            D_cVAE_1_fake_loss = criterion_Adversarial(prob_fake_D_cVAE_1,
                                                       fake_labels)

            D_cVAE_1_loss = D_cVAE_1_real_loss + D_cVAE_1_fake_loss

            # Adversarial Loss using cVAE_2 #
            real_labels = torch.ones(prob_real_D_cVAE_2.size()).to(device)
            D_cVAE_2_real_loss = criterion_Adversarial(prob_real_D_cVAE_2,
                                                       real_labels)

            fake_labels = torch.zeros(prob_fake_D_cVAE_2.size()).to(device)
            D_cVAE_2_fake_loss = criterion_Adversarial(prob_fake_D_cVAE_2,
                                                       fake_labels)

            D_cVAE_2_loss = D_cVAE_2_real_loss + D_cVAE_2_fake_loss

            ###########################
            # Train Discriminator cLR #
            ###########################

            # Initialize Optimizers #
            D_cVAE_optim.zero_grad()
            D_cLR_optim.zero_grad()
            E_optim.zero_grad()
            G_optim.zero_grad()

            # Generate Fake Image using Random Latent Vector #
            random_z = torch.randn(1, config.z_dim).to(device)
            fake_image_cLR = G(cLR_data['sketch'], random_z)

            # Forward to Discriminator cLR-GAN #
            prob_real_D_cLR_1, prob_real_D_cLR_2 = D_cLR(cLR_data['target'])
            prob_fake_D_cLR_1, prob_fake_D_cLR_2 = D_cLR(
                fake_image_cLR.detach())

            # Adversarial Loss using cLR-1 #
            real_labels = torch.ones(prob_real_D_cLR_1.size()).to(device)
            D_cLR_1_real_loss = criterion_Adversarial(prob_real_D_cLR_1,
                                                      real_labels)

            fake_labels = torch.zeros(prob_fake_D_cLR_1.size()).to(device)
            D_cLR_1_fake_loss = criterion_Adversarial(prob_fake_D_cLR_1,
                                                      fake_labels)

            D_cLR_1_loss = D_cLR_1_real_loss + D_cLR_1_fake_loss

            # Adversarial Loss using cLR-2 #
            real_labels = torch.ones(prob_real_D_cLR_2.size()).to(device)
            D_cLR_2_real_loss = criterion_Adversarial(prob_real_D_cLR_2,
                                                      real_labels)

            fake_labels = torch.zeros(prob_fake_D_cLR_2.size()).to(device)
            D_cLR_2_fake_loss = criterion_Adversarial(prob_fake_D_cLR_2,
                                                      fake_labels)

            D_cLR_2_loss = D_cLR_2_real_loss + D_cLR_2_fake_loss

            # Calculate Total Discriminator Loss #
            D_loss = D_cVAE_1_loss + D_cVAE_2_loss + D_cLR_1_loss + D_cLR_2_loss

            # Back Propagation and Update #
            D_loss.backward()
            D_cVAE_optim.step()
            D_cLR_optim.step()

            set_requires_grad([D_cVAE, D_cLR], requires_grad=False)

            ###############################
            # Train Encoder and Generator #
            ###############################

            # Initialize Optimizers #
            D_cVAE_optim.zero_grad()
            D_cLR_optim.zero_grad()
            E_optim.zero_grad()
            G_optim.zero_grad()

            # Encode Latent Vector #
            mean, std = E(cVAE_data['target'])
            random_z = torch.randn(1, config.z_dim).to(device)
            encoded_z = mean + (random_z * std)

            # Generate Fake Image #
            fake_image_cVAE = G(cVAE_data['sketch'], encoded_z)
            prob_fake_D_cVAE_1, prob_fake_D_cVAE_2 = D_cVAE(fake_image_cVAE)

            # Adversarial Loss using cVAE #
            real_labels = torch.ones(prob_fake_D_cVAE_1.size()).to(device)
            E_G_adv_cVAE_1_loss = criterion_Adversarial(
                prob_fake_D_cVAE_1, real_labels)

            real_labels = torch.ones(prob_fake_D_cVAE_2.size()).to(device)
            E_G_adv_cVAE_2_loss = criterion_Adversarial(
                prob_fake_D_cVAE_2, real_labels)

            E_G_adv_cVAE_loss = E_G_adv_cVAE_1_loss + E_G_adv_cVAE_2_loss

            # Generate Fake Image using Random Latent Vector #
            random_z = torch.randn(1, config.z_dim).to(device)
            fake_image_cLR = G(cLR_data['sketch'], random_z)
            prob_fake_D_cLR_1, prob_fake_D_cLR_2 = D_cLR(fake_image_cLR)

            # Adversarial Loss of cLR #
            real_labels = torch.ones(prob_fake_D_cLR_1.size()).to(device)
            E_G_adv_cLR_1_loss = criterion_Adversarial(prob_fake_D_cLR_1,
                                                       real_labels)

            real_labels = torch.ones(prob_fake_D_cLR_2.size()).to(device)
            E_G_adv_cLR_2_loss = criterion_Adversarial(prob_fake_D_cLR_2,
                                                       real_labels)

            E_G_adv_cLR_loss = E_G_adv_cLR_1_loss + E_G_adv_cLR_2_loss

            # KL Divergence with N ~ (0, 1) #
            E_KL_div_loss = config.lambda_KL * torch.sum(
                0.5 * (mean**2 + std - 2 * torch.log(std) - 1))

            # Reconstruction Loss #
            E_G_recon_loss = config.lambda_Image * criterion_Recon(
                fake_image_cVAE, cVAE_data['target'])

            # Total Encoder and Generator Loss ##
            E_G_loss = E_G_adv_cVAE_loss + E_G_adv_cLR_loss + E_KL_div_loss + E_G_recon_loss

            # Back Propagation and Update #
            E_G_loss.backward()
            E_optim.step()
            G_optim.step()

            ########################
            # Train Generator Only #
            ########################

            # Initialize Optimizers #
            D_cVAE_optim.zero_grad()
            D_cLR_optim.zero_grad()
            E_optim.zero_grad()
            G_optim.zero_grad()

            # Generate Fake Image using Random Latent Vector #
            random_z = torch.randn(1, config.z_dim).to(device)
            fake_image_cLR = G(cLR_data['sketch'], random_z)
            mean, std = E(fake_image_cLR)

            # Reconstruction Loss #
            G_recon_loss = criterion_Recon(mean, random_z)

            # Calculate Total Generator Loss #
            G_loss = config.lambda_Z * G_recon_loss

            # Back Propagation and Update #
            G_loss.backward()
            G_optim.step()

            # Add items to Lists #
            D_losses.append(D_loss.item())
            E_G_losses.append(E_G_loss.item())
            G_losses.append(G_loss.item())

            ####################
            # Print Statistics #
            ####################

            if (i + 1) % config.print_every == 0:
                print(
                    "BicycleGAN | Epochs [{}/{}] | Iterations [{}/{}] | D Loss {:.4f} | E_G Loss {:.4f} | G Loss {:.4f}"
                    .format(epoch + 1, config.num_epochs, i + 1, total_batch,
                            np.average(D_losses), np.average(E_G_losses),
                            np.average(G_losses)))

                # Save Sample Images #
                sample_images(val_loader, G, fixed_noise, epoch,
                              config.num_images, config.samples_path)

        # Adjust Learning Rate #
        D_cVAE_optim_scheduler.step()
        D_cLR_optim_scheduler.step()
        E_optim_scheduler.step()
        G_optim_scheduler.step()

        # Save Models #
        if (epoch + 1) % config.save_every == 0:
            torch.save(
                G.state_dict(),
                os.path.join(
                    config.weights_path,
                    'BicycleGAN_Generator_Epoch_{}.pkl'.format(epoch + 1)))

    # Make a GIF file #
    make_gifs_train("BicycleGAN", config.samples_path)

    # Plot Losses #
    plot_losses(D_losses, E_G_losses, G_losses, config.num_epochs,
                config.plots_path)

    print("Training finished.")