Exemple #1
0
def _conf_slice_refinement_net(conf_slice, img_features, is_training):
    """
    Small network with shared weights that learns to refine a confidence
    slice using image (texture) features.

    All variables are created under the fixed "conf_slice_refinement" scope
    opened with tf.AUTO_REUSE, so every call shares the same weights.

    conf_slice: confidence slice tensor. Axes 1 and 2 are read as height and
        width and are converted with int(), so they must be statically known.
    img_features: image feature tensor; bilinearly resized to the slice's
        height/width and concatenated with the slice features on axis 3.
    is_training: forwarded to the batch-normalization layer.

    Return conf_slice plus a tanh-bounded residual, clipped to [0, 1].
    """
    # Static spatial size of the slice; used below to resize the image features.
    h = int(conf_slice.shape[1])
    w = int(conf_slice.shape[2])
    # The scope name deliberately does not include the slice size, so slices
    # of any resolution reuse one set of weights.
    with tf.variable_scope("conf_slice_refinement", reuse=tf.AUTO_REUSE):
        # Features from the confidence slice itself.
        conf_slice_out = conv2d(inputs=conf_slice, filters=16, kernel_size=3, name='refine_conf_conv1')
        conf_slice_out = tf.layers.batch_normalization(inputs=conf_slice_out, training=is_training, name='refine_conf_bn1')
        conf_slice_out = tf.nn.leaky_relu(features=conf_slice_out, name='refine_conf_leaky1')
        conf_slice_out = resnet_block(inputs=conf_slice_out, filters=16, kernel_size=3, dilation_rate=1, name='refine_conf_res1')
        conf_slice_out = resnet_block(inputs=conf_slice_out, filters=16, kernel_size=3, dilation_rate=2, name='refine_conf_res2')

        # Bring the image features to the slice resolution before merging.
        image_out = tf.image.resize_bilinear(img_features, [h, w])
        concat_out = tf.concat([image_out, conf_slice_out], axis=3, name='refine_concat')

        # Joint processing of image and confidence features.
        concat_out = resnet_block(inputs=concat_out, filters=32, kernel_size=3, dilation_rate=4, name='refine_concat_res0')

        # Single-channel residual, squashed by tanh so it stays in (-1, 1).
        conf_residual = conv2d(inputs=concat_out, filters=1, kernel_size=3, name='refine_concat_conv1')
        conf_residual = tf.nn.tanh(conf_residual, name='conf_residual_tanh')

        # Apply the residual and keep the result in the [0, 1] range.
        refined_conf_slice = tf.clip_by_value(conf_slice + conf_residual, 0, 1)

        return refined_conf_slice
Exemple #2
0
def _tower_feature_net(img, is_training):
    """
    Build the feature network.

    Layout: one conv2d call, three residual blocks, then three stride-2
    conv2d calls, each followed by batch normalization and a leaky ReLU.
    Every layer is requested with 32 filters and kernel size 3. Variables
    are created under the "feature_tower" scope with tf.AUTO_REUSE.

    img: the input tensor for a batch of images
    is_training: forwarded to the batch-normalization layers

    Return a tensor of features.
    """
    num_filters = 32
    kernel = 3

    with tf.variable_scope("feature_tower", reuse=tf.AUTO_REUSE):
        features = conv2d(inputs=img,
                          filters=num_filters,
                          kernel_size=kernel,
                          name='tower_conv1')

        # Three residual blocks: tower_res1 .. tower_res3.
        for block_idx in (1, 2, 3):
            features = resnet_block(inputs=features,
                                    filters=num_filters,
                                    kernel_size=kernel,
                                    name='tower_res%d' % block_idx)

        # Three strided stages. The conv index is offset by one because
        # tower_conv1 is the input convolution above.
        for stage in (1, 2, 3):
            features = conv2d(inputs=features,
                              filters=num_filters,
                              kernel_size=kernel,
                              strides=2,
                              name='tower_conv%d' % (stage + 1))
            features = tf.layers.batch_normalization(
                inputs=features,
                training=is_training,
                name='tower_bn%d' % stage)
            features = tf.nn.leaky_relu(features=features,
                                        name='tower_leaky%d' % stage)

    return features
Exemple #3
0
def _image_feature_net(img, is_training):
    """
    Compute image features under the "image_features" variable scope.

    Applies conv2d (16 filters, kernel size 3), batch normalization and a
    leaky ReLU, followed by two residual blocks with dilation rates 1 and 2.

    img: the input image tensor
    is_training: forwarded to the batch-normalization layer

    Return the resulting feature tensor.
    """
    with tf.variable_scope("image_features"):
        features = conv2d(inputs=img,
                          filters=16,
                          kernel_size=3,
                          name='refine_image_conv1')
        features = tf.layers.batch_normalization(inputs=features,
                                                 training=is_training,
                                                 name='refine_image_bn1')
        features = tf.nn.leaky_relu(features=features,
                                    name='refine_image_leaky1')

        # refine_image_res1 (dilation 1), then refine_image_res2 (dilation 2).
        for block_idx, rate in enumerate((1, 2), start=1):
            features = resnet_block(inputs=features,
                                    filters=16,
                                    kernel_size=3,
                                    dilation_rate=rate,
                                    name='refine_image_res%d' % block_idx)

    return features
