Example #1
0
def video_load_crop(entry, input_size):
    """Load one video-matting sample, randomly crop it and resize it.

    Args:
        entry: 4-tuple ``(fg_path, bg_path, previous_path, flo_path)``.
            ``fg_path`` and ``previous_path`` are read with
            ``reader.read_fg_img`` (alpha + foreground of the current and the
            previous frame), ``bg_path`` with ``cv2.imread`` and ``flo_path``
            with ``reader.read_flow``.
        input_size: output size, passed as ``dsize`` to ``cv2.resize``
            (OpenCV expects ``(width, height)``).

    Returns:
        ``(cmp, bg, label, warped_alpha, fg)``:
        the composite built by ``reader.create_composite_image`` minus
        ``params.VGG_MEAN``; the background crop minus ``params.VGG_MEAN``;
        the alpha crop reshaped to ``(h, w, 1)``; the previous frame's alpha
        warped by the flow and repeated to 3 channels; the foreground crop.
    """
    fg_path, bg_path, previous_path, flo_path = entry
    alpha, fg = reader.read_fg_img(fg_path)
    # np.float was an alias of the builtin float (i.e. float64) and was removed
    # in NumPy 1.24; use the explicit dtype it always produced.
    fg = fg.astype(dtype=np.float64)  # potentially very big
    bg = cv2.imread(bg_path).astype(dtype=np.float64)
    flo = reader.read_flow(flo_path)
    prev_alpha, _ = reader.read_fg_img(previous_path)
    # Warp the previous frame's alpha with the flow and give it 3 channels.
    warped_alpha = flow.warp_img(prev_alpha, flo)
    warped_alpha = np.repeat(warped_alpha[:, :, np.newaxis], 3, axis=2)

    # We crop windows of different sizes (scale augmentation).
    crop_type = [(320, 320), (480, 480), (640, 640)]
    crop_h, crop_w = crop_type[np.random.randint(0, len(crop_type))]

    fg_h, fg_w = fg.shape[:2]
    if fg_h < crop_h or fg_w < crop_w:
        # The image is smaller than the crop: pad fg, alpha and warped_alpha
        # together (as one stacked array) so they stay pixel-aligned.
        alpha = alpha.reshape((alpha.shape[0], alpha.shape[1], 1))
        cat = np.concatenate((fg, alpha, warped_alpha), axis=2)
        padded_cat = get_padded_img(cat, crop_h, crop_w)
        fg, alpha, warped_alpha = np.split(padded_cat,
                                           indices_or_sections=[3, 4],
                                           axis=2)
    # Otherwise the fg is likely high-resolution: crop it directly and drop the
    # full-size image to avoid manipulating big arrays.
    fg_h, fg_w = fg.shape[:2]
    # Randomly pick the top-left corner of the fg crop window.
    i = np.random.randint(0, fg_h - crop_h + 1)
    j = np.random.randint(0, fg_w - crop_w + 1)
    fg = fg[i:i + crop_h, j:j + crop_w]
    alpha = alpha[i:i + crop_h, j:j + crop_w]
    warped_alpha = warped_alpha[i:i + crop_h, j:j + crop_w]

    # Crop the same fraction of bg as the fg window covers of the (padded) fg.
    # This must use the fg size from *before* the crop: fg.shape is now
    # (crop_h, crop_w), and using it would make the bg window the whole image.
    bg_crop_h = int(np.ceil(crop_h * bg.shape[0] / fg_h))
    bg_crop_w = int(np.ceil(crop_w * bg.shape[1] / fg_w))
    padded_bg = get_padded_img(bg, bg_crop_h, bg_crop_w)
    # Randomly pick the bg corner inside the image that is actually sliced.
    i = np.random.randint(0, padded_bg.shape[0] - bg_crop_h + 1)
    j = np.random.randint(0, padded_bg.shape[1] - bg_crop_w + 1)
    cropped_bg = padded_bg[i:i + bg_crop_h, j:j + bg_crop_w]

    bg = cv2.resize(src=cropped_bg,
                    dsize=input_size,
                    interpolation=cv2.INTER_LINEAR)
    fg = cv2.resize(fg, input_size, interpolation=cv2.INTER_LINEAR)
    alpha = cv2.resize(alpha, input_size, interpolation=cv2.INTER_LINEAR)
    warped_alpha = cv2.resize(warped_alpha,
                              input_size,
                              interpolation=cv2.INTER_LINEAR)

    # Composite with the un-normalised bg, then subtract the VGG mean from the
    # network inputs.
    cmp = reader.create_composite_image(fg, bg, alpha)
    cmp -= params.VGG_MEAN
    bg -= params.VGG_MEAN
    label = alpha.reshape((alpha.shape[0], alpha.shape[1], 1))

    return cmp, bg, label, warped_alpha, fg
