def main_train():
    """Train the enhancement network in two steps and validate after every epoch.

    Step 1 minimizes ``flow_loss + ratio_small * MSE`` (the key is the MC-subnet).
    Step 2 minimizes ``ratio_small * flow_loss + MSE`` (the key is the QE-subnet).

    After each epoch:
      * a checkpoint ``model_step<step>.ckpt-<epoch>`` is saved in ``dir_model``;
      * the mean PSNR improvement over all validation stacks is printed and
        written to ``file_object``.

    NOTE: the code that restored a previously trained model (fine-tuning) is
    commented out below, so training starts from freshly initialized variables.

    All configuration (QP, BATCH_SIZE, HEIGHT, WIDTH, CHANNEL, lr_ori,
    ratio_small, epoch_step1, epoch_step2, dir_model, dir_stack, net1_list,
    file_object, ...) is read from module-level globals.
    """
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # only show error and warning
    # os.environ['CUDA_VISIBLE_DEVICES'] = GPU
    os.environ['CUDA_VISIBLE_DEVICES'] = "2"  # NOTE(review): hard-coded GPU id

    # net1 QPs use network() with a BN is_training switch; the others use network2().
    use_bn = int(QP) in net1_list

    ### Define a session
    # allow_soft_placement: if GPU is not usable, then turn to CPU automatically
    config = tf.ConfigProto(allow_soft_placement=True)
    sess = tf.Session(config=config)

    ### Set placeholders
    x1 = tf.placeholder(tf.float32, [BATCH_SIZE, HEIGHT, WIDTH, CHANNEL])  # pre
    x2 = tf.placeholder(tf.float32, [BATCH_SIZE, HEIGHT, WIDTH, CHANNEL])  # cmp
    x3 = tf.placeholder(tf.float32, [BATCH_SIZE, HEIGHT, WIDTH, CHANNEL])  # sub
    x5 = tf.placeholder(tf.float32, [BATCH_SIZE, HEIGHT, WIDTH, CHANNEL])  # raw
    # For BN training/testing; defaults to testing. Only exists for net1.
    is_training = tf.placeholder_with_default(False, shape=()) if use_bn else None

    PSNR_0 = cal_PSNR(x2, x5)  # PSNR before enhancement (cmp and raw)

    ### Motion compensation
    x1to2 = net_MFCNN.warp_img(tf.shape(x2)[0], x2, x1, False)
    x3to2 = net_MFCNN.warp_img(tf.shape(x2)[0], x2, x3, True)

    ### Flow loss
    FlowLoss_1 = cal_MSE(x1to2, x2)
    FlowLoss_2 = cal_MSE(x3to2, x2)
    flow_loss = FlowLoss_1 + FlowLoss_2

    ### Enhance cmp frames
    if use_bn:
        x2_enhanced = net_MFCNN.network(x1to2, x2, x3to2, is_training)
    else:
        x2_enhanced = net_MFCNN.network2(x1to2, x2, x3to2)

    MSE = cal_MSE(x2_enhanced, x5)
    PSNR = cal_PSNR(x2_enhanced, x5)  # PSNR after enhancement (enhanced and raw)
    delta_PSNR = PSNR - PSNR_0

    ### 2 kinds of loss for 2-step training
    OptimizeLoss_1 = flow_loss + ratio_small * MSE  # step1: the key is MC-subnet.
    OptimizeLoss_2 = ratio_small * flow_loss + MSE  # step2: the key is QE-subnet.

    ### Define optimizers
    # Every training step also runs the ops in UPDATE_OPS (e.g. BN statistics, if any).
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        Training_step1 = tf.train.AdamOptimizer(lr_ori).minimize(OptimizeLoss_1)
        Training_step2 = tf.train.AdamOptimizer(lr_ori).minimize(OptimizeLoss_2)

    ### TensorBoard
    tf.summary.scalar('PSNR improvement', delta_PSNR)
    # tf.summary.scalar('PSNR before enhancement', PSNR_0)
    # tf.summary.scalar('PSNR after enhancement', PSNR)
    tf.summary.scalar('MSE loss of motion compensation', flow_loss)
    # tf.summary.scalar('MSE loss of final quality enhancement', MSE)
    tf.summary.scalar('MSE loss for training step1 (mainly MC-subnet)', OptimizeLoss_1)
    tf.summary.scalar('MSE loss for training step2 (mainly QE-subnet)', OptimizeLoss_2)
    tf.summary.image('cmp', x2)
    tf.summary.image('enhanced', x2_enhanced)
    tf.summary.image('raw', x5)
    tf.summary.image('x1to2', x1to2)
    tf.summary.image('x3to2', x3to2)

    summary_writer = tf.summary.FileWriter(dir_model, sess.graph)
    summary_op = tf.summary.merge_all()

    saver = tf.train.Saver(max_to_keep=None)  # define a saver (keep every checkpoint)
    sess.run(tf.global_variables_initializer())  # initialize network variables

    ### Calculate the num of parameters
    num_params = 0
    for variable in tf.trainable_variables():
        shape = variable.get_shape()
        num_params += reduce(mul, [dim.value for dim in shape], 1)
    print("# num of parameters: %d #" % num_params)
    file_object.write("# num of parameters: %d #\n" % num_params)
    file_object.flush()

    ### Find all stacks then cal their number
    stack_name = os.path.join(dir_stack, "stack_tra_pre_*")
    num_TrainingStack = len(glob.glob(stack_name))
    stack_name = os.path.join(dir_stack, "stack_val_pre_*")
    num_ValidationStack = len(glob.glob(stack_name))

    ### Restore
    ### The fine-tune code that restores a previously trained model is commented out.
    # saver_res = tf.train.Saver()
    # saver_res.restore(sess, model_res_path)
    # print("successfully restore model %d!" % (int(res_index) + 1))
    # file_object.write("successfully restore model %d!\n" % (int(res_index) + 1))
    # file_object.flush()

    # Data of the stack currently in memory. `loaded_stack` records which
    # ("tra"/"val", index) stack it is. Training and validation share these
    # variables, so this tag is what forces a reload after validation has
    # replaced the training data.
    pre_list, cmp_list, sub_list, raw_list = [], [], [], []
    loaded_stack = None

    def ensure_loaded(kind, index):
        """Load stack (kind, index) into the *_list variables unless it is already loaded."""
        nonlocal pre_list, cmp_list, sub_list, raw_list, loaded_stack
        if loaded_stack == (kind, index):
            return
        # Release the previous stack before loading the next one to keep peak memory down.
        pre_list, cmp_list, sub_list, raw_list = [], [], [], []
        gc.collect()
        pre_list, cmp_list, sub_list, raw_list = load_stack(kind, index)
        gc.collect()
        loaded_stack = (kind, index)

    def make_feed(batch, training):
        """Build the feed_dict for rows `batch` (a slice) of the currently loaded stack."""
        feed = {
            x1: pre_list[batch],
            x2: cmp_list[batch],
            x3: sub_list[batch],
            x5: raw_list[batch],
        }
        if use_bn:
            feed[is_training] = training
        return feed

    print("##### Start running! #####")

    num_TrainingBatch_count = 0
    try:
        ### Step 1: converge MC-subnet; Step 2: converge QE-subnet
        for ite_step in [1, 2]:
            if ite_step == 1:
                num_epoch = epoch_step1
                train_op = Training_step1
                CheckPoint_path = os.path.join(dir_model, "model_step1.ckpt")
            else:
                num_epoch = epoch_step2
                train_op = Training_step2
                CheckPoint_path = os.path.join(dir_model, "model_step2.ckpt")

            ### Epoch by Epoch
            for ite_epoch in range(num_epoch):

                ### Train stack by stack
                for ite_stack in range(num_TrainingStack):
                    ensure_loaded("tra", ite_stack)
                    num_batch = int(len(pre_list) / BATCH_SIZE)

                    ### Batch by batch
                    for ite_batch in range(num_batch):
                        print("\rstep %1d - epoch %2d/%2d - training stack %2d/%2d - batch %3d/%3d" %
                              (ite_step, ite_epoch + 1, num_epoch, ite_stack + 1,
                               num_TrainingStack, ite_batch + 1, num_batch), end="")

                        batch = slice(ite_batch * BATCH_SIZE, (ite_batch + 1) * BATCH_SIZE)
                        train_op.run(session=sess, feed_dict=make_feed(batch, True))  # train

                        # Update TensorBoard and print result (at half-stack and end of stack)
                        num_TrainingBatch_count += 1
                        if ((ite_batch + 1) == int(num_batch / 2)) or ((ite_batch + 1) == num_batch):
                            summary, delta_PSNR_batch, PSNR_0_batch = sess.run(
                                [summary_op, delta_PSNR, PSNR_0],
                                feed_dict=make_feed(batch, False))
                            summary_writer.add_summary(summary, num_TrainingBatch_count)
                            print("\rstep %1d - epoch %2d - imp PSNR: %.3f - ori PSNR: %.3f " %
                                  (ite_step, ite_epoch + 1, delta_PSNR_batch, PSNR_0_batch))
                            file_object.write("step %1d - epoch %2d - imp PSNR: %.3f - ori PSNR: %.3f\n" %
                                              (ite_step, ite_epoch + 1, delta_PSNR_batch, PSNR_0_batch))
                            file_object.flush()

                ### Store the model of this epoch
                saver.save(sess, CheckPoint_path, global_step=ite_epoch)

                ### Eval stack by stack, and report together for this epoch
                sum_improved_PSNR = 0
                num_patch_count = 0
                for ite_stack in range(num_ValidationStack):
                    ensure_loaded("val", ite_stack)
                    num_batch = int(len(pre_list) / BATCH_SIZE)

                    ### Batch by batch
                    for ite_batch in range(num_batch):
                        print("\rstep %1d - epoch %2d/%2d - validation stack %2d/%2d " %
                              (ite_step, ite_epoch + 1, num_epoch, ite_stack + 1,
                               num_ValidationStack), end="")

                        batch = slice(ite_batch * BATCH_SIZE, (ite_batch + 1) * BATCH_SIZE)
                        delta_PSNR_batch = sess.run(delta_PSNR, feed_dict=make_feed(batch, False))

                        sum_improved_PSNR += delta_PSNR_batch * BATCH_SIZE
                        num_patch_count += BATCH_SIZE

                if num_patch_count != 0:
                    print("\n### imp PSNR by model after step %1d - epoch %2d/%2d: %.3f ###\n" %
                          (ite_step, ite_epoch + 1, num_epoch, sum_improved_PSNR / num_patch_count))
                    file_object.write("### imp PSNR by model after step %1d - epoch %2d/%2d: %.3f ###\n" %
                                      (ite_step, ite_epoch + 1, num_epoch, sum_improved_PSNR / num_patch_count))
                    file_object.flush()
    finally:
        # Flush pending TensorBoard events and free the session's resources.
        summary_writer.close()
        sess.close()
