Exemplo n.º 1
0
def train():
    """Run two-stage SRGAN training on x4 bicubic-downgraded DIV2K.

    Stage 1 pre-trains ``generator()`` with ``SrganGeneratorTrainer``
    (checkpoints in ``.ckpt/pre_generator``) and saves its weights as
    ``pre_generator.h5``.  Stage 2 loads those weights into a fresh
    ``generator()``, trains it against ``discriminator()`` with
    ``SrganTrainer`` and saves both networks' weights.  Every weight path is
    resolved through ``weights_file``.
    """
    # Both splits share scale and downgrade; 'train' is built before 'valid'.
    div2k_tr, div2k_va = (
        DIV2K(scale=4, subset=split, downgrade='bicubic')
        for split in ('train', 'valid')
    )

    train_ds = div2k_tr.dataset(batch_size=16, random_transform=True)
    valid_ds = div2k_va.dataset(
        batch_size=16,
        random_transform=True,
        repeat_count=1,
    )

    # Stage 1: generator pre-training.
    pretrainer = SrganGeneratorTrainer(
        model=generator(),
        checkpoint_dir='.ckpt/pre_generator',
    )
    pretrainer.train(
        train_ds,
        valid_ds.take(10),
        steps=1_000_000,
        evaluate_every=1000,
        save_best_only=False,
    )
    pretrainer.model.save_weights(weights_file('pre_generator.h5'))

    # Stage 2: adversarial fine-tuning, starting from the pre-trained weights.
    fine_tuned = generator()
    fine_tuned.load_weights(weights_file('pre_generator.h5'))

    adversarial = SrganTrainer(generator=fine_tuned,
                               discriminator=discriminator())
    adversarial.train(train_ds, steps=200_000)

    # Generator first, then discriminator.
    for net, filename in ((adversarial.generator, 'gan_generator.h5'),
                          (adversarial.discriminator, 'gan_discriminator.h5')):
        net.save_weights(weights_file(filename))
Exemplo n.º 2
0
                  valid_ds.take(100),
                  steps=300000,
                  evaluate_every=1000,
                  save_best_only=True)

    # Restore from checkpoint with highest PSNR
    trainer.restore()

    # Evaluate model on full validation set
    psnr = trainer.evaluate(valid_ds)
    print(f'PSNR = {psnr.numpy():3f}')

    # Save weights to separate location
    trainer.model.save_weights(weights_file)

    # Custom WDSR B model (0.62M parameters)
    generator = wdsr_b(scale=4, num_res_blocks=32)
    generator.load_weights('weights/wdsr/weights.h5')

    train_ds_small_batch = catesr_train.dataset(batch_size=1,
                                                random_transform=True,
                                                shuffle_buffer_size=500)

    # Fine-tune EDSR model via SRGAN training.
    gan_trainer = SrganTrainer(generator=generator,
                               discriminator=discriminator())
    gan_trainer.train(train_ds_small_batch, steps=200000)

    new_weights_file = os.path.join(weights_dir,
                                    'weights_fine_tuned_200000_steps.h5')
    generator.save_weights(new_weights_file)
Exemplo n.º 3
0
    def _generator_loss(self, sr_out):
        """Adversarial loss for the generator.

        Scores the discriminator output for super-resolved images, ``sr_out``,
        against an all-ones ("real") target using ``self.binary_cross_entropy``.
        Assuming that is a standard BCE loss, the value is lowest when the
        discriminator rates the SR images as real.
        """
        real_target = tf.ones_like(sr_out)
        return self.binary_cross_entropy(real_target, sr_out)

    def _discriminator_loss(self, hr_out, sr_out):
        """Adversarial loss for the discriminator.

        Sum of two ``self.binary_cross_entropy`` terms:
        - the discriminator output on real HR images (``hr_out``) against ones;
        - its output on generated SR images (``sr_out``) against zeros.
        """
        bce = self.binary_cross_entropy
        real_term = bce(tf.ones_like(hr_out), hr_out)
        fake_term = bce(tf.zeros_like(sr_out), sr_out)
        return real_term + fake_term


