def rebuild(img_name):
    """
    图像超分辨率重建
    :return:
    """

    lr_image = Image.open(os.path.join(TEST_IMAGE_DIR, img_name))
    lr_image = np.asarray(lr_image)
    # lr_image, ori_image = prepare_data.preprocess_img(os.path.join(TEST_IMAGE_DIR,img_name))
    try:
        image_height, image_width, _ = lr_image.shape
    except ValueError:
        # grayscale or malformed input: shape does not unpack into (h, w, c)
        print("error", lr_image.shape)
        return
    with tf.Session() as sess:
        espcn = ESPCN(sess,
                      is_train=False,
                      image_height=image_height,
                      image_width=image_width,
                      image_channel=prepare_data.IMAGE_CHANNEl,
                      ratio=prepare_data.RATIO)
        sr_image = espcn.generate(lr_image / 255.0)

    # create the result directory first, otherwise image.save would fail
    if not os.path.isdir(TEST_RESULT_DIR):
        os.makedirs(TEST_RESULT_DIR)
    # sr image
    # util.show_img_from_array(sr_image)
    util.save_img_from_array(sr_image,
                             TEST_RESULT_DIR + img_name.split('.')[0] + '.png')
    print("saved")
Example 2
def rebuild(img_name):
    """
    图像超分辨率重建
    :return:
    """
    lr_image, ori_image = prepare_data.preprocess_img(TEST_IMAGE_DIR +
                                                      img_name)
    try:
        image_height, image_width, _ = lr_image.shape
    except ValueError:
        # grayscale or malformed input: shape does not unpack into (h, w, c)
        return
    with tf.Session() as sess:
        espcn = ESPCN(sess,
                      is_train=False,
                      image_height=image_height,
                      image_width=image_width,
                      image_channel=prepare_data.IMAGE_CHANNEl,
                      ratio=prepare_data.RATIO)
        sr_image = espcn.generate(lr_image / 255.0)

    # lr image
    # util.show_img_from_array(lr_image)
    util.save_img_from_array(
        lr_image, TEST_RESULT_DIR + img_name.split('.')[0] + '_lr.' +
        img_name.split('.')[-1])
    # original image
    # util.show_img_from_array(ori_image)
    util.save_img_from_array(
        ori_image, TEST_RESULT_DIR + img_name.split('.')[0] + '_hr.' +
        img_name.split('.')[-1])
    # sr image
    # util.show_img_from_array(sr_image)
    util.save_img_from_array(
        sr_image, TEST_RESULT_DIR + img_name.split('.')[0] + '_sr.' +
        img_name.split('.')[-1])
Example 3
def train():
    """
    шонч╗Г
    :return:
    """
    prepare_data.prepare_data()
    with tf.Session() as sess:
        espcn = ESPCN(sess,
                      is_train=True,
                      image_height=prepare_data.IMAGE_SIZE,
                      image_width=prepare_data.IMAGE_SIZE,
                      image_channel=prepare_data.IMAGE_CHANNEl,
                      ratio=prepare_data.RATIO)
        espcn.train()
Example 4
def rebuild(img_list):
    """
    图像超分辨率重建
    :return:
    """

    with tf.Session() as sess:
        espcn = ESPCN(
            sess,
            is_train=False,
            #   image_height=image_height,
            #   image_width=image_width,
            image_channel=prepare_data.IMAGE_CHANNEl,
            ratio=prepare_data.RATIO)
        espcn.generate(img_list)
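
# A hedged usage sketch for the list-based rebuild() above: load the test
# images into numpy arrays and hand the whole list to the network. The
# directory name and the /255.0 normalization are assumptions carried over
# from the earlier examples, not confirmed by this snippet.
if __name__ == "__main__":
    img_list = [np.asarray(Image.open(os.path.join(TEST_IMAGE_DIR, name))) / 255.0
                for name in sorted(os.listdir(TEST_IMAGE_DIR))]
    rebuild(img_list)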
Example 5
def main():
    # parse arguments
    args = parse_args()
    if args is None:
        exit()

    if args.gpu_mode and not torch.cuda.is_available():
        raise Exception("No GPU found, please run without --gpu_mode=False")

    # model
    if args.model_name == 'SRCNN':
        net = SRCNN(args)
    elif args.model_name == 'VDSR':
        net = VDSR(args)
    elif args.model_name == 'DRCN':
        net = DRCN(args)
    elif args.model_name == 'ESPCN':
        net = ESPCN(args)
    # elif args.model_name == 'FastNeuralStyle':
    #     net = FastNeuralStyle(args)
    elif args.model_name == 'FSRCNN':
        net = FSRCNN(args)
    elif args.model_name == 'SRGAN':
        net = SRGAN(args)
    elif args.model_name == 'LapSRN':
        net = LapSRN(args)
    # elif args.model_name == 'EnhanceNet':
    #     net = EnhanceNet(args)
    elif args.model_name == 'EDSR':
        net = EDSR(args)
    # elif args.model_name == 'EnhanceGAN':
    #     net = EnhanceGAN(args)
    else:
        raise Exception("[!] There is no option for " + args.model_name)

    # train
    net.train()

    # test
    #net.test()

    lr_dir = 'testing_lr_images'
    lr_images = os.listdir(lr_dir)
    for i, image_name in enumerate(lr_images):
        full_name = os.path.join(lr_dir, image_name)
        recon_img = net.test_single(full_name)
        recon_img.save(
            '/content/drive/MyDrive/Colab Notebooks/HW4/srgan_results/' +
            image_name)
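
# A hedged sketch of the parse_args() this main() relies on. Only the two
# options actually referenced above (--model_name and --gpu_mode) are shown;
# the real parser in the repository defines more (scale factor, epochs, data
# paths, ...), so treat this as illustrative rather than authoritative.
import argparse

def parse_args_sketch():
    parser = argparse.ArgumentParser(description='Single-image super-resolution')
    parser.add_argument('--model_name', type=str, default='ESPCN',
                        help='SRCNN | VDSR | DRCN | ESPCN | FSRCNN | SRGAN | LapSRN | EDSR')
    # argparse's type=bool would treat the string "False" as truthy, so parse explicitly
    parser.add_argument('--gpu_mode', type=lambda s: s.lower() in ('true', '1'),
                        default=True, help='pass --gpu_mode=False to run on CPU')
    return parser.parse_args()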
Example 6
def main():
    # parse arguments
    args = parse_args()
    if args is None:
        exit()

    if args.gpu_mode and not torch.cuda.is_available():
        raise Exception("No GPU found, please run without --gpu_mode=False")

    # model
    if args.model_name == 'SRCNN':
        net = SRCNN(args)
    elif args.model_name == 'VDSR':
        net = VDSR(args)
    elif args.model_name == 'DRCN':
        net = DRCN(args)
    elif args.model_name == 'ESPCN':
        net = ESPCN(args)
    # elif args.model_name == 'FastNeuralStyle':
    #     net = FastNeuralStyle(args)
    elif args.model_name == 'FSRCNN':
        net = FSRCNN(args)
    elif args.model_name == 'SRGAN':
        net = SRGAN(args)
    elif args.model_name == 'LapSRN':
        net = LapSRN(args)
    # elif args.model_name == 'EnhanceNet':
    #     net = EnhanceNet(args)
    elif args.model_name == 'EDSR':
        net = EDSR(args)
    # elif args.model_name == 'EnhanceGAN':
    #     net = EnhanceGAN(args)
    else:
        raise Exception("[!] There is no option for " + args.model_name)

    # train
    net.train()

    # test
    net.test()
Example 7
def generate():
    args = get_arguments()

    with open("./params2.json", 'r') as f:
        params = json.load(f)

    if not check_params(args, params):
        return

    sess = tf.Session()

    net = ESPCN(filters_size=params['filters_size'],
                channels=params['channels'],
                ratio=params['ratio'],
                batch_size=1,
                lr_size=params['lr_size'],
                edge=params['edge'])

    loss, images, labels = net.build_model()

    #rgb
    lr_image = tf.placeholder(tf.uint8)

    saver = tf.train.Saver()
    try:
        model_loaded = net.load(sess, saver, args.checkpoint)
    except Exception:
        raise Exception(
            "Failed to load model, does the ratio in params2.json match the ratio you trained your checkpoint with?"
        )

    if model_loaded:
        print("[*] Checkpoint load success!")
    else:
        print("[*] Checkpoint load failed/no checkpoint found")
        return

    def run_net(num):
        start0 = time.time()
        start1 = time.time()
        sr_image_y_data = sess.run(sr_image,
                                   feed_dict={lr_image: lr_image_batch})
        print('1', time.time() - start1)  #0.1967

        start2 = time.time()
        # pixel shuffle  b c r^2 h w ---> b c rh rw
        sr_image_y_data = shuffle(sr_image_y_data[0], params['ratio'])
        print('2', time.time() - start2)  #0.2775
        start3 = time.time()
        sr_image_ycbcr_data = misc.imresize(
            lr_image_ycbcr_data,
            params['ratio'] * np.array(lr_image_data.shape[0:2]), 'bicubic')
        edge = params['edge'] * params['ratio'] // 2

        sr_image_ycbcr_data = np.concatenate(
            (sr_image_y_data, sr_image_ycbcr_data[edge:-edge, edge:-edge,
                                                  1:3]),
            axis=2)
        print('3', time.time() - start3)  #0.0219
        start4 = time.time()
        #sr_image_data = ycbcr2rgb(sr_image_ycbcr_data)
        print('4', time.time() - start4)  #3.7009   86.59%

        # start5 = time.time()
        # # res_image = cv2.cvtColor(sr_image_data, cv2.COLOR_BGR2YUV)
        print(type(sr_image_ycbcr_data))
        with open("result/res.yuv", 'ab') as fw:
            fw.write(sr_image_ycbcr_data)
        end = time.time()

        #print(sr_image_data.shape)
        # cv2.namedWindow('show_sr', 0)
        # cv2.imshow('show_sr', sr_image_ycbcr_data)
        # cv2.waitKey(50)
        # #print('5', time.time() - start5) #0.0767
        # cv2.imwrite(args.out_path + '_' + str(num) + '.jpg', sr_image_ycbcr_data)
        #misc.imsave(args.out_path + '_' + str(num) + '.png', sr_image_data)
        print("{:f} seconds".format(time.time() - start0))  #4.2739

    if args.lr_image[-3:] == 'yuv':
        width = 352
        height = 288
        # #lr_image_yuv_data = data[0][0]
        # #lr_image_yuv_data = misc.imread(args.lr_image)
        # #print(type(args.lr_image))
        # lr_image_yuv_data = yuv_import(args.lr_image, (height, width), 1, 0)
        # print(lr_image_yuv_data)
        # lr_image_y_data = lr_image_yuv_data
        # #print(lr_image_y_data.shape)
        # # lr_image_cb_data = lr_image_yuv_data[:, :, 1:2]
        # # lr_image_cr_data = lr_image_yuv_data[:, :, 2:3]
        # lr_image_batch = np.zeros((1,) + lr_image_y_data.shape)
        # lr_image_batch[0] = lr_image_y_data
        fp = open(args.lr_image, 'rb')
        framesize = height * width * 3 // 2  # bytes per frame for YUV 4:2:0
        h_h = height // 2
        h_w = width // 2

        fp.seek(0, 2)  # move the file pointer to the end of the stream
        ps = fp.tell()  # current file pointer position
        numfrm = ps // framesize  # number of frames in the file
        fp.seek(0, 0)  # rewind to the first frame

        for i in range(10):  # process only the first 10 frames
            # read one frame of planar YUV 4:2:0 data (Y plane, then U, then V)
            Yt = np.frombuffer(fp.read(height * width), dtype='uint8').reshape(height, width)
            Ut = np.frombuffer(fp.read(h_h * h_w), dtype='uint8').reshape(h_h, h_w)
            Vt = np.frombuffer(fp.read(h_h * h_w), dtype='uint8').reshape(h_h, h_w)

            img = np.concatenate(
                (Yt.reshape(-1), Ut.reshape(-1), Vt.reshape(-1)))
            img = img.reshape((height * 3 // 2, width)).astype('uint8')

            # yuv2rgb
            bgr_img = cv2.cvtColor(img, cv2.COLOR_YUV2BGR_I420)
            # print(bgr_img)
            cv2.namedWindow('show', 0)
            cv2.imshow('show', bgr_img)
            #cv2.waitKey(10)
            # cv2.imwrite('result/007.jpg', bgr_img)
            # cv2.imwrite('yuv2bgr/%d.jpg' % (i + 1), bgr_img)
            # print("Extract frame %d " % (i + 1))
            lr_image_data = bgr_img
            # print(lr_image_data)
            lr_image_data = lr_image_data.reshape(
                lr_image_data.shape[0],
                lr_image_data.shape[1],
                3,
            )
            print(lr_image_data.shape)
            lr_image_ycbcr_data = rgb2ycbcr(lr_image_data)
            lr_image_y_data = lr_image_ycbcr_data[:, :, 0:1]
            print('************************')
            print(lr_image_y_data.shape)
            lr_image_cb_data = lr_image_ycbcr_data[:, :, 1:2]
            lr_image_cr_data = lr_image_ycbcr_data[:, :, 2:3]
            lr_image_batch = np.zeros((1, ) + lr_image_y_data.shape)
            lr_image_batch[0] = lr_image_y_data

            sr_image = net.generate(lr_image)
            run_net(i)

    else:
        #imghdr.what(args.lr_image) == 'jpeg' or imghdr.what(args.lr_image) == 'png' or imghdr.what(args.lr_image) == 'bmp':
        lr_image_data = misc.imread(args.lr_image)
        #print(lr_image_data)
        lr_image_data = lr_image_data.reshape(
            lr_image_data.shape[0],
            lr_image_data.shape[1],
            3,
        )
        print(lr_image_data.shape)
        lr_image_ycbcr_data = rgb2ycbcr(lr_image_data)
        lr_image_y_data = lr_image_ycbcr_data[:, :, 0:1]
        print('************************')
        print(lr_image_y_data.shape)
        lr_image_cb_data = lr_image_ycbcr_data[:, :, 1:2]
        lr_image_cr_data = lr_image_ycbcr_data[:, :, 2:3]
        lr_image_batch = np.zeros((1, ) + lr_image_y_data.shape)
        lr_image_batch[0] = lr_image_y_data
        sr_image = net.generate(lr_image)
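
# The "pixel shuffle" comment inside run_net() describes a depth-to-space
# rearrangement of the network's ratio**2 output channels into a single
# up-scaled luma plane. The numpy sketch below assumes the channels-last
# (h, w, ratio**2) layout the TF graph produces; the repository's own
# shuffle() helper may be implemented differently.
def shuffle_sketch(y, ratio):
    h, w, c = y.shape
    assert c == ratio * ratio
    y = y.reshape(h, w, ratio, ratio)          # split the channel axis into r x r
    y = y.transpose(0, 2, 1, 3)                # interleave sub-pixel rows and columns
    return y.reshape(h * ratio, w * ratio, 1)  # (r*h, r*w, 1) luma plane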
Example 8
def train():
    args = get_arguments()

    with open("./params.json", 'r') as f:
        params = json.load(f)

    if not check_params(args, params):
        return

    logdir_root = args.logdir_root # ./logdir
    if logdir_root == LOGDIR_ROOT:
        logdir_root = logdir_root.format(params['ratio']) # ./logdir_{RATIO}x
    logdir = os.path.join(logdir_root, 'train') # ./logdir_{RATIO}x/train

    # Load training data as np arrays
    lr_images, hr_labels = create_inputs(params)

    net = ESPCN(filters_size=params['filters_size'],
                channels=params['channels'],
                ratio=params['ratio'],
                batch_size=args.batch_size,
                lr_size=params['lr_size'],
                edge=params['edge'])

    loss, images, labels = net.build_model()
    optimizer = tf.train.AdamOptimizer(learning_rate=args.learning_rate)
    trainable = tf.trainable_variables()
    optim = optimizer.minimize(loss, var_list=trainable)

    # set up logging for tensorboard
    writer = tf.summary.FileWriter(logdir)
    writer.add_graph(tf.get_default_graph())
    summaries = tf.summary.merge_all()

    # set up session
    sess = tf.Session()

    # saver for storing/restoring checkpoints of the model
    saver = tf.train.Saver()

    init = tf.global_variables_initializer()  # initialize_all_variables() is deprecated
    sess.run(init)

    if net.load(sess, saver, logdir):
        print("[*] Checkpoint load success!")
    else:
        print("[*] Checkpoint load failed/no checkpoint found")

    try:
        steps, start_average, end_average = 0, 0, 0
        start_time = time.time()
        for ep in range(1, args.epochs + 1):
            batch_idxs = len(lr_images) // args.batch_size
            batch_average = 0
            for idx in range(0, batch_idxs):
                # On the fly batch generation instead of Queue to optimize GPU usage
                batch_images = lr_images[idx * args.batch_size : (idx + 1) * args.batch_size]
                batch_labels = hr_labels[idx * args.batch_size : (idx + 1) * args.batch_size]
                
                steps += 1
                summary, loss_value, _ = sess.run([summaries, loss, optim], feed_dict={images: batch_images, labels: batch_labels})
                writer.add_summary(summary, steps)
                batch_average += loss_value

            # Compare loss of first 20% and last 20%
            batch_average = float(batch_average) / batch_idxs
            if ep < (args.epochs * 0.2):
                start_average += batch_average
            elif ep >= (args.epochs * 0.8):
                end_average += batch_average

            duration = time.time() - start_time
            print('Epoch: {}, step: {:d}, loss: {:.9f}, ({:.3f} sec/epoch)'.format(ep, steps, batch_average, duration))
            start_time = time.time()
            net.save(sess, saver, logdir, steps)
    except KeyboardInterrupt:
        print()
    finally:
        start_average = float(start_average) / (args.epochs * 0.2)
        end_average = float(end_average) / (args.epochs * 0.2)
        print("Start Average: [%.6f], End Average: [%.6f], Improved: [%.2f%%]" \
          % (start_average, end_average, 100 - (100*end_average/start_average)))
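
# The scripts above read their hyper-parameters from params.json / params2.json.
# The dictionary below is a hedged guess at that file's shape for a 2x model;
# the concrete values are assumptions, not taken from the repository.
import json

params_guess = {
    "filters_size": [5, 3, 3],   # kernel size of each conv layer
    "channels": [64, 32],        # feature maps of the hidden layers
    "ratio": 2,                  # upscaling factor
    "lr_size": 17,               # low-resolution patch size
    "edge": 8,                   # border cropped from each training patch
}
with open("params.json", "w") as f:
    json.dump(params_guess, f, indent=4)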
Example 9
def generate():
    args = get_arguments()

    with open("./params2.json", 'r') as f:
        params = json.load(f)

    if not check_params(args, params):
        return

    sess = tf.Session()

    net = ESPCN(filters_size=params['filters_size'],
                channels=params['channels'],
                ratio=params['ratio'],
                batch_size=1,
                lr_size=params['lr_size'],
                edge=params['edge'])

    loss, images, labels = net.build_model()

    files = [f for f in os.listdir(args.lr_image_dir) if os.path.isfile(os.path.join(args.lr_image_dir, f))]

    saver = tf.train.Saver()
    if net.load(sess, saver, args.checkpoint):
        print("[*] Checkpoint load success!")
    else:
        print("[*] Checkpoint load failed/no checkpoint found")
        return

    frame_range = (87, 10000)

    for fileName in files:
        try:
            ts = time()
            frame_cnt = int(fileName[5:10])
            if frame_cnt < frame_range[0] or frame_cnt > frame_range[1]:
                print('Ignoring frame ' + str(frame_cnt))
                continue
            else:
                print('start sr for frame ' + str(frame_cnt))

            input_file = os.path.join(args.lr_image_dir, fileName)
            output_file = os.path.join(args.out_path_dir, fileName)

            lr_image = tf.placeholder(tf.uint8)
            lr_image_data = misc.imread(input_file) # pip install pillow
            lr_image_ycbcr_data = rgb2ycbcr(lr_image_data)
            lr_image_y_data = lr_image_ycbcr_data[:, :, 0:1]
            lr_image_cb_data = lr_image_ycbcr_data[:, :, 1:2]
            lr_image_cr_data = lr_image_ycbcr_data[:, :, 2:3]
            lr_image_batch = np.zeros((1,) + lr_image_y_data.shape)
            lr_image_batch[0] = lr_image_y_data
            print('preprocessed %d ms' % ((time()-ts)*1000))
            ts = time()

            sr_image = net.generate(lr_image)
            print('network generated %d ms' % ((time()-ts)*1000))
            ts = time()


            sr_image_y_data = sess.run(sr_image, feed_dict={lr_image: lr_image_batch})

            print('run %d ms' % ((time()-ts)*1000))
            ts = time()

            sr_image_y_data = shuffle(sr_image_y_data[0], args.ratio)
            sr_image_ycbcr_data = misc.imresize(lr_image_ycbcr_data,
                                            params['ratio'] * np.array(lr_image_data.shape[0:2]),
                                            'bicubic')

            edge = params['edge'] * params['ratio'] // 2  # integer, used as a slice index below

            sr_image_ycbcr_data = np.concatenate((sr_image_y_data, sr_image_ycbcr_data[edge:-edge,edge:-edge,1:3]), axis=2)
            print('mixed %d ms' % ((time()-ts)*1000))
            ts = time()
            sr_image_data = ycbcr2rgb(sr_image_ycbcr_data)
            #sr_image_data = sr_image_ycbcr_data
            print('converted %d ms' % ((time()-ts)*1000))
            ts = time()

            misc.imsave(output_file, sr_image_data)
            print(output_file + ' generated %d ms' % ((time()-ts)*1000))
            ts = time()

            if args.hr_image_dir is not None:
                hr_image_path = os.path.join(args.hr_image_dir, fileName)
                hr_image_data = misc.imread(hr_image_path)
                model_psnr = psnr(hr_image_data, sr_image_data, edge)
                print('PSNR of the model: {:.2f}dB'.format(model_psnr))

                sr_image_bicubic_data = misc.imresize(lr_image_data,
                                                params['ratio'] * np.array(lr_image_data.shape[0:2]),
                                                'bicubic')
                bicubic_path = os.path.join(args.out_path_dir, fileName + '_bicubic.png')
                misc.imsave(bicubic_path, sr_image_bicubic_data)
                bicubic_psnr = psnr(hr_image_data, sr_image_bicubic_data, 0)
                print('PSNR of Bicubic: {:.2f}dB'.format(bicubic_psnr))
        except IndexError:
            print('Index error caught')
        except IOError:
            print('Cannot identify image file: ' + fileName)
        except ValueError:
            print('Cannot parse file name: ' + fileName)
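
# The rgb2ycbcr()/ycbcr2rgb() helpers used above are not shown in these
# snippets. The sketch below follows the ITU-R BT.601 "studio swing"
# convention (as in MATLAB's rgb2ycbcr, Y in [16, 235]) for uint8 RGB input;
# the repository's own helper may use slightly different coefficients.
def rgb2ycbcr_sketch(rgb):
    rgb = rgb.astype(np.float64)
    m = np.array([[ 65.481, 128.553,  24.966],
                  [-37.797, -74.203, 112.000],
                  [112.000, -93.786, -18.214]])
    return rgb @ m.T / 255.0 + np.array([16.0, 128.0, 128.0])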
Example 10
def generate():
    args = get_arguments()

    with open("./params.json", 'r') as f:
        params = json.load(f)

    if not check_params(args, params):
        return

    sess = tf.Session()

    net = ESPCN(filters_size=params['filters_size'],
                channels=params['channels'],
                ratio=params['ratio'],
                batch_size=1,
                lr_size=params['lr_size'],
                edge=params['edge'])

    loss, images, labels = net.build_model()

    lr_image = tf.placeholder(tf.uint8)
    lr_image_data = misc.imread(args.lr_image)
    lr_image_ycbcr_data = rgb2ycbcr(lr_image_data)
    lr_image_y_data = lr_image_ycbcr_data[:, :, 0:1]
    lr_image_cb_data = lr_image_ycbcr_data[:, :, 1:2]
    lr_image_cr_data = lr_image_ycbcr_data[:, :, 2:3]
    lr_image_batch = np.zeros((1, ) + lr_image_y_data.shape)
    lr_image_batch[0] = lr_image_y_data

    sr_image = net.generate(lr_image)

    saver = tf.train.Saver()
    try:
        model_loaded = net.load(sess, saver, args.checkpoint)
    except Exception:
        raise Exception(
            "Failed to load model, does the ratio in params.json match the ratio you trained your checkpoint with?"
        )

    if model_loaded:
        print("[*] Checkpoint load success!")
    else:
        print("[*] Checkpoint load failed/no checkpoint found")
        return

    sr_image_y_data = sess.run(sr_image, feed_dict={lr_image: lr_image_batch})

    sr_image_y_data = shuffle(sr_image_y_data[0], params['ratio'])
    sr_image_ycbcr_data = misc.imresize(
        lr_image_ycbcr_data,
        params['ratio'] * np.array(lr_image_data.shape[0:2]), 'bicubic')

    edge = params['edge'] * params['ratio'] // 2  # integer, used as a slice index below

    sr_image_ycbcr_data = np.concatenate(
        (sr_image_y_data, sr_image_ycbcr_data[edge:-edge, edge:-edge, 1:3]),
        axis=2)
    sr_image_data = ycbcr2rgb(sr_image_ycbcr_data)

    misc.imsave(args.out_path + '.png', sr_image_data)

    if args.hr_image is not None:
        hr_image_data = misc.imread(args.hr_image)
        model_psnr = psnr(hr_image_data, sr_image_data, edge)
        print('PSNR of the model: {:.2f}dB'.format(model_psnr))

        sr_image_bicubic_data = misc.imresize(
            lr_image_data,
            params['ratio'] * np.array(lr_image_data.shape[0:2]), 'bicubic')
        misc.imsave(args.out_path + '_bicubic.png', sr_image_bicubic_data)
        bicubic_psnr = psnr(hr_image_data, sr_image_bicubic_data, 0)
        print('PSNR of Bicubic: {:.2f}dB'.format(bicubic_psnr))
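
# The psnr() helper is also not shown. A minimal sketch matching its call
# sites above (crop `edge` pixels from each border, compare 8-bit images)
# follows; the actual implementation may operate on the Y channel only.
def psnr_sketch(hr, sr, edge=0):
    hr = hr.astype(np.float64)
    sr = sr.astype(np.float64)
    if edge > 0:
        hr = hr[edge:-edge, edge:-edge]
        sr = sr[edge:-edge, edge:-edge]
    mse = np.mean((hr - sr) ** 2)
    if mse == 0:
        return float('inf')
    return 10.0 * np.log10(255.0 ** 2 / mse)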