示例#1
0
def train():
    """Pre-train the SRGAN generator on DIV2K, then fine-tune it adversarially."""
    # DIV2K 4x bicubic degradation, train and validation splits.
    train_data = DIV2K(scale=4, subset='train', downgrade='bicubic')
    valid_data = DIV2K(scale=4, subset='valid', downgrade='bicubic')

    train_ds = train_data.dataset(batch_size=16, random_transform=True)
    valid_ds = valid_data.dataset(batch_size=16,
                                  random_transform=True,
                                  repeat_count=1)

    # Stage 1: pixel-loss pre-training of the generator alone.
    pre_trainer = SrganGeneratorTrainer(model=generator(),
                                        checkpoint_dir='.ckpt/pre_generator')
    pre_trainer.train(train_ds,
                      valid_ds.take(10),
                      steps=1000000,
                      evaluate_every=1000,
                      save_best_only=False)
    pre_trainer.model.save_weights(weights_file('pre_generator.h5'))

    # Stage 2: adversarial fine-tuning, warm-started from the pre-trained weights.
    gan_generator = generator()
    gan_generator.load_weights(weights_file('pre_generator.h5'))

    gan_trainer = SrganTrainer(generator=gan_generator,
                               discriminator=discriminator())
    gan_trainer.train(train_ds, steps=200000)

    gan_trainer.generator.save_weights(weights_file('gan_generator.h5'))
    gan_trainer.discriminator.save_weights(weights_file('gan_discriminator.h5'))
示例#2
0
def main():
    """Super-resolve one image with a pre-trained SRGAN generator.

    Usage: <script> <low_res_image_path> <high_res_output_path>
    Falls back to module-level LR_IMG_LOCATION / HR_IMG_LOCATION when an
    argument is absent.
    """
    # Guard the argv accesses: the original indexed sys.argv unconditionally,
    # raising IndexError when run without arguments.
    if len(sys.argv) > 1 and sys.argv[1]:
        LR_IMG_LOCATION = sys.argv[1]

    if len(sys.argv) > 2 and sys.argv[2]:
        HR_IMG_LOCATION = sys.argv[2]

    gpus = tf.config.experimental.list_physical_devices('GPU')
    if gpus:
        try:
            # Currently, memory growth needs to be the same across GPUs
            for gpu in gpus:
                tf.config.experimental.set_memory_growth(gpu, True)
            # Report once after all GPUs are configured (the original printed
            # this inside the loop, i.e. once per physical GPU).
            logical_gpus = tf.config.experimental.list_logical_devices('GPU')
            print(len(gpus), "Physical GPUs,", len(logical_gpus),
                  "Logical GPUs")
        except RuntimeError as e:
            # Memory growth must be set before GPUs have been initialized
            print(e)

    # SRGAN generator with pre-trained weights.
    srgan_model = generator()
    srgan_model.load_weights('weights/srgan/gan_generator.h5')

    low_res = cv2.imread(LR_IMG_LOCATION)
    high_res = resolve_single(srgan_model, low_res).numpy()

    # Save with maximum PNG compression.
    cv2.imwrite(HR_IMG_LOCATION, high_res, [cv2.IMWRITE_PNG_COMPRESSION, 9])
示例#3
0
def main():
    """Live webcam demo: show the raw frame, a downscaled crop, and its SRGAN upscale.

    Optional argv[1] is the downscale percentage (default 30).
    """
    scale = 30
    # Guard against a missing CLI argument (the original indexed sys.argv
    # unconditionally, raising IndexError when run without arguments).
    if len(sys.argv) > 1 and sys.argv[1]:
        scale = int(sys.argv[1])

    gpus = tf.config.experimental.list_physical_devices('GPU')
    if gpus:
        try:
            # Currently, memory growth needs to be the same across GPUs
            for gpu in gpus:
                tf.config.experimental.set_memory_growth(gpu, True)
            # Report once after all GPUs are configured (the original printed
            # this inside the loop, i.e. once per physical GPU).
            logical_gpus = tf.config.experimental.list_logical_devices('GPU')
            print(len(gpus), "Physical GPUs,", len(logical_gpus),
                  "Logical GPUs")
        except RuntimeError as e:
            # Memory growth must be set before GPUs have been initialized
            print(e)

    # SRGAN generator with pre-trained weights.
    srgan_model = generator()
    srgan_model.load_weights('weights/srgan/gan_generator.h5')

    cap = cv2.VideoCapture(0)
    ret, frame = cap.read()

    while True:
        # Get dimensions of the cropped (downscaled) image.
        width = int(frame.shape[1] * (scale / 100))
        height = int(frame.shape[0] * (scale / 100))
        dim = (width, height)

        # Capture frame-by-frame; stop cleanly if the camera stops delivering.
        ret, frame = cap.read()
        if not ret:
            break

        # Downscale the frame by `scale` percent.
        crop = cv2.resize(frame, dim, interpolation=cv2.INTER_AREA)

        # Super-resolve the downscaled frame.
        upscale = resolve_single(srgan_model, crop).numpy()

        # Resize crop and upscaled image to the input frame size so the three
        # windows can be compared side by side.
        width = int(frame.shape[1])
        height = int(frame.shape[0])
        dim = (width, height)
        crop = cv2.resize(crop, dim, interpolation=cv2.INTER_AREA)
        upscale = cv2.resize(upscale, dim, interpolation=cv2.INTER_AREA)

        # Display the resulting frames.
        cv2.imshow('input', frame)
        cv2.imshow('cropped', crop)
        cv2.imshow('srgan', upscale)

        if cv2.waitKey(1) & 0xFF == ord('q'):
            break

    # When everything is done, release the capture.
    cap.release()
    cv2.destroyAllWindows()
