Пример #1
0
def main():
    """Save class-conditional sample grids from a trained generator/encoder.

    Command-line options:
        -r/--run_dir       Training run directory (required). Only its last
                           path component is used, as the run name.
        -b/--batch_size    Number of images generated per class (default 100).
        -s/--dataset_name  Dataset name, one of ``dataset_list``.

    For every class index in ``range(n_classes)`` two PNG files are written
    to ``<RUNS_DIR>/<dataset>/<run>/images``:
        class_<idx>_gen.png      images generated from sampled latents
        class_enc_<idx>_gen.png  images regenerated from the encoder's
                                 output on the first set of images

    ``latent_dim`` and ``n_classes`` are read from the first row of the
    run's ``training_details.csv``.
    """
    global args
    parser = argparse.ArgumentParser(
        description=
        "Script to save generated examples from learned ClusterGAN generator")
    # required=True: without a run directory there is nothing to load, and a
    # clean argparse error beats an AttributeError on None further down.
    parser.add_argument("-r",
                        "--run_dir",
                        dest="run_dir",
                        required=True,
                        help="Training run directory")
    parser.add_argument("-b",
                        "--batch_size",
                        dest="batch_size",
                        default=100,
                        type=int,
                        help="Batch size")
    parser.add_argument("-s",
                        "--dataset_name",
                        dest="dataset_name",
                        default='mnist',
                        choices=dataset_list,
                        help="Dataset name")
    args = parser.parse_args()

    batch_size = args.batch_size
    dataset_name = args.dataset_name

    # Directory structure for this run. normpath drops trailing separators
    # (and handles either separator style on Windows) before basename.
    run_name = os.path.basename(os.path.normpath(args.run_dir))
    run_dir = os.path.join(RUNS_DIR, dataset_name, run_name)
    imgs_dir = os.path.join(run_dir, 'images')
    models_dir = os.path.join(run_dir, 'models')

    # Latent space info; cast to plain ints so the model constructors do not
    # receive numpy scalar types.
    train_df = pd.read_csv(os.path.join(run_dir, 'training_details.csv'))
    latent_dim = int(train_df['latent_dim'].iloc[0])
    n_c = int(train_df['n_classes'].iloc[0])

    cuda = torch.cuda.is_available()

    # Load encoder model. Checkpoints are mapped to CPU first so that a
    # GPU-saved file can still be restored on a CPU-only machine.
    encoder = Encoder_CNN(latent_dim, n_c)
    enc_fname = os.path.join(models_dir, encoder.name + '.pth.tar')
    encoder.load_state_dict(torch.load(enc_fname, map_location='cpu'))
    if cuda:
        encoder.cuda()
    encoder.eval()

    # Load generator model
    x_shape = (1, 28, 28)
    generator = Generator_CNN(latent_dim, n_c, x_shape)
    gen_fname = os.path.join(models_dir, generator.name + '.pth.tar')
    generator.load_state_dict(torch.load(gen_fname, map_location='cpu'))
    if cuda:
        generator.cuda()
    generator.eval()

    # Roughly square image grid; never fewer than one image per row.
    nrow = max(1, int(np.sqrt(batch_size)))

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        # Loop through specific classes
        for idx in range(n_c):
            zn, zc, zc_idx = sample_z(shape=batch_size,
                                      latent_dim=latent_dim,
                                      n_c=n_c,
                                      fix_class=idx,
                                      req_grad=False)

            # Generate a batch of images for this class
            gen_imgs = generator(zn, zc)
            save_image(gen_imgs,
                       '%s/class_%i_gen.png' % (imgs_dir, idx),
                       nrow=nrow,
                       normalize=True)

            # Encode the generated images and regenerate from the encoding
            enc_zn, enc_zc, enc_zc_logits = encoder(gen_imgs)
            regen_imgs = generator(enc_zn, enc_zc)
            save_image(regen_imgs,
                       '%s/class_enc_%i_gen.png' % (imgs_dir, idx),
                       nrow=nrow,
                       normalize=True)
