Exemplo n.º 1
0
def func_enhance(dir_model_pre, QP, PreIndex_list, CmpIndex_list,
                 SubIndex_list):
    """Enhance PQFs or non-PQFs, record dpsnr, dssim and enhanced frames.

    Resets the default graph, rebuilds the motion-compensation and
    enhancement ops, restores ``model_step2.ckpt-<QP>`` from
    ``dir_model_pre`` and enhances the frames listed in ``CmpIndex_list``
    one at a time.

    Args:
        dir_model_pre: Directory containing the ``model_step2.ckpt-<QP>``
            checkpoint.
        QP: Value used for the checkpoint suffix; ``net_MFCNN.network`` is
            used when it is in ``net1_list``, ``net_MFCNN.network2``
            otherwise.
        PreIndex_list: Per target frame, index of the previous frame that is
            warped towards it.
        CmpIndex_list: Indices of the frames to enhance.
        SubIndex_list: Per target frame, index of the subsequent frame that
            is warped towards it.

    Side effects:
        Writes each enhanced frame into the global ``enhanced_list``, adds
        this call's PSNR/SSIM gains to the globals ``sum_dpsnr`` and
        ``sum_dssim``, and reports the averages on stdout and in
        ``file_object``.
    """
    global enhanced_list, sum_dpsnr, sum_dssim

    tf.reset_default_graph()

    # Decided once: it selects the network and whether `is_training` exists.
    use_net1 = QP in net1_list

    ### Define enhancement process
    frame_shape = [BATCH_SIZE, height, width, CHANNEL]
    x1 = tf.placeholder(tf.float32, frame_shape)  # previous
    x2 = tf.placeholder(tf.float32, frame_shape)  # current
    x3 = tf.placeholder(tf.float32, frame_shape)  # subsequent

    if use_net1:
        is_training = tf.placeholder_with_default(False, shape=())

    x1to2 = net_MFCNN.warp_img(BATCH_SIZE, x2, x1, False)
    x3to2 = net_MFCNN.warp_img(BATCH_SIZE, x2, x3, True)

    if use_net1:
        x2_enhanced = net_MFCNN.network(x1to2, x2, x3to2, is_training)
    else:
        x2_enhanced = net_MFCNN.network2(x1to2, x2, x3to2)

    saver = tf.train.Saver()

    with tf.Session(config=config) as sess:

        # Restore model
        model_path = os.path.join(dir_model_pre, "model_step2.ckpt-" + str(QP))
        saver.restore(sess, model_path)

        nfs = len(CmpIndex_list)

        sum_dpsnr_part = 0.0
        sum_dssim_part = 0.0

        for ite_frame in range(nfs):
            cmp_index = CmpIndex_list[ite_frame]

            # Load frames, add a trailing channel axis and scale by 1/255.
            pre_frame = y_import(CmpVideo_path, height, width, 1,
                                 PreIndex_list[ite_frame])[:, :, :,
                                                           np.newaxis] / 255.0
            cmp_frame = y_import(CmpVideo_path, height, width, 1,
                                 cmp_index)[:, :, :, np.newaxis] / 255.0
            sub_frame = y_import(CmpVideo_path, height, width, 1,
                                 SubIndex_list[ite_frame])[:, :, :,
                                                           np.newaxis] / 255.0

            # A "plane" target frame is skipped entirely: nothing is stored
            # in enhanced_list and nothing is accumulated for it.
            if isplane(cmp_frame):
                continue

            # A "plane" neighbour is replaced by a copy of the target frame.
            if isplane(pre_frame):
                pre_frame = np.copy(cmp_frame)
            if isplane(sub_frame):
                sub_frame = np.copy(cmp_frame)

            # Enhance
            feed = {x1: pre_frame, x2: cmp_frame, x3: sub_frame}
            if use_net1:
                feed[is_training] = False
            enhanced_frame = sess.run(x2_enhanced, feed_dict=feed)

            # Record for output video
            enhanced_list[cmp_index] = np.squeeze(enhanced_frame)

            # Evaluate and accumulate dpsnr / dssim against the raw frame
            raw_frame = np.squeeze(
                y_import(RawVideo_path, height, width, 1, cmp_index)) / 255.0
            cmp_frame = np.squeeze(cmp_frame)
            enhanced_frame = np.squeeze(enhanced_frame)

            raw_frame = np.float32(raw_frame)
            cmp_frame = np.float32(cmp_frame)

            psnr_ori = compare_psnr(cmp_frame, raw_frame, data_range=1.0)
            psnr_aft = compare_psnr(enhanced_frame, raw_frame, data_range=1.0)

            ssim_ori = compare_ssim(cmp_frame, raw_frame, data_range=1.0)
            ssim_aft = compare_ssim(enhanced_frame, raw_frame, data_range=1.0)

            sum_dpsnr_part += psnr_aft - psnr_ori
            sum_dssim_part += ssim_aft - ssim_ori

            print("%d | %d at QP = %d" % (ite_frame + 1, nfs, QP), end="\r")
        print("              ", end="\r")

        sum_dpsnr += sum_dpsnr_part
        sum_dssim += sum_dssim_part

        # NOTE: the average is taken over all requested frames, including
        # the ones skipped above. An empty request reports 0 instead of
        # raising ZeroDivisionError.
        average_dpsnr = sum_dpsnr_part / nfs if nfs else 0.0
        average_dssim = sum_dssim_part / nfs if nfs else 0.0
        print("dPSNR: %.3f - dSSIM: %.3f - nfs: %4d" %
              (average_dpsnr, average_dssim, nfs),
              flush=True)
        file_object.write("dPSNR: %.3f - dSSIM: %.3f - nfs: %4d\n" %
                          (average_dpsnr, average_dssim, nfs))
        file_object.flush()