def main():
    """Super-resolve a video file frame-by-frame with SRGAN.

    argv[1]: input (low-resolution) video path.
    argv[2]: output (high-resolution) video path.
    """
    # sys.argv indexing raises IndexError, not AssertionError — the original
    # except clauses could never fire and execution continued into a NameError.
    try:
        LR_VID_LOCATION = sys.argv[1]
        HR_VID_LOCATION = sys.argv[2]
    except IndexError:
        print('Usage: <script> <low_res_video_path> <high_res_video_path>')
        return

    gpus = tf.config.experimental.list_physical_devices('GPU')
    if gpus:
        try:
            # Currently, memory growth needs to be the same across GPUs
            for gpu in gpus:
                tf.config.experimental.set_memory_growth(gpu, True)
            # Report once after all GPUs are configured (the original printed
            # this inside the loop, i.e. once per physical GPU).
            logical_gpus = tf.config.experimental.list_logical_devices('GPU')
            print(len(gpus), "Physical GPUs,", len(logical_gpus),
                  "Logical GPUs")
        except RuntimeError as e:
            # Memory growth must be set before GPUs have been initialized
            print(e)

    # SRGAN generator with pre-trained weights.
    srgan_model = generator()
    srgan_model.load_weights('weights/srgan/gan_generator.h5')

    cap = cv2.VideoCapture(LR_VID_LOCATION)
    out = None

    success, frame = cap.read()
    if success:
        # Size the output writer from the first super-resolved frame.
        upscale = resolve_single(srgan_model, frame)
        width = int(upscale.shape[1])
        height = int(upscale.shape[0])
        dim = (width, height)
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        fps = cap.get(cv2.CAP_PROP_FPS)
        out = cv2.VideoWriter(HR_VID_LOCATION, fourcc, fps, dim)

        # The original reused `frame` as both the image and an int counter:
        # it super-resolved the integer 0 and then did `frame += 1` on an
        # image array. Keep a separate counter instead.
        frame_idx = 0
        while success:
            upscale = resolve_single(srgan_model, frame).numpy()
            out.write(upscale)
            print("Buffering Frame:", frame_idx)
            frame_idx += 1
            success, frame = cap.read()

    cap.release()
    # `out` exists only if at least one frame was read successfully.
    if out is not None:
        out.release()
示例#5
0
def get_generator(model_arc, is_train=True):
    """Build a generator network for the given architecture name.

    Args:
        model_arc: one of 'srfeat', 'srgan', 'esrgan', 'edsr', 'rcan',
            'erca' or 'gan'.
        is_train: forwarded to the architectures that distinguish a
            training graph from an inference graph.

    Returns:
        The constructed generator model.

    Raises:
        Exception: if `model_arc` is not a recognised architecture name.
    """
    if model_arc == 'srfeat':
        model = srfeat.generator(is_train=is_train)
    elif model_arc == 'srgan':
        model = srgan.generator(is_train=is_train)
    elif model_arc == 'esrgan':
        model = esrgan.generator()
    elif model_arc == 'edsr':
        model = edsr.generator()
    elif model_arc == 'rcan':
        model = rcan.generator()
    elif model_arc == 'erca':
        model = erca.generator()
    elif model_arc == 'gan':
        # SRFeat generator without batch norm.
        model = srfeat.generator(is_train=is_train, use_bn=False)
    else:
        # Fixed message: the original misspelled 'srgan' as 'argan' and
        # omitted the supported 'gan' option.
        raise Exception(
            'Wrong model architecture! It should be srfeat, srgan, esrgan, '
            'edsr, rcan, erca or gan.'
        )
    return model
示例#6
0
                          images_dir='/home/ec2-user/gans/data/images_rgb',
                          caches_dir='/home/ec2-user/gans/data/caches_rgb')
    catesr_valid = CATESR(subset='valid',
                          images_dir='/home/ec2-user/gans/data/images_rgb',
                          caches_dir='/home/ec2-user/gans/data/caches_rgb')

    train_ds = catesr_train.dataset(batch_size=1,
                                    random_transform=True,
                                    shuffle_buffer_size=500)
    valid_ds = catesr_valid.dataset(batch_size=1,
                                    random_transform=False,
                                    repeat_count=1)

    # First train the generator

    generator_model = generator()
    generator_model.load_weights(
        os.path.join(weights_dir, 'pretrained_gan_generator.h5'))

    pre_trainer = SrganGeneratorTrainer(model=generator_model,
                                        checkpoint_dir='.ckpt/pre_generator')
    pre_trainer.train(train_ds,
                      valid_ds.take(100),
                      steps=100000,
                      evaluate_every=1000,
                      save_best_only=True)

    pre_trainer.model.save_weights(weights_file('pre_generator.h5'))

    # Generator fine - tuning(GAN)
    gan_generator = generator()
示例#7
0
def Use_SRGAN():
    """Load the pre-trained SRGAN generator and super-resolve one image.

    NOTE(review): relies on a module-level `filename` global — confirm the
    caller sets it before invoking this function.
    """
    srgan_model = generator()
    srgan_model.load_weights('weights/srgan/gan_generator.h5')
    low_res = Load_Image(filename)
    super_res = resolve_single(srgan_model, low_res)
    Save_Image(super_res)
