Ejemplo n.º 1
0
def main_train():
    """Train the MFCNN model in two steps, validating after every epoch.

    Step 1 optimizes ``flow_loss + ratio_small * MSE`` (mainly the
    motion-compensation subnet); step 2 optimizes
    ``ratio_small * flow_loss + MSE`` (mainly the quality-enhancement
    subnet). After each epoch a checkpoint is saved to ``dir_model`` and the
    mean PSNR improvement over all validation stacks is printed and written
    to ``file_object``.

    Restoring a pre-trained model is currently disabled (see the commented
    ``Restore`` section), so training starts from freshly initialized
    variables.

    Relies on module-level settings and helpers defined elsewhere in the
    file: ``BATCH_SIZE``, ``HEIGHT``, ``WIDTH``, ``CHANNEL``, ``QP``,
    ``net1_list``, ``ratio_small``, ``lr_ori``, ``epoch_step1``,
    ``epoch_step2``, ``dir_model``, ``dir_stack``, ``file_object``,
    ``cal_PSNR``, ``cal_MSE``, ``load_stack`` and ``net_MFCNN``.
    """
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # only show error and warning
    # NOTE(review): the GPU index is hard-coded; the module-level GPU
    # setting that used to be applied here is commented out.
    #os.environ['CUDA_VISIBLE_DEVICES'] = GPU
    os.environ['CUDA_VISIBLE_DEVICES'] = "2"

    ### Define a session
    # allow_soft_placement: if an op cannot run on the GPU, place it on CPU.
    config = tf.ConfigProto(allow_soft_placement=True)
    sess = tf.Session(config=config)

    ### Set placeholders
    patch_shape = [BATCH_SIZE, HEIGHT, WIDTH, CHANNEL]
    x1 = tf.placeholder(tf.float32, patch_shape)  # pre
    x2 = tf.placeholder(tf.float32, patch_shape)  # cmp
    x3 = tf.placeholder(tf.float32, patch_shape)  # sub
    x5 = tf.placeholder(tf.float32, patch_shape)  # raw

    # Only net_MFCNN.network() takes a BN training/testing switch;
    # network2() has none, so no placeholder is created for it.
    use_net1 = int(QP) in net1_list
    if use_net1:
        # for BN training/testing. default testing.
        is_training = tf.placeholder_with_default(False, shape=())
    else:
        is_training = None

    PSNR_0 = cal_PSNR(x2, x5)  # PSNR before enhancement (cmp and raw)

    ### Motion compensation: warp pre (x1) and sub (x3) towards cmp (x2)
    x1to2 = net_MFCNN.warp_img(tf.shape(x2)[0], x2, x1, False)
    x3to2 = net_MFCNN.warp_img(tf.shape(x2)[0], x2, x3, True)

    ### Flow loss
    flow_loss = cal_MSE(x1to2, x2) + cal_MSE(x3to2, x2)

    ### Enhance cmp frames
    if use_net1:
        x2_enhanced = net_MFCNN.network(x1to2, x2, x3to2, is_training)
    else:
        x2_enhanced = net_MFCNN.network2(x1to2, x2, x3to2)

    MSE = cal_MSE(x2_enhanced, x5)
    PSNR = cal_PSNR(x2_enhanced, x5)  # PSNR after enhancement (enhanced and raw)
    delta_PSNR = PSNR - PSNR_0

    ### 2 kinds of loss for 2-step training
    OptimizeLoss_1 = flow_loss + ratio_small * MSE  # step1: the key is MC-subnet.
    OptimizeLoss_2 = ratio_small * flow_loss + MSE  # step2: the key is QE-subnet.

    ### Define optimizers
    # Ops in UPDATE_OPS (presumably batch-norm statistic updates) run with
    # every training step.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        Training_step1 = tf.train.AdamOptimizer(lr_ori).minimize(
            OptimizeLoss_1)
        Training_step2 = tf.train.AdamOptimizer(lr_ori).minimize(
            OptimizeLoss_2)

    ### TensorBoard
    tf.summary.scalar('PSNR improvement', delta_PSNR)
    #tf.summary.scalar('PSNR before enhancement', PSNR_0)
    #tf.summary.scalar('PSNR after enhancement', PSNR)
    tf.summary.scalar('MSE loss of motion compensation', flow_loss)
    #tf.summary.scalar('MSE loss of final quality enhancement', MSE)
    tf.summary.scalar('MSE loss for training step1 (mainly MC-subnet)',
                      OptimizeLoss_1)
    tf.summary.scalar('MSE loss for training step2 (mainly QE-subnet)',
                      OptimizeLoss_2)

    tf.summary.image('cmp', x2)
    tf.summary.image('enhanced', x2_enhanced)
    tf.summary.image('raw', x5)
    tf.summary.image('x1to2', x1to2)
    tf.summary.image('x3to2', x3to2)

    summary_writer = tf.summary.FileWriter(dir_model, sess.graph)
    summary_op = tf.summary.merge_all()

    saver = tf.train.Saver(max_to_keep=None)  # keep every saved checkpoint

    sess.run(tf.global_variables_initializer())  # initialize network variables

    ### Calculate the num of parameters
    num_params = 0
    for variable in tf.trainable_variables():
        shape = variable.get_shape()
        num_params += reduce(mul, [dim.value for dim in shape], 1)
    print("# num of parameters: %d #" % num_params)
    file_object.write("# num of parameters: %d #\n" % num_params)
    file_object.flush()

    ### Find all stacks then count them
    num_TrainingStack = len(
        glob.glob(os.path.join(dir_stack, "stack_tra_pre_*")))
    num_ValidationStack = len(
        glob.glob(os.path.join(dir_stack, "stack_val_pre_*")))

    ### Restore (disabled): fine-tuning from a previously trained model is
    ### commented out, so the freshly initialized variables are trained.
    # saver_res = tf.train.Saver()
    # saver_res.restore(sess, model_res_path)
    # print("successfully restore model %d!" % (int(res_index) + 1))
    # file_object.write("successfully restore model %d!\n" % (int(res_index) + 1))
    # file_object.flush()

    def make_feed(stack, start, stop, training):
        """Build a feed_dict for patches [start, stop) of a loaded stack.

        ``stack`` is the (pre, cmp, sub, raw) 4-tuple returned by
        ``load_stack``; ``training`` is fed to the BN switch when the
        selected network has one.
        """
        pre_list, cmp_list, sub_list, raw_list = stack
        feed = {
            x1: pre_list[start:stop],
            x2: cmp_list[start:stop],
            x3: sub_list[start:stop],
            x5: raw_list[start:stop],
        }
        if is_training is not None:
            feed[is_training] = training
        return feed

    print("##### Start running! #####")

    num_TrainingBatch_count = 0
    stack = None       # (pre, cmp, sub, raw) of the stack currently in memory
    loaded_key = None  # ("tra" | "val", index) identifying `stack`

    ### Step 1: converge MC-subnet; Step 2: converge QE-subnet
    for ite_step in [1, 2]:

        if ite_step == 1:
            num_epoch, train_op = epoch_step1, Training_step1
        else:
            num_epoch, train_op = epoch_step2, Training_step2
        CheckPoint_path = os.path.join(dir_model,
                                       "model_step%d.ckpt" % ite_step)

        ### Epoch by Epoch
        for ite_epoch in range(num_epoch):

            ### Train stack by stack
            for ite_stack in range(num_TrainingStack):

                # (Re)load unless this exact training stack is already in
                # memory. The validation pass below replaces `stack`, so
                # without this check later epochs (and all of step 2) would
                # train on validation data.
                if loaded_key != ("tra", ite_stack):
                    stack = None  # release the previous stack first
                    gc.collect()
                    stack = tuple(load_stack("tra", ite_stack))
                    loaded_key = ("tra", ite_stack)
                    gc.collect()

                num_batch = len(stack[0]) // BATCH_SIZE

                ### Batch by batch
                for ite_batch in range(num_batch):

                    print("\rstep %1d - epoch %2d/%2d - training stack %2d/%2d - batch %3d/%3d" % \
                        (ite_step, ite_epoch+1, num_epoch, ite_stack+1, num_TrainingStack, ite_batch+1, num_batch), end="")

                    start_index = ite_batch * BATCH_SIZE
                    next_start_index = start_index + BATCH_SIZE

                    train_op.run(session=sess,
                                 feed_dict=make_feed(stack, start_index,
                                                     next_start_index, True))

                    # Update TensorBoard and print result
                    num_TrainingBatch_count += 1

                    # Report at the middle batch and at the last batch.
                    if (ite_batch + 1) in (num_batch // 2, num_batch):
                        summary, delta_PSNR_batch, PSNR_0_batch = sess.run(
                            [summary_op, delta_PSNR, PSNR_0],
                            feed_dict=make_feed(stack, start_index,
                                                next_start_index, False))

                        summary_writer.add_summary(summary,
                                                   num_TrainingBatch_count)
                        report = "step %1d - epoch %2d - imp PSNR: %.3f - ori PSNR: %.3f" % \
                            (ite_step, ite_epoch + 1, delta_PSNR_batch, PSNR_0_batch)
                        print("\r" + report + "             ")
                        file_object.write(report + "\n")
                        file_object.flush()

            ### Store the model of this epoch
            saver.save(sess, CheckPoint_path, global_step=ite_epoch)

            sum_improved_PSNR = 0
            num_patch_count = 0

            ### Eval stack by stack, and report together for this epoch
            for ite_stack in range(num_ValidationStack):

                if loaded_key != ("val", ite_stack):
                    stack = None  # release the previous stack first
                    gc.collect()
                    stack = tuple(load_stack("val", ite_stack))
                    loaded_key = ("val", ite_stack)
                    gc.collect()

                num_batch = len(stack[0]) // BATCH_SIZE

                ### Batch by batch
                for ite_batch in range(num_batch):

                    print("\rstep %1d - epoch %2d/%2d - validation stack %2d/%2d                " % \
                        (ite_step, ite_epoch+1, num_epoch, ite_stack+1, num_ValidationStack), end="")

                    start_index = ite_batch * BATCH_SIZE
                    next_start_index = start_index + BATCH_SIZE

                    delta_PSNR_batch = sess.run(
                        delta_PSNR,
                        feed_dict=make_feed(stack, start_index,
                                            next_start_index, False))

                    sum_improved_PSNR += delta_PSNR_batch * BATCH_SIZE
                    num_patch_count += BATCH_SIZE

            if num_patch_count != 0:
                mean_report = "### imp PSNR by model after step %1d - epoch %2d/%2d: %.3f ###" % \
                    (ite_step, ite_epoch+1, num_epoch, sum_improved_PSNR/num_patch_count)
                print("\n" + mean_report + "\n")
                file_object.write(mean_report + "\n")
                file_object.flush()

    # Flush pending TensorBoard events and release session resources.
    summary_writer.close()
    sess.close()
Ejemplo n.º 2
0
def func_enhance(dir_model_pre, QP, PreIndex_list, CmpIndex_list,
                 SubIndex_list):
    """Enhance PQFs or non-PQFs, record dpsnr, dssim and enhanced frames.

    Builds the MFCNN inference graph and restores
    ``<dir_model_pre>/model_step2.ckpt-<QP>``. For every ``i`` it enhances
    frame ``CmpIndex_list[i]`` of the compressed video, using frames
    ``PreIndex_list[i]`` and ``SubIndex_list[i]`` as the previous/subsequent
    inputs.

    Side effects:
        * ``enhanced_list[CmpIndex_list[i]]`` receives the enhanced frame.
        * ``sum_dpsnr`` / ``sum_dssim`` grow by this call's summed PSNR / SSIM
          gains against the raw video.
        * The per-call averages are printed and written to ``file_object``.

    A frame whose compressed version ``isplane`` is skipped: it is neither
    enhanced nor measured, but still counts in ``nfs`` for the averages.
    A plane previous/subsequent frame is replaced by the compressed frame.
    """

    global enhanced_list, sum_dpsnr, sum_dssim

    tf.reset_default_graph()

    ### Define enhancement process
    frame_shape = [BATCH_SIZE, height, width, CHANNEL]
    x1 = tf.placeholder(tf.float32, frame_shape)  # previous
    x2 = tf.placeholder(tf.float32, frame_shape)  # current
    x3 = tf.placeholder(tf.float32, frame_shape)  # subsequent

    # Only net_MFCNN.network() takes the BN training/testing switch.
    use_net1 = QP in net1_list
    if use_net1:
        is_training = tf.placeholder_with_default(False, shape=())

    x1to2 = net_MFCNN.warp_img(BATCH_SIZE, x2, x1, False)
    x3to2 = net_MFCNN.warp_img(BATCH_SIZE, x2, x3, True)

    if use_net1:
        x2_enhanced = net_MFCNN.network(x1to2, x2, x3to2, is_training)
    else:
        x2_enhanced = net_MFCNN.network2(x1to2, x2, x3to2)

    saver = tf.train.Saver()

    def load_cmp_frame(index):
        """Read frame `index` of the compressed video as a 4-D array / 255."""
        return y_import(CmpVideo_path, height, width, 1,
                        index)[:, :, :, np.newaxis] / 255.0

    with tf.Session(config=config) as sess:

        # Restore model
        model_path = os.path.join(dir_model_pre, "model_step2.ckpt-" + str(QP))
        saver.restore(sess, model_path)

        nfs = len(CmpIndex_list)

        sum_dpsnr_part = 0.0
        sum_dssim_part = 0.0

        for ite_frame in range(nfs):

            # Load frames
            pre_frame = load_cmp_frame(PreIndex_list[ite_frame])
            cmp_frame = load_cmp_frame(CmpIndex_list[ite_frame])
            sub_frame = load_cmp_frame(SubIndex_list[ite_frame])

            # Nothing to enhance in a plane cmp frame.
            if isplane(cmp_frame):
                continue

            # A plane neighbour carries no information: use the cmp frame.
            if isplane(pre_frame):
                pre_frame = np.copy(cmp_frame)
            if isplane(sub_frame):
                sub_frame = np.copy(cmp_frame)

            # Enhance (inference only: the BN switch, if any, is False)
            feed = {x1: pre_frame, x2: cmp_frame, x3: sub_frame}
            if use_net1:
                feed[is_training] = False
            enhanced_frame = sess.run(x2_enhanced, feed_dict=feed)

            # Record for output video
            enhanced_list[CmpIndex_list[ite_frame]] = np.squeeze(
                enhanced_frame)

            # Evaluate and accumulate dpsnr / dssim against the raw frame.
            # Both inputs of each comparison are float32 so dtypes match.
            raw_frame = np.float32(
                np.squeeze(
                    y_import(RawVideo_path, height, width, 1,
                             CmpIndex_list[ite_frame])) / 255.0)
            cmp_frame = np.float32(np.squeeze(cmp_frame))
            enhanced_frame = np.squeeze(enhanced_frame)

            psnr_ori = compare_psnr(cmp_frame, raw_frame, data_range=1.0)
            psnr_aft = compare_psnr(enhanced_frame, raw_frame, data_range=1.0)

            ssim_ori = compare_ssim(cmp_frame, raw_frame, data_range=1.0)
            ssim_aft = compare_ssim(enhanced_frame, raw_frame, data_range=1.0)

            sum_dpsnr_part += psnr_aft - psnr_ori
            sum_dssim_part += ssim_aft - ssim_ori

            print("%d | %d at QP = %d" % (ite_frame + 1, nfs, QP), end="\r")
        print("              ", end="\r")

        sum_dpsnr += sum_dpsnr_part
        sum_dssim += sum_dssim_part

        # An empty frame list has nothing to average (avoid ZeroDivisionError).
        average_dpsnr = sum_dpsnr_part / nfs if nfs else 0.0
        average_dssim = sum_dssim_part / nfs if nfs else 0.0
        print("dPSNR: %.3f - dSSIM: %.3f - nfs: %4d" %
              (average_dpsnr, average_dssim, nfs),
              flush=True)
        file_object.write("dPSNR: %.3f - dSSIM: %.3f - nfs: %4d\n" %
                          (average_dpsnr, average_dssim, nfs))
        file_object.flush()
Ejemplo n.º 3
0
def func_enhance(dir_model_pre, QP, PreIndex_list, CmpIndex_list,
                 SubIndex_list):
    """Enhance PQFs or non-PQFs, record dpsnr, dssim and enhanced frames.

    For every ``i`` it enhances frame ``CmpIndex_list[i]`` of the compressed
    video, using frames ``PreIndex_list[i]`` and ``SubIndex_list[i]`` as the
    previous/subsequent inputs.

    Checkpoint lookup, in order:
        1. the checkpoint ``<dir_model_pre>/model_step2.ckpt-0`` if present;
        2. the latest checkpoint recorded in a ``checkpoint`` state file
           under that path, then under ``dir_model_pre``.
    If none is found, a warning is printed and the freshly initialized
    (untrained) variables are used.

    Profiling: while the global ``ywz_times`` is 0, each inference run is
    fully traced and written to ``timeline/timeline_<frame>.json``.

    Side effects:
        * ``enhanced_list[CmpIndex_list[i]]`` receives the enhanced frame.
        * ``sum_dpsnr`` / ``sum_dssim`` grow by this call's summed PSNR / SSIM
          gains against the raw video.
        * The per-call averages are printed and written to ``file_object``.

    A frame whose compressed version ``isplane`` is skipped: it is neither
    enhanced nor measured, but still counts in ``nfs`` for the averages.
    A plane previous/subsequent frame is replaced by the compressed frame.
    """
    # Profiling switch: runs are traced while this is 0.
    global ywz_times

    global enhanced_list, sum_dpsnr, sum_dssim

    tf.reset_default_graph()

    ### Define enhancement process
    frame_shape = [BATCH_SIZE, height, width, CHANNEL]
    x1 = tf.placeholder(tf.float32, frame_shape)  # previous
    x2 = tf.placeholder(tf.float32, frame_shape)  # current
    x3 = tf.placeholder(tf.float32, frame_shape)  # subsequent

    # Only net_MFCNN.network() takes the BN training/testing switch.
    use_net1 = QP in net1_list
    if use_net1:
        is_training = tf.placeholder_with_default(False, shape=())

    x1to2 = net_MFCNN.warp_img(BATCH_SIZE, x2, x1, False)
    x3to2 = net_MFCNN.warp_img(BATCH_SIZE, x2, x3, True)

    if use_net1:
        x2_enhanced = net_MFCNN.network(x1to2, x2, x3to2, is_training)
    else:
        x2_enhanced = net_MFCNN.network2(x1to2, x2, x3to2)

    saver = tf.train.Saver()

    def load_cmp_frame(index):
        """Read frame `index` of the compressed video as a 4-D array / 255."""
        return y_import(CmpVideo_path, height, width, 1,
                        index)[:, :, :, np.newaxis] / 255.0

    with tf.Session(config=config) as sess:
        # Created unconditionally so it is always defined when tracing.
        trace_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)

        # Restore model.
        # tf.train.latest_checkpoint() expects a *directory* holding a
        # "checkpoint" state file; given only the checkpoint prefix it returns
        # None, and the trained weights were silently never loaded. Prefer the
        # explicit checkpoint, then fall back to the state-file lookups.
        model_path = os.path.join(dir_model_pre, "model_step2.ckpt-0")
        sess.run(tf.global_variables_initializer())
        if os.path.exists(model_path + ".index") or os.path.isfile(model_path):
            module_file = model_path
        else:
            module_file = (tf.train.latest_checkpoint(model_path)
                           or tf.train.latest_checkpoint(dir_model_pre))
        if module_file is not None:
            saver.restore(sess, module_file)
        else:
            print("Warning: no checkpoint found under %s; evaluating an "
                  "untrained model." % dir_model_pre, flush=True)

        nfs = len(CmpIndex_list)

        sum_dpsnr_part = 0.0
        sum_dssim_part = 0.0

        for ite_frame in range(nfs):

            # Load frames
            pre_frame = load_cmp_frame(PreIndex_list[ite_frame])
            cmp_frame = load_cmp_frame(CmpIndex_list[ite_frame])
            sub_frame = load_cmp_frame(SubIndex_list[ite_frame])

            # Nothing to enhance in a plane cmp frame.
            if isplane(cmp_frame):
                continue

            # A plane neighbour carries no information: use the cmp frame.
            if isplane(pre_frame):
                pre_frame = np.copy(cmp_frame)
            if isplane(sub_frame):
                sub_frame = np.copy(cmp_frame)

            # Enhance (inference only: the BN switch, if any, is False)
            feed = {x1: pre_frame, x2: cmp_frame, x3: sub_frame}
            if use_net1:
                feed[is_training] = False

            tracing = ywz_times == 0
            if tracing:
                run_metadata = tf.RunMetadata()
                enhanced_frame = sess.run(x2_enhanced,
                                          options=trace_options,
                                          feed_dict=feed,
                                          run_metadata=run_metadata)
                # Write the traced step as a Chrome trace (chrome://tracing).
                fetched_timeline = timeline.Timeline(run_metadata.step_stats)
                chrome_trace = fetched_timeline.generate_chrome_trace_format()
                os.makedirs("timeline", exist_ok=True)
                with open("timeline/timeline_" + str(ite_frame) + ".json",
                          'w') as f:
                    f.write(chrome_trace)
            else:
                enhanced_frame = sess.run(x2_enhanced, feed_dict=feed)

            # Record for output video
            enhanced_list[CmpIndex_list[ite_frame]] = np.squeeze(
                enhanced_frame)

            # Evaluate and accumulate dpsnr / dssim against the raw frame.
            # Both inputs of each comparison are float32 so dtypes match.
            raw_frame = np.float32(
                np.squeeze(
                    y_import(RawVideo_path, height, width, 1,
                             CmpIndex_list[ite_frame])) / 255.0)
            cmp_frame = np.float32(np.squeeze(cmp_frame))
            enhanced_frame = np.squeeze(enhanced_frame)

            psnr_ori = compare_psnr(cmp_frame, raw_frame, data_range=1.0)
            psnr_aft = compare_psnr(enhanced_frame, raw_frame, data_range=1.0)

            ssim_ori = compare_ssim(cmp_frame, raw_frame, data_range=1.0)
            ssim_aft = compare_ssim(enhanced_frame, raw_frame, data_range=1.0)

            sum_dpsnr_part += psnr_aft - psnr_ori
            sum_dssim_part += ssim_aft - ssim_ori

            print("\r %d | %d at QP = %d" % (ite_frame + 1, nfs, QP), end="")

            # Uncomment to trace only the first frame:
            # ywz_times += 1

        print("              ", end="\r")

        sum_dpsnr += sum_dpsnr_part
        sum_dssim += sum_dssim_part

        # An empty frame list has nothing to average (avoid ZeroDivisionError).
        average_dpsnr = sum_dpsnr_part / nfs if nfs else 0.0
        average_dssim = sum_dssim_part / nfs if nfs else 0.0
        print("dPSNR: %.3f - dSSIM: %.3f - nfs: %4d" %
              (average_dpsnr, average_dssim, nfs),
              flush=True)
        file_object.write("dPSNR: %.3f - dSSIM: %.3f - nfs: %4d\n" %
                          (average_dpsnr, average_dssim, nfs))
        file_object.flush()