Exemplo n.º 2
0
def main_train():
    """Train the network in two steps and evaluate it after every epoch.

    Step 1 minimises ``flow_loss + ratio_small * MSE`` (emphasis on the
    MC-subnet) for ``epoch_step1`` epochs; step 2 minimises
    ``ratio_small * flow_loss + MSE`` (emphasis on the QE-subnet) for
    ``epoch_step2`` epochs. After each epoch a checkpoint is saved to
    ``dir_model`` and the mean PSNR improvement over all validation stacks
    is printed and written to ``file_object``.

    Variables are freshly initialised: the code that restored a pre-trained
    model for fine-tuning is disabled (see the "Restore" section below).
    """
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # only show error and warning
    #os.environ['CUDA_VISIBLE_DEVICES'] = GPU
    os.environ['CUDA_VISIBLE_DEVICES'] = "2"

    ### Define a session
    config = tf.ConfigProto(
        allow_soft_placement=True
    )  # if GPU is not usable, then turn to CPU automatically
    sess = tf.Session(config=config)

    # Decided once: it selects the network and whether `is_training` exists.
    use_net1 = int(QP) in net1_list

    ### Set placeholder
    patch_shape = [BATCH_SIZE, HEIGHT, WIDTH, CHANNEL]
    x1 = tf.placeholder(tf.float32, patch_shape)  # pre
    x2 = tf.placeholder(tf.float32, patch_shape)  # cmp
    x3 = tf.placeholder(tf.float32, patch_shape)  # sub
    x5 = tf.placeholder(tf.float32, patch_shape)  # raw

    if use_net1:
        is_training = tf.placeholder_with_default(
            False, shape=())  # for BN training/testing. default testing.

    PSNR_0 = cal_PSNR(x2, x5)  # PSNR before enhancement (cmp and raw)

    ### Motion compensation
    x1to2 = net_MFCNN.warp_img(tf.shape(x2)[0], x2, x1, False)
    x3to2 = net_MFCNN.warp_img(tf.shape(x2)[0], x2, x3, True)

    ### Flow loss
    FlowLoss_1 = cal_MSE(x1to2, x2)
    FlowLoss_2 = cal_MSE(x3to2, x2)
    flow_loss = FlowLoss_1 + FlowLoss_2

    ### Enhance cmp frames
    if use_net1:
        x2_enhanced = net_MFCNN.network(x1to2, x2, x3to2, is_training)
    else:
        x2_enhanced = net_MFCNN.network2(x1to2, x2, x3to2)

    MSE = cal_MSE(x2_enhanced, x5)
    PSNR = cal_PSNR(x2_enhanced,
                    x5)  # PSNR after enhancement (enhanced and raw)
    delta_PSNR = PSNR - PSNR_0

    ### 2 kinds of loss for 2-step training
    OptimizeLoss_1 = flow_loss + ratio_small * MSE  # step1: the key is MC-subnet.
    OptimizeLoss_2 = ratio_small * flow_loss + MSE  # step2: the key is QE-subnet.

    ### Define optimizer
    # Running the train ops also runs everything in UPDATE_OPS.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        Training_step1 = tf.train.AdamOptimizer(lr_ori).minimize(
            OptimizeLoss_1)
        Training_step2 = tf.train.AdamOptimizer(lr_ori).minimize(
            OptimizeLoss_2)

    ### TensorBoard
    tf.summary.scalar('PSNR improvement', delta_PSNR)
    #tf.summary.scalar('PSNR before enhancement', PSNR_0)
    #tf.summary.scalar('PSNR after enhancement', PSNR)
    tf.summary.scalar('MSE loss of motion compensation', flow_loss)
    #tf.summary.scalar('MSE loss of final quality enhancement', MSE)
    tf.summary.scalar('MSE loss for training step1 (mainly MC-subnet)',
                      OptimizeLoss_1)
    tf.summary.scalar('MSE loss for training step2 (mainly QE-subnet)',
                      OptimizeLoss_2)

    tf.summary.image('cmp', x2)
    tf.summary.image('enhanced', x2_enhanced)
    tf.summary.image('raw', x5)
    tf.summary.image('x1to2', x1to2)
    tf.summary.image('x3to2', x3to2)

    summary_writer = tf.summary.FileWriter(dir_model, sess.graph)
    summary_op = tf.summary.merge_all()

    saver = tf.train.Saver(max_to_keep=None)  # define a saver

    sess.run(tf.global_variables_initializer())  # initialize network variables

    ### Calculate the num of parameters
    num_params = 0
    for variable in tf.trainable_variables():
        shape = variable.get_shape()
        num_params += reduce(mul, [dim.value for dim in shape], 1)
    print("# num of parameters: %d #" % num_params)
    file_object.write("# num of parameters: %d #\n" % num_params)
    file_object.flush()

    ### Find all stacks then cal their number
    stack_name = os.path.join(dir_stack, "stack_tra_pre_*")
    num_TrainingStack = len(glob.glob(stack_name))
    stack_name = os.path.join(dir_stack, "stack_val_pre_*")
    num_ValidationStack = len(glob.glob(stack_name))

    ### Restore
    # The fine-tuning code that loaded an already trained model is disabled:
    # saver_res = tf.train.Saver()
    # saver_res.restore(sess, model_res_path)
    # print("successfully restore model %d!" % (int(res_index) + 1))
    # file_object.write("successfully restore model %d!\n" % (int(res_index) + 1))
    # file_object.flush()

    print("##### Start running! #####")

    num_TrainingBatch_count = 0

    # The stack currently held in memory. Training and validation share these
    # four names so that only one stack is resident at a time.
    pre_list, cmp_list, sub_list, raw_list = [], [], [], []
    loaded_stack = None  # ("tra" | "val", stack index) held in *_list

    def batch_feed(start, stop, training):
        """Return the feed dict for samples [start, stop) of the held stack."""
        feed = {
            x1: pre_list[start:stop],
            x2: cmp_list[start:stop],
            x3: sub_list[start:stop],
            x5: raw_list[start:stop],
        }
        if use_net1:
            feed[is_training] = training
        return feed

    ### Step 1: converge MC-subnet; Step 2: converge QE-subnet
    for ite_step in [1, 2]:

        if ite_step == 1:
            num_epoch = epoch_step1
            train_op = Training_step1
        else:
            num_epoch = epoch_step2
            train_op = Training_step2

        ### Epoch by Epoch
        for ite_epoch in range(num_epoch):

            ### Train stack by stack
            for ite_stack in range(num_TrainingStack):

                # Reload whenever the held stack is not this training stack.
                # Validation below overwrites the held stack, so skipping the
                # reload would train on validation data from the second
                # epoch on. A single training stack with no validation
                # stacks is still loaded only once.
                if loaded_stack != ("tra", ite_stack):
                    # Release the old stack first to limit peak memory.
                    pre_list, cmp_list, sub_list, raw_list = [], [], [], []
                    gc.collect()
                    pre_list, cmp_list, sub_list, raw_list = load_stack(
                        "tra", ite_stack)
                    gc.collect()
                    loaded_stack = ("tra", ite_stack)
                num_batch = len(pre_list) // BATCH_SIZE

                ### Batch by batch
                for ite_batch in range(num_batch):

                    print("\rstep %1d - epoch %2d/%2d - training stack %2d/%2d - batch %3d/%3d" % \
                        (ite_step, ite_epoch+1, num_epoch, ite_stack+1, num_TrainingStack, ite_batch+1, num_batch), end="")

                    start_index = ite_batch * BATCH_SIZE
                    next_start_index = (ite_batch + 1) * BATCH_SIZE

                    train_op.run(session=sess,
                                 feed_dict=batch_feed(start_index,
                                                      next_start_index,
                                                      True))  # train

                    # Update TensorBoard and print result
                    num_TrainingBatch_count += 1

                    # Log at the middle and at the end of every stack, on the
                    # batch just trained, in test mode.
                    if ((ite_batch + 1) == int(
                            num_batch / 2)) or ((ite_batch + 1) == num_batch):

                        summary, delta_PSNR_batch, PSNR_0_batch = sess.run(
                            [summary_op, delta_PSNR, PSNR_0],
                            feed_dict=batch_feed(start_index,
                                                 next_start_index, False))

                        summary_writer.add_summary(summary,
                                                   num_TrainingBatch_count)
                        print("\rstep %1d - epoch %2d - imp PSNR: %.3f - ori PSNR: %.3f             " % \
                            (ite_step, ite_epoch + 1, delta_PSNR_batch, PSNR_0_batch))
                        file_object.write("step %1d - epoch %2d - imp PSNR: %.3f - ori PSNR: %.3f\n" % \
                            (ite_step, ite_epoch + 1, delta_PSNR_batch, PSNR_0_batch))
                        file_object.flush()

            ### Store the model of this epoch
            if ite_step == 1:
                CheckPoint_path = os.path.join(dir_model, "model_step1.ckpt")
            else:
                CheckPoint_path = os.path.join(dir_model, "model_step2.ckpt")
            saver.save(sess, CheckPoint_path, global_step=ite_epoch)

            sum_improved_PSNR = 0
            num_patch_count = 0

            ### Eval stack by stack, and report together for this epoch
            for ite_stack in range(num_ValidationStack):

                pre_list, cmp_list, sub_list, raw_list = [], [], [], []
                gc.collect()
                pre_list, cmp_list, sub_list, raw_list = load_stack(
                    "val", ite_stack)
                gc.collect()
                loaded_stack = ("val", ite_stack)

                num_batch = len(pre_list) // BATCH_SIZE

                ### Batch by batch
                for ite_batch in range(num_batch):

                    print("\rstep %1d - epoch %2d/%2d - validation stack %2d/%2d                " % \
                        (ite_step, ite_epoch+1, num_epoch, ite_stack+1, num_ValidationStack), end="")

                    start_index = ite_batch * BATCH_SIZE
                    next_start_index = (ite_batch + 1) * BATCH_SIZE

                    delta_PSNR_batch = sess.run(
                        delta_PSNR,
                        feed_dict=batch_feed(start_index, next_start_index,
                                             False))

                    sum_improved_PSNR += delta_PSNR_batch * BATCH_SIZE
                    num_patch_count += BATCH_SIZE

            if num_patch_count != 0:
                print("\n### imp PSNR by model after step %1d - epoch %2d/%2d: %.3f ###\n" % \
                    (ite_step, ite_epoch+1, num_epoch, sum_improved_PSNR/num_patch_count))
                file_object.write("### imp PSNR by model after step %1d - epoch %2d/%2d: %.3f ###\n" % \
                    (ite_step, ite_epoch+1, num_epoch, sum_improved_PSNR/num_patch_count))
                file_object.flush()

    # Flush pending TensorBoard events and release the session's resources.
    summary_writer.close()
    sess.close()
