Example 1

import tensorflow as tf

# `utils` refers to the project's TensorFlow helper module (3-D conv / batch-norm /
# unpooling wrappers) and NUM_OF_CLASSESS is a module-level class-count constant;
# both are defined outside this snippet.
def inference_deconvolution(image, keep_prob, training):
    """
    FCN with deconvolution decoder definition
    :param image: input image. Should have values in range 0-255
    :param keep_prob: drop-out argument
    """
    print("setting up initialized conv layers ...")
    net = {}
    num_conv = 4
    height = 7
    width = 3
    depth = 2
    height1 = 7
    width1 = 3
    depth1 = 2
    inpt = image
    output_depth = 64
    input_depth = inpt.get_shape().as_list()[4]
    W_t = tf.Variable(tf.truncated_normal([height, width, depth, input_depth, output_depth], stddev=0.02))
    tf.add_to_collection('lr_w', W_t)
    tf.add_to_collection('weight_decay', tf.nn.l2_loss(W_t))
    output = utils.conv3d_strided_batch_normalization(inpt, W_t, training, stride=1)
    net['conv_batch_normalization1'] = output
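    # Encoder: each iteration downsamples with 2x2x2 average pooling and applies a
    # 7x3x2 conv + batch normalization; on the last iteration, two 1x1x1 convolutions
    # first squeeze the pooled features to 8 channels and expand them back to 64.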
    for i in range(1, num_conv):
        output = tf.nn.avg_pool3d(output, ksize=[1, 2, 2, 2, 1], strides=[1, 2, 2, 2, 1], padding='SAME')
        if i == (num_conv - 1):
            height_mlp = 1
            width_mlp = 1
            depth_mlp = 1
            output_mlp = 8
            W_t = tf.Variable(
                tf.truncated_normal([height_mlp, width_mlp, depth_mlp, output.get_shape().as_list()[4], output_mlp],
                                    stddev=0.02))
            tf.add_to_collection('weight_decay', tf.nn.l2_loss(W_t))
            tf.add_to_collection('lr_w', W_t)
            output = utils.conv3d_strided_batch_normalization(output, W_t, training, stride=1)
            output_mlp = 64
            W_t = tf.Variable(
                tf.truncated_normal([height_mlp, width_mlp, depth_mlp, output.get_shape().as_list()[4], output_mlp],
                                    stddev=0.02))
            tf.add_to_collection('weight_decay', tf.nn.l2_loss(W_t))
            tf.add_to_collection('lr_w', W_t)
            output = utils.conv3d_strided_batch_normalization(output, W_t, training, stride=1)
        net['pool%s' % str(i)] = output
        inpt = output
        output_depth = 64
        input_depth = inpt.get_shape().as_list()[4]
        W_t = tf.Variable(tf.truncated_normal([height, width, depth, input_depth, output_depth], stddev=0.02))
        tf.add_to_collection('weight_decay', tf.nn.l2_loss(W_t))
        tf.add_to_collection('lr_w', W_t)
        output = utils.conv3d_strided_batch_normalization(inpt, W_t, training, stride=1)
        net['conv_batch_normalization%s' % str(i + 1)] = output

    last_output_num = 64
    # now to upscale to actual image size
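    # Decoder: three stride-2 transposed convolutions upsample the features back to the
    # sizes of conv_batch_normalization3/2/1, and each result is concatenated with the
    # matching encoder output along the channel axis (skip connection) before the final
    # classification layers.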
    deconv_shape1 = net["conv_batch_normalization3"].get_shape()
    W_t1 = tf.Variable(
        tf.truncated_normal([height1, width1, depth1, deconv_shape1[4].value, last_output_num], stddev=0.02))
    tf.add_to_collection('weight_decay', tf.nn.l2_loss(W_t1))
    tf.add_to_collection('lr_w', W_t1)
    # b_t1 = utils.bias_variable([deconv_shape1[3].value], name="b_t1")
    conv_t1 = utils.conv3d_transpose_strided_batch_normalization(output, W_t1, training, output_shape=tf.shape(
        net["conv_batch_normalization3"]), stride=2)
    try:
        fuse_1 = tf.concat((conv_t1, net["conv_batch_normalization3"]), 4, name="fuse_1")
    except Exception:
        fuse_1 = tf.concat(4, (conv_t1, net["conv_batch_normalization3"]), name="fuse_1")
    deconv_shape2 = net["conv_batch_normalization2"].get_shape()
    W_t2 = tf.Variable(
        tf.truncated_normal([height1, width1, depth1, deconv_shape2[4].value, 2 * deconv_shape1[4].value], stddev=0.02))
    tf.add_to_collection('weight_decay', tf.nn.l2_loss(W_t2))
    tf.add_to_collection('lr_w', W_t2)
    # b_t2 = utils.bias_variable([deconv_shape2[3].value], name="b_t2")
    conv_t2 = utils.conv3d_transpose_strided_batch_normalization(fuse_1, W_t2, training, output_shape=tf.shape(
        net["conv_batch_normalization2"]), stride=2)
    try:
        fuse_2 = tf.concat((conv_t2, net["conv_batch_normalization2"]), 4, name="fuse_2")
    except Exception:
        fuse_2 = tf.concat(4, (conv_t2, net["conv_batch_normalization2"]), name="fuse_2")

    deconv_shape3 = net["conv_batch_normalization1"].get_shape()
    W_t3 = tf.Variable(
        tf.truncated_normal([height1, width1, depth1, deconv_shape3[4].value, 2 * deconv_shape2[4].value], stddev=0.02))
    tf.add_to_collection('weight_decay', tf.nn.l2_loss(W_t3))
    tf.add_to_collection('lr_w', W_t3)
    # b_t2 = utils.bias_variable([deconv_shape2[3].value], name="b_t2")
    conv_t3 = utils.conv3d_transpose_strided_batch_normalization(fuse_2, W_t3, training, output_shape=tf.shape(
        net["conv_batch_normalization1"]), stride=2)
    try:
        fuse_3 = tf.concat((conv_t3, net["conv_batch_normalization1"]), 4, name="fuse_3")
    except Exception:
        fuse_3 = tf.concat(4, (conv_t3, net["conv_batch_normalization1"]), name="fuse_3")

    W_t4 = tf.Variable(
        tf.truncated_normal([height1, width1, depth1, 2 * deconv_shape3[4].value, last_output_num], stddev=0.02))
    tf.add_to_collection('weight_decay', tf.nn.l2_loss(W_t4))
    tf.add_to_collection('lr_w', W_t4)
    # b_t4 = tf.Variable(tf.constant(0.0, shape=[last_output_num]))
    # tf.add_to_collection('lr_b', b_t4)
    conv_t4 = utils.conv3d_strided_batch_normalization(fuse_3, W_t4, training, stride=1)

    conv_t4 = tf.nn.dropout(conv_t4, keep_prob=keep_prob)
    W_4 = tf.Variable(tf.truncated_normal([1, 1, 1, last_output_num, NUM_OF_CLASSESS], stddev=0.02))
    tf.add_to_collection('weight_decay', tf.nn.l2_loss(W_4))
    tf.add_to_collection('lr_w', W_4)
    b_4 = tf.Variable(tf.constant(0.0, shape=[NUM_OF_CLASSESS]))
    tf.add_to_collection('lr_b', b_4)
    conv_t4 = tf.nn.conv3d(conv_t4, filter=W_4, strides=[1, 1, 1, 1, 1], padding="SAME")
    conv_t4 = tf.nn.bias_add(conv_t4, b_4)

    weight_decay_sum = tf.add_n(tf.get_collection('weight_decay'))
    lr_w_vars = tf.get_collection('lr_w')
    lr_b_vars = tf.get_collection('lr_b')

    return conv_t4, weight_decay_sum, lr_w_vars, lr_b_vars
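
A minimal usage sketch for the function above, assuming TensorFlow 1.x and an illustrative 64x64x64 single-channel input volume; the placeholder names, the input shape, and the use of a tf.bool placeholder for `training` are assumptions for illustration, not part of the original code:

image_ph = tf.placeholder(tf.float32, shape=[None, 64, 64, 64, 1], name="input_image")
keep_prob_ph = tf.placeholder(tf.float32, name="keep_probability")
training_ph = tf.placeholder(tf.bool, name="is_training")

# Build the graph and obtain per-voxel logits plus the regularization/variable collections.
logits, weight_decay_sum, lr_w_vars, lr_b_vars = inference_deconvolution(
    image_ph, keep_prob_ph, training_ph)

# Per-voxel class probabilities over the last (channel) axis.
probabilities = tf.nn.softmax(logits, name="class_probabilities")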
Example 2

def inference_median_unpool(image, keep_prob, training):
    """
    FCN with median-unpooling decoder definition.
    :param image: input image; values are expected in the range 0-255
    :param keep_prob: dropout keep probability (unused here; the dropout layer below is commented out)
    :param training: training-mode flag forwarded to the batch-normalization layers
    """
    print("setting up initialized conv layers ...")
    net = {}
    num_conv = 4
    height = 7
    width = 3
    depth = 3
    height1 = 7
    width1 = 3
    depth1 = 3
    pool_sz = [2, 2, 2]
    inpt = image
    output_depth = 64
    input_depth = inpt.get_shape().as_list()[4]
    W_t = tf.Variable(tf.truncated_normal([height, width, depth, input_depth, output_depth], stddev=0.02))
    tf.add_to_collection('lr_w', W_t)
    tf.add_to_collection('weight_decay', tf.nn.l2_loss(W_t))
    output = utils.conv3d_strided_batch_normalization(inpt, W_t, training, stride=1)
    net['conv_batch_normalization1'] = output
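    # Encoder: each iteration applies 2x2x2 median pooling (keeping the arg-median
    # indices for unpooling in the decoder) followed by a 7x3x3 conv + batch normalization.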
    for i in range(1, num_conv):
        print(output.get_shape())
        output, argmax = utils.median_pool_3d_with_argmedian(output, pool_sz)
        print(argmax.get_shape())
        net['pool%s' % str(i)] = output
        net['pool_argmax%s' % str(i)] = argmax
        inpt = output
        output_depth = 64
        input_depth = inpt.get_shape().as_list()[4]
        W_t = tf.Variable(tf.truncated_normal([height, width, depth, input_depth, output_depth], stddev=0.02))
        tf.add_to_collection('weight_decay', tf.nn.l2_loss(W_t))
        tf.add_to_collection('lr_w', W_t)
        output = utils.conv3d_strided_batch_normalization(inpt, W_t, training, stride=1)
        net['conv_batch_normalization%s' % str(i + 1)] = output

    last_output_num = 64
    # now to upscale to actual image size
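    # Decoder: each stage unpools with the stored arg-median indices, concatenates the
    # result with the matching encoder output along the channel axis, and applies a
    # 7x3x3 conv + batch normalization; a final 1x1x1 convolution produces the
    # per-voxel class logits.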
    unpool1 = utils.unpool_layer_batch_unraveled_indices(output, net['pool_argmax3'], pool_sz)
    try:
        fuse_1 = tf.concat((unpool1, net["conv_batch_normalization3"]), 4, name="fuse_1")
    except Exception:
        fuse_1 = tf.concat(4, (unpool1, net["conv_batch_normalization3"]), name="fuse_1")
    deconv_shape1 = net["conv_batch_normalization3"].get_shape()
    W_t1 = tf.Variable(
        tf.truncated_normal([height1, width1, depth1, 2 * deconv_shape1[4].value, last_output_num], stddev=0.02))
    tf.add_to_collection('weight_decay', tf.nn.l2_loss(W_t1))
    tf.add_to_collection('lr_w', W_t1)
    # b_t1 = utils.bias_variable([deconv_shape1[3].value], name="b_t1")
    conv_t1 = utils.conv3d_strided_batch_normalization(fuse_1, W_t1, training, stride=1)

    unpool2 = utils.unpool_layer_batch_unraveled_indices(conv_t1, net['pool_argmax2'], pool_sz)
    try:
        fuse_2 = tf.concat((unpool2, net["conv_batch_normalization2"]), 4, name="fuse_2")
    except Exception:
        fuse_2 = tf.concat(4, (unpool2, net["conv_batch_normalization2"]), name="fuse_2")
    deconv_shape2 = net["conv_batch_normalization2"].get_shape()
    W_t2 = tf.Variable(
        tf.truncated_normal([height1, width1, depth1, 2 * deconv_shape1[4].value, deconv_shape2[4].value],
                            stddev=0.02))
    tf.add_to_collection('weight_decay', tf.nn.l2_loss(W_t2))
    tf.add_to_collection('lr_w', W_t2)
    # b_t2 = utils.bias_variable([deconv_shape2[3].value], name="b_t2")
    conv_t2 = utils.conv3d_strided_batch_normalization(fuse_2, W_t2, training, stride=1)

    unpool3 = utils.unpool_layer_batch_unraveled_indices(conv_t2, net['pool_argmax1'], pool_sz)
    deconv_shape3 = net["conv_batch_normalization1"].get_shape()
    try:
        fuse_3 = tf.concat((unpool3, net["conv_batch_normalization1"]), 4, name="fuse_3")
    except Exception:
        fuse_3 = tf.concat(4, (unpool3, net["conv_batch_normalization1"]), name="fuse_3")
    W_t3 = tf.Variable(
        tf.truncated_normal([height1, width1, depth1, 2 * deconv_shape2[4].value, deconv_shape3[4].value],
                            stddev=0.02))
    tf.add_to_collection('weight_decay', tf.nn.l2_loss(W_t3))
    tf.add_to_collection('lr_w', W_t3)
    # b_t2 = utils.bias_variable([deconv_shape2[3].value], name="b_t2")
    conv_t3 = utils.conv3d_strided_batch_normalization(fuse_3, W_t3, training, stride=1)

    #    W_t4 = tf.Variable(tf.truncated_normal([height1, width1, depth1, 2*deconv_shape3[4].value, last_output_num], stddev=0.02))
    #    tf.add_to_collection('weight_decay', tf.nn.l2_loss(W_t3))
    #    tf.add_to_collection('lr_w', W_t4)
    #    #b_t4 = tf.Variable(tf.constant(0.0, shape=[last_output_num]))
    #    #tf.add_to_collection('lr_b', b_t4)
    #    conv_t4 = utils.conv3d_strided_batch_normalization(conv_t3, W_t4, training, stride=1)

    #    conv_t4 = tf.nn.dropout(conv_t4, keep_prob=keep_prob)
    W_4 = tf.Variable(tf.truncated_normal([1, 1, 1, last_output_num, NUM_OF_CLASSESS], stddev=0.02))
    tf.add_to_collection('weight_decay', tf.nn.l2_loss(W_4))
    tf.add_to_collection('lr_w', W_4)
    b_4 = tf.Variable(tf.constant(0.0, shape=[NUM_OF_CLASSESS]))
    tf.add_to_collection('lr_b', b_4)

    conv_t4 = tf.nn.conv3d(conv_t3, filter=W_4, strides=[1, 1, 1, 1, 1], padding="SAME")
    conv_t4 = tf.nn.bias_add(conv_t4, b_4)

    weight_decay_sum = tf.add_n(tf.get_collection('weight_decay'))
    lr_w_vars = tf.get_collection('lr_w')
    lr_b_vars = tf.get_collection('lr_b')
    return conv_t3, conv_t4, weight_decay_sum, lr_w_vars, lr_b_vars
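
A hedged sketch of wiring the returned tensors into a training objective (TensorFlow 1.x); the placeholders, the input shape, the weight-decay coefficient, and the optimizer choice are illustrative assumptions, not taken from the original code:

image_ph = tf.placeholder(tf.float32, shape=[None, 64, 64, 64, 1], name="input_image")
labels_ph = tf.placeholder(tf.int32, shape=[None, 64, 64, 64], name="labels")
keep_prob_ph = tf.placeholder(tf.float32, name="keep_probability")
training_ph = tf.placeholder(tf.bool, name="is_training")

# The second returned tensor holds the per-voxel class logits.
_, logits, weight_decay_sum, lr_w_vars, lr_b_vars = inference_median_unpool(
    image_ph, keep_prob_ph, training_ph)

data_loss = tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels_ph, logits=logits))
total_loss = data_loss + 1e-4 * weight_decay_sum  # assumed weight-decay coefficient

train_op = tf.train.AdamOptimizer(learning_rate=1e-4).minimize(
    total_loss, var_list=lr_w_vars + lr_b_vars)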