# Inference setup: rebuild the SRCNN graph whose architecture is encoded in
# the checkpoint directory name, then restore the latest checkpoint into a
# session (model placed on /cpu:0).
FLAGS = flags.FLAGS
FLAGS._parse_flags()

# The directory basename is parsed as "<layer sizes>_<filter sizes>", each a
# "-"-separated list of ints. normpath drops a trailing separator; without it
# os.path.basename("runs/64-32_9-5/") is "" and the int() parse fails.
experiment = os.path.basename(os.path.normpath(FLAGS.checkpoint_dir))
experiment_parts = experiment.split("_")
if len(experiment_parts) < 2:
    raise ValueError(
        "checkpoint_dir basename %r is not of the form "
        "'<layer sizes>_<filter sizes>'" % experiment)
layer_sizes = [int(k) for k in experiment_parts[0].split("-")]
filter_sizes = [int(k) for k in experiment_parts[1].split("-")]
print(layer_sizes)
x = tf.placeholder(tf.float32, shape=(None, None, None, 3), name="input")
y = tf.placeholder(tf.float32, shape=(None, None, None, 3), name="label")
# Defaults to False here (presumably inference mode inside SRCNN — confirm).
is_training = tf.placeholder_with_default(False, (), name='is_training')

model = srcnn.SRCNN(x,
                    y,
                    layer_sizes,
                    filter_sizes,
                    is_training=is_training,
                    device='/cpu:0',
                    input_depth=3,
                    output_depth=3)

saver = tf.train.Saver()
init_op = tf.group(tf.global_variables_initializer(),
                   tf.local_variables_initializer())

sess = tf.Session()
sess.run(init_op)

# latest_checkpoint returns None when the directory holds no checkpoint;
# fail with a message naming the directory rather than inside restore().
checkpoint = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
print("checkpoint", checkpoint)
if checkpoint is None:
    raise ValueError("no checkpoint found in %r" % FLAGS.checkpoint_dir)
saver.restore(sess, checkpoint)
# Example 2
def train():
    """Train SRCNN on BSD100 patches, checkpointing the best epoch.

    Patches come from ``SuperResData.make_patches`` and are fed in
    consecutive minibatches of ``FLAGS.batch_size``; a trailing partial
    batch is dropped. After every epoch the summed minibatch loss is
    compared with the best seen so far and, when it is lower or equal, the
    session is saved to ``SAVE_DIR/bestmodel.ckpt``.

    Raises:
        ValueError: if there are fewer patches than ``FLAGS.batch_size``,
            so an epoch would contain no minibatches.
    """
    with tf.Graph().as_default(), tf.device(FLAGS.device):
        image_obj = SuperResData(imageset='BSD100',
                                 upscale_factor=FLAGS.upscale)
        train_images, train_labels = image_obj.make_patches(
            patch_size=FLAGS.patch_size, stride=FLAGS.stride)
        data_length = len(train_labels)
        print(data_length)
        train_images = np.float32(train_images)
        train_labels = np.float32(train_labels)
        # The full patch set is embedded in the graph as constants.
        train_images_tensor = tf.constant(train_images,
                                          name='train_images',
                                          dtype=tf.float32)
        train_labels_tensor = tf.constant(train_labels,
                                          name='train_labels',
                                          dtype=tf.float32)

        # With is_training at its default (True) the graph reads the patch
        # constants; when fed False it reads the placeholders instead.
        is_training = tf.placeholder_with_default(True, (), name='is_training')
        x_placeholder = tf.placeholder_with_default(
            tf.zeros(shape=(1, 10, 10, 3), dtype=tf.float32),
            shape=(None, None, None, 3),
            name="input_placeholder")
        y_placeholder = tf.placeholder_with_default(
            tf.zeros(shape=(1, 20, 20, 3), dtype=tf.float32),
            shape=(None, None, None, 3),
            name="label_placeholder")
        x = tf.cond(is_training, lambda: train_images_tensor,
                    lambda: x_placeholder)
        y = tf.cond(is_training, lambda: train_labels_tensor,
                    lambda: y_placeholder)

        # Clamp inputs to [0, 255].
        x_interp = tf.minimum(tf.nn.relu(x), 255)

        model = srcnn.SRCNN(x_interp,
                            y,
                            FLAGS.HIDDEN_LAYERS,
                            FLAGS.KERNELS,
                            is_training=is_training,
                            input_depth=FLAGS.depth,
                            output_depth=FLAGS.depth,
                            upscale_factor=FLAGS.upscale,
                            learning_rate=1e-4,
                            device=FLAGS.device)

        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())
        saver = tf.train.Saver()

        batch_size = FLAGS.batch_size
        num_batches = data_length // batch_size
        if num_batches == 0:
            raise ValueError(
                "need at least %i patches for one batch, got %i" %
                (batch_size, data_length))

        update_check = 20  # batches per "Average loss for this update" line
        best_loss = float('inf')  # any finite first-epoch loss is a record
        counter = 0  # minibatches run so far, across all epochs

        # The context manager closes the session when training finishes.
        with tf.Session() as sess:
            sess.run(init_op)
            for epoch in range(FLAGS.num_epochs):
                epoch_loss = 0
                update_loss = 0
                for idx in range(num_batches):
                    start = idx * batch_size
                    batch_images = train_images[start:start + batch_size]
                    batch_labels = train_labels[start:start + batch_size]
                    # x and y are fed directly, replacing the tf.cond outputs
                    # with the current minibatch.
                    _, train_loss = sess.run([model.opt, model.loss],
                                             feed_dict={
                                                 x: batch_images,
                                                 y: batch_labels
                                             })

                    counter += 1
                    epoch_loss += train_loss
                    update_loss += train_loss
                    # (idx + 1) so each report averages a full window of
                    # update_check losses.
                    if (idx + 1) % update_check == 0:
                        print("Average loss for this update:",
                              round(update_loss / update_check, 3))
                        update_loss = 0

                    if counter % 200 == 0:
                        print("Step: %i, Index: %i, Train Loss: %2.4f" %
                              (epoch, idx, train_loss))

                batch_ave_loss = epoch_loss / num_batches
                print("Epoch: ", epoch, "Batch Loss: ", epoch_loss,
                      "Batch Average loss", batch_ave_loss)

                # Save when this epoch's summed loss matches or beats the best.
                if epoch_loss <= best_loss:
                    best_loss = epoch_loss
                    print('New Record!')
                    print("Step: %i, Batch Loss: %2.4f" % (epoch, epoch_loss))
                    saver.save(sess, os.path.join(SAVE_DIR, "bestmodel.ckpt"))
