# Example #1
# 0
def step_fn(inputs):
    """Build one training step of the two-stage shadow-removal graph.

    Stage one (variables created with ``ext='Ref_'``) predicts a shadow mask
    from the shadowed image concatenated with a grayscale pure-flash image.
    Stage two (``ext='Trans_'``) predicts the shadow-free image from the
    shadowed image concatenated with that predicted mask.

    Args:
        inputs: 4-tuple ``(img_with_shadow, shadow_mask, img_no_shadow,
            input_pureflash)``. The code concatenates on axis 3 and slices
            channels 0-2 of ``input_pureflash``, so NHWC tensors are assumed
            -- TODO confirm value range; PSNR below uses ``max_val=1.0``.

    Returns:
        Tuple ``(total_loss, psnr, visualisation)``:
        ``total_loss`` is ``lossDict["total"]`` wrapped so that evaluating it
        also runs the Adam update; ``psnr`` is computed on the first example
        of the batch only; ``visualisation`` is ``encode_jpeg`` applied to a
        ``concat_img`` of the first example's inputs, outputs and targets.

    Side effects:
        Writes ``"percep_t"``, ``"percep_r"`` and ``"total"`` into the
        module-level ``lossDict`` and creates an Adam optimizer.
    """
    img_with_shadow, shadow_mask, img_no_shadow, input_pureflash = inputs

    # Channel average of the RGB pure-flash image (0.33 approximates 1/3);
    # the 0:1-style slices keep a trailing channel axis of size 1.
    gray_pureflash = 0.33 * (input_pureflash[..., 0:1] +
                             input_pureflash[..., 1:2] +
                             input_pureflash[..., 2:3])

    # Stage 1: predict the shadow mask.
    shadow_mask_layer = UNet_SE(tf.concat([img_with_shadow, gray_pureflash],
                                          axis=3),
                                output_channel=3,
                                ext='Ref_')
    # Stage 2: predict the shadow-free image, conditioned on the mask.
    no_shadow_layer = UNet_SE(tf.concat([img_with_shadow, shadow_mask_layer],
                                        axis=3),
                              ext='Trans_')

    # First call passes reuse=False, second reuse=True -- presumably so both
    # losses share one feature extractor; confirm in compute_percep_loss.
    lossDict["percep_t"] = 0.1 * compute_percep_loss(
        img_no_shadow, no_shadow_layer, reuse=False)
    lossDict["percep_r"] = 0.1 * compute_percep_loss(
        shadow_mask, shadow_mask_layer, reuse=True)
    lossDict["total"] = lossDict["percep_t"] + lossDict["percep_r"]

    tf_psnr = tf.image.psnr(img_no_shadow[0], no_shadow_layer[0], 1.0)
    encoded_concat = encode_jpeg(
        concat_img((img_with_shadow[0], no_shadow_layer[0], img_no_shadow[0],
                    input_pureflash[0], shadow_mask_layer[0], shadow_mask[0])))

    # Optimize every trainable variable whose name contains 'g_' (assumed to
    # cover both UNet_SE networks -- TODO confirm UNet_SE's variable naming).
    all_vars = [var for var in tf.trainable_variables() if 'g_' in var.name]
    opt = tf.train.AdamOptimizer(learning_rate=0.0001).minimize(
        lossDict["total"], var_list=all_vars)
    with tf.control_dependencies([opt]):
        # The identity op depends on the update, so fetching the returned
        # loss also applies one optimizer step.
        return tf.identity(lossDict["total"]), tf_psnr, encoded_concat
# Example #2
# 0
# Validation pipeline: each element goes through gen_shadow with the mask
# file list, then is batched at twice BATCH_SIZE.
val_ds=val_ds.map(lambda x:gen_shadow(x,mask_file_list)).batch(2*BATCH_SIZE)

print(train_ds)

# One iterator shaped like train_ds; running training_init_op or
# validation_init_op switches it between the two datasets.
iterator = tf.data.Iterator.from_structure(train_ds.output_types,
                                           train_ds.output_shapes)
img_with_shadow,shadow_mask,img_no_shadow,input_pureflash = iterator.get_next()

