Ejemplo n.º 1
0
def rebuild(img_list):
    """
    Run ESPCN super-resolution reconstruction on ``img_list``.

    A TensorFlow session is opened for the duration of the call, an ESPCN
    model is constructed in inference mode (``is_train=False``) and
    ``img_list`` is handed to its ``generate`` method. No explicit
    ``image_height`` / ``image_width`` is passed to the model here.

    :param img_list: forwarded unchanged to ``ESPCN.generate``.
    :return: None
    """
    with tf.Session() as session:
        model_options = dict(
            is_train=False,
            image_channel=prepare_data.IMAGE_CHANNEl,
            ratio=prepare_data.RATIO,
        )
        model = ESPCN(session, **model_options)
        model.generate(img_list)
def rebuild(img_name):
    """
    Super-resolve one test image and save the result as a PNG.

    The image ``img_name`` is read from ``TEST_IMAGE_DIR``, scaled to the
    ``[0, 1]`` range, passed through an inference-mode ESPCN model and the
    output is written to ``<TEST_RESULT_DIR>/<name without extension>.png``.
    ``TEST_RESULT_DIR`` is created when it does not exist yet.

    :param img_name: file name of the low-resolution image inside
        ``TEST_IMAGE_DIR``.
    :return: None. An image whose array is not 3-dimensional (for example a
        single-channel image) is reported on stdout and skipped.
    """
    # Load inside a ``with`` so the underlying file handle is released as soon
    # as the pixel data has been copied into the array.
    with Image.open(os.path.join(TEST_IMAGE_DIR, img_name)) as source:
        lr_image = np.asarray(source)

    try:
        image_height, image_width, _ = lr_image.shape
    except ValueError:
        # Tuple unpacking fails when the shape does not have exactly 3 entries.
        print("error", lr_image.shape)
        return

    with tf.Session() as sess:
        espcn = ESPCN(sess,
                      is_train=False,
                      image_height=image_height,
                      image_width=image_width,
                      image_channel=prepare_data.IMAGE_CHANNEl,
                      ratio=prepare_data.RATIO)
        sr_image = espcn.generate(lr_image / 255.0)

    # The output directory must exist, otherwise saving the image fails.
    if not os.path.isdir(TEST_RESULT_DIR):
        os.makedirs(TEST_RESULT_DIR)

    # splitext keeps inner dots ("a.b.png" -> "a.b"), so distinct inputs do not
    # collapse onto the same output name; os.path.join works whether or not
    # TEST_RESULT_DIR carries a trailing separator.
    stem = os.path.splitext(img_name)[0]
    util.save_img_from_array(sr_image,
                             os.path.join(TEST_RESULT_DIR, stem + '.png'))
    print("saved")
Ejemplo n.º 3
0
def rebuild(img_name):
    """
    Super-resolve one test image and save its LR, HR and SR versions.

    ``prepare_data.preprocess_img`` supplies the low-resolution input and the
    original image for ``img_name`` (looked up in ``TEST_IMAGE_DIR``). The
    low-resolution image is scaled to ``[0, 1]`` and passed through an
    inference-mode ESPCN model. Three files are then written to
    ``TEST_RESULT_DIR``: ``<stem>_lr<ext>``, ``<stem>_hr<ext>`` and
    ``<stem>_sr<ext>``, where ``<stem>``/``<ext>`` come from ``img_name``.

    :param img_name: file name of the image inside ``TEST_IMAGE_DIR``.
    :return: None. Inputs whose low-resolution array does not have a
        3-entry shape are reported on stdout and skipped.
    """
    import os  # function-scope import: the block did not rely on ``os`` before

    lr_image, ori_image = prepare_data.preprocess_img(
        os.path.join(TEST_IMAGE_DIR, img_name))
    try:
        image_height, image_width, _ = lr_image.shape
    except (ValueError, AttributeError):
        # ValueError: shape is not (height, width, channels).
        # AttributeError: preprocessing did not hand back an array.
        # Still best-effort (skip the image), but no longer silent.
        print("error", getattr(lr_image, "shape", None))
        return

    with tf.Session() as sess:
        espcn = ESPCN(sess,
                      is_train=False,
                      image_height=image_height,
                      image_width=image_width,
                      image_channel=prepare_data.IMAGE_CHANNEl,
                      ratio=prepare_data.RATIO)
        sr_image = espcn.generate(lr_image / 255.0)

    # Split the name once; ``ext`` keeps its leading dot (or is empty).
    stem, ext = os.path.splitext(img_name)

    # Low-resolution input, original image and super-resolved output.
    for suffix, image in (('_lr', lr_image),
                          ('_hr', ori_image),
                          ('_sr', sr_image)):
        util.save_img_from_array(
            image, os.path.join(TEST_RESULT_DIR, stem + suffix + ext))
