Code Example #1
def inference():

    # Inference Path #
    make_dirs(config.inference_path)

    # Prepare Data Loader #
    test_loader = get_celeba_loader('test', config.batch_size,
                                    config.selected_attrs)

    # Prepare Generator #
    G = Generator(num_classes=len(config.selected_attrs)).to(device)
    G.load_state_dict(
        torch.load(
            os.path.join(
                config.weights_path,
                'StarGAN_Generator_Epoch_{}.pkl'.format(config.num_epochs))))

    # Test #
    print("StarGAN | Generating Aligned CelebA Images started...")
    for i, (image, label) in enumerate(test_loader):

        # Prepare Data #
        image = image.to(device)
        fixed_labels = create_labels(label,
                                     selected_attrs=config.selected_attrs)

        # Generate Fake Images #
        x_fake_list = [image]

        for c_fixed in fixed_labels:
            x_fake_list.append(G(image, c_fixed))
        x_concat = torch.cat(x_fake_list, dim=3)

        # Save Images #
        save_image(denorm(x_concat.data.cpu()),
                   os.path.join(
                       config.inference_path,
                       'StarGAN_Aligned_CelebA_Results_%04d.png' % (i + 1)),
                   nrow=1,
                   padding=0)

    make_gifs_test("StarGAN", config.inference_path)
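
This inference snippet leans on two helpers defined elsewhere in the repository, denorm and create_labels. As a rough sketch of what they likely do (the signatures and details below are assumptions, not the repository's code): denorm undoes the [-1, 1] input normalization before save_image, and create_labels builds one target label tensor per selected attribute by flipping that attribute's bit, ignoring the hair-color mutual exclusivity that the official StarGAN code additionally enforces.

import torch

def denorm(x):
    # Map images normalized to [-1, 1] back into [0, 1] for saving.
    return ((x + 1) / 2).clamp_(0, 1)

def create_labels(labels, selected_attrs):
    # One target label tensor per attribute: clone the original labels
    # and flip the bit of the attribute being edited.
    fixed_labels = []
    for idx, _ in enumerate(selected_attrs):
        target = labels.clone()
        target[:, idx] = 1 - target[:, idx]
        fixed_labels.append(target)
    return fixed_labels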
Code Example #2
File: train.py Project: hee9joon/Face-Generation
def train():

    # Fix Seed for Reproducibility #
    torch.manual_seed(9)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(9)

    # Samples, Weights and Results Path #
    paths = [config.samples_path, config.weights_path, config.plots_path]
    for path in paths:
        make_dirs(path)

    # Prepare Data Loader #
    celeba_loader = get_celeba_loader(path=config.celeba_path,
                                      batch_size=config.batch_size)
    total_batch = len(celeba_loader)

    # Prepare Networks #
    D = Discriminator().to(device)
    G = Generator().to(device)

    # Optimizer #
    D_optim = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                      D.parameters()),
                               lr=config.D_lr,
                               betas=(0.0, 0.9))
    G_optim = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                      G.parameters()),
                               lr=config.G_lr,
                               betas=(0.0, 0.9))

    D_optim_scheduler = get_lr_scheduler(D_optim)
    G_optim_scheduler = get_lr_scheduler(G_optim)

    # Fixed Noise #
    fixed_noise = torch.randn(config.batch_size, config.noise_dim, 1,
                              1).to(device)

    # Lists #
    D_losses, G_losses = [], []

    # Train #
    print("Training has started with total epoch of {}.".format(
        config.num_epochs))

    for epoch in range(config.num_epochs):
        for i, (images, labels) in enumerate(celeba_loader):

            # Data Preparation #
            images = images.to(device)
            noise = torch.randn(config.batch_size, config.noise_dim, 1,
                                1).to(device)

            # Initialize Optimizers #
            D_optim.zero_grad()
            G_optim.zero_grad()

            #######################
            # Train Discriminator #
            #######################

            # Hinge Loss using Real Image #
            prob_real = D(images)[0]
            D_real_loss = nn.ReLU()(1.0 - prob_real).mean()

            # Hinge Loss using Generated Image #
            fake_image = G(noise)[0]
            prob_fake = D(fake_image.detach())[0]
            D_fake_loss = nn.ReLU()(1.0 + prob_fake).mean()

            # Calculate Total Discriminator Loss #
            D_loss = D_real_loss + D_fake_loss

            # Back Propagation and Update #
            D_loss.backward()
            D_optim.step()

            ###################
            # Train Generator #
            ###################

            # Hinge Loss using Generated Image #
            fake_image = G(noise)[0]
            prob_fake = D(fake_image)[0]

            # Calculate Total Generator Loss #
            G_loss = -prob_fake.mean()

            # Back Propagation and Update #
            G_loss.backward()
            G_optim.step()

            # Contain Losses #
            D_losses.append(D_loss.item())
            G_losses.append(G_loss.item())

            ####################
            # Print Statistics #
            ####################

            if (i + 1) % config.print_every == 0:
                print(
                    "Epochs [{}/{}] | Iterations [{}/{}] | D Loss {:.4f} | G Loss {:.4f}"
                    .format(epoch + 1, config.num_epochs, i + 1, total_batch,
                            np.mean(D_losses), np.mean(G_losses)))

        # Sample Images #
        sample_images(G, fixed_noise, epoch)

        # Adjust Learning Rate #
        D_optim_scheduler.step()
        G_optim_scheduler.step()

        # Save Model Weights #
        if (epoch + 1) % config.save_every == 0:
            torch.save(
                G.state_dict(),
                os.path.join(config.weights_path,
                             'Face_Generator_Epoch_{}.pkl'.format(epoch + 1)))

    # Make a GIF file #
    make_gifs_train("Face_Generation", config.samples_path)

    # Plot Losses #
    plot_losses(D_losses, G_losses, config.num_epochs, config.plots_path)

    print("Training finished.")
Code Example #3
File: train.py Project: hee9joon/Face-Generation
def train():

    # Fix Seed for Reproducibility #
    torch.manual_seed(9)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(9)

    # Samples, Weights and Results Path #
    paths = [config.samples_path, config.weights_path, config.plots_path]
    for path in paths:
        make_dirs(path)

    # Prepare Data Loader #
    celeba_loader = get_celeba_loader(path=config.celeba_path,
                                      batch_size=config.batch_size)
    total_batch = len(celeba_loader)

    # Prepare Networks #
    D = Discriminator().to(device)
    G = Generator().to(device)

    # Loss Function #
    criterion = nn.MSELoss()

    # Optimizer #
    D_optim = torch.optim.Adam(D.parameters(),
                               lr=config.D_lr,
                               betas=(0.5, 0.999))
    G_optim = torch.optim.Adam(G.parameters(),
                               lr=config.G_lr,
                               betas=(0.5, 0.999))

    D_optim_scheduler = get_lr_scheduler(D_optim)
    G_optim_scheduler = get_lr_scheduler(G_optim)

    # Fixed Noise #
    fixed_noise = torch.randn(config.batch_size, config.noise_dim, 1,
                              1).to(device)

    # Lists #
    D_losses, G_losses = [], []

    # Train #
    print("Training has started with total epoch of {}.".format(
        config.num_epochs))

    for epoch in range(config.num_epochs):
        for i, (images, labels) in enumerate(celeba_loader):

            # Data Preparation #
            images = images.to(device)

            # Initialize Optimizers #
            G_optim.zero_grad()
            D_optim.zero_grad()

            #######################
            # Train Discriminator #
            #######################

            # Adversarial Loss using Real Image #
            _, prob_real = D(images)

            D_real_loss = criterion(prob_real, images)

            # Adversarial Loss using Fake Image #
            noise = torch.randn(config.batch_size, config.noise_dim, 1,
                                1).to(device)
            fake_images = G(noise)
            _, prob_fake = D(fake_images.detach())

            D_fake_loss = criterion(prob_fake, fake_images)
            D_fake_loss = torch.clamp(config.margin - D_fake_loss, min=0)

            # Calculate Total Discriminator Loss #
            D_loss = D_real_loss
            if D_fake_loss.item() < config.margin:
                D_loss += D_fake_loss

            # Back Propagation and Update #
            D_loss.backward()
            D_optim.step()

            ###################
            # Train Generator #
            ###################

            # Adversarial Loss #
            fake_images = G(noise)
            encoded, prob_fake = D(fake_images)
            G_loss = criterion(prob_fake, fake_images)

            # Pulling Away Loss #
            G_pulling_away_loss = pulling_away(encoded)

            # Calculate Total Generator Loss #
            G_loss += config.lambda_pt * G_pulling_away_loss

            # Back Propagation and Update #
            G_loss.backward()
            G_optim.step()

            # Add items to Lists #
            D_losses.append(D_loss.item())
            G_losses.append(G_loss.item())

            ####################
            # Print Statistics #
            ####################

            if (i + 1) % config.print_every == 0:
                print(
                    "Epochs [{}/{}] | Iterations [{}/{}] | D Loss {:.4f} | G Loss {:.4f}"
                    .format(epoch + 1, config.num_epochs, i + 1, total_batch,
                            np.mean(D_losses), np.mean(G_losses)))

        # Sample Images #
        sample_images(G, fixed_noise, epoch)

        # Adjust Learning Rate #
        D_optim_scheduler.step()
        G_optim_scheduler.step()

        # Save Model Weights #
        if (epoch + 1) % config.save_every == 0:
            torch.save(
                G.state_dict(),
                os.path.join(config.weights_path,
                             'Face_Generator_Epoch_{}.pkl'.format(epoch + 1)))

    # Make a GIF file #
    make_gifs_train("Face_Generation", config.samples_path)

    # Plot Losses #
    plot_losses(D_losses, G_losses, config.num_epochs, config.plots_path)

    print("Training finished.")
Code Example #4
def train():

    # Fix Seed for Reproducibility #
    torch.manual_seed(9)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(9)

    # Samples, Weights and Results Path #
    paths = [
        config.samples_gen_path, config.samples_recon_path,
        config.weights_path, config.plots_path
    ]
    for path in paths:
        make_dirs(path)

    # Prepare Data Loader #
    celeba_loader = get_celeba_loader(path=config.celeba_path,
                                      batch_size=config.batch_size)
    total_batch = len(celeba_loader)

    # Prepare Networks #
    D = Discriminator().to(device)
    G = Generator().to(device)

    # Loss Function #
    criterion_bce = nn.BCEWithLogitsLoss()
    criterion_l2 = nn.MSELoss()

    # Optimizer #
    Disc_optim = torch.optim.Adam(D.parameters(),
                                  lr=config.D_lr,
                                  betas=(0.5, 0.999))
    Enc_optim = torch.optim.Adam(G.encoder.parameters(),
                                 lr=config.G_lr,
                                 betas=(0.5, 0.999))
    Dec_optim = torch.optim.Adam(G.decoder.parameters(),
                                 lr=config.G_lr,
                                 betas=(0.5, 0.999))

    Disc_optim_scheduler = get_lr_scheduler(Disc_optim)
    Enc_optim_scheduler = get_lr_scheduler(Enc_optim)
    Dec_optim_scheduler = get_lr_scheduler(Dec_optim)

    # Labels #
    real_labels = torch.ones(config.batch_size, 1).to(device)
    fake_labels = torch.zeros(config.batch_size, 1).to(device)

    # Fixed Noise #
    fixed_noise = torch.randn(config.batch_size, config.noise_dim).to(device)

    # Lists #
    Disc_losses, Enc_losses, Dec_losses = list(), list(), list()

    # Train #
    print("Training has started with total epoch of {}.".format(
        config.num_epochs))

    for epoch in range(config.num_epochs):
        for i, (images, labels) in enumerate(celeba_loader):

            # Data Preparation #
            images = images.to(device)

            # Initialize Optimizers #
            Disc_optim.zero_grad()
            Enc_optim.zero_grad()
            Dec_optim.zero_grad()

            #######################
            # Train Discriminator #
            #######################

            # Noise and Generate Fake Images #
            noise = torch.randn(images.size(0), config.noise_dim).to(device)
            fake_images, mu, log_var = G(images)

            # Discriminator GAN Loss #
            prob_real = D(images)
            Disc_real_loss = criterion_bce(prob_real, real_labels)

            prob_fake = D(fake_images)
            Disc_fake_loss = criterion_bce(prob_fake, fake_labels)

            prob_p_fake = D(G.decoder(noise))
            Disc_fake_p_loss = criterion_bce(prob_p_fake, fake_labels)

            # Calculate Total Discriminator Loss #
            Disc_loss = Disc_real_loss + Disc_fake_loss + Disc_fake_p_loss

            # Back Propagation and Update #
            Disc_loss.backward()
            Disc_optim.step()

            #################
            # Train Decoder #
            #################

            # Generate Fake Images #
            fake_images, mu, log_var = G(images)

            # Decoder GAN Loss #
            prob_real = D(images)
            Dec_real_loss = criterion_bce(prob_real, real_labels)

            prob_fake = D(fake_images)
            Dec_fake_loss = criterion_bce(prob_fake, fake_labels)

            prob_p_fake = D(G.decoder(noise))
            Dec_fake_p_loss = criterion_bce(prob_p_fake, fake_labels)

            # Calculate Total Decoder GAN Loss #
            Dec_gan_loss = -1 * (Dec_fake_loss + Dec_fake_p_loss +
                                 Dec_real_loss)

            # Decoder Reconstruction Loss #
            Dec_recon_loss = criterion_l2(D.feature(fake_images),
                                          D.feature(images))

            # Calculate Total Decoder Loss #
            Dec_loss = Dec_gan_loss + config.gamma * Dec_recon_loss

            # Back Propagation and Update #
            Dec_loss.backward()
            Dec_optim.step()

            #################
            # Train Encoder #
            #################

            # Generate Fake Images #
            fake_images, mu, log_var = G(images)

            # Encoder Prior Loss #
            Enc_prior_loss = -0.5 * torch.mean(1 + log_var - torch.pow(mu, 2) -
                                               torch.exp(log_var))

            # Encoder Reconstruction Loss #
            Enc_recon_loss = criterion_l2(D.feature(fake_images),
                                          D.feature(images))

            # Calculate Total Encoder Loss #
            Enc_loss = Enc_prior_loss + config.beta * Enc_recon_loss

            # Back Propagation and Update #
            Enc_loss.backward()
            Enc_optim.step()

            # Add items to Lists #
            Disc_losses.append(Disc_loss.item())
            Enc_losses.append(Enc_loss.item())
            Dec_losses.append(Dec_loss.item())

            ####################
            # Print Statistics #
            ####################

            if (i + 1) % config.print_every == 0:
                print(
                    "Epochs [{}/{}] | Iterations [{}/{}] | Disc Loss {:.4f} | Enc Loss {:.4f} | Dec Loss {:.4f}"
                    .format(epoch + 1, config.num_epochs, i + 1, total_batch,
                            np.average(Disc_losses), np.average(Enc_losses),
                            np.average(Dec_losses)))

                # Sample Images #
                sample_images(G, celeba_loader, fixed_noise, epoch)

        # Adjust Learning Rate #
        Disc_optim_scheduler.step()
        Enc_optim_scheduler.step()
        Dec_optim_scheduler.step()

        # Save Model Weights #
        if (epoch + 1) % config.save_every == 0:
            torch.save(
                G.state_dict(),
                os.path.join(config.weights_path,
                             'Face_Generator_Epoch_{}.pkl'.format(epoch + 1)))

    # Make a GIF file #
    make_gifs_train("Face_Generation", config.samples_gen_path)
    make_gifs_train("Face_Reconstruction", config.samples_recon_path)

    # Plot Losses #
    plot_losses(Disc_losses, Enc_losses, Dec_losses, config.num_epochs,
                config.plots_path)

    print("Training finished.")
Code Example #5
File: train.py Project: hee9joon/Face-Generation
def train():

    # Fix Seed for Reproducibility #
    torch.manual_seed(9)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(9)

    # Samples, Weights and Results Path #
    paths = [config.samples_path, config.weights_path, config.plots_path]
    for path in paths:
        make_dirs(path)

    # Prepare Data Loader #
    celeba_loader = get_celeba_loader(path=config.celeba_path,
                                      batch_size=config.batch_size)
    total_batch = len(celeba_loader)

    # Prepare Networks #
    D = Discriminator().to(device)
    G = Generator().to(device)

    # Loss Function #
    criterion = nn.L1Loss().to(device)

    # Optimizers #
    D_optim = torch.optim.Adam(D.parameters(),
                               lr=config.lr,
                               betas=(0.5, 0.999))
    G_optim = torch.optim.Adam(G.parameters(),
                               lr=config.lr,
                               betas=(0.5, 0.999))

    D_optim_scheduler = get_lr_scheduler(D_optim)
    G_optim_scheduler = get_lr_scheduler(G_optim)

    # Lists #
    D_losses, G_losses = [], []

    # Fixed Noise #
    fixed_noise = torch.randn(config.batch_size, config.noise_dim, 1,
                              1).to(device)

    # Constants #
    k_t = 0
    lr_k = 0.001
    gamma = 0.7

    # Train #
    print("Training started with total epoch of {}.".format(config.num_epochs))

    for epoch in range(config.num_epochs):
        for i, (images, labels) in enumerate(celeba_loader):

            # Data Preparation #
            images = images.to(device)

            noise = torch.randn(config.batch_size, config.noise_dim, 1,
                                1).to(device)

            # Initialize Optimizers #
            D_optim.zero_grad()
            G_optim.zero_grad()

            #######################
            # Train Discriminator #
            #######################

            # Adversarial Loss using Real Image #
            prob_real = D(images)
            D_real_loss = criterion(prob_real, images)

            # Adversarial Loss using Generated Image #
            fake_images = G(noise)
            prob_fake = D(fake_images.detach())

            D_fake_loss = criterion(prob_fake, fake_images)

            # Calculate Total Discriminator Loss #
            D_loss = D_real_loss - k_t * D_fake_loss

            # Back Propagation and Update #
            D_loss.backward()
            D_optim.step()

            ###################
            # Train Generator #
            ###################

            # Adversarial Loss using Generated Image #
            fake_images = G(noise)
            prob_fake = D(fake_images)

            # Calculate Total Generator Loss #
            G_loss = criterion(prob_fake, fake_images)

            # Back Propagation and Update #
            G_loss.backward()
            G_optim.step()

            # Update Constants #
            balance = (gamma * D_real_loss - G_loss).item()
            k_t += lr_k * balance
            k_t = min(max(k_t, 0), 1)

            D_losses.append(D_loss.item())
            G_losses.append(G_loss.item())

            ####################
            # Print Statistics #
            ####################

            if (i + 1) % config.print_every == 0:
                print(
                    "Epoch [{}/{}] | Iter [{}/{}] | D Loss {:.4f} | G Loss {:.4f}"
                    .format(epoch + 1, config.num_epochs, i + 1, total_batch,
                            np.average(D_losses), np.average(G_losses)))

        # Sample Images #
        sample_images(G, fixed_noise, epoch)

        # Adjust Learning Rate #
        D_optim_scheduler.step()
        G_optim_scheduler.step()

        # Save Model Weights #
        if (epoch + 1) % config.save_every == 0:
            torch.save(
                G.state_dict(),
                os.path.join(config.weights_path,
                             'Face_Generator_Epoch_{}.pkl'.format(epoch + 1)))

    # Make a GIF file #
    make_gifs_train("Face_Generation", config.samples_path)

    # Plot Losses #
    plot_losses(D_losses, G_losses, config.num_epochs, config.plots_path)

    print("Training finished.")
Code Example #6
def train():

    # Fix Seed for Reproducibility #
    torch.manual_seed(9)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(9)

    # Samples, Weights and Results Path #
    paths = [config.samples_path, config.weights_path, config.plots_path]
    for path in paths:
        make_dirs(path)

    # Prepare Data Loader #
    train_loader = get_celeba_loader('train', config.batch_size,
                                     config.selected_attrs)
    total_batch = len(train_loader)

    fixed_image, original_label = next(iter(train_loader))
    fixed_image = fixed_image.to(device)
    fixed_labels_list = create_labels(original_label, config.selected_attrs)

    # Prepare Networks #
    D = Discriminator(num_classes=len(config.selected_attrs)).to(device)
    G = Generator(num_classes=len(config.selected_attrs)).to(device)

    # Optimizers #
    D_optim = torch.optim.Adam(D.parameters(),
                               lr=config.lr,
                               betas=(0.5, 0.999))
    G_optim = torch.optim.Adam(G.parameters(),
                               lr=config.lr,
                               betas=(0.5, 0.999))

    D_optim_scheduler = get_lr_scheduler(D_optim)
    G_optim_scheduler = get_lr_scheduler(G_optim)

    # Lists #
    D_losses, G_losses = [], []

    # Train #
    print("Training StarGAN started with total epoch of {}.".format(
        config.num_epochs))

    for epoch in range(config.num_epochs):

        for i, (real_image, label) in enumerate(train_loader):

            # Data Preparation #

            real_image = real_image.to(device)
            label = label.to(device)

            rand_idx = torch.randperm(label.size(0))
            target_label = label[rand_idx].to(device)

            # Initialize Optimizers #
            D_optim.zero_grad()
            G_optim.zero_grad()

            #######################
            # Train Discriminator #
            #######################

            set_requires_grad(D, requires_grad=True)

            # Discriminator Loss using Real Image #
            prob_real_src, prob_real_cls = D(real_image)
            D_real_loss = -torch.mean(prob_real_src)
            D_cls_loss = config.lambda_cls * criterion_CLS(
                prob_real_cls, label)

            # Discriminator Loss using Generated Image #
            fake_image = G(real_image, target_label)
            prob_fake_src, prob_fake_cls = D(fake_image.detach())
            D_fake_loss = torch.mean(prob_fake_src)

            # Discriminator Loss using Wasserstein GAN Gradient Penalty #
            D_gp_loss = config.lambda_gp * get_gradient_penalty(
                real_image, fake_image, D)

            # Calculate Total Discriminator Loss #
            D_loss = D_real_loss + D_fake_loss + D_cls_loss + D_gp_loss

            # Back Propagation and Update #
            D_loss.backward()
            D_optim.step()

            # Add items to Lists #
            D_losses.append(D_loss.item())

            ###################
            # Train Generator #
            ###################

            if (i + 1) % config.n_critics == 0:

                # Prevent Discriminator Update during Generator Update #
                set_requires_grad(D, requires_grad=False)

                # Initialize Optimizers #
                D_optim.zero_grad()
                G_optim.zero_grad()

                # Generator Loss using Fake Images #
                fake_image = G(real_image, target_label)
                prob_fake_src, prob_fake_cls = D(fake_image)
                G_fake_loss = -torch.mean(prob_fake_src)
                G_cls_loss = config.lambda_cls * criterion_CLS(
                    prob_fake_cls, target_label)

                # Reconstruction Loss #
                recon_image = G(fake_image, label)
                G_recon_loss = config.lambda_recon * torch.mean(
                    torch.abs(real_image - recon_image))

                # Calculate Total Generator Loss #
                G_loss = G_fake_loss + G_recon_loss + G_cls_loss

                # Back Propagation and Update #
                G_loss.backward()
                G_optim.step()

                # Add items to Lists #
                G_losses.append(G_loss.item())

            ####################
            # Print Statistics #
            ####################

            if (i + 1) % config.print_every == 0:
                print(
                    "StarGAN | Epoch [{}/{}] | Iteration [{}/{}] | D Loss {:.4f} | G Loss {:.4f}"
                    .format(epoch + 1, config.num_epochs, i + 1, total_batch,
                            np.average(D_losses), np.average(G_losses)))

                # Save Sample Images #
                save_samples(fixed_image, fixed_labels_list, G, epoch,
                             config.samples_path)

        # Adjust Learning Rate #
        D_optim_scheduler.step()
        G_optim_scheduler.step()

        # Save Model Weights #
        if (epoch + 1) % config.save_every == 0:
            torch.save(
                G.state_dict(),
                os.path.join(
                    config.weights_path,
                    'StarGAN_Generator_Epoch_{}.pkl'.format(epoch + 1)))

    # Make a GIF file #
    make_gifs_train('StarGAN', config.samples_path)

    # Plot Losses #
    plot_losses(D_losses, G_losses, config.num_epochs, config.plots_path)

    print("Training Finished.")