Example #1
import tensorflow as tf
import srResNet
import vgg19
import discriminator

# Assumed: H, W, and batch_size are module-level constants from the
# original script. The snippet starts mid-function, so the `read`
# signature and the steps_per_epoch line below are reconstructions.


def read(file_names):
    steps_per_epoch = len(file_names) // batch_size  # assumed definition
    filename_queue = tf.train.string_input_producer(file_names)
    reader = tf.WholeFileReader()
    _, value = reader.read(filename_queue)
    image = tf.image.decode_jpeg(value)
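    # Augmentation: random 4x-resolution crop plus random horizontal flip.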
    cropped = tf.random_crop(image, [H * 4, W * 4, 3])
    random_flipped = tf.image.random_flip_left_right(cropped)
    minibatch = tf.cast(
        tf.train.batch([random_flipped], batch_size, capacity=300),
        tf.float32) / 255.0
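    # Bicubic-downsample the HR batch to the low-res network input.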
    resized = tf.image.resize_bicubic(minibatch, [H, W])
    return steps_per_epoch, minibatch, resized
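
# `filenames` is assumed to be a plain Python list of image paths; a
# minimal way to build it, mirroring Example #3 below (the listing-file
# path is a hypothetical placeholder):
filenames = open('./trainset_path.txt', 'r').read().split('\n')
filenames.pop()  # drop the empty entry left by the trailing newline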


with tf.device('/cpu:0'):
    steps_per_epoch, minibatch, resized = read(filenames)
resnet = srResNet.srResNet(resized * 2.0 - 1)  # map [0, 1] inputs to [-1, 1]
result = resnet.out
gen_var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)

dbatch = tf.concat([minibatch, result], 0)  # real + generated, stacked
bicubic = tf.clip_by_value(
    tf.image.resize_bicubic(resized, [H * 4, W * 4]), 0, 1)
out = [minibatch, result, bicubic]

vgg = vgg19.Vgg19()
vgg.build(dbatch)
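# Perceptual (content) loss: split the stacked batch back into its real
# and generated halves and compare their VGG-19 feature maps.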
fmap = tf.split(vgg.conv2_2, 2)
content_loss = tf.losses.mean_squared_error(fmap[0], fmap[1])

disc = discriminator.Discriminator(dbatch)
D_x, D_G_z = tf.split(disc.dense2, 2)
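
# Example #1 stops before defining any training ops. A minimal sketch of
# the TF1 session setup needed to pull data through the queue-based
# pipeline above (this part is an assumption, not in the original):
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    sess.run(out)  # fetch one batch of [real, generated, bicubic] images
    coord.request_stop()
    coord.join(threads)
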
Example #2
def read(file_names):
    filename_queue = tf.train.string_input_producer(file_names,
                                                    capacity=1000,
                                                    num_epochs=100)
    reader = tf.WholeFileReader()
    _, value = reader.read(filename_queue)
    image = tf.image.decode_jpeg(value)
    cropped = tf.random_crop(image, [resolution * 4, resolution * 4, 3])
    random_flipped = tf.image.random_flip_left_right(cropped)
    minibatch = tf.train.batch([random_flipped], batch_size, capacity=300)
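    # Bicubic-downsample to the training resolution and rescale to [-1, 1].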
    rescaled = tf.image.resize_bicubic(minibatch,
                                       [resolution, resolution]) / 127.5 - 1
    return minibatch, rescaled


with tf.device('/cpu:0'):
    minibatch, rescaled = read(filenames)
resnet = srResNet.srResNet(rescaled)
result = (resnet.conv5 + 1) * 127.5
gen_var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)

dbatch = tf.concat([tf.cast(minibatch, tf.float32), result], 0)
vgg = vgg19.Vgg19()
vgg.build(dbatch)
fmap = tf.split(vgg.conv5_4, 2)
content_loss = tf.losses.mean_squared_error(fmap[0], fmap[1])

disc = discriminator.Discriminator(dbatch)
D_x, D_G_z = tf.split(tf.squeeze(disc.dense2), 2)
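# Least-squares GAN objectives: the generator pushes D(G(z)) toward 1;
# the discriminator pushes D(x) toward 1 and D(G(z)) toward 0.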
adv_loss = tf.reduce_mean(tf.square(D_G_z - 1.0))
gen_loss = adv_loss + content_loss
disc_loss = tf.reduce_mean(tf.square(D_x - 1.0) + tf.square(D_G_z))
disc_var_list = tf.get_collection(
    tf.GraphKeys.TRAINABLE_VARIABLES)  # still includes the generator's variables
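
# Example #2 also ends before the training ops. A hedged sketch of the
# usual alternating update (the Adam optimizer and learning rate are
# assumptions, not taken from the original):
gen_names = set(v.name for v in gen_var_list)
disc_only_vars = [v for v in disc_var_list if v.name not in gen_names]
gen_train_op = tf.train.AdamOptimizer(1e-4).minimize(gen_loss,
                                                     var_list=gen_var_list)
disc_train_op = tf.train.AdamOptimizer(1e-4).minimize(disc_loss,
                                                      var_list=disc_only_vars)
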
Example #3
import os

import matplotlib.image as mpimg
import tensorflow as tf
from skimage import io

import srResNet

filenames = './testset_path.txt'
H = 28
W = 24
variable_path = './save/srGAN/srgan'
output_path = "./outputdata/test/"
with tf.device('/cpu:0'):
    file_names = open(filenames, 'r').read().split('\n')
    file_names.pop()  # drop the empty entry left by the trailing newline
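
# Inference graph: feed one full-resolution frame, bicubic-downsample it
# to H x W, let srResNet upscale it 4x, and clip to [0, 1] for saving.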

image = tf.placeholder(tf.float32, shape=[1, H * 4, W * 4, 3])
#cropped = tf.random_crop(img,[ H *4, W*4,3])
#random_flipped=tf.image.random_flip_left_right(cropped)
rescaled = tf.image.resize_images(image, [H, W], tf.image.ResizeMethod.BICUBIC)
#bicubic=tf.image.resize_images( rescaled , [H*4, W*4], tf.image.ResizeMethod.BICUBIC )
resnet = srResNet.srResNet(rescaled / 127.5 - 1.0)  # network input in [-1, 1]
result = (resnet.conv5 + 1) * 127.5  # back to the [0, 255] range
out = tf.clip_by_value(result / 255.0, 0, 1)  # [0, 1] for imsave

config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)
config.gpu_options.allow_growth = True
with tf.Session(config=config) as sess:
    loader = tf.train.Saver(
        var_list=tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES))
    loader.restore(sess, variable_path)
    for i in range(len(file_names)):
        img = mpimg.imread(file_names[i])
        print(img.shape)
        outd = sess.run(out, feed_dict={image: [img]})
        # keep only the file name when building the output path
        io.imsave(output_path + os.path.basename(file_names[i]), outd[0])
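
# For a side-by-side baseline like Example #1's, a bicubic upscale of the
# same low-res tensor could be added to the graph before the session
# starts (a sketch, not part of the original snippet):
#     bicubic = tf.clip_by_value(
#         tf.image.resize_images(rescaled, [H * 4, W * 4],
#                                tf.image.ResizeMethod.BICUBIC) / 255.0, 0, 1)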