import torch
import cv2
import srcnn
import numpy as np
import glob as glob
import os
from torchvision.utils import save_image

# Run a trained SRCNN over bicubic-upscaled inputs converted to grayscale.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = srcnn.SRCNN().to(device)
# map_location remaps tensors saved on a GPU so the weights also load on a
# CPU-only machine.
model.load_state_dict(
    torch.load('/home/harshubh/Desktop/SRCNN/outputs/model.pth',
               map_location=device))
# Inference mode does not change per image, so set it once.
model.eval()

image_paths = glob.glob('/home/harshubh/Desktop/SRCNN/inputs/bicubic_2x/*')
for image_path in image_paths:
    image = cv2.imread(image_path, cv2.IMREAD_COLOR)
    if image is None:
        # cv2.imread returns None (it does not raise) for undecodable files.
        print(f"Skipping unreadable image: {image_path}")
        continue
    # splitext keeps dots inside the stem ("a.b.png" -> "a.b").
    test_image_name = os.path.splitext(os.path.basename(image_path))[0]
    image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    # Add an explicit channel axis: (H, W) -> (H, W, 1).
    image = image.reshape(image.shape[0], image.shape[1], 1)
    cv2.imwrite(
        f"/home/harshubh/Desktop/SRCNN/outputs/test_{test_image_name}.png",
        image)
    image = image / 255.  # normalize the pixel values
    cv2.imshow('Greyscale image', image)
    cv2.waitKey(0)  # blocks until a key is pressed in the preview window
    with torch.no_grad():
        # (H, W, 1) -> (1, 1, H, W) float tensor on the model's device.
        image = np.transpose(image, (2, 0, 1)).astype(np.float32)
        image = torch.tensor(image, dtype=torch.float).to(device)
        image = image.unsqueeze(0)
        # NOTE(review): the excerpt stops before calling model(image), and the
        # imported save_image is unused here — presumably the forward pass and
        # output save were cut off. Confirm.