def func_enhance(dir_model_pre, QP, PreIndex_list, CmpIndex_list, SubIndex_list):
    """Enhance PQFs or non-PQFs, record dpsnr, dssim and enhanced frames.

    Builds a fresh graph and restores ``<dir_model_pre>/model_step2.ckpt-<QP>``.
    For every i, it enhances frame ``CmpIndex_list[i]`` of ``CmpVideo_path``,
    using ``PreIndex_list[i]`` and ``SubIndex_list[i]`` as the previous and
    subsequent frames.

    Side effects:
      * stores each enhanced frame in the global ``enhanced_list``;
      * adds the per-call sums of (PSNR, SSIM) gains vs. the raw video
        (``RawVideo_path``) to the globals ``sum_dpsnr`` / ``sum_dssim``;
      * prints the averages and writes them to ``file_object``.

    Compressed frames for which ``isplane`` is true are skipped. They still
    count in ``nfs``, so they contribute 0 to the averages.

    NOTE(review): a later ``def func_enhance`` in this file rebinds the name,
    so this definition is only live until that point.
    """
    global enhanced_list, sum_dpsnr, sum_dssim

    tf.reset_default_graph()
    use_bn = QP in net1_list  # net1 takes a BN is_training switch; network2 does not

    ### Define enhancement process
    x1 = tf.placeholder(tf.float32, [BATCH_SIZE, height, width, CHANNEL])  # previous
    x2 = tf.placeholder(tf.float32, [BATCH_SIZE, height, width, CHANNEL])  # current
    x3 = tf.placeholder(tf.float32, [BATCH_SIZE, height, width, CHANNEL])  # subsequent
    is_training = tf.placeholder_with_default(False, shape=()) if use_bn else None

    x1to2 = net_MFCNN.warp_img(BATCH_SIZE, x2, x1, False)
    x3to2 = net_MFCNN.warp_img(BATCH_SIZE, x2, x3, True)

    if use_bn:
        x2_enhanced = net_MFCNN.network(x1to2, x2, x3to2, is_training)
    else:
        x2_enhanced = net_MFCNN.network2(x1to2, x2, x3to2)

    saver = tf.train.Saver()

    def load_cmp_frame(index):
        """Read one Y frame of the compressed video, scaled to [0, 1], with a trailing channel axis."""
        return y_import(CmpVideo_path, height, width, 1, index)[:, :, :, np.newaxis] / 255.0

    nfs = len(CmpIndex_list)
    sum_dpsnr_part = 0.0
    sum_dssim_part = 0.0

    with tf.Session(config=config) as sess:
        # Restore model
        model_path = os.path.join(dir_model_pre, "model_step2.ckpt-" + str(QP))
        saver.restore(sess, model_path)

        for ite_frame in range(nfs):
            # Load frames
            pre_frame = load_cmp_frame(PreIndex_list[ite_frame])
            cmp_frame = load_cmp_frame(CmpIndex_list[ite_frame])
            sub_frame = load_cmp_frame(SubIndex_list[ite_frame])

            # Skip plane cmp frames entirely.
            if isplane(cmp_frame):
                continue

            # If a PQF reference frame is plane, fall back to the cmp frame itself.
            if isplane(pre_frame):
                pre_frame = np.copy(cmp_frame)
            if isplane(sub_frame):
                sub_frame = np.copy(cmp_frame)

            # Enhance
            feed = {x1: pre_frame, x2: cmp_frame, x3: sub_frame}
            if use_bn:
                feed[is_training] = False
            enhanced_frame = sess.run(x2_enhanced, feed_dict=feed)

            # Record for output video
            enhanced_list[CmpIndex_list[ite_frame]] = np.squeeze(enhanced_frame)

            # Evaluate and accumulate dpsnr / dssim against the raw frame
            raw_frame = np.squeeze(
                y_import(RawVideo_path, height, width, 1, CmpIndex_list[ite_frame])) / 255.0
            cmp_frame = np.float32(np.squeeze(cmp_frame))
            enhanced_frame = np.squeeze(enhanced_frame)
            raw_frame = np.float32(raw_frame)

            psnr_ori = compare_psnr(cmp_frame, raw_frame, data_range=1.0)
            psnr_aft = compare_psnr(enhanced_frame, raw_frame, data_range=1.0)
            ssim_ori = compare_ssim(cmp_frame, raw_frame, data_range=1.0)
            ssim_aft = compare_ssim(enhanced_frame, raw_frame, data_range=1.0)

            sum_dpsnr_part += psnr_aft - psnr_ori
            sum_dssim_part += ssim_aft - ssim_ori

            print("%d | %d at QP = %d" % (ite_frame + 1, nfs, QP), end="\r")
        print(" ", end="\r")

    sum_dpsnr += sum_dpsnr_part
    sum_dssim += sum_dssim_part

    # Guard against an empty frame list instead of raising ZeroDivisionError.
    if nfs > 0:
        average_dpsnr = sum_dpsnr_part / nfs
        average_dssim = sum_dssim_part / nfs
    else:
        average_dpsnr = 0.0
        average_dssim = 0.0
    print("dPSNR: %.3f - dSSIM: %.3f - nfs: %4d" % (average_dpsnr, average_dssim, nfs), flush=True)
    file_object.write("dPSNR: %.3f - dSSIM: %.3f - nfs: %4d\n" % (average_dpsnr, average_dssim, nfs))
    file_object.flush()