# Script: pre-train the SRGAN generator on x4 bicubic DIV2K, then fine-tune it
# adversarially and store all weights under <cwd>/weights.
div2k_train = DIV2K(scale=4, subset='train', downgrade='bicubic')
div2k_valid = DIV2K(scale=4, subset='valid', downgrade='bicubic')

train_ds = div2k_train.dataset(batch_size=16, random_transform=True)
valid_ds = div2k_valid.dataset(batch_size=16, random_transform=True, repeat_count=1)

CWD_PATH = os.getcwd()
WEIGHTS_DIR = os.path.join(CWD_PATH, 'weights')
# The .h5 weight files are written straight into WEIGHTS_DIR, so make sure it exists.
os.makedirs(WEIGHTS_DIR, exist_ok=True)

# To pretrain gen
pre_trainer = SrganGeneratorTrainer(model=generator(), checkpoint_dir='.ckpt/pre_generator')
pre_trainer.train(train_ds, valid_ds.take(10), steps=50000, evaluate_every=1000, save_best_only=False)

# Persist the pre-trained weights. The GAN stage below loads them from this
# path, and previously nothing in the script wrote the file.
PRE_GENERATOR_WEIGHTS = os.path.join(WEIGHTS_DIR, 'pre_generator.h5')
pre_trainer.model.save_weights(PRE_GENERATOR_WEIGHTS)

# To train gan (module-level statements: no indentation here; the old
# indented lines raised IndentationError).
gan_generator = generator()
gan_generator.load_weights(PRE_GENERATOR_WEIGHTS)

gan_trainer = SrganTrainer(generator=gan_generator, discriminator=discriminator())
gan_trainer.train(train_ds, steps=50000)