# Inference setup: rebuild the SRCNN graph whose architecture is encoded in
# the checkpoint directory name (built with gpu=False) and restore the latest
# checkpoint into a session.
FLAGS = flags.FLAGS
FLAGS._parse_flags()

# The directory basename is parsed as "<layer sizes>_<filter sizes>", each a
# "-"-separated list of ints. normpath drops a trailing separator; without it
# os.path.basename("runs/64-32_9-5/") is "" and the int() parse fails.
experiment = os.path.basename(os.path.normpath(FLAGS.checkpoint_dir))
experiment_parts = experiment.split("_")
if len(experiment_parts) < 2:
    raise ValueError(
        "checkpoint_dir basename %r is not of the form "
        "'<layer sizes>_<filter sizes>'" % experiment)
layer_sizes = [int(k) for k in experiment_parts[0].split("-")]
filter_sizes = [int(k) for k in experiment_parts[1].split("-")]

x = tf.placeholder(tf.float32, shape=(None, None, None, 3), name="input")
y = tf.placeholder(tf.float32, shape=(None, None, None, 3), name="label")
# Defaults to False here (presumably inference mode inside SRCNN — confirm).
is_training = tf.placeholder_with_default(False, (), name='is_training')

model = srcnn.SRCNN(x,
                    y,
                    layer_sizes,
                    filter_sizes,
                    is_training=is_training,
                    gpu=False,
                    input_depth=3,
                    output_depth=3)

saver = tf.train.Saver()
init_op = tf.group(tf.global_variables_initializer(),
                   tf.local_variables_initializer())

sess = tf.Session()
sess.run(init_op)

# latest_checkpoint returns None when the directory holds no checkpoint;
# fail with a message naming the directory rather than inside restore().
checkpoint = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
print("checkpoint", checkpoint)
if checkpoint is None:
    raise ValueError("no checkpoint found in %r" % FLAGS.checkpoint_dir)
saver.restore(sess, checkpoint)
# Inference setup: rebuild the SRCNN graph whose architecture is encoded in
# the checkpoint directory name (placed on FLAGS.device) and restore the
# latest checkpoint into a session.
FLAGS = flags.FLAGS
FLAGS._parse_flags()

# The directory basename is parsed as "<layer sizes>_<filter sizes>", each a
# "-"-separated list of ints. normpath drops a trailing separator; without it
# os.path.basename("runs/64-32_9-5/") is "" and the int() parse fails.
experiment = os.path.basename(os.path.normpath(FLAGS.checkpoint_dir))
experiment_parts = experiment.split("_")
if len(experiment_parts) < 2:
    raise ValueError(
        "checkpoint_dir basename %r is not of the form "
        "'<layer sizes>_<filter sizes>'" % experiment)
layer_sizes = [int(k) for k in experiment_parts[0].split("-")]
filter_sizes = [int(k) for k in experiment_parts[1].split("-")]

x = tf.placeholder(tf.float32, shape=(None, None, None, 3), name="input")
y = tf.placeholder(tf.float32, shape=(None, None, None, 3), name="label")
# Defaults to False here (presumably inference mode inside SRCNN — confirm).
is_training = tf.placeholder_with_default(False, (), name='is_training')

model = srcnn.SRCNN(x,
                    y,
                    layer_sizes,
                    filter_sizes,
                    is_training=is_training,
                    device=FLAGS.device,
                    input_depth=3,
                    output_depth=3)

saver = tf.train.Saver()
init_op = tf.group(tf.global_variables_initializer(),
                   tf.local_variables_initializer())

sess = tf.Session()
sess.run(init_op)

# latest_checkpoint returns None when the directory holds no checkpoint;
# fail with a message naming the directory rather than inside restore().
checkpoint = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
print("checkpoint", checkpoint)
if checkpoint is None:
    raise ValueError("no checkpoint found in %r" % FLAGS.checkpoint_dir)