training_init_op = iterator.make_initializer(train_ds)
validation_init_op = iterator.make_initializer(val_ds)

with tf.variable_scope(tf.get_variable_scope()):

    # Channel average of the pure-flash image (0.33 approximates 1/3),
    # kept as a single trailing channel.
    gray_pureflash = 0.33 * (input_pureflash[...,0:1] + input_pureflash[...,1:2] + input_pureflash[...,2:3])
    # bad_mask = detect_shadow(img_with_shadow, input_pureflash)
    # Stage 1 ('Ref_'): predict a 1-channel shadow mask.
    shadow_mask_layer = UNet_SE(tf.concat([img_with_shadow, gray_pureflash], axis=3), output_channel = 1, ext='Ref_')

    # Stage 2 ('Trans_'): predict the shadow-free image from image + mask.
    no_shadow_layer = UNet_SE(tf.concat([img_with_shadow, shadow_mask_layer], axis=3), ext='Trans_')
    lossDict["percep_t"] = 0.1 * compute_percep_loss(img_no_shadow, no_shadow_layer, reuse=False)
    # Mask loss is a mean absolute (L1) difference rather than perceptual.
    lossDict["percep_r"]=0.1* tf.reduce_mean(tf.math.abs(shadow_mask-shadow_mask_layer))
    # lossDict["percep_r"] = 0.1 * compute_percep_loss(shadow_mask, shadow_mask_layer, reuse=True)
    lossDict["total"] = lossDict["percep_t"] + lossDict["percep_r"]
    # Batch-mean PSNR with both images clipped to [0, 1] and max_val=1.0.
    tf_psnr=tf.math.reduce_mean(tf.image.psnr(tf.clip_by_value(img_no_shadow,0,1),
                        tf.clip_by_value(no_shadow_layer,0,1),1.0))
    # Visualisation of the first example; the 1-channel masks are expanded
    # to RGB (predicted mask clipped to [0, 1] first).
    encoded_concat=encode_jpeg(
        concat_img((img_with_shadow[0],no_shadow_layer[0],img_no_shadow[0],
            input_pureflash[0],tf.image.grayscale_to_rgb(tf.clip_by_value(shadow_mask_layer[0],0,1)),
            tf.image.grayscale_to_rgb(shadow_mask[0]))))


# One iterator shaped like train_ds; running training_init_op or
# validation_init_op switches it between the two datasets.
iterator = tf.data.Iterator.from_structure(train_ds.output_types,
                                           train_ds.output_shapes)
img_with_shadow, shadow_mask, img_no_shadow, input_pureflash = iterator.get_next(
)

training_init_op = iterator.make_initializer(train_ds)
validation_init_op = iterator.make_initializer(val_ds)

with tf.variable_scope(tf.get_variable_scope()):

    # Channel average of the pure-flash image (0.33 approximates 1/3).
    # Not used below in this section.
    gray_pureflash = 0.33 * (input_pureflash[..., 0:1] +
                             input_pureflash[..., 1:2] +
                             input_pureflash[..., 2:3])
    # bad_mask = detect_shadow(img_with_shadow, input_pureflash)
    # Stage 1 ('Ref_'): predict a 3-channel mask from the shadowed image only.
    shadow_mask_layer = UNet_SE(img_with_shadow, output_channel=3, ext='Ref_')
    # Stage 2 ('Trans_'): predict the shadow-free image from image + mask.
    no_shadow_layer = UNet_SE(tf.concat([img_with_shadow, shadow_mask_layer],
                                        axis=3),
                              ext='Trans_')
    # First call reuse=False, second reuse=True -- presumably sharing one
    # perceptual feature extractor; confirm in compute_percep_loss.
    lossDict["percep_t"] = 0.1 * compute_percep_loss(
        img_no_shadow, no_shadow_layer, reuse=False)
    lossDict["percep_r"] = 0.1 * compute_percep_loss(
        shadow_mask, shadow_mask_layer, reuse=True)
    lossDict["total"] = lossDict["percep_t"] + lossDict["percep_r"]
    # PSNR of the first example only, max_val=1.0.
    tf_psnr = tf.image.psnr(img_no_shadow[0], no_shadow_layer[0], 1.0)
    # Visualisation strip of the first example's inputs, outputs and targets.
    encoded_concat = encode_jpeg(
        concat_img((img_with_shadow[0], no_shadow_layer[0], img_no_shadow[0],
                    input_pureflash[0], shadow_mask_layer[0], shadow_mask[0])))