gan_trainer.generator.save_weights(os.path.join(WEIGHTS_DIR, 'gan_generator.h5'))
gan_trainer.discriminator.save_weights(os.path.join(WEIGHTS_DIR, 'gan_discriminator.h5'))
Exemplo n.º 4
0
def main(args):
    """Train a generator/discriminator pair adversarially on DIV2K.

    Creates a training workspace under ``args.outdir`` and builds the
    generator: EDSR-style when ``args.generator == 'edsr-gen'``, SRGAN
    otherwise.  Each iteration performs one discriminator update followed by
    one update of the combined ``gan`` model.  After every epoch it appends
    the mean losses to ``losses.csv`` and saves the generator to
    ``generator-epoch-NNN.h5`` in the models directory.

    Args:
        args: Parsed command-line namespace.  Attributes read here:
            ``outdir``, ``dataset``, ``scale``, ``downgrade``, ``batch_size``,
            ``generator``, ``num_filters``, ``num_res_blocks``,
            ``pretrained_model``, ``generator_learning_rate``,
            ``discriminator_learning_rate``, ``learning_rate_step_size``,
            ``learning_rate_decay``, ``epochs``, ``iterations_per_epoch`` and
            ``label_noise``.
    """
    train_dir, models_dir = create_train_workspace(args.outdir)
    losses_file = os.path.join(train_dir, 'losses.csv')
    write_args(train_dir, args)
    logger.info('Training workspace is %s', train_dir)

    # Training images 1..800 (range end is exclusive).
    sequence = DIV2KSequence(args.dataset,
                             scale=args.scale,
                             subset='train',
                             downgrade=args.downgrade,
                             image_ids=range(1, 801),
                             batch_size=args.batch_size,
                             crop_size=96)

    if args.generator == 'edsr-gen':
        generator = edsr.edsr_generator(args.scale, args.num_filters, args.num_res_blocks)
    else:
        generator = srgan.generator(args.num_filters, args.num_res_blocks)

    if args.pretrained_model:
        generator.load_weights(args.pretrained_model)

    generator_optimizer = Adam(lr=args.generator_learning_rate)

    discriminator = srgan.discriminator()
    discriminator_optimizer = Adam(lr=args.discriminator_learning_rate)
    discriminator.compile(loss='binary_crossentropy',
                          optimizer=discriminator_optimizer,
                          metrics=[])

    # Combined model with two outputs:
    # - the first is trained against the HR image with content_loss;
    # - the second is trained against "real" labels with BCE.
    # loss_weights scales each term in the total loss.
    gan = srgan.srgan(generator, discriminator)
    gan.compile(loss=[content_loss, 'binary_crossentropy'],
                loss_weights=[0.006, 0.001],
                optimizer=generator_optimizer,
                metrics=[])

    # Training uses train_on_batch instead of fit(), so the LR schedulers are
    # attached with set_model and driven by hand via on_epoch_begin/end.
    generator_lr_scheduler = learning_rate(step_size=args.learning_rate_step_size, decay=args.learning_rate_decay, verbose=0)
    generator_lr_scheduler.set_model(gan)

    discriminator_lr_scheduler = learning_rate(step_size=args.learning_rate_step_size, decay=args.learning_rate_decay, verbose=0)
    discriminator_lr_scheduler.set_model(discriminator)

    with open(losses_file, 'w') as f:
        f.write('Epoch,Discriminator loss,Generator loss\n')

    with concurrent_generator(sequence, num_workers=1) as gen:
        for epoch in range(args.epochs):

            generator_lr_scheduler.on_epoch_begin(epoch)
            discriminator_lr_scheduler.on_epoch_begin(epoch)

            d_losses = []
            g_losses_0 = []  # total (weighted) loss of the combined model
            g_losses_1 = []  # content-loss output term
            g_losses_2 = []  # adversarial BCE output term

            for iteration in range(args.iterations_per_epoch):

                # ----------------------
                #  Train Discriminator
                # ----------------------

                lr, hr = next(gen)
                sr = generator.predict(lr)

                # Labels are sized from the actual arrays rather than
                # args.batch_size, so a short batch cannot cause a
                # label/sample count mismatch (identical for full batches).
                # Noisy targets: real in [1, 1 + label_noise), fake in
                # [0, label_noise).
                # NOTE(review): real targets above 1 are unusual for BCE
                # (one-sided smoothing normally lowers them); confirm intent.
                hr_labels = np.ones(len(hr)) + args.label_noise * np.random.random(len(hr))
                sr_labels = np.zeros(len(sr)) + args.label_noise * np.random.random(len(sr))

                hr_loss = discriminator.train_on_batch(hr, hr_labels)
                sr_loss = discriminator.train_on_batch(sr, sr_labels)

                d_losses.append((hr_loss + sr_loss) / 2)

                # ------------------
                #  Train Generator
                # ------------------

                lr, hr = next(gen)

                # The generator is rewarded when the discriminator labels its
                # output as real (1).
                labels = np.ones(len(lr))

                perceptual_loss = gan.train_on_batch(lr, [hr, labels])

                g_losses_0.append(perceptual_loss[0])
                g_losses_1.append(perceptual_loss[1])
                g_losses_2.append(perceptual_loss[2])

                # Running means over (at most) the last 50 iterations.
                print(f'[{epoch:03d}-{iteration:03d}] '
                      f'discriminator loss = {np.mean(d_losses[-50:]):.3f} '
                      f'generator loss = {np.mean(g_losses_0[-50:]):.3f} ('
                      f'mse = {np.mean(g_losses_1[-50:]):.3f} '
                      f'bxe = {np.mean(g_losses_2[-50:]):.3f})')

            generator_lr_scheduler.on_epoch_end(epoch)
            discriminator_lr_scheduler.on_epoch_end(epoch)

            with open(losses_file, 'a') as f:
                f.write(f'{epoch},{np.mean(d_losses)},{np.mean(g_losses_0)}\n')

            model_path = os.path.join(models_dir, f'generator-epoch-{epoch:03d}.h5')
            print('Saving model', model_path)
            generator.save(model_path)