示例#8
0
# Pre-train the SRGAN generator on DIV2K, then prepare the GAN fine-tuning
# stage. Imports are grouped at the top: the original placed the model
# imports after their first use, which raises NameError when the script is
# executed top-to-bottom.
from data import DIV2K
from model.srgan import generator, discriminator
from train import SrganGeneratorTrainer, SrganTrainer

train_loader = DIV2K(scale=4, downgrade='bicubic', subset='train')

# repeat_count=None -> repeat indefinitely, for step-based training.
train_ds = train_loader.dataset(batch_size=16,
                                random_transform=True,
                                repeat_count=None)
valid_loader = DIV2K(scale=4, downgrade='bicubic', subset='valid')

valid_ds = valid_loader.dataset(batch_size=1,
                                random_transform=False,
                                repeat_count=1)

pre_trainer = SrganGeneratorTrainer(model=generator(num_res_blocks=6),
                                    checkpoint_dir='.ckpt/pre_generator')

pre_trainer.train(train_ds,
                  valid_ds.take(10),
                  steps=1000000,
                  evaluate_every=1000)

pre_trainer.model.save_weights('weights/srgan/pre_generator_6.h5')

# The fine-tuning generator must match the pre-trained architecture: the
# original built it with num_res_blocks=8, which cannot load weights saved
# from the 6-block model above.
gan_generator = generator(num_res_blocks=6)
gan_generator.load_weights('weights/srgan/pre_generator_6.h5')
示例#9
0
"""Upload images, poses, signatures"""
poses = data.read_pose('./data/pose.pkl')
signatures = data.read_signatures('./data/signatures.pkl')
with Images('data/images.tar') as images:
    path = images.paths[20000]
    image = images._getitem(path)
    print ('read image {} of shape {}'.format(path, image.shape))


my_split=poses[0]
my_split=[path[:-4] for path in my_split]


"""Use SRGAN"""
srgan = generator()
srgan.load_weights('weights/srgan/gan_generator.h5')


"""Upload customed cnn model"""
cnn = CNN(256, 256, 3, 101)
cnn.load_weights('weights/custom/cnn_plus.h5')
plot_model(cnn, to_file='./model.png', show_shapes=True, show_layer_names=True)


train_model(2, 'cnn_plus', cnn, srgan)

#filepath="./cnn_weights.h5"
#checkpoint = ModelCheckpoint(filepath, monitor='accuracy', verbose=1, save_best_only=True, mode='max')
#callbacks_list = [checkpoint]
def main(args):
    """Train one super-resolution model (SRResNet, EDSR or WDSR variant).

    Builds train/validation data generators, constructs the model selected
    by ``args.model`` (or resumes from a checkpoint when
    ``args.initial_epoch`` is set), compiles it with the matching loss, and
    runs Keras generator-based training with LR-schedule / checkpoint /
    TensorBoard callbacks.
    """
    train_dir, models_dir = create_train_workspace(args.outdir)
    write_args(train_dir, args)
    logger.info('Training workspace is %s', train_dir)

    training_generator = cropped_sequence(args.dataset,
                                          scale=args.scale,
                                          subset='train',
                                          downgrade=args.downgrade,
                                          image_ids=args.training_images,
                                          batch_size=args.batch_size)

    if args.benchmark:
        logger.info(
            'Validation with full-size images from DIV2K validation set')
        # One validation step per full-size image.
        validation_steps = len(args.validation_images)
        validation_generator = fullsize_sequence(
            args.dataset,
            scale=args.scale,
            subset='valid',
            downgrade=args.downgrade,
            image_ids=args.validation_images)
    else:
        logger.info(
            'Validation with randomly cropped images from DIV2K validation set'
        )
        validation_steps = args.validation_steps
        validation_generator = cropped_sequence(
            args.dataset,
            scale=args.scale,
            subset='valid',
            downgrade=args.downgrade,
            image_ids=args.validation_images,
            batch_size=args.batch_size)

    if args.initial_epoch:
        # Resuming: architecture, weights and optimizer state come from the
        # saved checkpoint.
        logger.info('Resume training of model %s', args.pretrained_model)
        model = _load_model(args.pretrained_model)

    else:
        if args.model == "sr-resnet":
            #
            # Pre-training of SRResNet-based generator
            # (for usage in SRGAN)
            #

            model = srgan.generator(num_filters=args.num_filters,
                                    num_res_blocks=args.num_res_blocks,
                                    pred_logvar=args.pred_logvar)
            if not args.pred_logvar:
                loss = mean_squared_error
            else:
                # Model also predicts log-variance -> heteroscedastic L2 loss.
                loss = heteroscedastic_loss(args.attention,
                                            args.block_attention_gradient,
                                            mode='l2')
        elif args.model == "edsr-gen":
            #
            # Pre-training of EDSR-based generator
            # (for usage in an SRGAN-like network)
            #
            loss = mean_squared_error
            model = edsr.edsr_generator(scale=args.scale,
                                        num_filters=args.num_filters,
                                        num_res_blocks=args.num_res_blocks)
        elif args.model == "edsr":
            loss = mean_absolute_error
            model = edsr.edsr(scale=args.scale,
                              num_filters=args.num_filters,
                              num_res_blocks=args.num_res_blocks,
                              res_block_scaling=args.res_scaling)
        else:
            # WDSR-A or WDSR-B, selected by args.model.
            loss = mae
            model_fn = wdsr.wdsr_b if args.model == 'wdsr-b' else wdsr.wdsr_a
            model = model_fn(scale=args.scale,
                             num_filters=args.num_filters,
                             num_res_blocks=args.num_res_blocks,
                             res_block_expansion=args.res_expansion,
                             res_block_scaling=args.res_scaling)

        if args.weightnorm:
            model.compile(
                optimizer=wn.AdamWithWeightnorm(lr=args.learning_rate),
                loss=loss,
                metrics=[psnr])
            if args.num_init_batches > 0:
                logger.info(
                    'Data-based initialization of weights with %d batches',
                    args.num_init_batches)
                model_weightnorm_init(model, training_generator,
                                      args.num_init_batches)
        else:
            model.compile(
                optimizer=Adam(lr=args.learning_rate),
                loss=loss,
                metrics=[psnr] if not args.pred_logvar else [psnr_unc, mse])

        if args.pretrained_model:
            logger.info(
                'Initialization with weights from pre-trained model %s',
                args.pretrained_model)
            copy_weights(from_model=_load_model(args.pretrained_model),
                         to_model=model)

    if args.print_model_summary:
        model.summary()

    callbacks = [
        tensor_board(train_dir),
        learning_rate(step_size=args.learning_rate_step_size,
                      decay=args.learning_rate_decay),
        model_checkpoint_after(
            args.save_models_after_epoch,
            models_dir,
            # Plain strings (the original used f-strings with no placeholders).
            monitor='val_psnr' if not args.pred_logvar else 'val_psnr_unc',
            save_best_only=args.save_best_models_only or args.benchmark)
    ]

    model.fit_generator(training_generator,
                        epochs=args.epochs,
                        initial_epoch=args.initial_epoch,
                        steps_per_epoch=args.iterations_per_epoch,
                        validation_data=validation_generator,
                        validation_steps=validation_steps,
                        use_multiprocessing=args.use_multiprocessing,
                        max_queue_size=args.max_queue_size,
                        workers=args.num_workers,
                        callbacks=callbacks)