saver.restore(sess, checkpoint)
# Example 6
def train():
    """Train SRCNN from queued BSD100 patches, reporting Set5 PSNR periodically.

    Training batches come from ``SuperResData(...).tf_patches``, a
    queue-based input, hence the queue runners below. Every
    ``FLAGS.test_step`` steps the Set5 images are fed one at a time through
    the placeholders with ``is_training=False``. The mean luminance PSNR of
    the model output and of the bicubic input is then printed. Checkpoints
    are written to ``SAVE_DIR`` every ``FLAGS.save_step`` steps and once
    more after the loop (skipped when ``FLAGS.num_epochs`` is 0).
    """
    with tf.Graph().as_default(), tf.device("/cpu:0"):
        train_images, train_labels = SuperResData(
            imageset='BSD100',
            upscale_factor=FLAGS.upscale).tf_patches(batch_size=FLAGS.batch_size)
        test_images_arr, test_labels_arr = SuperResData(
            imageset='Set5', upscale_factor=FLAGS.upscale).get_images()

        # Training reads the queued tensors; at test time (is_training fed
        # False) the placeholders are read instead.
        is_training = tf.placeholder_with_default(True, (), name='is_training')
        x_placeholder = tf.placeholder_with_default(
            tf.zeros(shape=(1, 10, 10, 3), dtype=tf.float32),
            shape=(None, None, None, 3),
            name="input_placeholder")
        y_placeholder = tf.placeholder_with_default(
            tf.zeros(shape=(1, 20, 20, 3), dtype=tf.float32),
            shape=(None, None, None, 3),
            name="label_placeholder")

        x = tf.cond(is_training, lambda: train_images, lambda: x_placeholder)
        y = tf.cond(is_training, lambda: train_labels, lambda: y_placeholder)

        # x needs to be interpolated to the shape of y
        h = tf.shape(x)[1] * FLAGS.upscale
        w = tf.shape(x)[2] * FLAGS.upscale
        x_interp = tf.image.resize_bicubic(x, [h, w])
        x_interp = tf.minimum(tf.nn.relu(x_interp), 255)

        # build graph
        model = srcnn.SRCNN(x_interp, y, FLAGS.HIDDEN_LAYERS, FLAGS.KERNELS,
                            is_training=is_training, input_depth=FLAGS.depth,
                            output_depth=FLAGS.depth,
                            upscale_factor=FLAGS.upscale,
                            learning_rate=1e-4, device=FLAGS.device)

        def log10(t):
            """Base-10 logarithm via change of base on the natural tf.log."""
            numerator = tf.log(t)
            denominator = tf.log(tf.constant(10, dtype=numerator.dtype))
            return numerator / denominator

        def luminance(img):
            """BT.601 luma of an NHWC image (assumes RGB channel order)."""
            return (0.299 * img[:, :, :, 0] + 0.587 * img[:, :, :, 1] +
                    0.114 * img[:, :, :, 2])

        def compute_psnr(x1, x2):
            """PSNR in dB between the luminances of x1 and x2, peak 255."""
            mse = tf.reduce_mean((luminance(x1) - luminance(x2)) ** 2)
            return 10 * log10(255 ** 2 / mse)

        pred = tf.cast(tf.minimum(tf.nn.relu(model.prediction * 255), 255),
                       tf.float32)
        label_scaled = tf.cast(y * 255, tf.float32)
        psnr = compute_psnr(pred, label_scaled)
        # NOTE(review): unlike pred, the bicubic baseline is not clipped to
        # 255 after scaling; confirm the intended value range of x.
        bic_psnr = compute_psnr(x_interp * 255., label_scaled)

        # initialize graph
        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())

        sess = tf.Session()
        saver = tf.train.Saver()
        sess.run(init_op)

        # Start input enqueue threads.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        step = None  # stays None when FLAGS.num_epochs == 0
        try:
            for step in range(FLAGS.num_epochs):
                _, train_loss = sess.run([model.opt, model.loss])
                if step % FLAGS.test_step == 0:
                    model_psnrs = []
                    bicubic_psnrs = []
                    for xtest, ytest in zip(test_images_arr, test_labels_arr):
                        model_val, bicubic_val = sess.run(
                            [psnr, bic_psnr],
                            feed_dict={is_training: False,
                                       x_placeholder: xtest,
                                       y_placeholder: ytest})
                        model_psnrs.append(model_val)
                        bicubic_psnrs.append(bicubic_val)
                    print("Step: %i, Train Loss: %2.4f, Test PSNR: %2.4f, "
                          "Bicubic PSNR: %2.4f" %
                          (step, train_loss, np.mean(model_psnrs),
                           np.mean(bicubic_psnrs)))
                if step % FLAGS.save_step == 0:
                    saver.save(sess, os.path.join(SAVE_DIR,
                                                  "model_%08i.ckpt" % step))
            if step is not None:
                saver.save(sess, os.path.join(SAVE_DIR,
                                              "model_%08i.ckpt" % step))
        finally:
            # Stop and join the queue-runner threads, then release the session.
            coord.request_stop()
            coord.join(threads)
            sess.close()
