Example #1
def main():
    # Parse CLI arguments, including any architecture-specific ones
    parser = options.get_all_parser()
    args = options.parse_args_and_arch(parser)
    #print(args)

    # Create the output directory for the test images
    os.makedirs(args.save_test_images, exist_ok=True)

    model = models.build_gan(args)
    """
    model = model.load_from_metrics(
                weights_path=args.restore_checkpoint,
                tags_csv=args.meta_tags,
                map_location=None)
    """
    model = model.load_from_checkpoint(checkpoint_path=args.restore_checkpoint)
    trainer = pl.Trainer(gpus=args.num_gpu)
    trainer.test(model)
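A note on the restore call above: load_from_checkpoint is a classmethod, so the freshly built model is only used to reach its class before the reassignment replaces it. In recent PyTorch Lightning the same test run would look roughly like the sketch below, where GANModel is a hypothetical stand-in for whatever LightningModule class models.build_gan constructs, and the gpus flag has been replaced by accelerator/devices.

# Hedged sketch for recent PyTorch Lightning; GANModel is a hypothetical
# stand-in for the LightningModule class built by models.build_gan(args).
model = GANModel.load_from_checkpoint(args.restore_checkpoint)
trainer = pl.Trainer(accelerator="gpu", devices=args.num_gpu)
trainer.test(model)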
Example #2
def main():
    #parser = options.get_training_parser()
    parser = options.get_all_parser()
    args = options.parse_args_and_arch(parser)
    torch.manual_seed(args.seed)
   
    # Create the log output directory
    os.makedirs(args.save_log_dir, exist_ok=True)
    #os.makedirs(args.save_valid_images, exist_ok=True)

    model = models.build_gan(args)
    trainer = pl.Trainer(show_progress_bar=args.progress_bar,
                         checkpoint_callback=save_load_checkpoint(args),
                         early_stop_callback=early_stop(args),
                         default_save_path=args.save_log_dir,
                         gpus=args.num_gpu, distributed_backend='dp',
                         train_percent_check=args.train_subset_split,
                         accumulate_grad_batches=args.update_freq,
                        )
    trainer.fit(model)
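Several of the Trainer keywords above (show_progress_bar, early_stop_callback, checkpoint_callback, default_save_path, train_percent_check, distributed_backend) belong to early PyTorch Lightning (~0.7) and were later removed. A rough modern equivalent, offered as an assumption about intent rather than a drop-in replacement:

# Hedged sketch of the same configuration for PyTorch Lightning >= 1.6.
trainer = pl.Trainer(
    enable_progress_bar=args.progress_bar,
    callbacks=[save_load_checkpoint(args), early_stop(args)],
    default_root_dir=args.save_log_dir,
    accelerator="gpu",
    devices=args.num_gpu,
    strategy="dp",  # "dp" itself was removed in Lightning 2.x
    limit_train_batches=args.train_subset_split,
    accumulate_grad_batches=args.update_freq,
)
trainer.fit(model)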
Example #3

from keras.callbacks import TensorBoard
from keras.callbacks import Callback
from skimage.io import imsave
from tqdm import tqdm

# `args` is assumed to be parsed at module level elsewhere in this script
use_data_augmentation = not args.no_augmentation

BS = args.batch_size
EPOCHS = args.epochs
w, h, c = args.width, args.height, args.channels
latent_dim = args.z_dim
D_ITER = 5  # discriminator updates per generator update
generator_model, discriminator_model, decoder, discriminator = build_gan(
    h=h,
    w=w,
    c=c,
    latent_dim=latent_dim,
    epsilon_std=args.std,
    dropout_rate=0.2)

train_generator = data_generator(args.dataset,
                                 height=h,
                                 width=w,
                                 channel=c,
                                 batch_size=BS,
                                 shuffle=True,
                                 normalize=not use_data_augmentation)
seq = get_imgaug()

if args.load_weights:
    decoder.load_weights('./decoder.h5')
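The snippet defines D_ITER = 5 but the training loop itself is not shown; the constant suggests the usual WGAN-style schedule of several discriminator updates per generator update. A minimal sketch of that schedule, assuming train_on_batch-style compiled models with dummy targets (both assumptions about this project's API, not code from it):

# Hedged sketch: D_ITER discriminator steps per generator step.
import numpy as np

STEPS_PER_EPOCH = 100  # assumption; depends on the dataset size and BS

for epoch in range(EPOCHS):
    for _step in tqdm(range(STEPS_PER_EPOCH)):
        real_batch = next(train_generator)
        if use_data_augmentation:
            real_batch = seq.augment_images(real_batch)  # imgaug pipeline
        # several discriminator updates on (real, noise) pairs...
        for _ in range(D_ITER):
            z = np.random.normal(size=(BS, latent_dim))
            d_loss = discriminator_model.train_on_batch([real_batch, z],
                                                        np.zeros((BS, 1)))
        # ...then a single generator update
        z = np.random.normal(size=(BS, latent_dim))
        g_loss = generator_model.train_on_batch(z, np.zeros((BS, 1)))
    decoder.save_weights('./decoder.h5')  # matches the load path above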
Example #4
File: train.py  Project: rahhul/GANs
        # (tail of the train() loop; the lines above it were truncated)
        d2_hist.append(d_loss2)
        g_hist.append(g_loss)
        # evaluate once per epoch
        if (i+1) % batch_per_epoch == 0:
            log_performance(i, g_model, latent_dim)
    # plot the loss curves
    plot_history(d1_hist, d2_hist, g_hist)



# EXAMPLE

latent_dim = 100

# discriminator model
discriminator = build_discriminator(in_shape=(28, 28, 1))

# generator model
generator = build_generator(latent_dim=latent_dim)

# gan model
gan_model = build_gan(generator, discriminator)

# image dataset
dataset = load_mnist()
print(dataset.shape)

# train

train(generator, discriminator, gan_model, dataset, latent_dim)
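The train() called here is the function truncated at the top of Example #4. A hedged reconstruction following the common Keras DCGAN tutorial pattern; the generate_* helpers are assumptions and are not shown in the original snippet:

# Hedged reconstruction of the truncated train() from Example #4.
# generate_real_samples, generate_fake_samples, generate_latent_points,
# log_performance and plot_history are assumed project helpers.
import numpy as np

def train(g_model, d_model, gan_model, dataset, latent_dim,
          n_epochs=100, n_batch=128):
    batch_per_epoch = dataset.shape[0] // n_batch
    half_batch = n_batch // 2
    d1_hist, d2_hist, g_hist = [], [], []
    for i in range(n_epochs * batch_per_epoch):
        # discriminator: one real half-batch, one generated half-batch
        X_real, y_real = generate_real_samples(dataset, half_batch)
        d_loss1 = d_model.train_on_batch(X_real, y_real)
        X_fake, y_fake = generate_fake_samples(g_model, latent_dim, half_batch)
        d_loss2 = d_model.train_on_batch(X_fake, y_fake)
        # generator: trained through the stacked GAN with "real" labels
        X_gan = generate_latent_points(latent_dim, n_batch)
        g_loss = gan_model.train_on_batch(X_gan, np.ones((n_batch, 1)))
        d1_hist.append(d_loss1)
        d2_hist.append(d_loss2)
        g_hist.append(g_loss)
        # evaluate once per epoch
        if (i + 1) % batch_per_epoch == 0:
            log_performance(i, g_model, latent_dim)
    plot_history(d1_hist, d2_hist, g_hist)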
Example #5
File: train.py  Project: sam1902/IMDB-GAN
def main():
    np.random.seed(0)

    # Quick run
    # vocab_size = 1000
    # max_len = 600
    # n_samples = 128

    vocab_size = 2000
    max_len = 1000
    n_samples = 128

    noise_size = 100
    data, labels = load_data(max_len=max_len,
                             vocab_size=vocab_size,
                             n_samples=n_samples)

    generative_optimizer = Adam(lr=1e-4)
    discriminative_optimizer = Adam(lr=1e-3)

    # In: (None, 100) <-- rnd noise
    # Out: (None, 1000, 2000) <-- sentence of (max_len) words encoded in (vocab_size)
    generative_model = build_generative(noise_size, max_len, vocab_size)
    generative_model.compile(loss='binary_crossentropy',
                             optimizer=generative_optimizer)
    # print(generative_model.summary())

    # In: (None, 1000, 2000) <-- sentence of (max_len) words encoded in (vocab_size)
    # Out: (None, 1) <-- probability of the sentence being real
    discriminative_model = build_discriminative(max_len, vocab_size)
    discriminative_model.compile(loss='binary_crossentropy',
                                 optimizer=discriminative_optimizer)
    # print(discriminative_model.summary())

    # Stacked GAN
    # In: (None, 100) <-- rnd noise
    # Out: (None, 1) <-- probability of the sentence being real
    gan = build_gan(noise_size, discriminative_model, generative_model)
    gan.compile(loss='binary_crossentropy', optimizer=generative_optimizer)
    # print(gan.summary())

    # -- Training the discriminator alone
    print('=' * 10 + 'Training discriminative model' + '=' * 10)

    print('-' * 10 + 'Building training data' + '-' * 10)
    training_samples, training_outputs = generate_mixed_data(
        data.train,
        generative_model,
        noise_size=noise_size,
        vocab_size=vocab_size,
        real_samples_size=100,
        generated_samples_size=100)
    print('Training samples shape: ', training_samples.shape)
    print('Training outputs shape: ', training_outputs.shape)

    print('-' * 10 + 'Building testing data' + '-' * 10)
    testing_samples, testing_outputs = generate_mixed_data(
        data.test,
        generative_model,
        noise_size=noise_size,
        vocab_size=vocab_size,
        real_samples_size=100,
        generated_samples_size=100)
    print('Testing samples shape: ', testing_samples.shape)
    print('Testing outputs shape: ', testing_outputs.shape)

    print('-' * 10 + 'Running the training process' + '-' * 10)
    make_trainable(discriminative_model, True)
    discriminative_model.fit(training_samples,
                             training_outputs,
                             batch_size=128,
                             epochs=2)

    print('-' * 10 + 'Evaluating the discriminative model' + '-' * 10)
    scores = discriminative_model.evaluate(testing_samples, testing_outputs)
    print('Loss on testing samples: {:.4f}'.format(scores))

    losses = {"d": [], "g": []}

    try:
        change_lr(gan, 1e-4)
        change_lr(discriminative_model, 1e-3)

        losses = train(nb_epochs=6000,
                       batch_size=32,
                       training_data=data.train,
                       discriminative_model=discriminative_model,
                       generative_model=generative_model,
                       gan_model=gan,
                       noise_size=noise_size,
                       vocab_size=vocab_size,
                       losses=losses)
        export('1', losses, discriminative_model, generative_model, gan)

        change_lr(gan, 1e-5)
        change_lr(discriminative_model, 1e-4)

        losses = train(nb_epochs=2000,
                       batch_size=32,
                       training_data=data.train,
                       discriminative_model=discriminative_model,
                       generative_model=generative_model,
                       gan_model=gan,
                       noise_size=noise_size,
                       vocab_size=vocab_size,
                       losses=losses)
        export('2', losses, discriminative_model, generative_model, gan)

        change_lr(gan, 1e-6)
        change_lr(discriminative_model, 1e-5)

        losses = train(nb_epochs=2000,
                       batch_size=32,
                       training_data=data.train,
                       discriminative_model=discriminative_model,
                       generative_model=generative_model,
                       gan_model=gan,
                       noise_size=noise_size,
                       vocab_size=vocab_size,
                       losses=losses)
        export('3', losses, discriminative_model, generative_model, gan)

    except KeyboardInterrupt:
        export('quitedInBetween', losses, discriminative_model,
               generative_model, gan)
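Example #5 leans on two helpers that are not shown, change_lr and make_trainable. A minimal sketch of what such helpers conventionally do in Keras GAN code (an assumption, not the project's actual implementation):

# Hedged sketches of the undefined helpers used above.
from keras import backend as K

def change_lr(model, new_lr):
    # overwrite the compiled optimizer's learning rate in place
    K.set_value(model.optimizer.lr, new_lr)

def make_trainable(model, flag):
    # freeze/unfreeze every layer, e.g. before fitting the stacked GAN
    model.trainable = flag
    for layer in model.layers:
        layer.trainable = flag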