示例#11
0
def main(args):
    """Adversarially train an SRGAN (generator + discriminator) on DIV2K.

    Each epoch alternates discriminator updates (real HR labeled ~1 vs
    generated SR labeled ~0, with optional label noise) and combined GAN
    updates (content loss + adversarial BCE), logs running losses, and
    saves the generator after every epoch.
    """
    train_dir, models_dir = create_train_workspace(args.outdir)
    losses_file = os.path.join(train_dir, 'losses.csv')
    write_args(train_dir, args)
    logger.info('Training workspace is %s', train_dir)

    # DIV2K training crops (image ids 1..800).
    sequence = DIV2KSequence(args.dataset,
                             scale=args.scale,
                             subset='train',
                             downgrade=args.downgrade,
                             image_ids=range(1,801),
                             batch_size=args.batch_size,
                             crop_size=96)

    # Generator backbone: EDSR-based or SRGAN's default, per CLI choice.
    if args.generator == 'edsr-gen':
        generator = edsr.edsr_generator(args.scale, args.num_filters, args.num_res_blocks)
    else:
        generator = srgan.generator(args.num_filters, args.num_res_blocks)

    if args.pretrained_model:
        generator.load_weights(args.pretrained_model)

    generator_optimizer = Adam(lr=args.generator_learning_rate)

    discriminator = srgan.discriminator()
    discriminator_optimizer = Adam(lr=args.discriminator_learning_rate)
    discriminator.compile(loss='binary_crossentropy',
                          optimizer=discriminator_optimizer,
                          metrics=[])

    # Combined model trained with content loss + adversarial BCE.
    # NOTE(review): assumes srgan.srgan() chains generator -> discriminator
    # and handles discriminator weight freezing internally — confirm.
    gan = srgan.srgan(generator, discriminator)
    gan.compile(loss=[content_loss, 'binary_crossentropy'],
                loss_weights=[0.006, 0.001],
                optimizer=generator_optimizer,
                metrics=[])

    # LR schedulers are driven manually because training uses train_on_batch.
    generator_lr_scheduler = learning_rate(step_size=args.learning_rate_step_size, decay=args.learning_rate_decay, verbose=0)
    generator_lr_scheduler.set_model(gan)

    discriminator_lr_scheduler = learning_rate(step_size=args.learning_rate_step_size, decay=args.learning_rate_decay, verbose=0)
    discriminator_lr_scheduler.set_model(discriminator)

    # CSV header; one row per epoch is appended below.
    with open(losses_file, 'w') as f:
        f.write('Epoch,Discriminator loss,Generator loss\n')

    with concurrent_generator(sequence, num_workers=1) as gen:
        for epoch in range(args.epochs):

            generator_lr_scheduler.on_epoch_begin(epoch)
            discriminator_lr_scheduler.on_epoch_begin(epoch)

            d_losses = []
            g_losses_0 = []  # total combined (perceptual) loss
            g_losses_1 = []  # content-loss component
            g_losses_2 = []  # adversarial BCE component

            for iteration in range(args.iterations_per_epoch):

                # ----------------------
                #  Train Discriminator
                # ----------------------

                lr, hr = next(gen)
                sr = generator.predict(lr)

                # Noisy labels: real ~1 and fake ~0, jittered by label_noise.
                hr_labels = np.ones(args.batch_size) + args.label_noise * np.random.random(args.batch_size)
                sr_labels = np.zeros(args.batch_size) + args.label_noise * np.random.random(args.batch_size)

                hr_loss = discriminator.train_on_batch(hr, hr_labels)
                sr_loss = discriminator.train_on_batch(sr, sr_labels)

                d_losses.append((hr_loss + sr_loss) / 2)

                # ------------------
                #  Train Generator
                # ------------------

                lr, hr = next(gen)

                # Generator target: make the discriminator output "real" (1).
                labels = np.ones(args.batch_size)

                perceptual_loss = gan.train_on_batch(lr, [hr, labels])

                g_losses_0.append(perceptual_loss[0])
                g_losses_1.append(perceptual_loss[1])
                g_losses_2.append(perceptual_loss[2])

                # Running means over the last 50 iterations.
                print(f'[{epoch:03d}-{iteration:03d}] '
                      f'discriminator loss = {np.mean(d_losses[-50:]):.3f} '
                      f'generator loss = {np.mean(g_losses_0[-50:]):.3f} ('
                      f'mse = {np.mean(g_losses_1[-50:]):.3f} '
                      f'bxe = {np.mean(g_losses_2[-50:]):.3f})')

            generator_lr_scheduler.on_epoch_end(epoch)
            discriminator_lr_scheduler.on_epoch_end(epoch)

            with open(losses_file, 'a') as f:
                f.write(f'{epoch},{np.mean(d_losses)},{np.mean(g_losses_0)}\n')

            # Persist the generator after every epoch.
            model_path = os.path.join(models_dir, f'generator-epoch-{epoch:03d}.h5')
            print('Saving model', model_path)
            generator.save(model_path)