Example #2
0
# Frame geometry for the placeholders below: batches of one 240x416
# single-channel frame.
# NOTE(review): statement order in this TF1 graph-mode script matters -- op and
# variable auto-names depend on creation order, and tf.train.Saver matches
# checkpoint entries by variable name. Do not reorder without re-checking
# checkpoint compatibility.
Height = 240
Width = 416
Channel = 1
batch_size = 1

# Session
# allow_soft_placement lets TF fall back to another device for ops that have
# no kernel on the requested one.
# NOTE(review): `sess` is not used in the visible part of this script.
config = tf.ConfigProto(allow_soft_placement=True)
sess = tf.Session(config=config)

# Placeholders: x2 is the frame to enhance; x1 and x3 are its neighbouring
# frames (presumably the previous/subsequent PQFs loaded below -- confirm).
x1 = tf.placeholder(tf.float32, [batch_size, Height, Width, Channel])
x2 = tf.placeholder(tf.float32, [batch_size, Height, Width, Channel])
x3 = tf.placeholder(tf.float32, [batch_size, Height, Width, Channel])

## MC-subnet
# Motion-compensate x1 and x3 towards x2. The last argument differs between the
# two calls; its meaning is defined in flow.warp_img (presumably variable
# reuse for the second call -- TODO confirm).
x1to2 = flow.warp_img(batch_size, x2, x1, False)
x3to2 = flow.warp_img(batch_size, x2, x3, True)

## QE-subnet
# Enhance x2 from the two motion-compensated neighbours.
x2_enhanced = net.network(x1to2, x2, x3to2)

## Import data
# Load frames 93 (PQF), 96 (non-PQF) and 97 (PQF). Each call returns three
# arrays; the names suggest Y, U and V planes -- confirm in data.input_data.
PQF_Frame_93_Y, PQF_Frame_93_U, PQF_Frame_93_V = data.input_data(
    Height, Width, 'Frame_93')
non_PQF_Frame_96_Y, non_PQF_Frame_96_U, non_PQF_Frame_96_V = data.input_data(
    Height, Width, 'Frame_96')
PQF_Frame_97_Y, PQF_Frame_97_U, PQF_Frame_97_V = data.input_data(
    Height, Width, 'Frame_97')