train_vars = tf.trainable_variables()

# Placeholder-fed reflection/transmission graph. The indented body needs an
# enclosing block (it was an IndentationError without one); re-enter the
# current variable scope exactly as the other sections of this script do.
with tf.variable_scope(tf.get_variable_scope()):
    input_ambient = tf.placeholder(tf.float32, shape=[None, None, None, 3])
    input_pureflash = tf.placeholder(tf.float32, shape=[None, None, None, 3])
    input_flash = tf.placeholder(tf.float32, shape=[None, None, None, 3])
    reflection = tf.placeholder(tf.float32, shape=[None, None, None, 3])
    target = tf.placeholder(tf.float32, shape=[None, None, None, 3])

    # 1.0 where the pure-flash value exceeds 0.02 and the flash value is
    # below 0.96, else 0.0. Not used further in this section.
    mask_shadow = tf.cast(tf.greater(input_pureflash, 0.02), tf.float32)
    mask_highlight = tf.cast(tf.less(input_flash, 0.96), tf.float32)
    mask_shadow_highlight = mask_shadow * mask_highlight

    # Channel average of the pure-flash image (0.33 approximates 1/3).
    gray_pureflash = 0.33 * (input_pureflash[..., 0:1] +
                             input_pureflash[..., 1:2] +
                             input_pureflash[..., 2:3])
    bad_mask = detect_shadow(input_ambient, input_pureflash)
    # (-bad_mask + 1) is 1 - bad_mask, i.e. its complement if detect_shadow
    # returns a 0/1 mask (TODO confirm); fed to both networks as a channel.
    reflection_layer = UNet_SE(tf.concat(
        [input_ambient, gray_pureflash, (-bad_mask + 1)], axis=3),
                               output_channel=3,
                               ext='Ref_')
    transmission_layer = UNet_SE(tf.concat(
        [input_ambient, reflection_layer, (-bad_mask + 1)], axis=3),
                                 ext='Trans_')
    # First call reuse=False, second reuse=True -- presumably sharing one
    # perceptual feature extractor; confirm in compute_percep_loss.
    lossDict["percep_t"] = 0.1 * compute_percep_loss(
        target, transmission_layer, reuse=False)
    lossDict["percep_r"] = 0.1 * compute_percep_loss(
        reflection, reflection_layer, reuse=True)
    lossDict["total"] = lossDict["percep_t"] + lossDict["percep_r"]

train_vars = tf.trainable_variables()

R_vars = [var for var in train_vars if 'Ref_' in var.name]
T_vars = [var for var in train_vars if 'Trans_' in var.name]
all_vars = [var for var in train_vars if 'g_' in var.name]

# Variant: feed the ground-truth shadow mask to the 'Trans_' network instead
# of predicting it with a 'Ref_' network. The indented body needs an
# enclosing block (it was an IndentationError without one); re-enter the
# current variable scope as the other sections of this script do.
with tf.variable_scope(tf.get_variable_scope()):
    # Channel average of the pure-flash image (0.33 approximates 1/3).
    # Not used below in this section.
    gray_pureflash = 0.33 * (input_pureflash[..., 0:1] +
                             input_pureflash[..., 1:2] +
                             input_pureflash[..., 2:3])
    # Ground truth stands in for the predicted mask.
    shadow_mask_layer = shadow_mask
    transmission_layer = UNet_SE(tf.concat(
        [img_with_shadow, shadow_mask_layer], axis=3),
                                 ext='Trans_')
    lossDict["percep_t"] = 0.1 * compute_percep_loss(
        img_no_shadow, transmission_layer, reuse=False)
    # No mask-prediction loss in this variant. A float zero (float32) keeps
    # the entry's dtype consistent with the other losses, so logging it or
    # re-adding it to "total" cannot hit an int32/float32 mismatch.
    lossDict["percep_r"] = tf.constant(0.0)
    lossDict["total"] = lossDict["percep_t"]
    # Per-image PSNR over the whole batch (not reduced), max_val=1.
    tf_psnr = tf.image.psnr(img_no_shadow, transmission_layer, 1)