示例#12
0
def run(args):
    """Generate a dataset of ConfigNet face renders: one image per
    (yaw, pitch) grid point plus 100 randomized variations, optionally
    super-resolved with SRGAN before saving.
    """
    args = parse_args(args)
    if args.image_path is not None:
        input_images = process_image(args.image_path, args.resolution)
        latentgan_model = None
    else:
        input_images = None
        print(
            "WARNING: no input image directory specified, embeddings will be sampled using Laten GAN"
        )
        latentgan_model = LatentGAN.load(args.latent_gan_model_path)
    confignet_model = ConfigNet.load(args.confignet_model_path)

    #basic_ui = BasicUI(confignet_model)

    # Sample latent embeddings from input images if available and if not sample from Latent GAN
    current_embedding_unmodified, current_rotation, orig_images = get_new_embeddings(
        input_images, latentgan_model, confignet_model)
    # Set next embedding value for rendering
    if args.enable_sr == 1:
        # SRGAN generator used to upscale the 256x256 renders before saving.
        modelSR = generator()
        modelSR.load_weights('evaluation/weights/srgan/gan_generator.h5')

    # Head-pose sweep limits (degrees), stepped by delta_angle.
    yaw_min_angle = -args.max_angle
    pitch_min_angle = -args.max_angle
    yaw_max_angle = args.max_angle
    pitch_max_angle = args.max_angle
    delta_angle = 5

    rotation_offset = np.zeros((1, 3))

    eye_rotation_offset = np.zeros((1, 3))

    facemodel_param_names = list(
        confignet_model.config["facemodel_inputs"].keys())
    # remove eye rotation as in the demo it is controlled separately
    eye_rotation_param_idx = facemodel_param_names.index(
        "bone_rotations:left_eye")
    facemodel_param_names.pop(eye_rotation_param_idx)

    render_input_interp_0 = current_embedding_unmodified
    render_input_interp_1 = current_embedding_unmodified

    interpolation_coef = 0
    # NOTE(review): `dataset_directory` is not defined in this function —
    # presumably a module-level global; confirm before reuse.
    if not os.path.exists(dataset_directory):
        os.makedirs(dataset_directory)
    # This interpolates between the previous and next set embeddings
    current_renderer_input = render_input_interp_0 * (
        1 - interpolation_coef) + render_input_interp_1 * interpolation_coef
    # Set eye gaze direction as controlled by the user
    current_renderer_input = set_gaze_direction_in_embedding(
        current_renderer_input, eye_rotation_offset, confignet_model)

    # all angles
    #image = Image.open(args.image_path)
    #print(np.array(image))
    #return
    i = 1
    print('All angles')
    # Render one image per (yaw, pitch) grid point.
    for yaw in range(yaw_min_angle, yaw_max_angle + 1, delta_angle):
        for pitch in range(pitch_min_angle, pitch_max_angle + 1, delta_angle):
            rotation_offset[0, 0] = to_rad(yaw)
            rotation_offset[0, 1] = to_rad(pitch)
            generated_imgs = confignet_model.generate_images(
                current_renderer_input, current_rotation + rotation_offset)
            if args.enable_sr == 1:
                # Downscale to 256x256, then let SRGAN upscale.
                img = cv2.resize(generated_imgs[0], (256, 256))
                sr_img = resolve_single(modelSR, img)
                cv2.imwrite(dataset_directory + '/%d_%d.png' % (yaw, pitch),
                            np.array(sr_img))
            else:
                img = cv2.resize(generated_imgs[0], (1024, 1024))
                cv2.imwrite(dataset_directory + '/%d_%d.png' % (yaw, pitch),
                            img)
            print(i)
            i += 1

    #all random
    # 100 images with random head rotations in [-20, 20] degrees, random
    # eye rotations and random facial expressions.
    print('All random')
    current_attribute_name = facemodel_param_names[1]  #blendshape_values
    frame_embedding = render_input_interp_0 * (
        1 - interpolation_coef) + render_input_interp_1 * interpolation_coef
    for i in range(100):
        eye_rotation_offset[0, 2] = to_rad(np.random.randint(-40, 40))
        eye_rotation_offset[0, 0] = to_rad(np.random.randint(-40, 40))
        rotation_offset[0, 0] = to_rad(np.random.randint(-20, 20))
        rotation_offset[0, 1] = to_rad(np.random.randint(-20, 20))
        frame_embedding = set_gaze_direction_in_embedding(
            frame_embedding, eye_rotation_offset, confignet_model)
        new_embedding_value = get_embedding_with_new_attribute_value(
            current_attribute_name, frame_embedding, confignet_model)

        generated_imgs = confignet_model.generate_images(
            new_embedding_value, current_rotation + rotation_offset)

        if args.enable_sr == 1:
            img = cv2.resize(generated_imgs[0], (256, 256))
            sr_img = resolve_single(modelSR, img)
            cv2.imwrite(dataset_directory + '/random_%d.png' % (i),
                        np.array(sr_img))
        else:
            img = cv2.resize(generated_imgs[0], (1024, 1024))
            cv2.imwrite(dataset_directory + '/random_%d.png' % (i), img)
        print(i)
 def __init__(self):
     """Build the SRGAN generator and load <cwd>/weights/gan_generator.h5."""
     self.gan_generator = generator()
     # Working directory at construction time; weights are resolved
     # relative to it.
     self.CWD_PATH = os.getcwd()
     self.gan_generator.load_weights(
         os.path.join(self.CWD_PATH, 'weights', 'gan_generator.h5'))