Ejemplo n.º 4
0
def generate():
    """
    Super-resolve the first frames of a raw YUV clip with ESPCN.

    Hyper-parameters are read from ``./params2.json`` and the model weights
    from ``args.checkpoint``. When ``args.lr_image`` ends in ``yuv`` the file
    is read as planar 4:2:0 frames of 352x288 pixels; up to ten frames are
    decoded, shown in an OpenCV window, super-resolved on the Y channel,
    merged with bicubically upscaled chroma and appended as raw array bytes
    to ``result/res.yuv``.

    For any other input the image is loaded and the generator graph is built,
    but the network is not executed (see the review note in that branch).

    :return: None. Returns early when ``check_params`` reports ``False`` or
        when no checkpoint could be loaded.
    :raises Exception: when loading the checkpoint raises.
    """
    args = get_arguments()

    with open("./params2.json", 'r') as f:
        params = json.load(f)

    # Kept as an explicit comparison: ``not check_params(...)`` would also
    # abort on a ``None`` return value, which the original did not.
    if check_params(args, params) == False:
        return

    ratio = params['ratio']
    # Number of rows/columns trimmed from each side of the upscaled chroma so
    # that it can be concatenated with the network's Y output.
    edge = params['edge'] * ratio // 2

    def prepare_input(lr_image_data):
        """Return ``(ycbcr_image, y_batch)`` for one H x W x 3 image."""
        # reshape doubles as a check that the image really has 3 channels.
        lr_image_data = lr_image_data.reshape(
            lr_image_data.shape[0],
            lr_image_data.shape[1],
            3,
        )
        print(lr_image_data.shape)
        lr_image_ycbcr_data = rgb2ycbcr(lr_image_data)
        lr_image_y_data = lr_image_ycbcr_data[:, :, 0:1]
        print('************************')
        print(lr_image_y_data.shape)
        lr_image_batch = np.zeros((1, ) + lr_image_y_data.shape)
        lr_image_batch[0] = lr_image_y_data
        return lr_image_ycbcr_data, lr_image_batch

    def run_net(sess, sr_image, feed, lr_image_ycbcr_data, lr_height_width):
        """Run the network for one frame and append the result to disk."""
        start0 = time.time()
        start1 = time.time()
        sr_image_y_data = sess.run(sr_image, feed_dict=feed)
        print('1', time.time() - start1)

        start2 = time.time()
        # pixel shuffle  b c r^2 h w ---> b c rh rw
        sr_image_y_data = shuffle(sr_image_y_data[0], ratio)
        print('2', time.time() - start2)

        start3 = time.time()
        sr_image_ycbcr_data = misc.imresize(
            lr_image_ycbcr_data,
            ratio * np.array(lr_height_width), 'bicubic')
        # Explicit end indices: ``edge:-edge`` would be an empty slice when
        # edge == 0.
        hr_height, hr_width = sr_image_ycbcr_data.shape[0:2]
        sr_image_ycbcr_data = np.concatenate(
            (sr_image_y_data,
             sr_image_ycbcr_data[edge:hr_height - edge,
                                 edge:hr_width - edge, 1:3]),
            axis=2)
        print('3', time.time() - start3)

        # The YCbCr -> RGB conversion this step used to time is disabled; the
        # line is kept so the console output stays the same.
        start4 = time.time()
        print('4', time.time() - start4)

        print(type(sr_image_ycbcr_data))
        # Append mode: every frame is added to the same output file. The
        # ``with`` makes sure the handle is flushed and closed per frame.
        with open("result/res.yuv", 'ab') as fw:
            fw.write(sr_image_ycbcr_data)

        print("{:f} seconds".format(time.time() - start0))

    # The session is closed on every exit path, including early returns.
    with tf.Session() as sess:
        net = ESPCN(filters_size=params['filters_size'],
                    channels=params['channels'],
                    ratio=ratio,
                    batch_size=1,
                    lr_size=params['lr_size'],
                    edge=params['edge'])

        # Called for its graph-building effect; the returned tensors are not
        # used for inference.
        net.build_model()

        # rgb
        lr_image = tf.placeholder(tf.uint8)

        saver = tf.train.Saver()
        try:
            model_loaded = net.load(sess, saver, args.checkpoint)
        except Exception:
            raise Exception(
                "Failed to load model, does the ratio in params2.json match the ratio you trained your checkpoint with?"
            )

        if model_loaded:
            print("[*] Checkpoint load success!")
        else:
            print("[*] Checkpoint load failed/no checkpoint found")
            return

        if args.lr_image[-3:] == 'yuv':
            # Frame geometry is hard-coded to 352x288.
            width = 352
            height = 288
            # Bytes per frame: full-size Y plane plus two quarter-size planes.
            framesize = height * width * 3 // 2

            with open(args.lr_image, 'rb') as fp:
                fp.seek(0, 2)  # jump to the end of the file
                numfrm = fp.tell() // framesize  # complete frames available
                fp.seek(0, 0)

                # At most ten frames, and never more than the file contains.
                for _ in range(min(10, numfrm)):
                    # One read per frame; Y, U and V planes are stored back
                    # to back, which is exactly the layout cvtColor expects.
                    img = np.fromfile(fp, dtype=np.uint8, count=framesize)
                    img = img.reshape((height * 3 // 2, width))

                    # yuv2rgb
                    bgr_img = cv2.cvtColor(img, cv2.COLOR_YUV2BGR_I420)
                    cv2.namedWindow('show', 0)
                    cv2.imshow('show', bgr_img)

                    # NOTE(review): the frame is in BGR channel order but is
                    # passed to a helper named rgb2ycbcr - confirm the channel
                    # order that helper expects.
                    lr_image_ycbcr_data, lr_image_batch = prepare_input(
                        bgr_img)

                    # NOTE(review): generate() is invoked once per frame, as
                    # before; if it only adds ops to the graph it could
                    # probably be built once before the loop - confirm.
                    sr_image = net.generate(lr_image)
                    run_net(sess, sr_image, {lr_image: lr_image_batch},
                            lr_image_ycbcr_data, bgr_img.shape[0:2])

        else:
            lr_image_data = misc.imread(args.lr_image)
            prepare_input(lr_image_data)
            # NOTE(review): this branch has never executed the network or
            # written any output; behaviour is preserved here - confirm
            # whether a run_net call is missing.
            net.generate(lr_image)
Ejemplo n.º 5
0
def generate():
    """
    Super-resolve every frame image in ``args.lr_image_dir`` with ESPCN.

    Hyper-parameters come from ``./params2.json`` and the weights from
    ``args.checkpoint``. Each regular file in ``args.lr_image_dir`` whose
    name carries a frame number in characters 5..9 within the range
    87..10000 is processed: the Y channel is super-resolved by the network,
    the chroma channels are upscaled bicubically, and the merged RGB result
    is saved under the same file name in ``args.out_path_dir``.

    When ``args.hr_image_dir`` is given, the PSNR of the model output and of
    a plain bicubic upscale (also saved, with a ``_bicubic.png`` suffix) is
    printed for each frame.

    :return: None. Returns early when ``check_params`` reports ``False`` or
        when no checkpoint could be loaded. ``IndexError``, ``IOError`` and
        ``ValueError`` raised while handling a file are reported and the
        loop continues with the next file.
    """
    args = get_arguments()

    with open("./params2.json", 'r') as f:
        params = json.load(f)

    # Kept as an explicit comparison: ``not check_params(...)`` would also
    # abort on a ``None`` return value, which the original did not.
    if check_params(args, params) == False:
        return

    # The session is closed on every exit path, including early returns.
    with tf.Session() as sess:
        net = ESPCN(filters_size=params['filters_size'],
                    channels=params['channels'],
                    ratio=params['ratio'],
                    batch_size=1,
                    lr_size=params['lr_size'],
                    edge=params['edge'])

        # Called for its graph-building effect; the returned tensors are not
        # used for inference.
        net.build_model()

        files = [f for f in os.listdir(args.lr_image_dir)
                 if os.path.isfile(os.path.join(args.lr_image_dir, f))]

        saver = tf.train.Saver()
        if net.load(sess, saver, args.checkpoint):
            print("[*] Checkpoint load success!")
        else:
            print("[*] Checkpoint load failed/no checkpoint found")
            return

        frame_range = (87, 10000)

        # Integer division: the value is used as a slice index, and true
        # division would yield a float (TypeError when slicing) on Python 3.
        edge = params['edge'] * params['ratio'] // 2

        for fileName in files:
            try:
                ts = time()
                # Frame number is taken from a fixed position in the name; a
                # non-numeric slice raises ValueError and the file is skipped.
                frame_cnt = int(fileName[5:10])
                if frame_cnt < frame_range[0] or frame_cnt > frame_range[1]:
                    print('Ignoring frame ' + str(frame_cnt))
                    continue
                else:
                    print('start sr for frame ' + str(frame_cnt))

                input_file = os.path.join(args.lr_image_dir, fileName)
                output_file = os.path.join(args.out_path_dir, fileName)

                lr_image = tf.placeholder(tf.uint8)
                lr_image_data = misc.imread(input_file)  # pip install pillow
                lr_image_ycbcr_data = rgb2ycbcr(lr_image_data)
                lr_image_y_data = lr_image_ycbcr_data[:, :, 0:1]
                lr_image_batch = np.zeros((1,) + lr_image_y_data.shape)
                lr_image_batch[0] = lr_image_y_data
                print('preprocessed %d ms' % ((time() - ts) * 1000))
                ts = time()

                # NOTE(review): a new placeholder and generator graph are
                # created for every file, as before; if generate() only adds
                # ops this could probably be built once - confirm.
                sr_image = net.generate(lr_image)
                print('network generated %d ms' % ((time() - ts) * 1000))
                ts = time()

                sr_image_y_data = sess.run(
                    sr_image, feed_dict={lr_image: lr_image_batch})

                print('run %d ms' % ((time() - ts) * 1000))
                ts = time()

                # NOTE(review): uses args.ratio here but params['ratio']
                # everywhere else - confirm they are guaranteed to match.
                sr_image_y_data = shuffle(sr_image_y_data[0], args.ratio)
                sr_image_ycbcr_data = misc.imresize(
                    lr_image_ycbcr_data,
                    params['ratio'] * np.array(lr_image_data.shape[0:2]),
                    'bicubic')

                # Explicit end indices: ``edge:-edge`` would be an empty
                # slice when edge == 0.
                hr_height, hr_width = sr_image_ycbcr_data.shape[0:2]
                sr_image_ycbcr_data = np.concatenate(
                    (sr_image_y_data,
                     sr_image_ycbcr_data[edge:hr_height - edge,
                                         edge:hr_width - edge, 1:3]),
                    axis=2)
                print('mixed %d ms' % ((time() - ts) * 1000))
                ts = time()
                sr_image_data = ycbcr2rgb(sr_image_ycbcr_data)
                print('converted %d ms' % ((time() - ts) * 1000))
                ts = time()

                misc.imsave(output_file, sr_image_data)
                print(output_file + ' generated %d ms' % ((time() - ts) * 1000))

                if args.hr_image_dir is not None:
                    hr_image_path = os.path.join(args.hr_image_dir, fileName)
                    hr_image_data = misc.imread(hr_image_path)
                    model_psnr = psnr(hr_image_data, sr_image_data, edge)
                    print('PSNR of the model: {:.2f}dB'.format(model_psnr))

                    sr_image_bicubic_data = misc.imresize(
                        lr_image_data,
                        params['ratio'] * np.array(lr_image_data.shape[0:2]),
                        'bicubic')
                    bicubic_path = os.path.join(args.out_path_dir,
                                                fileName + '_bicubic.png')
                    misc.imsave(bicubic_path, sr_image_bicubic_data)
                    bicubic_psnr = psnr(hr_image_data, sr_image_bicubic_data, 0)
                    print('PSNR of Bicubic: {:.2f}dB'.format(bicubic_psnr))
            except IndexError:
                print('Index error caught')
            except IOError:
                print('Cannot identify image file: ' + fileName)
            except ValueError:
                print('Cannot parse file name: ' + fileName)
Ejemplo n.º 6
0
def generate():
    """
    Super-resolve a single image with ESPCN and save it as a PNG.

    Hyper-parameters come from ``./params.json`` and the weights from
    ``args.checkpoint``. The Y channel of ``args.lr_image`` is super-resolved
    by the network, the chroma channels are upscaled bicubically, and the
    merged RGB result is written to ``args.out_path + '.png'``.

    When ``args.hr_image`` is given, the PSNR of the model output and of a
    plain bicubic upscale (saved as ``args.out_path + '_bicubic.png'``) is
    printed.

    :return: None. Returns early when ``check_params`` reports ``False`` or
        when no checkpoint could be loaded.
    :raises Exception: when loading the checkpoint raises.
    """
    args = get_arguments()

    with open("./params.json", 'r') as f:
        params = json.load(f)

    # Kept as an explicit comparison: ``not check_params(...)`` would also
    # abort on a ``None`` return value, which the original did not.
    if check_params(args, params) == False:
        return

    # The session is closed on every exit path, including early returns.
    with tf.Session() as sess:
        net = ESPCN(filters_size=params['filters_size'],
                    channels=params['channels'],
                    ratio=params['ratio'],
                    batch_size=1,
                    lr_size=params['lr_size'],
                    edge=params['edge'])

        # Called for its graph-building effect; the returned tensors are not
        # used for inference.
        net.build_model()

        lr_image = tf.placeholder(tf.uint8)
        lr_image_data = misc.imread(args.lr_image)
        lr_image_ycbcr_data = rgb2ycbcr(lr_image_data)
        # Only the Y channel goes through the network, as a batch of one.
        lr_image_y_data = lr_image_ycbcr_data[:, :, 0:1]
        lr_image_batch = np.zeros((1, ) + lr_image_y_data.shape)
        lr_image_batch[0] = lr_image_y_data

        sr_image = net.generate(lr_image)

        saver = tf.train.Saver()
        try:
            model_loaded = net.load(sess, saver, args.checkpoint)
        except Exception:
            raise Exception(
                "Failed to load model, does the ratio in params.json match the ratio you trained your checkpoint with?"
            )

        if model_loaded:
            print("[*] Checkpoint load success!")
        else:
            print("[*] Checkpoint load failed/no checkpoint found")
            return

        sr_image_y_data = sess.run(sr_image,
                                   feed_dict={lr_image: lr_image_batch})

        sr_image_y_data = shuffle(sr_image_y_data[0], params['ratio'])
        sr_image_ycbcr_data = misc.imresize(
            lr_image_ycbcr_data,
            params['ratio'] * np.array(lr_image_data.shape[0:2]), 'bicubic')

        # Integer division: the value is used as a slice index, and true
        # division would yield a float (TypeError when slicing) on Python 3.
        edge = params['edge'] * params['ratio'] // 2

        # Explicit end indices: ``edge:-edge`` would be an empty slice when
        # edge == 0.
        hr_height, hr_width = sr_image_ycbcr_data.shape[0:2]
        sr_image_ycbcr_data = np.concatenate(
            (sr_image_y_data,
             sr_image_ycbcr_data[edge:hr_height - edge,
                                 edge:hr_width - edge, 1:3]),
            axis=2)
        sr_image_data = ycbcr2rgb(sr_image_ycbcr_data)

        misc.imsave(args.out_path + '.png', sr_image_data)

        if args.hr_image is not None:
            hr_image_data = misc.imread(args.hr_image)
            model_psnr = psnr(hr_image_data, sr_image_data, edge)
            print('PSNR of the model: {:.2f}dB'.format(model_psnr))

            sr_image_bicubic_data = misc.imresize(
                lr_image_data,
                params['ratio'] * np.array(lr_image_data.shape[0:2]),
                'bicubic')
            misc.imsave(args.out_path + '_bicubic.png', sr_image_bicubic_data)
            bicubic_psnr = psnr(hr_image_data, sr_image_bicubic_data, 0)
            print('PSNR of Bicubic: {:.2f}dB'.format(bicubic_psnr))