Exemple #4
0
def _blur_image_refinement_net(img,
                               upsampled_blur_image,
                               scale_name="1",
                               is_training=False):
    """
    Refine an upsampled blur image, guided by the AIF image.

    Per the original author's note, this is the same refinement block as
    StereoNet. Two identical feature branches (one per input) are
    concatenated on axis 3, passed through four residual blocks, and
    reduced to a 3-channel residual that is squashed by tanh and added to
    the upsampled blur image.

    img: the AIF image (downsampled beforehand by the caller if needed)
    upsampled_blur_image: the blur image to refine
    scale_name: string appended to the variable scope and to every layer
        name
    is_training: forwarded to the batch-normalization layers

    Return upsampled_blur_image plus the residual. The sum is not clipped.
    """

    def _feature_branch(inputs, prefix):
        # conv -> batch norm -> leaky ReLU -> residual blocks (dilation 1, 2).
        # Layer names are "<prefix>_<layer>_<scale_name>".
        def layer_name(layer):
            return prefix + '_' + layer + '_' + scale_name

        branch = conv2d(inputs=inputs,
                        filters=16,
                        kernel_size=3,
                        name=layer_name('conv1'))
        branch = tf.layers.batch_normalization(inputs=branch,
                                               training=is_training,
                                               name=layer_name('bn1'))
        branch = tf.nn.leaky_relu(features=branch,
                                  name=layer_name('leaky1'))
        for block_idx, rate in enumerate((1, 2), start=1):
            branch = resnet_block(inputs=branch,
                                  filters=16,
                                  kernel_size=3,
                                  dilation_rate=rate,
                                  name=layer_name('res%d' % block_idx))
        return branch

    with tf.variable_scope("guided_upsampling_scale_" + scale_name):
        # Blur-image branch is built first, then the AIF-image branch.
        blur_features = _feature_branch(upsampled_blur_image,
                                        'refine_blur_image')
        aif_features = _feature_branch(img, 'refine_image')

        merged = tf.concat([aif_features, blur_features],
                           axis=3,
                           name='refine_concat_' + scale_name)

        # Residual blocks refine_concat_res0..3 with dilations 4, 8, 1, 1.
        for block_idx, rate in enumerate((4, 8, 1, 1)):
            merged = resnet_block(
                inputs=merged,
                filters=32,
                kernel_size=3,
                dilation_rate=rate,
                name=('refine_concat_res%d_' % block_idx) + scale_name)

        # 3-channel residual, bounded to (-1, 1) by tanh.
        residual = conv2d(inputs=merged,
                          filters=3,
                          kernel_size=3,
                          name='refine_concat_conv1_' + scale_name)
        residual = tf.nn.tanh(residual)

        return upsampled_blur_image + residual
Exemple #5
0
def _disparity_map_refinement_net(img, upsampled_disparity_map, scale_name,
                                  is_training):
    """
    Refine an already-upsampled disparity map, aided by the input image.

    Two identical feature branches (disparity and image) are concatenated
    on axis 3, passed through four residual blocks, and reduced to a
    1-channel residual that is added to the input disparity map. No
    activation is applied to the residual.

    img: the guiding image tensor
    upsampled_disparity_map: the disparity map to refine
    scale_name: string appended to the variable scope and to every layer
        name
    is_training: forwarded to the batch-normalization layers

    Return upsampled_disparity_map plus the residual.
    """

    def _feature_branch(inputs, prefix):
        # conv -> batch norm -> leaky ReLU -> residual blocks (dilation 1, 2).
        # Layer names are "<prefix>_<layer>_<scale_name>".
        def layer_name(layer):
            return prefix + '_' + layer + '_' + scale_name

        branch = conv2d(inputs=inputs,
                        filters=16,
                        kernel_size=3,
                        name=layer_name('conv1'))
        branch = tf.layers.batch_normalization(inputs=branch,
                                               training=is_training,
                                               name=layer_name('bn1'))
        branch = tf.nn.leaky_relu(features=branch,
                                  name=layer_name('leaky1'))
        for block_idx, rate in enumerate((1, 2), start=1):
            branch = resnet_block(inputs=branch,
                                  filters=16,
                                  kernel_size=3,
                                  dilation_rate=rate,
                                  name=layer_name('res%d' % block_idx))
        return branch

    with tf.variable_scope("guided_upsampling_scale_" + scale_name):
        # Disparity branch is built first, then the image branch.
        disparity_features = _feature_branch(upsampled_disparity_map,
                                             'refine_disparity')
        image_features = _feature_branch(img, 'refine_image')

        merged = tf.concat([image_features, disparity_features],
                           axis=3,
                           name='refine_concat_' + scale_name)

        # Residual blocks refine_concat_res0..3 with dilations 4, 8, 1, 1.
        for block_idx, rate in enumerate((4, 8, 1, 1)):
            merged = resnet_block(
                inputs=merged,
                filters=32,
                kernel_size=3,
                dilation_rate=rate,
                name=('refine_concat_res%d_' % block_idx) + scale_name)

        # 1-channel residual, added as-is to the input map.
        residual = conv2d(inputs=merged,
                          filters=1,
                          kernel_size=3,
                          name='refine_concat_conv1_' + scale_name)

        return upsampled_disparity_map + residual