import tensorflow as tf
from imutils.video import VideoStream
import imutils
import cv2, os, urllib.request
import numpy as np
from django.conf import settings
from model.srgan import generator
from model.wdsr import wdsr_b
from model import resolve_single

# SRGAN
# Module-level generator, loaded once with pre-trained weights.
srgan_model = generator()
srgan_model.load_weights('weights/srgan/gan_generator.h5')

# Enable GPU memory growth so TensorFlow does not reserve all GPU memory
# up front; this must run before any GPU is initialized.
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        # Currently, memory growth needs to be the same across GPUs
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        logical_gpus = tf.config.experimental.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
    except RuntimeError as e:
        # Memory growth must be set before GPUs have been initialized
        print(e)


class VideoCamera(object):
    def __init__(self, scale=20):
        self.video = cv2.VideoCapture(0)
    hr_images = glob(os.path.join(hr_dir, '*.png'))

    if len(lr_images) != len(hr_images):
        raise ValueError

    no_of_images = len(lr_images)

    print(no_of_images)

    for i in range(0, no_of_images):
        resolve_and_plot(lr_images[i], hr_images[i], out_dir, i)


if __name__ == "__main__":

    pre_generator = generator()
    gan_generator = generator()

    # Location of model weights
    weights_dir = 'weights/srgan'
    weights_file = lambda filename: os.path.join(weights_dir, filename)

    pre_generator.load_weights(weights_file('pre_generator.h5'))
    gan_generator.load_weights(weights_file('gan_generator.h5'))

    pre_generator.load_weights(
        '/Users/cate/git/remote-sensing-super-resolution/SRGAN/weights/srgan/pre_generator.h5'
    )
    gan_generator.load_weights(
        '/Users/cate/git/remote-sensing-super-resolution/SRGAN/weights/srgan/gan_generator.h5'
    )
示例#16
0
    def _generator_loss(self, sr_out):
        """Adversarial generator loss: BCE of SR scores against all-ones."""
        target = tf.ones_like(sr_out)
        return self.binary_cross_entropy(target, sr_out)

    def _discriminator_loss(self, hr_out, sr_out):
        """Discriminator loss: real (HR) scored against ones, generated (SR) against zeros."""
        real_loss = self.binary_cross_entropy(tf.ones_like(hr_out), hr_out)
        fake_loss = self.binary_cross_entropy(tf.zeros_like(sr_out), sr_out)
        return real_loss + fake_loss


# SRGAN training example on DIV2K.
div2k_train = DIV2K(scale=4, subset='train', downgrade='bicubic')
div2k_valid = DIV2K(scale=4, subset='valid', downgrade='bicubic')

train_ds = div2k_train.dataset(batch_size=16, random_transform=True)
valid_ds = div2k_valid.dataset(batch_size=16, random_transform=True, repeat_count=1)

# To pretrain the generator with a pixel-wise loss.
# (Plain string checkpoint_dir: the original f-string had no placeholders.)
pre_trainer = SrganGeneratorTrainer(model=generator(), checkpoint_dir='.ckpt/pre_generator')
pre_trainer.train(train_ds,valid_ds.take(10),steps=50000,evaluate_every=1000,save_best_only=False)

CWD_PATH = os.getcwd()

# To train the GAN, starting from the pre-trained generator weights.
# (The original indented the next two lines at module level, which is an
# IndentationError.)
gan_generator = generator()
gan_generator.load_weights(os.path.join(CWD_PATH,'weights','pre_generator.h5'))