def func_enhance(dir_model_pre, QP, PreIndex_list, CmpIndex_list, SubIndex_list):
    """Enhance PQFs or non-PQFs, record dpsnr, dssim and enhanced frames.

    Builds a fresh graph and restores the latest checkpoint found in
    ``dir_model_pre``. If there is none, it falls back to
    ``<dir_model_pre>/model_step2.ckpt-0``. For every i, it enhances frame
    ``CmpIndex_list[i]`` of ``CmpVideo_path``, using ``PreIndex_list[i]`` and
    ``SubIndex_list[i]`` as the previous and subsequent frames.

    Side effects:
      * stores each enhanced frame in the global ``enhanced_list``;
      * adds the per-call sums of (PSNR, SSIM) gains vs. the raw video
        (``RawVideo_path``) to the globals ``sum_dpsnr`` / ``sum_dssim``;
      * prints the averages and writes them to ``file_object``;
      * while the global ``ywz_times == 0``, traces every sess.run with
        FULL_TRACE and writes ``timeline/timeline_<frame>.json`` (Chrome trace).

    Compressed frames for which ``isplane`` is true are skipped. They still
    count in ``nfs``, so they contribute 0 to the averages.
    """
    # ywz: profiling switch
    global ywz_times
    global enhanced_list, sum_dpsnr, sum_dssim

    tf.reset_default_graph()
    use_bn = QP in net1_list  # net1 takes a BN is_training switch; network2 does not

    ### Define enhancement process
    x1 = tf.placeholder(tf.float32, [BATCH_SIZE, height, width, CHANNEL])  # previous
    x2 = tf.placeholder(tf.float32, [BATCH_SIZE, height, width, CHANNEL])  # current
    x3 = tf.placeholder(tf.float32, [BATCH_SIZE, height, width, CHANNEL])  # subsequent
    is_training = tf.placeholder_with_default(False, shape=()) if use_bn else None

    x1to2 = net_MFCNN.warp_img(BATCH_SIZE, x2, x1, False)
    x3to2 = net_MFCNN.warp_img(BATCH_SIZE, x2, x3, True)

    if use_bn:
        x2_enhanced = net_MFCNN.network(x1to2, x2, x3to2, is_training)
    else:
        x2_enhanced = net_MFCNN.network2(x1to2, x2, x3to2)

    saver = tf.train.Saver()

    def load_cmp_frame(index):
        """Read one Y frame of the compressed video, scaled to [0, 1], with a trailing channel axis."""
        return y_import(CmpVideo_path, height, width, 1, index)[:, :, :, np.newaxis] / 255.0

    nfs = len(CmpIndex_list)
    sum_dpsnr_part = 0.0
    sum_dssim_part = 0.0

    with tf.Session(config=config) as sess:
        # ywz: timeline profiling
        trace = ywz_times == 0
        if trace:
            options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
            os.makedirs("timeline", exist_ok=True)  # open() below fails if the dir is missing

        # Restore model.
        # tf.train.latest_checkpoint() expects the checkpoint *directory* (it reads
        # the "checkpoint" index file there). The previous code passed the file
        # prefix, which normally yields None, so the net silently ran with
        # freshly initialized weights.
        model_path = os.path.join(dir_model_pre, "model_step2.ckpt-0")
        sess.run(tf.global_variables_initializer())
        module_file = tf.train.latest_checkpoint(dir_model_pre)
        if module_file is None and tf.train.checkpoint_exists(model_path):
            module_file = model_path
        if module_file is not None:
            saver.restore(sess, module_file)
        else:
            print("WARNING: no checkpoint found in %s; enhancing with freshly initialized weights."
                  % dir_model_pre, flush=True)

        for ite_frame in range(nfs):
            # Load frames
            pre_frame = load_cmp_frame(PreIndex_list[ite_frame])
            cmp_frame = load_cmp_frame(CmpIndex_list[ite_frame])
            sub_frame = load_cmp_frame(SubIndex_list[ite_frame])

            # Skip plane cmp frames entirely.
            if isplane(cmp_frame):
                continue

            # If a PQF reference frame is plane, fall back to the cmp frame itself.
            if isplane(pre_frame):
                pre_frame = np.copy(cmp_frame)
            if isplane(sub_frame):
                sub_frame = np.copy(cmp_frame)

            # Enhance
            feed = {x1: pre_frame, x2: cmp_frame, x3: sub_frame}
            if use_bn:
                feed[is_training] = False

            if trace:
                run_metadata = tf.RunMetadata()
                enhanced_frame = sess.run(x2_enhanced, options=options, feed_dict=feed,
                                          run_metadata=run_metadata)
                # ywz: timeline record.
                # Create the Timeline object, and write it to a json file.
                fetched_timeline = timeline.Timeline(run_metadata.step_stats)
                chrome_trace = fetched_timeline.generate_chrome_trace_format()
                with open("timeline/timeline_" + str(ite_frame) + ".json", 'w') as f:
                    f.write(chrome_trace)
            else:
                enhanced_frame = sess.run(x2_enhanced, feed_dict=feed)

            # Record for output video
            enhanced_list[CmpIndex_list[ite_frame]] = np.squeeze(enhanced_frame)

            # Evaluate and accumulate dpsnr / dssim against the raw frame
            raw_frame = np.squeeze(
                y_import(RawVideo_path, height, width, 1, CmpIndex_list[ite_frame])) / 255.0
            cmp_frame = np.float32(np.squeeze(cmp_frame))
            enhanced_frame = np.squeeze(enhanced_frame)
            raw_frame = np.float32(raw_frame)

            psnr_ori = compare_psnr(cmp_frame, raw_frame, data_range=1.0)
            psnr_aft = compare_psnr(enhanced_frame, raw_frame, data_range=1.0)
            ssim_ori = compare_ssim(cmp_frame, raw_frame, data_range=1.0)
            ssim_aft = compare_ssim(enhanced_frame, raw_frame, data_range=1.0)

            sum_dpsnr_part += psnr_aft - psnr_ori
            sum_dssim_part += ssim_aft - ssim_ori

            print("\r %d | %d at QP = %d" % (ite_frame + 1, nfs, QP), end="")

        # ywz: record (trace) only once.
        # NOTE(review): while the increment below stays disabled, every frame of
        # every call is traced, and later calls overwrite timeline_<frame>.json.
        # ywz_times += 1
        print(" ", end="\r")

    sum_dpsnr += sum_dpsnr_part
    sum_dssim += sum_dssim_part

    # Guard against an empty frame list instead of raising ZeroDivisionError.
    if nfs > 0:
        average_dpsnr = sum_dpsnr_part / nfs
        average_dssim = sum_dssim_part / nfs
    else:
        average_dpsnr = 0.0
        average_dssim = 0.0
    print("dPSNR: %.3f - dSSIM: %.3f - nfs: %4d" % (average_dpsnr, average_dssim, nfs), flush=True)
    file_object.write("dPSNR: %.3f - dSSIM: %.3f - nfs: %4d\n" % (average_dpsnr, average_dssim, nfs))
    file_object.flush()