Пример #2
0
def main():
    """Train the generator, encoder and discriminator, then save the results.

    Command-line options:
        -r/--run_name      Name of the training run (default 'clusgan').
        -n/--n_epochs      Number of epochs (default 200).
        -b/--batch_size    Training batch size (default 64).
        -s/--dataset_name  Dataset name, one of ``dataset_list``.

    Outputs, all under ``<RUNS_DIR>/<dataset>/<run_name>``:
        images/                 per-epoch sample grids (real, reconstructed,
                                generated, and per-class generated)
        training_details.csv    hyper-parameters and per-epoch loss histories
        training_*.png          loss plots produced by ``plot_train_loss``
        models/                 models written by ``save_model``
    """
    global args
    parser = argparse.ArgumentParser(
        description="Convolutional NN Training Script")
    parser.add_argument("-r",
                        "--run_name",
                        dest="run_name",
                        default='clusgan',
                        help="Name of training run")
    parser.add_argument("-n",
                        "--n_epochs",
                        dest="n_epochs",
                        default=200,
                        type=int,
                        help="Number of epochs")
    parser.add_argument("-b",
                        "--batch_size",
                        dest="batch_size",
                        default=64,
                        type=int,
                        help="Batch size")
    parser.add_argument("-s",
                        "--dataset_name",
                        dest="dataset_name",
                        default='mnist',
                        choices=dataset_list,
                        help="Dataset name")
    args = parser.parse_args()

    run_name = args.run_name
    dataset_name = args.dataset_name

    # Make directory structure for this run
    run_dir = os.path.join(RUNS_DIR, dataset_name, run_name)
    data_dir = os.path.join(DATASETS_DIR, dataset_name)
    imgs_dir = os.path.join(run_dir, 'images')
    models_dir = os.path.join(run_dir, 'models')

    os.makedirs(data_dir, exist_ok=True)
    os.makedirs(run_dir, exist_ok=True)
    os.makedirs(imgs_dir, exist_ok=True)
    os.makedirs(models_dir, exist_ok=True)
    print('\nResults to be saved in directory %s\n' % (run_dir))

    # Training details
    n_epochs = args.n_epochs
    batch_size = args.batch_size
    test_batch_size = 5000
    lr = 1e-4
    b1 = 0.5
    b2 = 0.9
    decay = 2.5 * 1e-5
    # The generator/encoder are stepped once every n_skip_iter batches;
    # the discriminator is stepped on every batch.
    n_skip_iter = 1

    img_size = 28
    channels = 1

    # Latent space info
    latent_dim = 30
    n_c = 10
    # Weights of the z_n and z_c reconstruction terms in the G+E loss
    betan = 10
    betac = 10

    # Wasserstein metric flag (False selects the BCE-based vanilla GAN loss)
    wass_metric = True

    x_shape = (channels, img_size, img_size)

    cuda = torch.cuda.is_available()

    # Loss functions
    bce_loss = torch.nn.BCELoss()
    xe_loss = torch.nn.CrossEntropyLoss()
    mse_loss = torch.nn.MSELoss()

    # Initialize generator, encoder and discriminator
    generator = Generator_CNN(latent_dim, n_c, x_shape)
    encoder = Encoder_CNN(latent_dim, n_c)
    discriminator = Discriminator_CNN(wass_metric=wass_metric)

    if cuda:
        generator.cuda()
        encoder.cuda()
        discriminator.cuda()
        bce_loss.cuda()
        xe_loss.cuda()
        mse_loss.cuda()

    Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

    # Configure training data loader
    dataloader = get_dataloader(dataset_name=dataset_name,
                                data_dir=data_dir,
                                batch_size=batch_size)

    # Test data loader; only its first batch is used, for the per-epoch
    # image reconstruction loss.
    testdata = get_dataloader(dataset_name=dataset_name,
                              data_dir=data_dir,
                              batch_size=test_batch_size,
                              train_set=False)
    test_imgs, test_labels = next(iter(testdata))
    test_imgs = test_imgs.type(Tensor)

    # Generator and encoder share one optimizer (with weight decay);
    # the discriminator has its own (without weight decay).
    ge_chain = ichain(generator.parameters(), encoder.parameters())
    optimizer_GE = torch.optim.Adam(ge_chain,
                                    lr=lr,
                                    betas=(b1, b2),
                                    weight_decay=decay)
    optimizer_D = torch.optim.Adam(discriminator.parameters(),
                                   lr=lr,
                                   betas=(b1, b2))

    # ----------
    #  Training
    # ----------
    # Per-epoch histories: model losses and cycle-consistency losses
    ge_l = []
    d_l = []

    c_zn = []
    c_zc = []
    c_i = []

    # Training loop
    print('\nBegin training session with %i epochs...\n' % (n_epochs))
    for epoch in range(n_epochs):
        for i, (imgs, _) in enumerate(dataloader):

            # Ensure generator/encoder are trainable
            generator.train()
            encoder.train()
            # Zero gradients for models
            generator.zero_grad()
            encoder.zero_grad()
            discriminator.zero_grad()

            # Configure input
            real_imgs = imgs.type(Tensor)

            # ---------------------------
            #  Train Generator + Encoder
            # ---------------------------

            optimizer_GE.zero_grad()

            # Sample random latent variables
            zn, zc, zc_idx = sample_z(shape=imgs.shape[0],
                                      latent_dim=latent_dim,
                                      n_c=n_c)

            # Generate a batch of images
            gen_imgs = generator(zn, zc)

            # Step for Generator & Encoder, n_skip_iter times less than for discriminator
            if i % n_skip_iter == 0:
                # Discriminator output on generated samples, with the graph
                # reaching back into the generator
                D_gen = discriminator(gen_imgs)

                # Encode the generated images
                enc_gen_zn, enc_gen_zc, enc_gen_zc_logits = encoder(gen_imgs)

                # Calculate losses for z_n, z_c
                zn_loss = mse_loss(enc_gen_zn, zn)
                zc_loss = xe_loss(enc_gen_zc_logits, zc_idx)

                # Check requested metric
                if wass_metric:
                    # Wasserstein GAN loss
                    ge_loss = torch.mean(
                        D_gen) + betan * zn_loss + betac * zc_loss
                else:
                    # Vanilla GAN loss
                    valid = Tensor(gen_imgs.size(0), 1).fill_(1.0)
                    v_loss = bce_loss(D_gen, valid)
                    ge_loss = v_loss + betan * zn_loss + betac * zc_loss

                ge_loss.backward()
                optimizer_GE.step()

            # ---------------------
            #  Train Discriminator
            # ---------------------

            # Also clears the discriminator gradients accumulated by the
            # generator/encoder backward pass above.
            optimizer_D.zero_grad()

            # Score a detached copy of the fakes. optimizer_GE.step() has
            # just modified the generator weights in place, so backpropagating
            # the discriminator loss through the generator graph would both
            # waste work and trip autograd's in-place modification check.
            fake_imgs = gen_imgs.detach()
            D_real = discriminator(real_imgs)
            D_fake = discriminator(fake_imgs)

            # Measure discriminator's ability to classify real from generated samples
            if wass_metric:
                # Gradient penalty term
                # NOTE(review): assumes calc_gradient_penalty does not need
                # gradients w.r.t. the generator -- confirm in its definition.
                grad_penalty = calc_gradient_penalty(discriminator, real_imgs,
                                                     fake_imgs)

                # Wasserstein GAN loss w/gradient penalty
                d_loss = torch.mean(D_real) - torch.mean(D_fake) + grad_penalty

            else:
                # Vanilla GAN loss. Targets are built here, for the current
                # batch, so they exist (with the right size) even on
                # iterations where the generator/encoder step was skipped.
                valid = Tensor(fake_imgs.size(0), 1).fill_(1.0)
                fake = Tensor(fake_imgs.size(0), 1).fill_(0.0)
                real_loss = bce_loss(D_real, valid)
                fake_loss = bce_loss(D_fake, fake)
                d_loss = (real_loss + fake_loss) / 2

            d_loss.backward()
            optimizer_D.step()

        # Save training losses (values from the last batch of the epoch)
        d_l.append(d_loss.item())
        ge_l.append(ge_loss.item())

        # Generator in eval mode
        generator.eval()
        encoder.eval()

        # Set number of examples for cycle calcs
        n_sqrt_samp = 5
        n_samp = n_sqrt_samp * n_sqrt_samp

        # Everything below is evaluation only, so skip building autograd
        # graphs (the test batch alone holds test_batch_size images).
        with torch.no_grad():
            ## Cycle through test real -> enc -> gen
            t_imgs = test_imgs
            # Encode sample real instances
            e_tzn, e_tzc, e_tzc_logits = encoder(t_imgs)
            # Generate sample instances from encoding
            teg_imgs = generator(e_tzn, e_tzc)
            # Calculate cycle reconstruction loss
            img_mse_loss = mse_loss(t_imgs, teg_imgs)
            # Save img reco cycle loss
            c_i.append(img_mse_loss.item())

            ## Cycle through randomly sampled encoding -> generator -> encoder
            zn_samp, zc_samp, zc_samp_idx = sample_z(shape=n_samp,
                                                     latent_dim=latent_dim,
                                                     n_c=n_c)
            # Generate sample instances
            gen_imgs_samp = generator(zn_samp, zc_samp)
            # Encode sample instances
            zn_e, zc_e, zc_e_logits = encoder(gen_imgs_samp)
            # Calculate cycle latent losses
            lat_mse_loss = mse_loss(zn_e, zn_samp)
            lat_xe_loss = xe_loss(zc_e_logits, zc_samp_idx)
            # Save latent space cycle losses
            c_zn.append(lat_mse_loss.item())
            c_zc.append(lat_xe_loss.item())

            # Save cycled and generated examples, using the first n_samp
            # images of the last training batch of this epoch.
            r_imgs = real_imgs[:n_samp]
            e_zn, e_zc, e_zc_logits = encoder(r_imgs)
            reg_imgs = generator(e_zn, e_zc)
            save_image(r_imgs[:n_samp],
                       '%s/real_%06i.png' % (imgs_dir, epoch),
                       nrow=n_sqrt_samp,
                       normalize=True)
            save_image(reg_imgs[:n_samp],
                       '%s/reg_%06i.png' % (imgs_dir, epoch),
                       nrow=n_sqrt_samp,
                       normalize=True)
            save_image(gen_imgs_samp[:n_samp],
                       '%s/gen_%06i.png' % (imgs_dir, epoch),
                       nrow=n_sqrt_samp,
                       normalize=True)

            ## Generate samples for specified classes: n_c images for each
            ## of the n_c classes, concatenated in class order.
            class_batches = []
            for idx in range(n_c):
                # Sample specific class
                zn_cls, zc_cls, zc_cls_idx = sample_z(shape=n_c,
                                                      latent_dim=latent_dim,
                                                      n_c=n_c,
                                                      fix_class=idx)
                # Generate sample instances
                class_batches.append(generator(zn_cls, zc_cls))
            stack_imgs = torch.cat(class_batches, 0)

            # Save class-specified generated examples!
            save_image(stack_imgs,
                       '%s/gen_classes_%06i.png' % (imgs_dir, epoch),
                       nrow=n_c,
                       normalize=True)

        print ("[Epoch %d/%d] \n"\
               "\tModel Losses: [D: %f] [GE: %f]" % (epoch,
                                                     n_epochs,
                                                     d_loss.item(),
                                                     ge_loss.item())
              )

        print("\tCycle Losses: [x: %f] [z_n: %f] [z_c: %f]" %
              (img_mse_loss.item(), lat_mse_loss.item(), lat_xe_loss.item()))

    # Save training results. Each loss column is a two-element list
    # [plot label, history], so the frame has two rows and the scalar
    # hyper-parameters are repeated on both.
    train_df = pd.DataFrame({
        'n_epochs': n_epochs,
        'learning_rate': lr,
        'beta_1': b1,
        'beta_2': b2,
        'weight_decay': decay,
        'n_skip_iter': n_skip_iter,
        'latent_dim': latent_dim,
        'n_classes': n_c,
        'beta_n': betan,
        'beta_c': betac,
        'wass_metric': wass_metric,
        'gen_enc_loss': ['G+E', ge_l],
        'disc_loss': ['D', d_l],
        'zn_cycle_loss': ['$||Z_n-E(G(x))_n||$', c_zn],
        'zc_cycle_loss': ['$||Z_c-E(G(x))_c||$', c_zc],
        'img_cycle_loss': ['$||X-G(E(x))||$', c_i]
    })

    train_df.to_csv('%s/training_details.csv' % (run_dir))

    # Plot some training results
    plot_train_loss(df=train_df,
                    arr_list=['gen_enc_loss', 'disc_loss'],
                    figname='%s/training_model_losses.png' % (run_dir))

    plot_train_loss(
        df=train_df,
        arr_list=['zn_cycle_loss', 'zc_cycle_loss', 'img_cycle_loss'],
        figname='%s/training_cycle_loss.png' % (run_dir))

    # Save current state of trained models
    model_list = [discriminator, encoder, generator]
    save_model(models=model_list, out_dir=models_dir)