# Example 7
def train():
    """Train SRCNN on queued batches, periodically reporting the mean test loss.

    Train and test batches come from ``inputs(...)`` queue tensors. They are
    pulled into numpy with ``sess.run`` and then fed to the ``x``/``y``
    placeholders. Every ``FLAGS.test_step`` steps, ``test_iters`` test
    batches are evaluated with ``is_training=False`` and their mean loss is
    printed. ``test_iters`` is 14 for Set14, 5 for Set5 and 1 otherwise,
    based on ``FLAGS.test_dir``. Checkpoints go to ``SAVE_DIR`` every
    ``FLAGS.save_step`` steps and once more after the loop (skipped when
    ``FLAGS.num_epochs`` is 0).
    """
    with tf.Graph().as_default(), tf.device("/cpu:0"):
        train_images, train_labels = inputs(True, FLAGS.batch_size,
                                            FLAGS.num_epochs)
        test_images, test_labels = inputs(False, FLAGS.batch_size,
                                          FLAGS.num_epochs)

        # set some placeholders
        x = tf.placeholder(tf.float32,
                           shape=(None, None, None, FLAGS.depth),
                           name="input")
        y = tf.placeholder(tf.float32,
                           shape=(None, None, None, FLAGS.depth),
                           name="label")
        is_training = tf.placeholder_with_default(True, (), name='is_training')

        # build graph
        model = srcnn.SRCNN(x,
                            y,
                            FLAGS.HIDDEN_LAYERS,
                            FLAGS.KERNELS,
                            is_training=is_training,
                            input_depth=FLAGS.depth,
                            output_depth=FLAGS.depth,
                            learning_rate=1e-4,
                            gpu=FLAGS.gpu)

        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())

        sess = tf.Session()
        saver = tf.train.Saver()
        sess.run(init_op)

        # Start input enqueue threads.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        # How many test batches to evaluate per report.
        test_dir = FLAGS.test_dir.lower()
        if 'set14' in test_dir:
            test_iters = 14
        elif 'set5' in test_dir:
            test_iters = 5
        else:
            test_iters = 1

        # For demo simplicity batches are pulled to numpy and fed back in;
        # in practice train_images/train_labels would feed SRCNN directly.
        def feed_dict(train=True):
            """Pull one train (or test) batch and build the matching feed."""
            if train:
                im, lab = sess.run([train_images, train_labels])
            else:
                im, lab = sess.run([test_images, test_labels])
            # is_training follows the batch source, so test batches are
            # evaluated with is_training=False.
            return {x: im, y: lab, is_training: train}

        step = None  # stays None when FLAGS.num_epochs == 0
        try:
            for step in range(FLAGS.num_epochs):
                _, train_loss = sess.run([model.opt, model.loss],
                                         feed_dict=feed_dict(True))

                if step % FLAGS.test_step == 0:
                    test_losses = [
                        sess.run(model.loss, feed_dict=feed_dict(False))
                        for _ in range(test_iters)
                    ]
                    print("Step: %i, Train Loss: %2.4f, Test Loss: %2.4f" %
                          (step, train_loss,
                           sum(test_losses) / len(test_losses)))
                if step % FLAGS.save_step == 0:
                    saver.save(sess, os.path.join(SAVE_DIR,
                                                  "model_%08i.ckpt" % step))
            if step is not None:
                saver.save(sess, os.path.join(SAVE_DIR,
                                              "model_%08i.ckpt" % step))
        finally:
            # Stop and join the queue-runner threads, then release the session.
            coord.request_stop()
            coord.join(threads)
            sess.close()