train_vars = tf.trainable_variables()

R_vars = [var for var in train_vars if 'Ref_' in var.name]
T_vars = [var for var in train_vars if 'Trans_' in var.name]
all_vars = [var for var in train_vars if 'g_' in var.name]

# NOTE(review): this snippet ended with a bare `for var in R_vars:` header
# whose body is missing (a SyntaxError as written). The header is disabled
# until the intended loop body is restored.
# for var in R_vars:
# One iterator shaped like train_ds; running training_init_op or
# validation_init_op switches it between the two datasets. Elements unpack
# as (ref_gt, img_with_shadow, input_pureflash, tran_gt).
iterator = tf.data.Iterator.from_structure(train_ds.output_types,
                                           train_ds.output_shapes)
ref_gt,img_with_shadow,input_pureflash,tran_gt = iterator.get_next()

training_init_op = iterator.make_initializer(train_ds)
validation_init_op = iterator.make_initializer(val_ds)

with tf.variable_scope(tf.get_variable_scope()):

    # Exact mean of channels 0-3 of the pure-flash input (assumes at least
    # four channels; the networks below also output 4 channels).
    gray_pureflash = 0.25 * (input_pureflash[...,0:1] + input_pureflash[...,1:2] + input_pureflash[...,2:3]+input_pureflash[...,3:4])
    # bad_mask = detect_shadow(img_with_shadow, input_pureflash)


    # NOFLASH: the 'Ref_' network sees only the input image (concat of a
    # single tensor is a no-op); otherwise the gray pure-flash channel too.
    if NOFLASH:
        reflection_layer = UNet_SE(tf.concat([img_with_shadow], axis=3), output_channel = 4, ext='Ref_')
    else:
        reflection_layer = UNet_SE(tf.concat([img_with_shadow, gray_pureflash], axis=3), output_channel = 4, ext='Ref_')

    # 'Trans_' network conditioned on the input and the predicted reflection.
    tran_layer = UNet_SE(tf.concat([img_with_shadow, reflection_layer], axis=3),output_channel = 4, ext='Trans_')
    # Active losses are mean absolute (L1) errors: tran_layer vs tran_gt and
    # reflection_layer vs ref_gt. The commented-out perceptual versions use
    # the opposite pairing.
    # lossDict["percep_t"] = 0.1 * compute_percep_loss(ref_gt, tran_layer, reuse=False)
    lossDict["percep_t"]=0.1* tf.reduce_mean(tf.abs(tran_gt- tran_layer))
    # lossDict["percep_r"] = 0.1 * compute_percep_loss(tran_gt, reflection_layer, reuse=True)
    lossDict["percep_r"]=0.1* tf.reduce_mean(tf.abs(ref_gt-reflection_layer))
    lossDict["total"] = lossDict["percep_t"] + lossDict["percep_r"]
    # PSNR of tran_layer against ref_gt, both clipped to [0, 1]. With
    # RGB_PSNR only the first example is used, after linref2srgb; otherwise
    # the batch mean.
    # NOTE(review): the active L1 loss pairs tran_layer with tran_gt, so
    # comparing tran_layer to ref_gt here looks inconsistent -- confirm
    # which ground truth is intended.
    if RGB_PSNR:
        tf_psnr=tf.math.reduce_mean(tf.image.psnr(tf.clip_by_value(linref2srgb(ref_gt[0]),0,1),
                        tf.clip_by_value(linref2srgb(tran_layer[0]),0,1),1.0))
    else:
        tf_psnr=tf.math.reduce_mean(tf.image.psnr(tf.clip_by_value(ref_gt,0,1),
                        tf.clip_by_value(tran_layer,0,1),1.0))