## Load model
# A Saver with no var_list covers all saveable variables created above.
# NOTE(review): the visible script ends here -- no restore or sess.run follows,
# so it is presumably truncated.
saver = tf.train.Saver()
Example #3
0
def enhance(QP, input, Non_PQF_indices, pre_PQF_indices, sub_PQF_indices):
    """Enhance the non-PQF frames of a video with a pre-trained MF-CNN model.

    Args:
        QP: quantization parameter; selects the checkpoint and the network
            variant. Must be one of 22, 27, 32, 37, 42.
        input: frame stack of shape ``(num_frames, height, width)``. It is
            divided by 255.0 before inference (presumably 8-bit samples --
            confirm with the caller).
        Non_PQF_indices: indices into ``input`` of the frames to enhance.
        pre_PQF_indices: for each entry of ``Non_PQF_indices``, the index of
            the previous reference frame.
        sub_PQF_indices: for each entry of ``Non_PQF_indices``, the index of
            the subsequent reference frame.

    Returns:
        ``(enhanced_frames, average_fps)``: an array of shape
        ``(len(Non_PQF_indices), height, width)`` holding the squeezed network
        outputs (not rescaled by 255), and the enhancement throughput in
        frames per second (0.0 if no measurable time elapsed).

    Raises:
        ValueError: if ``QP`` has no recommended model.
    """
    # Recommended MF-CNN checkpoint step for each supported QP.
    model_indices = {
        22: 1230000,
        27: 275000,
        32: 1227500,
        37: 1122500,
        42: 1250000,
    }
    # Fail early with a clear message rather than an UnboundLocalError later.
    if QP not in model_indices:
        raise ValueError("Unsupported QP %r; expected one of %s"
                         % (QP, sorted(model_indices)))
    model_index = model_indices[QP]

    model_path = "./Model_MFCNN/QP" + str(QP) + "/model.ckpt-" + str(model_index)

    # Video information; only the non-PQF frames are enhanced.
    _, height, width = input.shape
    nfs = len(Non_PQF_indices)

    input = input[:, :, :, np.newaxis]  # add a channel axis
    input = input / 255.0

    enhanced_frames = np.zeros([nfs, height, width])

    with tf.Graph().as_default():

        x1 = tf.placeholder(tf.float32, [BATCH_SIZE, height, width, CHANNEL])  # previous
        x2 = tf.placeholder(tf.float32, [BATCH_SIZE, height, width, CHANNEL])  # current
        x3 = tf.placeholder(tf.float32, [BATCH_SIZE, height, width, CHANNEL])  # subsequent
        is_training = tf.placeholder_with_default(False, shape=())

        # MC-subnet: warp the previous and subsequent frames onto the current one.
        x1to2 = flow.warp_img(BATCH_SIZE, x2, x1, False)
        x3to2 = flow.warp_img(BATCH_SIZE, x2, x3, True)

        # QE-subnet: the high-QP models (37, 42) use network(); the others
        # (22, 27, 32) use network2().
        if QP in (37, 42):
            x2_enhanced = net_MFCNN.network(x1to2, x2, x3to2, is_training)
        else:
            x2_enhanced = net_MFCNN.network2(x1to2, x2, x3to2)

        # Restore the variables created above.
        saver = tf.train.Saver()

        with tf.Session(config=config) as sess:

            saver.restore(sess, model_path)

            start_time = time.perf_counter()
            for ite in range(nfs):
                # Slice with p:p+1 to keep the batch axis.
                prev_idx = pre_PQF_indices[ite]
                cur_idx = Non_PQF_indices[ite]
                sub_idx = sub_PQF_indices[ite]
                x1_feed = input[prev_idx:prev_idx + 1, :, :, :]
                x2_feed = input[cur_idx:cur_idx + 1, :, :, :]
                x3_feed = input[sub_idx:sub_idx + 1, :, :, :]

                x2_enhanced_frame = sess.run(
                    x2_enhanced,
                    feed_dict={x1: x1_feed, x2: x2_feed, x3: x3_feed,
                               is_training: False})
                enhanced_frames[ite] = np.squeeze(x2_enhanced_frame)

                print("\r" + str(ite + 1) + " | " + str(nfs), end="", flush=True)

            elapsed = time.perf_counter() - start_time
            # Guard against a zero-length interval (e.g. nfs == 0).
            average_fps = nfs / elapsed if elapsed > 0 else 0.0
            print("")

    return enhanced_frames, average_fps