gan_trainer = SrganTrainer(generator=gan_generator, discriminator=discriminator())
gan_trainer.train(train_ds, steps=50000)

gan_trainer.generator.save_weights(os.path.join(CWD_PATH,'weights','gan_generator.h5'))
gan_trainer.discriminator.save_weights(os.path.join(CWD_PATH,'weights','gan_discriminator.h5'))
示例#17
0
            upscale_image(image_path, sr_model)

        old_seed = new_seed

    os.system("ffmpeg -y -framerate 25 -i ../generated_video/frames/test_%03d.png ../generated_video/latentSpaceNavigation_0.18.mov")


def interpolate_points(p1, p2, n_steps=100):
    """Return `n_steps` points linearly interpolated from p1 to p2 (inclusive).

    p1 and p2 may be scalars or numpy arrays of matching shape; the result
    is a numpy array whose first axis has length n_steps.
    """
    ratios = np.linspace(0, 1, num=n_steps)
    return np.asarray([(1.0 - r) * p1 + r * p2 for r in ratios])


def upscale_image(filePath, model):
    """Super-resolve the image at filePath twice and overwrite it in place.

    The source image is converted to RGB, upscaled once, resized back to
    256x256, upscaled a second time, converted to grayscale ('L') and
    saved over the original file.
    """
    source = np.asarray(Image.open(filePath).convert('RGB'))
    first_pass = resolve_single(model, source)
    intermediate = Image.fromarray(np.asarray(first_pass)).resize((256, 256))
    second_pass = resolve_single(model, np.asarray(intermediate))
    grayscale = Image.fromarray(np.asarray(second_pass)).convert('L')
    grayscale.save(filePath)


# Load the pre-trained SRGAN generator used for frame upscaling, then
# render the latent-space-navigation video.
sr_model = generator()
sr_model.load_weights('../ThirdParty/super-resolution/weights/srgan/gan_generator.h5')
createVideo(1.8, 20, sr_model)
示例#18
0
# Load dataset
train_ds = tfds.load('sr_dataset',
                     split='train',
                     as_supervised=True,
                     shuffle_files=True)
valid_ds = tfds.load('sr_dataset',
                     split='test',
                     as_supervised=True,
                     shuffle_files=True)

train_ds = train_ds.shuffle(buffer_size=100).repeat(100).batch(4)
valid_ds = valid_ds.shuffle(buffer_size=100).repeat(100).batch(4)

# Generator pre-training

model = generator()
model.summary()
pre_trainer = SrganGeneratorTrainer(model=generator(),
                                    checkpoint_dir=f'.ckpt/pre_generator')
pre_trainer.train(train_ds,
                  valid_ds.take(10),
                  steps=8000,
                  evaluate_every=80,
                  save_best_only=False)

pre_trainer.model.save_weights(weights_file('pre_generator.h5'))

# Generator fine-tuning (GAN)

#gan_generator = generator()
#gan_generator.load_weights(weights_file('pre_generator.h5'))
示例#19
0
# Batch super-resolution: upscale every image in `folder` and save
# numbered PNGs. `os` was used below (os.path.join / os.makedirs /
# os.listdir) but never imported in the original; duplicate imports of
# resolve_single / load_image have also been collapsed.
import os

import cv2
import tensorflow as tf
import PIL
from PIL import Image
from tensorflow.python.types import core as core_tf_types

from model import resolve_single
from model.srgan import generator, discriminator
from utils import load_image

weights_dir = 'weights/srgan'
weights_file = lambda filename: os.path.join(weights_dir, filename)
os.makedirs(weights_dir, exist_ok=True)

gan_generator = generator()
gan_generator.load_weights(weights_file('gan_generator.h5'))

frame_number = 0
folder = r"C:\Users\Hari\Desktop\super-resolution-master\data"
current_path = "C:\\Users\\Hari\\Desktop\\super-resolution-master\\save\\"

for filename in os.listdir(folder):
    frame_number += 1
    print(frame_number)
    lr = load_image(os.path.join(folder, filename))
    gan_sr = resolve_single(gan_generator, lr)
    tf.keras.preprocessing.image.save_img(
        current_path + str(frame_number) + ".png", gan_sr)
    #save = cv2.resize(lr, (100,100))
    #cv2.imwrite(current_path + str(frame_number) + ".png",save)