Exemplo n.º 3
0
def func_enhance(dir_model_pre, QP, PreIndex_list, CmpIndex_list, SubIndex_list):
    """Enhance PQFs or non-PQFs, record dpsnr, dssim and enhanced frames.

    Resets the default graph, rebuilds the motion-compensation and
    enhancement ops, restores the ``model_step2.ckpt-0`` checkpoint from
    ``dir_model_pre`` and enhances the frames listed in ``CmpIndex_list``
    one at a time. While the global counter ``ywz_times`` is 0 (until the
    first frame has been enhanced) the run is traced and a Chrome trace is
    written to ``timeline/timeline_<frame>.json``.

    Args:
        dir_model_pre: Directory containing the checkpoint.
        QP: ``net_MFCNN.network`` is used when it is in ``net1_list``,
            ``net_MFCNN.network2`` otherwise; also shown in progress output.
        PreIndex_list: Per target frame, index of the previous frame that is
            warped towards it.
        CmpIndex_list: Indices of the frames to enhance.
        SubIndex_list: Per target frame, index of the subsequent frame that
            is warped towards it.

    Side effects:
        Writes each enhanced frame into the global ``enhanced_list``, adds
        this call's PSNR/SSIM gains to the globals ``sum_dpsnr`` and
        ``sum_dssim``, increments ``ywz_times`` once per enhanced frame, and
        reports the averages on stdout and in ``file_object``.
    """
    global ywz_times

    global enhanced_list, sum_dpsnr, sum_dssim

    tf.reset_default_graph()

    # Decided once: it selects the network and whether `is_training` exists.
    use_net1 = QP in net1_list

    ### Define enhancement process
    frame_shape = [BATCH_SIZE, height, width, CHANNEL]
    x1 = tf.placeholder(tf.float32, frame_shape)  # previous
    x2 = tf.placeholder(tf.float32, frame_shape)  # current
    x3 = tf.placeholder(tf.float32, frame_shape)  # subsequent

    if use_net1:
        is_training = tf.placeholder_with_default(False, shape=())

    x1to2 = warp_img(BATCH_SIZE, x2, x1, False)
    x3to2 = warp_img(BATCH_SIZE, x2, x3, True)

    if use_net1:
        x2_enhanced = net_MFCNN.network(x1to2, x2, x3to2, is_training)
    else:
        x2_enhanced = net_MFCNN.network2(x1to2, x2, x3to2)

    saver = tf.train.Saver()

    with tf.Session(config=config) as sess:
        # Always built (it is cheap) so that tracing cannot hit an undefined
        # name; it is only passed to sess.run while ywz_times == 0.
        options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)

        # Restore model
        # model_path = os.path.join(dir_model_pre, "model_step2.ckpt-" + str(QP))
        model_path = os.path.join(dir_model_pre, "model_step2.ckpt-0")

        # latest_checkpoint() expects a checkpoint *directory*. model_path is
        # normally a checkpoint prefix, for which it returns None, so fall
        # back to the prefix itself when its files exist instead of silently
        # running with untrained weights.
        module_file = tf.train.latest_checkpoint(model_path)
        if module_file is None and (os.path.isfile(model_path + ".index")
                                    or os.path.isfile(model_path)):
            module_file = model_path

        sess.run(tf.global_variables_initializer())
        if module_file is not None:
            saver.restore(sess, module_file)
        else:
            print("WARNING: no checkpoint found at %s; "
                  "using freshly initialized weights." % model_path,
                  flush=True)

        nfs = len(CmpIndex_list)

        sum_dpsnr_part = 0.0
        sum_dssim_part = 0.0

        for ite_frame in range(nfs):
            cmp_index = CmpIndex_list[ite_frame]

            # Load frames, add a trailing channel axis and scale by 1/255.
            pre_frame = y_import(CmpVideo_path, height, width, 1, PreIndex_list[ite_frame])[:, :, :, np.newaxis] / 255.0
            cmp_frame = y_import(CmpVideo_path, height, width, 1, cmp_index)[:, :, :, np.newaxis] / 255.0
            sub_frame = y_import(CmpVideo_path, height, width, 1, SubIndex_list[ite_frame])[:, :, :, np.newaxis] / 255.0

            # A "plane" target frame is skipped entirely: nothing is stored
            # in enhanced_list, nothing is accumulated and ywz_times is not
            # incremented.
            if isplane(cmp_frame):
                continue

            # A "plane" neighbour is replaced by a copy of the target frame.
            if isplane(pre_frame):
                pre_frame = np.copy(cmp_frame)
            if isplane(sub_frame):
                sub_frame = np.copy(cmp_frame)

            feed = {x1: pre_frame, x2: cmp_frame, x3: sub_frame}
            if use_net1:
                feed[is_training] = False

            # Enhance; trace the run while ywz_times is still 0.
            tracing = ywz_times == 0
            if tracing:
                run_metadata = tf.RunMetadata()
                enhanced_frame = sess.run(x2_enhanced, options=options, feed_dict=feed, run_metadata=run_metadata)
            else:
                enhanced_frame = sess.run(x2_enhanced, feed_dict=feed)

            # Timeline record: create the Timeline object and write it to a
            # json file (the output directory is created on demand).
            if tracing:
                fetched_timeline = timeline.Timeline(run_metadata.step_stats)
                chrome_trace = fetched_timeline.generate_chrome_trace_format()
                os.makedirs("timeline", exist_ok=True)
                with open("timeline/timeline_" + str(ite_frame) + ".json", 'w') as f:
                    f.write(chrome_trace)

            # Record for output video
            enhanced_list[cmp_index] = np.squeeze(enhanced_frame)

            # Evaluate and accumulate dpsnr / dssim against the raw frame
            raw_frame = np.squeeze(y_import(RawVideo_path, height, width, 1, cmp_index)) / 255.0
            cmp_frame = np.squeeze(cmp_frame)
            enhanced_frame = np.squeeze(enhanced_frame)

            raw_frame = np.float32(raw_frame)
            cmp_frame = np.float32(cmp_frame)

            psnr_ori = compare_psnr(cmp_frame, raw_frame, data_range=1.0)
            psnr_aft = compare_psnr(enhanced_frame, raw_frame, data_range=1.0)

            ssim_ori = compare_ssim(cmp_frame, raw_frame, data_range=1.0)
            ssim_aft = compare_ssim(enhanced_frame, raw_frame, data_range=1.0)

            sum_dpsnr_part += psnr_aft - psnr_ori
            sum_dssim_part += ssim_aft - ssim_ori

            print("\r %d | %d at QP = %d" % (ite_frame + 1, nfs, QP), end="")

            ywz_times += 1

        print("              ", end="\r")

        sum_dpsnr += sum_dpsnr_part
        sum_dssim += sum_dssim_part

        # NOTE: the average is taken over all requested frames, including
        # the ones skipped above. An empty request reports 0 instead of
        # raising ZeroDivisionError.
        average_dpsnr = sum_dpsnr_part / nfs if nfs else 0.0
        average_dssim = sum_dssim_part / nfs if nfs else 0.0
        print("dPSNR: %.3f - dSSIM: %.3f - nfs: %4d" % (average_dpsnr, average_dssim, nfs), flush=True)
        file_object.write("dPSNR: %.3f - dSSIM: %.3f - nfs: %4d\n" % (average_dpsnr, average_dssim, nfs))
        file_object.flush()
