import cv2
import numpy as np

# repo-local helpers assumed available in this project:
# reader (image/flow IO), flow (warping), params (VGG_MEAN), get_padded_img


def video_load_crop(entry, input_size):
    fg_path, bg_path, previous_path, flo_path = entry
    alpha, fg = reader.read_fg_img(fg_path)
    fg = fg.astype(dtype=np.float64)  # potentially very big
    bg = cv2.imread(bg_path).astype(dtype=np.float64)
    flo = reader.read_flow(flo_path)
    prev_alpha, _ = reader.read_fg_img(previous_path)
    warped_alpha = flow.warp_img(prev_alpha, flo)
    warped_alpha = np.repeat(warped_alpha[:, :, np.newaxis], 3, axis=2)

    crop_type = [(320, 320), (480, 480), (640, 640)]  # we crop patches of different sizes
    crop_h, crop_w = crop_type[np.random.randint(0, len(crop_type))]
    fg_h, fg_w = fg.shape[:2]
    if fg_h < crop_h or fg_w < crop_w:
        # in that case the image is not too big, and we have to add padding
        alpha = alpha.reshape((alpha.shape[0], alpha.shape[1], 1))
        cat = np.concatenate((fg, alpha, warped_alpha), axis=2)
        cropped_cat = get_padded_img(cat, crop_h, crop_w)
        fg, alpha, warped_alpha = np.split(cropped_cat, indices_or_sections=[3, 4], axis=2)
    # otherwise the fg is likely to be high-resolution; we directly crop it and dismiss the
    # original image to avoid manipulating big images
    fg_h, fg_w = fg.shape[:2]
    # randomly pick the top-left corner of the foreground crop
    i, j = np.random.randint(0, fg_h - crop_h + 1), np.random.randint(0, fg_w - crop_w + 1)
    fg = fg[i:i + crop_h, j:j + crop_w]
    alpha = alpha[i:i + crop_h, j:j + crop_w]
    warped_alpha = warped_alpha[i:i + crop_h, j:j + crop_w]

    # crop a proportionally sized patch from the background, padding it first if needed
    bg_crop_h, bg_crop_w = int(np.ceil(crop_h * bg.shape[0] / fg.shape[0])), \
                           int(np.ceil(crop_w * bg.shape[1] / fg.shape[1]))
    padded_bg = get_padded_img(bg, bg_crop_h, bg_crop_w)
    i, j = np.random.randint(0, padded_bg.shape[0] - bg_crop_h + 1), \
           np.random.randint(0, padded_bg.shape[1] - bg_crop_w + 1)
    cropped_bg = padded_bg[i:i + bg_crop_h, j:j + bg_crop_w]

    # resize everything to the network input size
    bg = cv2.resize(src=cropped_bg, dsize=input_size, interpolation=cv2.INTER_LINEAR)
    fg = cv2.resize(fg, input_size, interpolation=cv2.INTER_LINEAR)
    alpha = cv2.resize(alpha, input_size, interpolation=cv2.INTER_LINEAR)
    warped_alpha = cv2.resize(warped_alpha, input_size, interpolation=cv2.INTER_LINEAR)

    cmp = reader.create_composite_image(fg, bg, alpha)
    cmp -= params.VGG_MEAN
    bg -= params.VGG_MEAN
    # inp = np.concatenate((cmp,
    #                       bg,
    #                       trimap.reshape((h, w, 1))), axis=2)
    label = alpha.reshape((alpha.shape[0], alpha.shape[1], 1))
    # warped_alpha = np.dstack([warped_alpha[:, :, np.newaxis]] * 3)
    return cmp, bg, label, warped_alpha, fg
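# Example usage (a minimal sketch): the `entry` tuple is expected to hold paths to the
# foreground image, the background image, the previous-frame foreground and the optical-flow
# file. The file names below are hypothetical placeholders, and `input_size` follows
# cv2.resize's (width, height) convention.
if __name__ == "__main__":
    example_entry = ("data/fg/0001.png",    # foreground + alpha (hypothetical path)
                     "data/bg/0001.jpg",    # background (hypothetical path)
                     "data/fg/0000.png",    # previous-frame foreground, used for the warped alpha
                     "data/flow/0001.flo")  # optical flow between previous and current frame
    cmp, bg, label, warped_alpha, fg = video_load_crop(example_entry, input_size=(320, 320))
    print(cmp.shape, bg.shape, label.shape, warped_alpha.shape, fg.shape)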
import tensorflow as tf

# repo-local modules assumed available: flow (MC-subnet warping), net (QE-subnet), data (frame IO)

Height = 240
Width = 416
Channel = 1
batch_size = 1

# Session
config = tf.ConfigProto(allow_soft_placement=True)
sess = tf.Session(config=config)

# Placeholders
x1 = tf.placeholder(tf.float32, [batch_size, Height, Width, Channel])
x2 = tf.placeholder(tf.float32, [batch_size, Height, Width, Channel])
x3 = tf.placeholder(tf.float32, [batch_size, Height, Width, Channel])

## MC-subnet
x1to2 = flow.warp_img(batch_size, x2, x1, False)
x3to2 = flow.warp_img(batch_size, x2, x3, True)

## QE-subnet
x2_enhanced = net.network(x1to2, x2, x3to2)

## Import data
PQF_Frame_93_Y, PQF_Frame_93_U, PQF_Frame_93_V = data.input_data(Height, Width, 'Frame_93')
non_PQF_Frame_96_Y, non_PQF_Frame_96_U, non_PQF_Frame_96_V = data.input_data(Height, Width, 'Frame_96')
PQF_Frame_97_Y, PQF_Frame_97_U, PQF_Frame_97_V = data.input_data(Height, Width, 'Frame_97')

## Load model
saver = tf.train.Saver()
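# A minimal sketch of how the graph above might be run: restore a trained checkpoint and
# enhance the non-PQF frame (Frame_96) between its two neighbouring PQFs (Frame_93, Frame_97).
# The checkpoint path is a hypothetical placeholder, and the reshape assumes data.input_data
# returns Y channels of shape (Height, Width); adjust if the actual shapes differ.
saver.restore(sess, "./model/model.ckpt")  # hypothetical checkpoint path
feed = {
    x1: PQF_Frame_93_Y.reshape([batch_size, Height, Width, Channel]),
    x2: non_PQF_Frame_96_Y.reshape([batch_size, Height, Width, Channel]),
    x3: PQF_Frame_97_Y.reshape([batch_size, Height, Width, Channel]),
}
enhanced_Y = sess.run(x2_enhanced, feed_dict=feed)
print(enhanced_Y.shape)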
def enhance(QP, input, Non_PQF_indices, pre_PQF_indices, sub_PQF_indices):
    # Recommended MF-CNN model
    if QP == 22:
        model_index = 1230000
    elif QP == 27:
        model_index = 275000
    elif QP == 32:
        model_index = 1227500
    elif QP == 37:
        model_index = 1122500
    elif QP == 42:
        model_index = 1250000
    model_path = "./Model_MFCNN/QP" + str(QP) + "/model.ckpt-" + str(model_index)

    # video information
    nfs, height, width = input.shape
    nfs = len(Non_PQF_indices)

    input = input[:, :, :, np.newaxis]
    input = input / 255.0

    enhanced_frames = np.zeros([nfs, height, width])

    with tf.Graph().as_default() as g:
        x1 = tf.placeholder(tf.float32, [BATCH_SIZE, height, width, CHANNEL])  # previous
        x2 = tf.placeholder(tf.float32, [BATCH_SIZE, height, width, CHANNEL])  # current
        x3 = tf.placeholder(tf.float32, [BATCH_SIZE, height, width, CHANNEL])  # subsequent
        is_training = tf.placeholder_with_default(False, shape=())

        x1to2 = flow.warp_img(BATCH_SIZE, x2, x1, False)
        x3to2 = flow.warp_img(BATCH_SIZE, x2, x3, True)

        if (QP == 37) or (QP == 42):
            x2_enhanced = net_MFCNN.network(x1to2, x2, x3to2, is_training)
        elif (QP == 22) or (QP == 27) or (QP == 32):
            x2_enhanced = net_MFCNN.network2(x1to2, x2, x3to2)

        # restore vars above
        saver = tf.train.Saver()

        with tf.Session(config=config) as sess:
            # restore model
            saver.restore(sess, model_path)

            # enhance
            start_time = time.time()
            for ite in range(nfs):
                x1_feed = input[pre_PQF_indices[ite]:pre_PQF_indices[ite] + 1, :, :, :]
                x2_feed = input[Non_PQF_indices[ite]:Non_PQF_indices[ite] + 1, :, :, :]
                x3_feed = input[sub_PQF_indices[ite]:sub_PQF_indices[ite] + 1, :, :, :]

                x2_enhanced_frame = sess.run(
                    x2_enhanced,
                    feed_dict={x1: x1_feed, x2: x2_feed, x3: x3_feed, is_training: False})
                enhanced_frames[ite] = np.squeeze(x2_enhanced_frame)

                print("\r" + str(ite + 1) + " | " + str(nfs), end="", flush=True)
            end_time = time.time()
            average_fps = nfs / (end_time - start_time)
            print("")

    return enhanced_frames, average_fps
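# Example call (a sketch): `input` is expected to be an array of decoded frames with shape
# (num_frames, height, width). The index lists below, which pair every non-PQF frame with its
# nearest preceding and following PQFs, are illustrative assumptions only; the real PQF
# detection is done elsewhere in the pipeline.
frames = np.zeros([5, 240, 416], dtype=np.uint8)  # placeholder video
Non_PQF_indices = [1, 2, 3]                        # frames to enhance
pre_PQF_indices = [0, 0, 0]                        # nearest preceding PQF for each non-PQF frame
sub_PQF_indices = [4, 4, 4]                        # nearest following PQF for each non-PQF frame
enhanced, fps = enhance(37, frames, Non_PQF_indices, pre_PQF_indices, sub_PQF_indices)
print("enhanced frames:", enhanced.shape, "average fps:", fps)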
def train():
    path1 = "../../dataset_256/train_256_b8_LD37_HP1.h5"
    path2 = "../../dataset_256/train_256_b8_LD37_HA1.h5"
    path3 = "../../dataset_256/train_256_b8_LD37_H1.h5"
    f1 = h5py.File(path1, 'r')
    f2 = h5py.File(path2, 'r')
    f3 = h5py.File(path3, 'r')

    train_net = net.Conv_Net_Train()
    length_f = min(f1['data_cur'].shape[0], f2['data_cur'].shape[0], f3['data_cur'].shape[0])

    lr = tf.placeholder(tf.float32, [])
    x1 = tf.placeholder(tf.float32, [FLAGS.batch_size, FLAGS.img_height, FLAGS.img_width, FLAGS.img_channel])
    x2 = tf.placeholder(tf.float32, [FLAGS.batch_size, FLAGS.img_height, FLAGS.img_width, FLAGS.img_channel])
    x3 = tf.placeholder(tf.float32, [FLAGS.batch_size, FLAGS.img_height, FLAGS.img_width, FLAGS.img_channel])
    x2_label = tf.placeholder(tf.float32, [FLAGS.batch_size, FLAGS.img_height, FLAGS.img_width, FLAGS.img_channel])

    ## MC-subnet
    x1to2, flow1To2 = flow.warp_img(FLAGS.batch_size, x2, x1, False)
    x3to2, flow3To2 = flow.warp_img(FLAGS.batch_size, x2, x3, True)

    ## QE-subnet
    enhanced_image = train_net.enhanced_Net(x1to2, x2, x3to2, is_train=True, name='enhanced_Net')

    # motion-compensation losses and enhancement loss
    l2_loss_1 = tf.nn.l2_loss(x1to2 - x2)
    l2_loss_2 = tf.nn.l2_loss(x3to2 - x2)
    l2_loss_3 = tf.nn.l2_loss(enhanced_image - x2_label)
    a = 0.01
    b = 1
    loss_total = a * (l2_loss_1 + l2_loss_2) + b * l2_loss_3

    configProt = tf.ConfigProto()
    configProt.gpu_options.allow_growth = True
    configProt.allow_soft_placement = True
    sess = tf.Session(config=configProt)

    optimizer = tf.train.AdamOptimizer(lr, name='AdamOptimizer')
    train_op = optimizer.minimize(loss_total)

    init = tf.global_variables_initializer()
    sess.run(init)

    # restore the pre-trained optical-flow (MC-subnet) weights
    var = tf.trainable_variables()
    variables_to_restore = [val for val in var if 'easyflow' in val.name]
    saver = tf.train.Saver(variables_to_restore)
    saver.restore(sess, './pre_flow_model/mdoel-ckpt-15')
    print('loaded successfully')

    lr_value = 0.0001
    saver = tf.train.Saver(var_list=tf.trainable_variables(), max_to_keep=60)
    # print('params:', get_num_params())

    checkpoint_path = './266_SDTS/'
    os.makedirs(checkpoint_path, exist_ok=True)

    for epoch in range(1, FLAGS.max_epochs + 1):
        if epoch % 10 == 0:
            lr_value = lr_value * 0.1
        start1 = 0
        start2 = 0
        start3 = 0
        for itr in range((length_f * 3) // 8):
            time_start = time.time()
            if itr <= (length_f // 8):
                image_pre_batch, image_cur_batch, image_aft_batch, label_batch, start1 = train_batch_h5(
                    f1, length_f, start1, batch_size=FLAGS.batch_size)
            elif itr > (length_f // 8) and itr <= (length_f * 2) // 8:
                image_pre_batch, image_cur_batch, image_aft_batch, label_batch, start2 = train_batch_h5(
                    f2, length_f, start2, batch_size=FLAGS.batch_size)
            else:
                image_pre_batch, image_cur_batch, image_aft_batch, label_batch, start3 = train_batch_h5(
                    f3, length_f, start3, batch_size=FLAGS.batch_size)

            feed_dict = {x1: image_pre_batch, x2: image_cur_batch, x3: image_aft_batch,
                         x2_label: label_batch, lr: lr_value}
            # _, l2_loss_value, MC_image, lr_value_net = sess.run([train_op, loss_total, x1to2, lr], feed_dict)
            _, l2_loss_MC, l2_loss_3_net, MC_image, enhanced, lr_value_net = sess.run(
                [train_op, l2_loss_1, l2_loss_3, x1to2, enhanced_image, lr], feed_dict)
            time_end = time.time()
            time_step = time_end - time_start
            lr_value_net = np.mean(lr_value_net)

            if itr % 10 == 0:
                l1_loss_value = np.mean(np.abs(enhanced - label_batch))
                # rough estimate of the remaining training time, in hours
                total_time = time_step * ((length_f * 3) // 8) * (FLAGS.max_epochs + 1 - epoch) / 3600
                print("===> Epoch[{}]({}/{}): lr:{:.10f} Loss_l1: {:.04f}: time_step:{:.04f} total_time:{:.04f}".format(
                    epoch, itr, (length_f * 3) // 8, lr_value_net, l1_loss_value * 255.0, time_step, total_time))
                # print('===> Epoch[]({:.04f}/{:.04f})'.format(l2_loss_MC, l2_loss_3_net))

        # save a checkpoint at the end of each epoch, reusing the Saver created above (max_to_keep=60)
        saver.save(sess, checkpoint_path + 'model-ckpt', global_step=epoch)
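# `train_batch_h5` is called by train() above but not defined in this excerpt. Below is a
# minimal sketch of what it is assumed to do: slice a contiguous batch out of the HDF5 file
# and return the updated read position. Only the 'data_cur' key is confirmed by the code
# above; the 'data_pre', 'data_aft' and 'label' keys are assumptions and may differ in the
# real dataset.
def train_batch_h5(f, length_f, start, batch_size):
    end = start + batch_size
    if end > length_f:                              # wrap around when the slice runs past the end
        start, end = 0, batch_size
    image_pre_batch = f['data_pre'][start:end]      # assumed key: previous frames
    image_cur_batch = f['data_cur'][start:end]      # confirmed key: current frames
    image_aft_batch = f['data_aft'][start:end]      # assumed key: following frames
    label_batch = f['label'][start:end]             # assumed key: ground-truth frames
    return image_pre_batch, image_cur_batch, image_aft_batch, label_batch, end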