Beispiel #1
0
def train():
    """Build and train the weight-clipped WGAN super-resolution model.

    The graph maps 16x16x3 low-resolution inputs to 64x64x3 outputs through
    ``net.generator`` and scores images with ``net.discrimintor`` (the
    critic). The generator minimises ``gan_ratio * G_loss + SR_loss`` where
    ``SR_loss`` is the mean absolute error against the high-resolution
    target; the critic maximises ``D_loss`` and has its variables clipped to
    [-0.01, 0.01] after every update.

    Training data is read from ``train_espcn.tfrecords`` through a TF1
    queue-based input pipeline. Scalars are written to ``filewriter_path``
    and a checkpoint is saved via ``net.save`` after every epoch.

    Relies on module-level names that are not defined in this function
    (``WGAN``, ``gamma``, ``gan_ratio``, ``LEARNING_RATE``, ``NUM_EPOCHS``,
    ``BATCH_SIZE``, ``critic``, ``filewriter_path``, ``checkpoint_path``).
    """
    image_lr = tf.placeholder(dtype=tf.float32,
                              shape=(None, 16, 16, 3),
                              name='lr')
    image_hr = tf.placeholder(dtype=tf.float32,
                              shape=(None, 64, 64, 3),
                              name='hr')

    net = WGAN(gamma)

    gen = net.generator(image_lr, bottleneck_num=2)

    # The first critic call creates the variables, the second one reuses
    # them.  The score of the generated image is the "fake" score and the
    # score of the ground-truth image is the "real" score (the original code
    # had the two names swapped).
    fake_score = net.discrimintor(gen)
    real_score = net.discrimintor(image_hr, reuse=True)

    with tf.name_scope('SR_loss'):
        # L1 reconstruction loss between target and generated image.
        abs_residual = tf.abs(image_hr - gen)
        SR_loss = tf.reduce_mean(abs_residual)

        tf.summary.scalar('SR_loss', SR_loss)
    print('test1')

    with tf.name_scope('gan_loss'):
        # Critic's estimate of the Wasserstein distance; the critic
        # maximises it (see ``minimize(-D_loss)`` below).
        D_loss = tf.reduce_mean(real_score) - tf.reduce_mean(fake_score)

        # The generator must be scored on its OWN output, otherwise the
        # adversarial term carries no gradient w.r.t. generator variables.
        G_loss = -tf.reduce_mean(fake_score)

        tf.summary.scalar('G_loss', G_loss)
        tf.summary.scalar('D_loss', D_loss)

        G_overall_loss = gan_ratio * G_loss + SR_loss  # this part might need modification

    print('test2')

    # get variable from G and D
    var_g = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'generator')
    var_d = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                              'discriminator')

    with tf.name_scope('optim'):
        optim_g = tf.train.RMSPropOptimizer(learning_rate=LEARNING_RATE)\
            .minimize(G_overall_loss, var_list=var_g)
        optim_d = tf.train.RMSPropOptimizer(learning_rate=LEARNING_RATE) \
            .minimize(-D_loss, var_list=var_d)

    # Weight clipping applied to every critic variable after each update.
    clip_D = [p.assign(tf.clip_by_value(p, -0.01, 0.01)) for p in var_d]

    print('test3')

    # set up logging for tensorboard
    writer = tf.summary.FileWriter(filewriter_path)
    writer.add_graph(tf.get_default_graph())
    summaries = tf.summary.merge_all()

    # saver for storing/restoring checkpoints of the model
    saver = tf.train.Saver()

    data_path = 'train_espcn.tfrecords'

    feature = {
        'train/image_small': tf.FixedLenFeature([], tf.string),
        'train/image_origin': tf.FixedLenFeature([], tf.string)
    }

    # create a list of file names
    filename_queue = tf.train.string_input_producer([data_path],
                                                    num_epochs=NUM_EPOCHS)

    reader = tf.TFRecordReader()
    _, tfrecord_serialized = reader.read(filename_queue)

    features = tf.parse_single_example(tfrecord_serialized, features=feature)

    # Convert the image data from string back to the numbers
    image_blur = tf.decode_raw(features['train/image_small'], tf.uint8)
    image_origin = tf.decode_raw(features['train/image_origin'], tf.uint8)

    image_blur = tf.reshape(image_blur, [32, 32, 3])
    image_origin = tf.reshape(image_origin, [128, 128, 3])

    images, labels = tf.train.shuffle_batch([image_blur, image_origin],
                                            batch_size=BATCH_SIZE,
                                            capacity=30,
                                            num_threads=16,
                                            min_after_dequeue=10)

    images = tf.image.resize_images(images, (16, 16))
    labels = tf.image.resize_images(labels, (64, 64))

    print('test4')

    g_ops = [optim_g, G_loss, summaries]
    d_ops = [optim_d, D_loss]

    # Number of generator updates per epoch.
    # NOTE(review): the divisor 6.0 presumably equals ``critic + 1`` (batches
    # consumed per generator update) -- confirm against the ``critic`` value.
    batch_num = int(np.floor(192794 / BATCH_SIZE / 6.0))

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    # The config must actually be handed to the session to take effect.
    with tf.Session(config=config) as sess:

        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())
        sess.run(init_op)

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        def _next_batch():
            """Dequeue one batch and scale pixel values to [-1, 1]."""
            batch_lr, batch_hr = sess.run([images, labels])
            batch_lr = (batch_lr.astype(np.float32) - 127.5) / 127.5
            batch_hr = (batch_hr.astype(np.float32) - 127.5) / 127.5
            return {image_lr: batch_lr, image_hr: batch_hr}

        steps = 0
        loss_d = 0  # keeps the progress print valid even when critic == 0
        train_start = time.time()
        start_time = train_start

        try:
            for ii in range(NUM_EPOCHS):

                batch_average = 0

                for jj in range(batch_num):

                    # ``critic`` critic updates per generator update.
                    for kk in range(critic):
                        steps += 1
                        _, loss_d = sess.run(d_ops, feed_dict=_next_batch())
                        sess.run(clip_D)

                    steps += 1
                    _, loss_g, summary = sess.run(g_ops,
                                                  feed_dict=_next_batch())

                    writer.add_summary(summary, steps)
                    batch_average += loss_g

                    if steps % 100 == 0:
                        print('step: {:d}, G_loss: {:.9f}, D_loss: {:.9f}'
                              .format(steps, loss_g, loss_d))
                        print('time:', time.time() - train_start)

                batch_average = float(batch_average) / batch_num

                duration = time.time() - start_time
                print('Epoch: {}, step: {:d}, loss: {:.9f}, '
                      '({:.3f} sec/epoch)'.format(ii, steps, batch_average,
                                                  duration))

                start_time = time.time()
                net.save(sess, saver, checkpoint_path, steps)
        except tf.errors.OutOfRangeError:
            # The filename queue is limited to NUM_EPOCHS passes over the
            # data; stop cleanly if it is exhausted before the loop ends.
            print('Input queue exhausted after {:d} steps'.format(steps))
        finally:
            # Always stop and join the queue-runner threads and flush the
            # summary writer, even when training aborts with an error.
            coord.request_stop()
            coord.join(threads)
            writer.close()