# Example 8
def train():
    """Train SRCNN on BSD100 patches, saving whenever a step's loss is a new low.

    Builds patches with ``SuperResData.make_patches`` and embeds the full
    patch arrays in the graph as constants. It then runs
    ``FLAGS.num_epochs`` epochs over consecutive ``FLAGS.batch_size``
    minibatches; integer division means a trailing partial batch is never
    visited. After each step, if that step's loss is strictly below every
    previously recorded value (starting from the 100000 sentinel), the
    session is saved to ``SAVE_DIR/bestmodel.ckpt``.

    NOTE(review): the tf.Session is never closed in the visible code, and
    this excerpt may continue past its last visible line — confirm.
    """
    # checkpoint = "/Users/will/Desktop/DS1001-Intro-to-data-science/projectbestmodel.ckpt"
    with tf.Graph().as_default(), tf.device(FLAGS.device):
        # train_images, train_labels = SuperResData(imageset='BSD100', upscale_factor=FLAGS.upscale).tf_patches(batch_size=FLAGS.batch_size)
        image_obj = SuperResData(imageset='BSD100',
                                 upscale_factor=FLAGS.upscale)
        train_images, train_labels = image_obj.make_patches(
            patch_size=FLAGS.patch_size, stride=FLAGS.stride)
        data_length = len(train_labels)
        train_images = np.float32(train_images)
        train_labels = np.float32(train_labels)
        # The full patch set is embedded in the graph as constants.
        train_images_tensor = tf.constant(train_images,
                                          name='train_images',
                                          dtype=tf.float32)
        train_labels_tensor = tf.constant(train_labels,
                                          name='train_labels',
                                          dtype=tf.float32)
        # is_training defaults to True, which selects the patch constants via
        # tf.cond; when fed False the placeholders are used instead (their
        # defaults are small all-zero images).

        is_training = tf.placeholder_with_default(True, (), name='is_training')
        x_placeholder = tf.placeholder_with_default(tf.zeros(shape=(1, 10, 10,
                                                                    3),
                                                             dtype=tf.float32),
                                                    shape=(None, None, None,
                                                           3),
                                                    name="input_placeholder")
        y_placeholder = tf.placeholder_with_default(tf.zeros(shape=(1, 20, 20,
                                                                    3),
                                                             dtype=tf.float32),
                                                    shape=(None, None, None,
                                                           3),
                                                    name="label_placeholder")
        x = tf.cond(is_training, lambda: train_images_tensor,
                    lambda: x_placeholder)
        y = tf.cond(is_training, lambda: train_labels_tensor,
                    lambda: y_placeholder)

        # The bicubic resize is disabled; inputs are only clamped to [0, 255].
        # x_interp = tf.image.resize_bicubic(x, [h, w])
        x_interp = tf.minimum(tf.nn.relu(x), 255)

        # build graph
        model = srcnn.SRCNN(x_interp,
                            y,
                            FLAGS.HIDDEN_LAYERS,
                            FLAGS.KERNELS,
                            is_training=is_training,
                            input_depth=FLAGS.depth,
                            output_depth=FLAGS.depth,
                            upscale_factor=FLAGS.upscale,
                            learning_rate=1e-4,
                            device=FLAGS.device)

        # initialize graph
        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())

        # Create a session for running operations in the Graph.
        sess = tf.Session()
        saver = tf.train.Saver()

        # Initialize the variables (the trained variables and the # epoch counter).
        sess.run(init_op)

        # Record of per-step losses that set a new low; 100000 is the
        # initial upper bound (no queue runners are started in this version).
        summary_loss = [100000]
        update_loss = 0

        for epoch in range(FLAGS.num_epochs):
            # Number of full minibatches; a trailing partial batch is skipped.
            batch_inx = (data_length) // FLAGS.batch_size
            for idx in range(batch_inx):
                batch_images = train_images[idx * FLAGS.batch_size:(idx + 1) *
                                            FLAGS.batch_size]
                batch_labels = train_labels[idx * FLAGS.batch_size:(idx + 1) *
                                            FLAGS.batch_size]
                # x and y are tf.cond outputs; feeding them directly replaces
                # them with the current minibatch.
                _, train_loss = sess.run([model.opt, model.loss],
                                         feed_dict={
                                             x: batch_images,
                                             y: batch_labels
                                         })
                # print("Step: %i, Index: %i, Train Loss: %2.4f" % (epoch, idx, train_loss))
                update_loss = train_loss
                # Checkpoint whenever this single step's loss is strictly
                # below every earlier record.
                if update_loss < min(summary_loss):
                    print('new record')
                    print("Step: %i, Index: %i, Train Loss: %2.4f" %
                          (epoch, idx, train_loss))
                    save_path = saver.save(
                        sess, os.path.join(SAVE_DIR, "bestmodel.ckpt"))
                    summary_loss.append(update_loss)