Example #4
0
def train():
    """Train the motion-compensation + quality-enhancement network.

    Opens three HDF5 training sets and builds the TF1 graph:
    ``flow.warp_img`` for motion compensation and
    ``net.Conv_Net_Train().enhanced_Net`` for enhancement. It then restores the
    pre-trained flow variables (trainable variables whose name contains
    ``'easyflow'``) from ``./pre_flow_model/mdoel-ckpt-15``. Training runs
    with Adam for ``FLAGS.max_epochs`` epochs, and a checkpoint is saved to
    ``./266_SDTS/`` after every epoch.

    Each epoch draws batches from the three files in turn: first f1, then f2,
    then f3. The learning rate starts at 1e-4 and is multiplied by 0.1 at the
    start of every 10th epoch.

    Loss: ``0.01 * (l2(x1to2 - x2) + l2(x3to2 - x2)) + l2(enhanced - label)``,
    with ``l2`` being ``tf.nn.l2_loss``.
    """
    path1 = "../../dataset_256/train_256_b8_LD37_HP1.h5"
    path2 = "../../dataset_256/train_256_b8_LD37_HA1.h5"
    path3 = "../../dataset_256/train_256_b8_LD37_H1.h5"
    checkpoint_path = './266_SDTS/'
    # The with-statements close the HDF5 files (and, below, the session) even
    # if training is interrupted.
    with h5py.File(path1, 'r') as f1, \
            h5py.File(path2, 'r') as f2, \
            h5py.File(path3, 'r') as f3:
        train_net = net.Conv_Net_Train()
        length_f = min(f1['data_cur'].shape[0], f2['data_cur'].shape[0],
                       f3['data_cur'].shape[0])

        lr = tf.placeholder(tf.float32, [])

        img_shape = [FLAGS.batch_size, FLAGS.img_height, FLAGS.img_width,
                     FLAGS.img_channel]
        x1 = tf.placeholder(tf.float32, img_shape)  # previous frame
        x2 = tf.placeholder(tf.float32, img_shape)  # current frame
        x3 = tf.placeholder(tf.float32, img_shape)  # subsequent frame
        x2_label = tf.placeholder(tf.float32, img_shape)  # target for x2

        ## MC-subnet: warp the neighbouring frames onto x2.
        x1to2, _ = flow.warp_img(FLAGS.batch_size, x2, x1, False)
        x3to2, _ = flow.warp_img(FLAGS.batch_size, x2, x3, True)

        enhanced_image = train_net.enhanced_Net(x1to2, x2, x3to2,
                                                is_train=True,
                                                name='enhanced_Net')

        l2_loss_1 = tf.nn.l2_loss(x1to2 - x2)
        l2_loss_2 = tf.nn.l2_loss(x3to2 - x2)
        l2_loss_3 = tf.nn.l2_loss(enhanced_image - x2_label)
        a = 0.01  # weight of the motion-compensation losses
        b = 1     # weight of the enhancement loss
        loss_total = a * (l2_loss_1 + l2_loss_2) + b * l2_loss_3

        configProt = tf.ConfigProto()
        configProt.gpu_options.allow_growth = True
        configProt.allow_soft_placement = True

        with tf.Session(config=configProt) as sess:
            optimizer = tf.train.AdamOptimizer(lr, name='AdamOptimizer')
            train_op = optimizer.minimize(loss_total)
            sess.run(tf.global_variables_initializer())

            # Overwrite the freshly initialised flow variables with the
            # pre-trained flow model.
            var = tf.trainable_variables()
            variables_to_restore = [val for val in var if 'easyflow' in val.name]
            flow_saver = tf.train.Saver(variables_to_restore)
            flow_saver.restore(sess, './pre_flow_model/mdoel-ckpt-15')
            print('load successfully')
            lr_value = 0.0001
            # Create the checkpoint Saver once. A new Saver per epoch would
            # fall back to the default max_to_keep and add save ops to the
            # graph on every epoch.
            saver = tf.train.Saver(var_list=tf.trainable_variables(),
                                   max_to_keep=60)
            os.makedirs(checkpoint_path, exist_ok=True)

            # Each file contributes about length_f // 8 batches per epoch.
            # NOTE(review): the divisor 8 is hard-coded (file names say "b8");
            # it presumably should match FLAGS.batch_size -- confirm.
            steps_per_epoch = (length_f * 3) // 8
            f1_end = length_f // 8
            f2_end = (length_f * 2) // 8
            for epoch in range(1, FLAGS.max_epochs + 1):
                if epoch % 10 == 0:
                    lr_value = lr_value * 0.1
                # Read offsets into each file, reset every epoch.
                start1 = start2 = start3 = 0
                for itr in range(steps_per_epoch):
                    time_start = time.time()
                    # Strict '<' gives each file exactly its share of batches.
                    if itr < f1_end:
                        image_pre_batch, image_cur_batch, image_aft_batch, label_batch, start1 = train_batch_h5(
                            f1, length_f, start1, batch_size=FLAGS.batch_size)
                    elif itr < f2_end:
                        image_pre_batch, image_cur_batch, image_aft_batch, label_batch, start2 = train_batch_h5(
                            f2, length_f, start2, batch_size=FLAGS.batch_size)
                    else:
                        image_pre_batch, image_cur_batch, image_aft_batch, label_batch, start3 = train_batch_h5(
                            f3, length_f, start3, batch_size=FLAGS.batch_size)

                    feed_dict = {x1: image_pre_batch, x2: image_cur_batch,
                                 x3: image_aft_batch, x2_label: label_batch,
                                 lr: lr_value}

                    # Only fetch what is used below.
                    _, enhanced = sess.run([train_op, enhanced_image], feed_dict)
                    time_step = time.time() - time_start
                    if itr % 10 == 0:
                        l1_loss_value = np.mean(np.abs(enhanced - label_batch))
                        # Rough remaining-time estimate (hours) from the last
                        # step's duration.
                        total_time = (time_step * steps_per_epoch
                                      * (FLAGS.max_epochs + 1 - epoch) / 3600)
                        print("===> Epoch[{}]({}/{}): lr:{:.10f} Loss_l1: {:.04f}: time_step:{:.04f} total_time:{:.04f}".format(
                            epoch, itr, steps_per_epoch, lr_value,
                            l1_loss_value * 255.0, time_step, total_time))
                saver.save(sess, checkpoint_path + 'model-ckpt', global_step=epoch)