Beispiel #2
0
def train():
    """Build and train the WGAN-GP super-resolution model.

    The graph maps 16x16x3 low-resolution inputs to 64x64x3 outputs through
    ``net.generator`` and scores images with ``net.discrimintor`` (the
    critic). The generator minimises ``gan_ratio * G_loss + SR_loss`` where
    ``SR_loss`` is the mean absolute error against the high-resolution
    target; the critic minimises ``D_loss + gp_rate * gp_loss`` where
    ``gp_loss`` is a gradient penalty evaluated on random interpolations
    between generated and real images.

    Training data is read from ``train_espcn.tfrecords`` through a TF1
    queue-based input pipeline. Scalars are written to ``filewriter_path``
    every 10th step count and a checkpoint is saved via ``net.save`` after
    every epoch.

    Relies on module-level names that are not defined in this function
    (``WGAN``, ``gamma``, ``gan_ratio``, ``gp_rate``, ``LEARNING_RATE``,
    ``NUM_EPOCHS``, ``BATCH_SIZE``, ``critic``, ``filewriter_path``,
    ``checkpoint_path``).
    """
    image_lr = tf.placeholder(dtype=tf.float32,
                              shape=(BATCH_SIZE, 16, 16, 3),
                              name='lr')
    image_hr = tf.placeholder(dtype=tf.float32,
                              shape=(BATCH_SIZE, 64, 64, 3),
                              name='hr')
    net = WGAN(gamma)

    gen = net.generator(image_lr, bottleneck_num=2)

    # The first critic call creates the variables, the later ones reuse
    # them.  The score of the generated image is the "fake" score and the
    # score of the ground-truth image is the "real" score (the original code
    # had the two names swapped).
    fake_score = net.discrimintor(gen)
    real_score = net.discrimintor(image_hr, reuse=True)

    with tf.name_scope('SR_loss'):
        # L1 reconstruction loss between target and generated image.
        abs_residual = tf.abs(image_hr - gen)
        SR_loss = tf.reduce_mean(abs_residual)

        tf.summary.scalar('SR_loss', SR_loss)
    print('test1')

    with tf.name_scope('gan_loss'):
        # Standard WGAN-GP critic loss: E[D(fake)] - E[D(real)].
        D_loss = tf.reduce_mean(fake_score) - tf.reduce_mean(real_score)

        # The generator must be scored on its OWN output, otherwise the
        # adversarial term carries no gradient w.r.t. generator variables.
        G_loss = -tf.reduce_mean(fake_score)

        def interpolate(a, b):
            """Return a + alpha * (b - a) with one random alpha per sample."""
            # alpha has shape [batch, 1, 1, ...] so it broadcasts over every
            # non-batch axis of ``a``.
            shape = tf.concat(
                (tf.shape(a)[0:1], tf.tile([1], [a.shape.ndims - 1])), axis=0)
            alpha = tf.random_uniform(shape=shape, minval=0., maxval=1.)
            inter = a + alpha * (b - a)
            inter.set_shape(a.get_shape().as_list())
            return inter

        gp_sample = interpolate(gen, image_hr)

        print(gen.get_shape(), 'test2')

        print(image_hr.get_shape())

        print(gp_sample.get_shape())

        gp_gradient = tf.gradients(net.discrimintor(gp_sample, reuse=True),
                                   gp_sample)[0]

        # One gradient norm per sample: reduce over every non-batch axis
        # (reducing over the channel axis only would penalise per-pixel
        # norms instead).  The epsilon keeps sqrt differentiable at zero.
        non_batch_axes = list(range(1, gp_sample.shape.ndims))
        grad_norm = tf.sqrt(
            tf.reduce_sum(tf.square(gp_gradient), axis=non_batch_axes) + 1e-12)

        gp_loss = tf.reduce_mean(tf.square(grad_norm - 1.))

        D_overall_loss = D_loss + gp_rate * gp_loss

        tf.summary.scalar('G_loss', G_loss)
        tf.summary.scalar('D_loss', D_loss)
        tf.summary.scalar('GP_loss', gp_loss)

        G_overall_loss = gan_ratio * G_loss + SR_loss  # this part might need modification

    print('test2')

    # get variable from G and D
    var_g = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'generator')
    var_d = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                              'discriminator')

    with tf.name_scope('optim'):
        optim_g = tf.train.AdamOptimizer(learning_rate=LEARNING_RATE, beta1=0.5, beta2=0.9)\
            .minimize(G_overall_loss, var_list=var_g)
        optim_d = tf.train.AdamOptimizer(learning_rate=LEARNING_RATE, beta1=0.5, beta2=0.9) \
            .minimize(D_overall_loss, var_list=var_d)

    print('test3')

    # set up logging for tensorboard
    writer = tf.summary.FileWriter(filewriter_path)
    writer.add_graph(tf.get_default_graph())
    summaries = tf.summary.merge_all()

    # saver for storing/restoring checkpoints of the model
    saver = tf.train.Saver()

    data_path = 'train_espcn.tfrecords'

    feature = {
        'train/image_small': tf.FixedLenFeature([], tf.string),
        'train/image_origin': tf.FixedLenFeature([], tf.string)
    }

    # create a list of file names
    filename_queue = tf.train.string_input_producer([data_path],
                                                    num_epochs=NUM_EPOCHS)

    reader = tf.TFRecordReader()
    _, tfrecord_serialized = reader.read(filename_queue)

    features = tf.parse_single_example(tfrecord_serialized, features=feature)

    # Convert the image data from string back to the numbers
    image_blur = tf.decode_raw(features['train/image_small'], tf.uint8)
    image_origin = tf.decode_raw(features['train/image_origin'], tf.uint8)

    image_blur = tf.reshape(image_blur, [32, 32, 3])
    image_origin = tf.reshape(image_origin, [128, 128, 3])

    images, labels = tf.train.shuffle_batch([image_blur, image_origin],
                                            batch_size=BATCH_SIZE,
                                            capacity=30,
                                            num_threads=16,
                                            min_after_dequeue=10)

    images = tf.image.resize_images(images, (16, 16))
    labels = tf.image.resize_images(labels, (64, 64))

    print('test4')

    g_ops = [optim_g, G_overall_loss]
    d_ops = [optim_d, D_overall_loss]

    # Number of generator updates per epoch.
    # NOTE(review): the divisor 6.0 presumably equals ``critic + 1`` (batches
    # consumed per generator update) -- confirm against the ``critic`` value.
    batch_num = int(np.floor(192794 / BATCH_SIZE / 6.0))

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    # The config must actually be handed to the session to take effect.
    with tf.Session(config=config) as sess:

        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())
        sess.run(init_op)

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        def _next_batch():
            """Dequeue one batch and scale pixel values to [-1, 1]."""
            batch_lr, batch_hr = sess.run([images, labels])
            batch_lr = (batch_lr.astype(np.float32) - 127.5) / 127.5
            batch_hr = (batch_hr.astype(np.float32) - 127.5) / 127.5
            return {image_lr: batch_lr, image_hr: batch_hr}

        steps = 0
        loss_d = 0  # keeps the progress print valid even when critic == 0
        train_start = time.time()
        start_time = train_start

        try:
            for ii in range(NUM_EPOCHS):

                batch_average = 0

                for jj in range(batch_num):

                    # ``critic`` critic updates per generator update.
                    for kk in range(critic):
                        steps += 1
                        _, loss_d = sess.run(d_ops, feed_dict=_next_batch())

                    steps += 1
                    feed = _next_batch()
                    _, loss_g = sess.run(g_ops, feed_dict=feed)

                    if steps % 10 == 0:
                        # Summaries are evaluated on the same batch that was
                        # just used for the generator update.
                        summary = sess.run(summaries, feed_dict=feed)
                        writer.add_summary(summary, steps)

                    batch_average += loss_g

                    if steps % 100 == 0:
                        print('step: {:d}, G_loss: {:.9f}, D_loss: {:.9f}'
                              .format(steps, loss_g, loss_d))
                        print('time:', time.time() - train_start)

                batch_average = float(batch_average) / batch_num

                duration = time.time() - start_time
                print('Epoch: {}, step: {:d}, loss: {:.9f}, '
                      '({:.3f} sec/epoch)'.format(ii, steps, batch_average,
                                                  duration))

                start_time = time.time()
                net.save(sess, saver, checkpoint_path, steps)
        except tf.errors.OutOfRangeError:
            # The filename queue is limited to NUM_EPOCHS passes over the
            # data; stop cleanly if it is exhausted before the loop ends.
            print('Input queue exhausted after {:d} steps'.format(steps))
        finally:
            # Always stop and join the queue-runner threads and flush the
            # summary writer, even when training aborts with an error.
            coord.request_stop()
            coord.join(threads)
            writer.close()