示例#20
0
def main(args):
    """Evaluate an SRGAN generator, quantize it, and re-evaluate the result.

    Pipeline: load generator weights into a dedicated TF1-style graph/session,
    measure PSNR/SSIM on paired LR/HR images, optionally apply cross-layer
    equalization and bias correction, wrap the session in a quantization
    simulator, compute activation encodings, and report PSNR/SSIM for the
    quantized model.

    Attributes read from ``args`` (as used below): ``weights_path``,
    ``images_path``, ``representative_datapath``, ``cross_layer_equalization``,
    ``bias_correction``, ``use_cuda``, ``quant_scheme``, ``num_quant_samples``,
    ``num_bias_correct_samples``, ``default_output_bw``, ``default_param_bw``,
    ``output_dir``.

    NOTE(review): the quantization helpers (equalize_model, BiasCorrection,
    QuantParams, quantsim.QuantizationSimModel) look like AIMET TensorFlow
    APIs — confirm against this file's imports, which are not visible here.
    """
    # configuration for efficient use of gpu
    # (tf.ConfigProto/tf.Session are TF1-compat APIs; allow_growth avoids
    # grabbing all GPU memory up front)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    print('Loading srgan generator...')
    # Build the generator inside its own graph/session so later session
    # transformations (CLE/BC/quantsim) operate on an isolated graph.
    gen_graph = tf.Graph()
    with gen_graph.as_default():
        gen_sess = tf.Session(config=config, graph=gen_graph)
        with gen_sess.as_default():
            srgan_generator = generator()
            srgan_generator.load_weights(args.weights_path)

    # sort files by filenames, assuming names match in both paths
    lr_images_files = sorted(
        glob.glob(os.path.join(args.images_path, '*LR.png')))
    hr_images_files = sorted(
        glob.glob(os.path.join(args.images_path, '*HR.png')))

    # check if number of images align
    if len(lr_images_files) != len(hr_images_files):
        raise RuntimeError('length of image files doesn`t match,'
                           'need same number of images for both'
                           'low resolution and high resolution!')

    image_files = (lr_images_files, hr_images_files)

    # two list of metrics on all images
    # Baseline (float) model metrics before any quantization steps.
    psnr_vals, ssim_vals = evaluate_session(gen_sess, image_files,
                                            srgan_generator.input.name,
                                            srgan_generator.output.name)
    psnr_val = np.mean(psnr_vals)
    ssim_val = np.mean(ssim_vals)
    print(
        f'Mean PSNR and SSIM for given images on original model are: [{psnr_val}, {ssim_val}]'
    )

    # TODO: use a better default dataset for compute encodings when not given by users
    # use low resolution images if no representative lr data are provided

    # use low and high resolution images if no representative lr and hr data are provided
    if args.representative_datapath:
        bc_lr_data = glob.glob(
            os.path.join(args.representative_datapath, '*LR.png'))
        comp_encodings_lr_data = glob.glob(
            os.path.join(args.representative_datapath, '*LR.png'))
        comp_encodings_hr_data = glob.glob(
            os.path.join(args.representative_datapath, '*HR.png'))
    else:
        # Fall back to the evaluation images themselves; warn because this
        # can bias the computed encodings toward the test set.
        warnings.warn(
            'No representative input data are given,'
            'bias correction and computation of encodings will be done'
            'on part of all of the low resolution images!')
        bc_lr_data = lr_images_files

        warnings.warn('No representative reference data are given,'
                      'computation of encodings will be done'
                      'on part of all of the high resolution images!')
        comp_encodings_lr_data = lr_images_files
        comp_encodings_hr_data = hr_images_files

    comp_encodings_data = (comp_encodings_lr_data, comp_encodings_hr_data)

    if args.cross_layer_equalization:
        print('Applying cross layer equalization (CLE) to session...')
        # NOTE(review): equalize_model returns a new session replacing the
        # original; op names from srgan_generator are assumed stable — confirm.
        gen_sess = equalize_model(
            gen_sess,
            start_op_names=srgan_generator.input.op.name,
            output_op_names=srgan_generator.output.op.name)

    if args.bias_correction:
        print('Applying Bias Correction (BC) to session...')
        # the dataset being evaluated might have varying image sizes
        # so right now only use batch size 1
        batch_size = 1
        num_imgs = len(bc_lr_data)

        quant_params = QuantParams(use_cuda=args.use_cuda,
                                   quant_mode=args.quant_scheme)
        # Cap sample counts at the number of available images.
        bias_correction_params = BiasCorrectionParams(
            batch_size=batch_size,
            num_quant_samples=min(num_imgs, args.num_quant_samples),
            num_bias_correct_samples=min(num_imgs,
                                         args.num_bias_correct_samples),
            input_op_names=[srgan_generator.input.op.name],
            output_op_names=[srgan_generator.output.op.name])

        ds = make_dataset(bc_lr_data)
        ds = ds.batch(batch_size)

        gen_sess = BiasCorrection.correct_bias(gen_sess,
                                               bias_correction_params,
                                               quant_params, ds)

    # creating quantsim object which inserts quantizer ops
    sim = quantsim.QuantizationSimModel(
        gen_sess,
        starting_op_names=[srgan_generator.input.op.name],
        output_op_names=[srgan_generator.output.op.name],
        quant_scheme=args.quant_scheme,
        default_output_bw=args.default_output_bw,
        default_param_bw=args.default_param_bw)

    # compute activation encodings
    # usually achieves good results when data being used for computing
    # encodings are representative of its task
    # NOTE(review): 'lambda_3/mul_quantized:0' is a hard-coded name for the
    # quantized output tensor; it will break if the generator architecture
    # (and hence layer naming) changes — verify against the built graph.
    partial_eval = partial(evaluate_session,
                           input_name=srgan_generator.input.name,
                           output_name='lambda_3/mul_quantized:0')
    sim.compute_encodings(partial_eval, comp_encodings_data)

    # Re-evaluate on the same images using the quantized-simulation session.
    psnr_vals, ssim_vals = evaluate_session(sim.session,
                                            image_files,
                                            srgan_generator.input.name,
                                            'lambda_3/mul_quantized:0',
                                            output_dir=args.output_dir)
    psnr_val = np.mean(psnr_vals)
    ssim_val = np.mean(ssim_vals)

    print(
        f'Mean PSNR and SSIM for given images on quantized model are: [{psnr_val}, {ssim_val}]'
    )