Exemplo n.º 4
0
def enhance(QP, input, Non_PQF_indices, pre_PQF_indices, sub_PQF_indices):
    """Enhance the non-PQF frames of a video with the recommended MF-CNN model.

    Args:
        QP: One of 22, 27, 32, 37 or 42; selects the checkpoint under
            ``./Model_MFCNN/QP<QP>/`` and the network variant.
        input: Array that unpacks as ``(frames, height, width)``; it is
            divided by 255 before being fed to the network.
        Non_PQF_indices: Indices into ``input`` of the frames to enhance.
        pre_PQF_indices: For each frame to enhance, the index of the
            preceding frame that is warped towards it.
        sub_PQF_indices: For each frame to enhance, the index of the
            subsequent frame that is warped towards it.

    Returns:
        A tuple ``(enhanced_frames, average_fps)``: an array of shape
        ``(len(Non_PQF_indices), height, width)`` with the network outputs,
        and the number of frames enhanced per second (0.0 when no time
        elapsed, e.g. when there was nothing to enhance).

    Raises:
        ValueError: If ``QP`` has no recommended model.
    """
    # Recommended MF-CNN model (checkpoint step) per QP
    recommended_model_index = {
        22: 1230000,
        27: 275000,
        32: 1227500,
        37: 1122500,
        42: 1250000,
    }
    # Fail early with a clear message instead of an UnboundLocalError below.
    if QP not in recommended_model_index:
        raise ValueError(
            "Unsupported QP %r; expected one of %s." %
            (QP, sorted(recommended_model_index)))
    model_index = recommended_model_index[QP]

    model_path = "./Model_MFCNN/QP" + str(QP) + "/model.ckpt-" + str(model_index)

    # video information
    _, height, width = input.shape
    nfs = len(Non_PQF_indices)

    input = input[:, :, :, np.newaxis]
    input = input / 255.0

    enhanced_frames = np.zeros([nfs, height, width])

    with tf.Graph().as_default():

        frame_shape = [BATCH_SIZE, height, width, CHANNEL]
        x1 = tf.placeholder(tf.float32, frame_shape)  # previous
        x2 = tf.placeholder(tf.float32, frame_shape)  # current
        x3 = tf.placeholder(tf.float32, frame_shape)  # subsequent
        is_training = tf.placeholder_with_default(False, shape=())

        x1to2 = flow.warp_img(BATCH_SIZE, x2, x1, False)
        x3to2 = flow.warp_img(BATCH_SIZE, x2, x3, True)

        # QP 37 and 42 use `network`; QP 22, 27 and 32 use `network2`.
        if QP in (37, 42):
            x2_enhanced = net_MFCNN.network(x1to2, x2, x3to2, is_training)
        else:
            x2_enhanced = net_MFCNN.network2(x1to2, x2, x3to2)

        # restore vars above
        saver = tf.train.Saver()

        with tf.Session(config=config) as sess:

            # restore model
            saver.restore(sess, model_path)

            # enhance
            start_time = time.time()
            for ite in range(nfs):
                pre_index = pre_PQF_indices[ite]
                cur_index = Non_PQF_indices[ite]
                sub_index = sub_PQF_indices[ite]

                # Length-1 slices keep the leading frame axis.
                feed = {
                    x1: input[pre_index:pre_index + 1, :, :, :],
                    x2: input[cur_index:cur_index + 1, :, :, :],
                    x3: input[sub_index:sub_index + 1, :, :, :],
                    is_training: False,
                }
                x2_enhanced_frame = sess.run(x2_enhanced, feed_dict=feed)
                enhanced_frames[ite] = np.squeeze(x2_enhanced_frame)

                print("\r" + str(ite + 1) + " | " + str(nfs), end="", flush=True)

            elapsed = time.time() - start_time
            # Guard against a zero-length interval (coarse clock, no frames).
            average_fps = nfs / elapsed if elapsed > 0 else 0.0
            print("")

    return enhanced_frames, average_fps