Example #1
    # Declare the test data reader
    test_data = test_data_loader(FLAGS)

    inputs_raw = tf.placeholder(tf.float32,
                                shape=[1, None, None, 3],
                                name='inputs_raw')
    targets_raw = tf.placeholder(tf.float32,
                                 shape=[1, None, None, 3],
                                 name='targets_raw')
    path_LR = tf.placeholder(tf.string, shape=[], name='path_LR')
    path_HR = tf.placeholder(tf.string, shape=[], name='path_HR')

    with tf.variable_scope('generator'):
        if FLAGS.task == 'SRGAN' or FLAGS.task == 'SRResnet':
            gen_output = generator(inputs_raw, 3, reuse=False, FLAGS=FLAGS)
        else:
            raise NotImplementedError('Unknown task!!')

    print('Finish building the network')

    with tf.name_scope('convert_image'):
        # Deprocess the images output by the model
        inputs = deprocessLR(inputs_raw)
        targets = deprocess(targets_raw)
        outputs = deprocess(gen_output)

        # Convert back to uint8
        converted_inputs = tf.image.convert_image_dtype(inputs,
                                                        dtype=tf.uint8,
                                                        saturate=True)
Example #2
    def __init__(self, filenames_train, filenames_dev, FLAGS):

        self.filenames_train = filenames_train
        self.filenames_dev = filenames_dev

        self.dataset_train = tf.data.TFRecordDataset(self.filenames_train)
        self.dataset_train = self.dataset_train.map(parseTFRecordExample)
        self.dataset_train = self.dataset_train.shuffle(buffer_size=10000)
        self.dataset_train = self.dataset_train.batch(FLAGS.batch_size)
        if FLAGS.mode == 'train':
            self.dataset_train = self.dataset_train.repeat(FLAGS.max_epoch)
        else:
            self.dataset_train = self.dataset_train.repeat(1)

        self.iterator_train = self.dataset_train.make_one_shot_iterator()
        # self.iterator_train_handle = session.run(self.iterator_train.string_handle())

        self.dataset_dev = tf.data.TFRecordDataset(self.filenames_dev)
        self.dataset_dev = self.dataset_dev.map(parseTFRecordExample)
        self.dataset_dev = self.dataset_dev.shuffle(buffer_size=10000)
        self.dataset_dev = self.dataset_dev.batch(FLAGS.batch_size)
        if FLAGS.mode == 'train':
            self.dataset_dev = self.dataset_dev.repeat()
        else:
            self.dataset_dev = self.dataset_dev.repeat(1)

        self.iterator_dev = self.dataset_dev.make_one_shot_iterator()
        # self.iterator_dev_handle = session.run(self.iterator_dev.string_handle())

        self.handle = tf.placeholder(tf.string, shape=[])
        self.iterator = tf.data.Iterator.from_string_handle(
            self.handle, self.iterator_train.output_types)

        # TODO: Fix batch_size not being a factor of the total dataset size
        self.next_batch_HR = self.iterator.get_next()
        self.next_batch_HR.set_shape([
            FLAGS.batch_size, FLAGS.input_size * 4, FLAGS.input_size * 4,
            FLAGS.input_size * 4, 4
        ])

        self.next_batch_LR = ops.filter3d(self.next_batch_HR)
        self.next_batch_LR.set_shape([
            FLAGS.batch_size, FLAGS.input_size, FLAGS.input_size,
            FLAGS.input_size, 4
        ])

        self.FLAGS = FLAGS

        # Build the generator part
        with tf.variable_scope('generator'):
            self.output_channels = self.next_batch_HR.get_shape().as_list()[-1]
            self.gen_output = model.generator(self.next_batch_LR,
                                              self.output_channels,
                                              reuse=False,
                                              FLAGS=FLAGS)
            # self.gen_output.set_shape([FLAGS.batch_size, FLAGS.input_size * 4, FLAGS.input_size * 4, FLAGS.input_size * 4, 4])

        # Build the fake discriminator
        with tf.name_scope('fake_discriminator'):
            with tf.variable_scope('discriminator', reuse=False):
                self.discrim_fake_output = model.discriminator(self.gen_output,
                                                               FLAGS=FLAGS)

        # Build the real discriminator
        with tf.name_scope('real_discriminator'):
            with tf.variable_scope('discriminator', reuse=True):
                self.discrim_real_output = model.discriminator(
                    self.next_batch_HR, FLAGS=FLAGS)

        # Summary
        tf.summary.image("High resolution", self.next_batch_HR[0:1, :, :, 0,
                                                               0:1])
        tf.summary.image("Low resolution", self.next_batch_LR[0:1, :, :, 0,
                                                              0:1])
        tf.summary.image("Generated", self.gen_output[0:1, :, :, 0, 0:1])
        tf.summary.image(
            "Concat",
            tf.concat([
                self.next_batch_HR[0:1, :, :, 0, 0:1],
                self.gen_output[0:1, :, :, 0, 0:1]
            ],
                      axis=2))

        # Calculating the generator loss
        with tf.variable_scope('generator_loss'):

            dx = 2. * np.pi / (4. * self.FLAGS.input_size)

            vel_grad = ops.get_velocity_grad(self.gen_output, dx, dx, dx)
            vel_grad_HR = ops.get_velocity_grad(self.next_batch_HR, dx, dx, dx)
            strain_rate_2_HR = tf.reduce_mean(
                tf.reduce_mean(
                    tf.reduce_mean(ops.get_strain_rate_mag2(vel_grad_HR),
                                   axis=1, keep_dims=True),
                    axis=2, keep_dims=True),
                axis=3, keep_dims=True)

            self.continuity_res = ops.get_continuity_residual(vel_grad)
            self.pressure_res = ops.get_pressure_residual(
                self.gen_output, vel_grad, dx, dx, dx)

            tke_gen = ops.get_TKE(self.gen_output)
            tke_hr = ops.get_TKE(self.next_batch_HR)
            tke_hr_mean2 = tf.reduce_mean(
                tf.reduce_mean(
                    tf.reduce_mean(tf.square(tke_hr), axis=1, keep_dims=True),
                    axis=2, keep_dims=True),
                axis=3, keep_dims=True)
            self.tke_loss = tf.reduce_mean(
                tf.square(tke_gen - tke_hr) / tke_hr_mean2)

            vorticity_gen = ops.get_vorticity(vel_grad)
            vorticity_hr = ops.get_vorticity(vel_grad_HR)
            # self.vorticity_loss = tf.reduce_mean(tf.square(vorticity_gen-vorticity_hr))

            ens_gen = ops.get_enstrophy(vorticity_gen)
            ens_hr = ops.get_enstrophy(vorticity_hr)
            ens_hr_mean2 = tf.reduce_mean(
                tf.reduce_mean(
                    tf.reduce_mean(tf.square(ens_hr), axis=1, keep_dims=True),
                    axis=2, keep_dims=True),
                axis=3, keep_dims=True)
            self.ens_loss = tf.reduce_mean(
                tf.square(ens_gen - ens_hr) / ens_hr_mean2)

            # Normalized mean squared error between the generated and HR fields
            mse_hr_mean2 = tf.reduce_mean(
                tf.reduce_mean(
                    tf.reduce_mean(tf.square(self.next_batch_HR),
                                   axis=1, keep_dims=True),
                    axis=2, keep_dims=True),
                axis=3, keep_dims=True)
            self.mse_loss = tf.reduce_mean(
                tf.square(self.gen_output - self.next_batch_HR) / mse_hr_mean2)

            # Content loss
            with tf.variable_scope('content_loss'):
                # Content loss => mse + enstrophy
                self.content_loss = (
                    1 - self.FLAGS.lambda_ens
                ) * self.mse_loss + self.FLAGS.lambda_ens * self.ens_loss

            # Physics loss
            with tf.variable_scope('physics_loss'):
                self.continuity_loss = tf.reduce_mean(
                    tf.square(self.continuity_res) / strain_rate_2_HR)
                self.pressure_loss = tf.reduce_mean(
                    tf.square(self.pressure_res) / strain_rate_2_HR**2)

                self.physics_loss = (
                    1 - self.FLAGS.lambda_con
                ) * self.pressure_loss + self.FLAGS.lambda_con * self.continuity_loss

            with tf.variable_scope('adversarial_loss'):
                if (FLAGS.GAN_type == 'GAN'):
                    self.adversarial_loss = tf.reduce_mean(
                        -tf.log(self.discrim_fake_output + FLAGS.EPS))

                if (FLAGS.GAN_type == 'WGAN_GP'):
                    self.adversarial_loss = tf.reduce_mean(
                        -self.discrim_fake_output)

            self.gen_loss = (
                1 - self.FLAGS.lambda_phy
            ) * self.content_loss + self.FLAGS.lambda_phy * self.physics_loss
            self.gen_loss = (
                1 - self.FLAGS.adversarial_ratio) * self.gen_loss + (
                    self.FLAGS.adversarial_ratio) * self.adversarial_loss

        tf.summary.scalar('Generator loss', self.gen_loss)
        tf.summary.scalar('Adversarial loss', self.adversarial_loss)
        tf.summary.scalar('Content loss', self.content_loss)
        tf.summary.scalar('Physics loss', self.physics_loss)
        tf.summary.scalar('MSE error', tf.sqrt(self.mse_loss))
        tf.summary.scalar('Continuity error', tf.sqrt(self.continuity_loss))
        tf.summary.scalar('Pressure error', tf.sqrt(self.pressure_loss))
        tf.summary.scalar('TKE error', tf.sqrt(self.tke_loss))
        # tf.summary.scalar('Vorticity loss', self.vorticity_loss)
        tf.summary.scalar('Enstrophy error', tf.sqrt(self.ens_loss))

        tf.summary.image('Z - Continuity residual',
                         self.continuity_res[0:1, :, :, 0, 0:1])
        tf.summary.image('Z - Pressure residual', self.pressure_res[0:1, :, :,
                                                                    0, 0:1])

        # Create a new instance of the discriminator for gradient penalty
        if (FLAGS.GAN_type == 'WGAN_GP'):
            eps_WGAN = tf.random_uniform(shape=[FLAGS.batch_size, 1, 1, 1, 1],
                                         minval=0.,
                                         maxval=1.)
            inpt_hat = eps_WGAN * self.next_batch_HR + (
                1 - eps_WGAN) * self.gen_output

            # Build the interpolated discriminator for WGAN-GP
            with tf.name_scope('hat_discriminator'):
                with tf.variable_scope('discriminator', reuse=True):
                    discrim_hat_output = model.discriminator(inpt_hat,
                                                             FLAGS=FLAGS)

        # Calculating the discriminator loss
        with tf.variable_scope('discriminator_loss'):
            if (FLAGS.GAN_type == 'GAN'):
                discrim_fake_loss = tf.log(1 - self.discrim_fake_output +
                                           FLAGS.EPS)
                discrim_real_loss = tf.log(self.discrim_real_output +
                                           FLAGS.EPS)

                self.discrim_loss = tf.reduce_mean(-(discrim_fake_loss +
                                                     discrim_real_loss))

            if (FLAGS.GAN_type == 'WGAN_GP'):
                self.discrim_loss = tf.reduce_mean(self.discrim_fake_output -
                                                   self.discrim_real_output)

                grad_discrim_inpt_hat = tf.gradients(discrim_hat_output,
                                                     [inpt_hat])[0]

                # L2-norm across channels
                gradnorm_discrim_inpt_hat = tf.sqrt(
                    tf.reduce_sum(tf.square(grad_discrim_inpt_hat),
                                  axis=[-1]))
                gradient_penalty = tf.reduce_mean(
                    (gradnorm_discrim_inpt_hat - 1.)**2)

                self.discrim_loss += FLAGS.lambda_WGAN * gradient_penalty

        tf.summary.scalar('Discriminator loss', self.discrim_loss)

        with tf.variable_scope('get_learning_rate_and_global_step'):

            self.global_step = tf.train.create_global_step()
            self.learning_rate = tf.train.exponential_decay(
                FLAGS.learning_rate,
                self.global_step,
                FLAGS.decay_step,
                FLAGS.decay_rate,
                staircase=FLAGS.stair)
            self.incr_global_step = tf.assign(self.global_step,
                                              self.global_step + 1)

        with tf.variable_scope('discriminator_train'):

            discrim_tvars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                              scope='discriminator')
            discrim_optimizer = tf.train.AdamOptimizer(self.learning_rate,
                                                       beta1=FLAGS.beta)
            discrim_grads_and_vars = discrim_optimizer.compute_gradients(
                self.discrim_loss, discrim_tvars)
            self.discrim_train = discrim_optimizer.apply_gradients(
                discrim_grads_and_vars)
            # self.discrim_train = discrim_optimizer.minimize( self.discrim_loss, self.global_step )

        with tf.variable_scope('generator_train'):
            gen_tvars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                          scope='generator')

            # Wait for the discriminator train step (and any update ops) before updating the generator
            with tf.control_dependencies(
                [self.discrim_train] +
                    tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
                gen_optimizer = tf.train.AdamOptimizer(self.learning_rate,
                                                       beta1=FLAGS.beta)
                gen_grads_and_vars = gen_optimizer.compute_gradients(
                    self.gen_loss, gen_tvars)
                self.gen_train = gen_optimizer.apply_gradients(
                    gen_grads_and_vars)
                # self.gen_train = gen_optimizer.minimize( self.gen_loss )

        exp_averager = tf.train.ExponentialMovingAverage(decay=0.99)
        self.update_loss = exp_averager.apply(
            [self.discrim_loss, self.content_loss, self.adversarial_loss])

        # Define data saver
        self.saver = tf.train.Saver(max_to_keep=10)
        self.weights_initializer = tf.train.Saver(discrim_tvars + gen_tvars)
        self.weights_initializer_g = tf.train.Saver(gen_tvars)

        # Summary
        tf.summary.scalar("Discriminator fake output",
                          self.discrim_fake_output[0, 0])
        tf.summary.scalar("Learning rate", self.learning_rate)
        self.merged_summary = tf.summary.merge_all()
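
A minimal driver sketch for this class (not part of the example above): the instance name `model`, the flag `max_iter`, and the instantiation are assumptions; the attributes used (`handle`, `iterator_train`, `gen_train`, `gen_loss`, `discrim_loss`) all come from the constructor shown here.

    # Hypothetical usage sketch; assumes `model` is an instance of the class above
    # and FLAGS provides max_iter (neither is shown in this snippet).
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # Select the training one-shot iterator at run time via its string handle.
        train_handle = sess.run(model.iterator_train.string_handle())
        for step in range(FLAGS.max_iter):
            # gen_train has a control dependency on discrim_train, so one run()
            # performs both the discriminator and the generator update.
            _, g_loss, d_loss = sess.run(
                [model.gen_train, model.gen_loss, model.discrim_loss],
                feed_dict={model.handle: train_handle})
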
Example #3
import tensorflow as tf
import tensorflow.contrib.slim as slim
import os
from lib.model import generator, inference_data_loader, save_images
from lib.ops import *
import math
import time
import numpy as np


# Declare the test data reader
inputs_raw = tf.placeholder(tf.float32, shape=[1, None, None, 3], name='inputs_raw')
path_LR = tf.placeholder(tf.string, shape=[], name='path_LR')

with tf.variable_scope('generator'):
    gen_output = generator(inputs_raw, 3, reuse=False)

with tf.name_scope('convert_image'):
    # Deprocess the images output by the model
    inputs = deprocessLR(inputs_raw)
    outputs = deprocess(gen_output)

    # Convert back to uint8
    converted_inputs = tf.image.convert_image_dtype(inputs, dtype=tf.uint8, saturate=True)
    converted_outputs = tf.image.convert_image_dtype(outputs, dtype=tf.uint8, saturate=True)

with tf.name_scope('encode_image'):
    save_fetch = {
        "path_LR": path_LR,
        "inputs": tf.map_fn(tf.image.encode_png, converted_inputs, dtype=tf.string, name='input_pngs'),
        "outputs": tf.map_fn(tf.image.encode_png, converted_outputs, dtype=tf.string, name='output_pngs')
    }
Example #4
    def __init__(self, filenames_train, filenames_dev, FLAGS):

        self.filenames_train = filenames_train
        self.filenames_dev = filenames_dev

        self.dataset_train = tf.data.TFRecordDataset(self.filenames_train)
        self.dataset_train = self.dataset_train.map(parseTFRecordExample)
        self.dataset_train = self.dataset_train.shuffle(buffer_size=10000)
        self.dataset_train = self.dataset_train.batch(FLAGS.batch_size)
        if FLAGS.mode == 'train':
            self.dataset_train = self.dataset_train.repeat(FLAGS.max_epoch)
        else:
            self.dataset_train = self.dataset_train.repeat(1)

        self.iterator_train = self.dataset_train.make_one_shot_iterator()
        # self.iterator_train_handle = session.run(self.iterator_train.string_handle())

        self.dataset_dev = tf.data.TFRecordDataset(self.filenames_dev)
        self.dataset_dev = self.dataset_dev.map(parseTFRecordExample)
        self.dataset_dev = self.dataset_dev.shuffle(buffer_size=10000)
        self.dataset_dev = self.dataset_dev.batch(FLAGS.batch_size)
        if FLAGS.mode == 'train':
            self.dataset_dev = self.dataset_dev.repeat()
        else:
            self.dataset_dev = self.dataset_dev.repeat(1)

        self.iterator_dev = self.dataset_dev.make_one_shot_iterator()
        # self.iterator_dev_handle = session.run(self.iterator_dev.string_handle())

        self.handle = tf.placeholder(tf.string, shape=[])
        self.iterator = tf.data.Iterator.from_string_handle(
            self.handle, self.iterator_train.output_types)

        # TODO: Fix batch_size not being a factor of the total dataset size
        self.next_batch_HR = self.iterator.get_next()
        self.next_batch_HR.set_shape([
            FLAGS.batch_size, FLAGS.input_size * 4, FLAGS.input_size * 4,
            FLAGS.input_size * 4, 4
        ])

        self.next_batch_LR = ops.filter3d(self.next_batch_HR)
        self.next_batch_LR.set_shape([
            FLAGS.batch_size, FLAGS.input_size, FLAGS.input_size,
            FLAGS.input_size, 4
        ])

        self.FLAGS = FLAGS

        # Build the generator part
        with tf.variable_scope('generator'):
            self.output_channels = self.next_batch_HR.get_shape().as_list()[-1]
            self.gen_output = model.generator(self.next_batch_LR,
                                              self.output_channels,
                                              reuse=False,
                                              FLAGS=FLAGS)
            # self.gen_output.set_shape([FLAGS.batch_size, FLAGS.input_size * 4, FLAGS.input_size * 4, FLAGS.input_size * 4, 4])

        # Summary
        tf.summary.image("High resolution", self.next_batch_HR[0:1, :, :, 0,
                                                               0:1])
        tf.summary.image("Low resolution", self.next_batch_LR[0:1, :, :, 0,
                                                              0:1])
        tf.summary.image("Generated", self.gen_output[0:1, :, :, 0, 0:1])
        tf.summary.image(
            "Concat",
            tf.concat([
                self.next_batch_HR[0:1, :, :, 0, 0:1],
                self.gen_output[0:1, :, :, 0, 0:1]
            ],
                      axis=2))

        # Calculating the generator loss
        with tf.variable_scope('generator_loss'):

            dx = 2. * np.pi / (4. * self.FLAGS.input_size)

            vel_grad = ops.get_velocity_grad(self.gen_output, dx, dx, dx)
            vel_grad_HR = ops.get_velocity_grad(self.next_batch_HR, dx, dx, dx)
            strain_rate_2_HR = tf.reduce_mean(
                tf.reduce_mean(
                    tf.reduce_mean(ops.get_strain_rate_mag2(vel_grad_HR),
                                   axis=1, keep_dims=True),
                    axis=2, keep_dims=True),
                axis=3, keep_dims=True)

            self.continuity_res = ops.get_continuity_residual(vel_grad)
            self.pressure_res = ops.get_pressure_residual(
                self.gen_output, vel_grad, dx, dx, dx)

            tke_gen = ops.get_TKE(self.gen_output)
            tke_hr = ops.get_TKE(self.next_batch_HR)
            tke_hr_mean2 = tf.reduce_mean(
                tf.reduce_mean(
                    tf.reduce_mean(tf.square(tke_hr), axis=1, keep_dims=True),
                    axis=2, keep_dims=True),
                axis=3, keep_dims=True)
            self.tke_loss = tf.reduce_mean(
                tf.square(tke_gen - tke_hr) / tke_hr_mean2)

            vorticity_gen = ops.get_vorticity(vel_grad)
            vorticity_hr = ops.get_vorticity(vel_grad_HR)
            # self.vorticity_loss = tf.reduce_mean(tf.square(vorticity_gen-vorticity_hr))

            ens_gen = ops.get_enstrophy(vorticity_gen)
            ens_hr = ops.get_enstrophy(vorticity_hr)
            ens_hr_mean2 = tf.reduce_mean(
                tf.reduce_mean(
                    tf.reduce_mean(tf.square(ens_hr), axis=1, keep_dims=True),
                    axis=2, keep_dims=True),
                axis=3, keep_dims=True)
            self.ens_loss = tf.reduce_mean(
                tf.square(ens_gen - ens_hr) / ens_hr_mean2)

            # Normalized mean squared error between the generated and HR fields
            mse_hr_mean2 = tf.reduce_mean(
                tf.reduce_mean(
                    tf.reduce_mean(tf.square(self.next_batch_HR),
                                   axis=1, keep_dims=True),
                    axis=2, keep_dims=True),
                axis=3, keep_dims=True)
            self.mse_loss = tf.reduce_mean(
                tf.square(self.gen_output - self.next_batch_HR) / mse_hr_mean2)

            # Content loss
            with tf.variable_scope('content_loss'):
                # Content loss => mse + enstrophy
                self.content_loss = (
                    1 - self.FLAGS.lambda_ens
                ) * self.mse_loss + self.FLAGS.lambda_ens * self.ens_loss

            # Physics loss
            with tf.variable_scope('physics_loss'):
                self.continuity_loss = tf.reduce_mean(
                    tf.square(self.continuity_res) / strain_rate_2_HR)
                self.pressure_loss = tf.reduce_mean(
                    tf.square(self.pressure_res) / strain_rate_2_HR**2)

                self.physics_loss = (
                    1 - self.FLAGS.lambda_con
                ) * self.pressure_loss + self.FLAGS.lambda_con * self.continuity_loss

            self.gen_loss = (
                1 - self.FLAGS.lambda_phy
            ) * self.content_loss + self.FLAGS.lambda_phy * self.physics_loss

        tf.summary.scalar('Generator loss', self.gen_loss)
        tf.summary.scalar('Content loss', self.content_loss)
        tf.summary.scalar('Physics loss', self.physics_loss)
        tf.summary.scalar('MSE error', tf.sqrt(self.mse_loss))
        tf.summary.scalar('Continuity error', tf.sqrt(self.continuity_loss))
        tf.summary.scalar('Pressure error', tf.sqrt(self.pressure_loss))
        tf.summary.scalar('TKE error', tf.sqrt(self.tke_loss))
        # tf.summary.scalar('Vorticity loss', self.vorticity_loss)
        tf.summary.scalar('Enstrophy error', tf.sqrt(self.ens_loss))

        tf.summary.image('Z - Continuity residual',
                         self.continuity_res[0:1, :, :, 0, 0:1])
        tf.summary.image('Z - Pressure residual', self.pressure_res[0:1, :, :,
                                                                    0, 0:1])

        # Define the learning rate and global step
        with tf.variable_scope('get_learning_rate_and_global_step'):

            self.global_step = tf.contrib.framework.get_or_create_global_step()
            self.learning_rate = tf.train.exponential_decay(
                FLAGS.learning_rate,
                self.global_step,
                FLAGS.decay_step,
                FLAGS.decay_rate,
                staircase=FLAGS.stair)
            # self.incr_global_step = tf.assign(self.global_step, self.global_step + 1)

        with tf.variable_scope('generator_train'):

            # Apply any pending update ops (e.g. batch norm statistics) before the generator train step
            with tf.control_dependencies(
                    tf.get_collection(tf.GraphKeys.UPDATE_OPS)):

                gen_tvars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                              scope='generator')
                gen_optimizer = tf.train.AdamOptimizer(self.learning_rate,
                                                       beta1=FLAGS.beta)
                self.gen_train = gen_optimizer.minimize(
                    self.gen_loss, self.global_step)

        exp_averager = tf.train.ExponentialMovingAverage(decay=0.99)
        self.update_loss = exp_averager.apply([self.content_loss])

        tf.summary.scalar('Learning rate', self.learning_rate)

        # Define data saver
        self.saver = tf.train.Saver(max_to_keep=10)
        self.weights_initializer = tf.train.Saver(gen_tvars)

        self.merged_summary = tf.summary.merge_all()
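
A possible training-loop sketch for this generator-only variant (the instance name `model` and the flags `max_iter`, `summary_freq`, `summary_dir` and `output_dir` are assumptions; the attributes come from the constructor above).

    # Hypothetical driver; `model` is assumed to be an instance of the class above.
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        writer = tf.summary.FileWriter(FLAGS.summary_dir, sess.graph)
        train_handle = sess.run(model.iterator_train.string_handle())
        for step in range(FLAGS.max_iter):
            # minimize() was given the global step, so it is incremented automatically.
            fetches = {'train': model.gen_train, 'loss': model.gen_loss}
            if step % FLAGS.summary_freq == 0:
                fetches['summary'] = model.merged_summary
            results = sess.run(fetches, feed_dict={model.handle: train_handle})
            if 'summary' in results:
                writer.add_summary(results['summary'], step)
        model.saver.save(sess, FLAGS.output_dir + 'model', global_step=model.global_step)
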
Example #5
def infer2():  # (inputdir, outputdir)
    
    Flags = tf.app.flags

    # The system parameter
    Flags.DEFINE_string('output_dir', './result/', 'The output directory of the checkpoint')
    Flags.DEFINE_string('summary_dir', './result/log/', 'The directory to output the summary')
    Flags.DEFINE_string('mode', 'inference', 'The mode of the model train, test.')
    Flags.DEFINE_string('checkpoint', './SRGAN_pre-trained/model-200000', 'If provided, the weight will be restored from the provided checkpoint')
    Flags.DEFINE_boolean('pre_trained_model', True, 'If set True, the weights will be loaded but the global_step will still '
                                                     'be 0. If set False, the training is continued, i.e. '
                                                     'the global_step will also be initialized from the checkpoint')
    Flags.DEFINE_string('pre_trained_model_type', 'SRResnet', 'The type of pretrained model (SRGAN or SRResnet)')
    Flags.DEFINE_boolean('is_training', False, 'Training => True, Testing => False')
    Flags.DEFINE_string('vgg_ckpt', './vgg19/vgg_19.ckpt', 'path to checkpoint file for the vgg19')
    Flags.DEFINE_string('task', 'SRGAN', 'The task: SRGAN, SRResnet')
    # The data preparing operation
    Flags.DEFINE_integer('batch_size', 16, 'Batch size of the input batch')
    Flags.DEFINE_string('input_dir_LR', './infer/sample/', 'The directory of the low resolution input data')
    Flags.DEFINE_string('input_dir_HR', './data/infer_HR', 'The directory of the high resolution input data')
    Flags.DEFINE_boolean('flip', True, 'Whether random flip data augmentation is applied')
    Flags.DEFINE_boolean('random_crop', True, 'Whether perform the random crop')
    Flags.DEFINE_integer('crop_size', 24, 'The crop size of the training image')
    Flags.DEFINE_integer('name_queue_capacity', 2048, 'The capacity of the filename queue (suggest a large value to '
                                                      'ensure enough random shuffling)')
    Flags.DEFINE_integer('image_queue_capacity', 2048, 'The capacity of the image queue (suggest a large value to '
                                                       'ensure enough random shuffling)')
    Flags.DEFINE_integer('queue_thread', 10, 'The number of queue threads (more threads can speed up the training process)')
    # Generator configuration
    Flags.DEFINE_integer('num_resblock', 16, 'The number of residual blocks in the generator')
    # The content loss parameter
    Flags.DEFINE_string('perceptual_mode', 'VGG54', 'The type of feature used in perceptual loss')
    Flags.DEFINE_float('EPS', 1e-12, 'The eps added to prevent nan')
    Flags.DEFINE_float('ratio', 0.001, 'The ratio between content loss and adversarial loss')
    Flags.DEFINE_float('vgg_scaling', 0.0061, 'The scaling factor for the perceptual loss if using vgg perceptual loss')
    # The training parameters
    Flags.DEFINE_float('learning_rate', 0.0001, 'The learning rate for the network')
    Flags.DEFINE_integer('decay_step', 500000, 'The steps needed to decay the learning rate')
    Flags.DEFINE_float('decay_rate', 0.1, 'The decay rate of each decay step')
    Flags.DEFINE_boolean('stair', False, 'Whether perform staircase decay. True => decay in discrete interval.')
    Flags.DEFINE_float('beta', 0.9, 'The beta1 parameter for the Adam optimizer')
    Flags.DEFINE_integer('max_epoch', None, 'The max epoch for the training')
    Flags.DEFINE_integer('max_iter', 1000000, 'The max iteration of the training')
    Flags.DEFINE_integer('display_freq', 20, 'The display frequency of the training process')
    Flags.DEFINE_integer('summary_freq', 100, 'The frequency of writing summary')
    Flags.DEFINE_integer('save_freq', 10000, 'The frequency of saving images')

    
    FLAGS = Flags.FLAGS
    #FLAGS.input_dir_LR = inputdir
    #FLAGS.output_dir = outputdir

    # Print the configuration of the model
    print_configuration_op(FLAGS)

    # Check the output_dir is given
    if FLAGS.output_dir is None:
        raise ValueError('The output directory is needed')

    # Check the output directory to save the checkpoint
    if not os.path.exists(FLAGS.output_dir):
        os.mkdir(FLAGS.output_dir)

    # Check the summary directory to save the event
    if not os.path.exists(FLAGS.summary_dir):
        os.mkdir(FLAGS.summary_dir)


    # Check the checkpoint
    if FLAGS.checkpoint is None:
        raise ValueError('The checkpoint file is needed to perform the test.')

    # At test time, no flipping or cropping is needed
    if FLAGS.flip:
        FLAGS.flip = False

    if FLAGS.crop_size is not None:
        FLAGS.crop_size = None

    # Declare the test data reader
    #inference_data = inference_data_loader2(FLAGS, inputdir)

    inputs_raw = tf.placeholder(tf.float32, shape=[1, None, None, 3], name='inputs_raw')
    path_LR = tf.placeholder(tf.string, shape=[], name='path_LR')

    with tf.variable_scope('generator'):
        if FLAGS.task == 'SRGAN' or FLAGS.task == 'SRResnet':
            gen_output = generator(inputs_raw, 3, reuse=False, FLAGS=FLAGS)
        else:
            raise NotImplementedError('Unknown task!!')

    print('Finish building the network')

    with tf.name_scope('convert_image'):
        # Deprocess the images output by the model
        inputs = deprocessLR(inputs_raw)
        outputs = deprocess(gen_output)

        # Convert back to uint8
        converted_inputs = tf.image.convert_image_dtype(inputs, dtype=tf.uint8, saturate=True)
        converted_outputs = tf.image.convert_image_dtype(outputs, dtype=tf.uint8, saturate=True)

    with tf.name_scope('encode_image'):
        save_fetch = {
            "path_LR": path_LR,
            "inputs": tf.map_fn(tf.image.encode_png, converted_inputs, dtype=tf.string, name='input_pngs'),
            "outputs": tf.map_fn(tf.image.encode_png, converted_outputs, dtype=tf.string, name='output_pngs')
        }

    # Define the weight initializer (at inference time we only need to restore the generator weights)
    var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='generator')
    weight_initializer = tf.train.Saver(var_list)

    # Define the initialization operation
    init_op = tf.global_variables_initializer()

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        # Load the pretrained model
        print('Loading weights from the pre-trained model')
        weight_initializer.restore(sess, FLAGS.checkpoint)




        dire0 = "./lowresdata/train/"
        folder = sorted(glob(dire0 + "*/"))
        folders = folder[:]

        num = 7
        for inputf in folders:
            outputdir = inputf[0:num] + '2' + inputf[num:]
            if not os.path.exists(outputdir):
                os.makedirs(outputdir)

            inference_data = inference_data_loader2(FLAGS, inputf)


            max_iter = len(inference_data.inputs)
            print('Evaluation starts for ', inputf)
            for i in range(max_iter):
                input_im = np.array([inference_data.inputs[i]]).astype(np.float32)
                path_lr = inference_data.paths_LR[i]
                results = sess.run(save_fetch, feed_dict={inputs_raw: input_im, path_LR: path_lr})
                filesets = save_images2(outputdir, results, FLAGS)
                for f in filesets:
                    print('evaluate image', f['name'])

    delflags(FLAGS)
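
`inference_data_loader2`, `save_images2` and `delflags` are not shown in this example. For illustration only, a hypothetical sketch of what a `save_images2`-style helper could look like (the file-naming convention and return format are assumptions; the PNG byte strings do come from the `tf.image.encode_png` ops in `save_fetch`):

    import os

    def save_images2(output_dir, fetches, FLAGS):
        # fetches['inputs'] / fetches['outputs'] hold PNG-encoded byte strings
        # produced by tf.map_fn(tf.image.encode_png, ...); the batch size is 1 here.
        path = fetches['path_LR']
        if isinstance(path, bytes):
            path = path.decode('utf-8')
        name, _ = os.path.splitext(os.path.basename(path))

        fileset = {'name': name}
        for kind in ('inputs', 'outputs'):
            filename = '%s-%s.png' % (name, kind)
            with open(os.path.join(output_dir, filename), 'wb') as f:
                f.write(fetches[kind][0])
            fileset[kind] = filename
        return [fileset]
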
Example #6
    data_HR = scio.loadmat(r'./data/test.mat')
    data_LR = scio.loadmat(r'./data/test_ds.mat')
    HR = data_HR['test']
    LR = data_LR['test_ds']

    LR = np.expand_dims(LR, axis=4)
    OR = LR
    LR = np.transpose(LR, axes=[0, 3, 1, 2, 4])

    HR = HR.astype(np.float32)
    LR = LR.astype(np.float32)


    with tf.variable_scope('generator'):
        if FLAGS.task == '_3DSRGAN' or FLAGS.task == '_3DSRResnet':
            gen_output = generator(LR[[0]], 191, reuse=False, FLAGS=FLAGS)
        else:
            raise NotImplementedError('Unknown task!!')

    print('Finish building the network')

    # Define the weight initializer (at inference time we only need to restore the generator weights)
    var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='generator')
    weight_initializer = tf.train.Saver(var_list)

    # Define the initialization operation
    init_op = tf.global_variables_initializer()

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
Example #7
        FLAGS.flip = False

    if FLAGS.crop_size is not None:
        FLAGS.crop_size = None

    # Declare the test data reader
    test_data = test_data_loader(FLAGS)

    inputs_raw = tf.placeholder(tf.float32, shape=[1, None, None, 3], name='inputs_raw')
    targets_raw = tf.placeholder(tf.float32, shape=[1, None, None, 3], name='targets_raw')
    path_LR = tf.placeholder(tf.string, shape=[], name='path_LR')
    path_HR = tf.placeholder(tf.string, shape=[], name='path_HR')

    with tf.variable_scope('generator'):
        if FLAGS.task == 'SRGAN' or FLAGS.task == 'SRResnet':
            gen_output = generator(inputs_raw, 3, reuse=False, FLAGS=FLAGS)
        else:
            raise NotImplementedError('Unknown task!!')

    print('Finish building the network')

    with tf.name_scope('convert_image'):
        # Deprocess the images output by the model
        inputs = deprocessLR(inputs_raw)
        targets = deprocess(targets_raw)
        outputs = deprocess(gen_output)

        # Convert back to uint8
        converted_inputs = tf.image.convert_image_dtype(inputs, dtype=tf.uint8, saturate=True)
        converted_targets = tf.image.convert_image_dtype(targets, dtype=tf.uint8, saturate=True)
        converted_outputs = tf.image.convert_image_dtype(outputs, dtype=tf.uint8, saturate=True)
        initials = None
    elif FLAGS.initials == 'content':
        if FLAGS.content_dir is None:
            raise ValueError('The content image path is needed')
elif FLAGS.task_mode == 'style_transfer':
    contents = initials
    if FLAGS.initials == 'content':
        initials = contents
    elif FLAGS.initials == 'noise':
        initials = None
    elif FLAGS.initials == 'style':
        initials = targets
        
with tf.variable_scope('generator'):
    if FLAGS.task_mode == 'texture_synthesis':
        gen_output = generator(FLAGS, targets, initials, None, reuse=False)
    elif FLAGS.task_mode == 'style_transfer':
        gen_output = generator(FLAGS, targets, initials, contents, reuse=False)

# Calculate the generator loss
with tf.name_scope('generator_loss'):
    with tf.name_scope('tv_loss'):
        tv_loss = total_variation_loss(gen_output)

    with tf.name_scope('style_loss'):
        _, vgg_gen_output = vgg_19(gen_output, is_training=False, reuse=False)
        _, vgg_tar_output = vgg_19(targets, is_training=False, reuse=True)
        style_layer_list = get_layer_list(FLAGS.top_style_layer, False)
        sl = tf.zeros([])
        ratio_list = [100.0, 1.0, 0.1, 0.0001, 1.0, 100.0]
        for i in range(len(style_layer_list)):