def _deconv_fcn_channels(tensor):
    """Return the static size of axis 4 of `tensor` (read as the channel count)."""
    return tensor.get_shape().as_list()[4]


def _deconv_fcn_weight(shape):
    """Create a weight variable and register it in the graph collections.

    :param shape: list of ints, the full shape of the weight tensor
    :return: the new tf.Variable, initialised from a truncated normal
             distribution with stddev 0.02
    """
    weight = tf.Variable(tf.truncated_normal(shape, stddev=0.02))
    # The L2 penalty and the variable itself are read back from these two
    # collections at the end of the network builder.
    tf.add_to_collection('weight_decay', tf.nn.l2_loss(weight))
    tf.add_to_collection('lr_w', weight)
    return weight


def _deconv_fcn_concat(tensors, name):
    """Concatenate `tensors` along axis 4 under either tf.concat argument order."""
    try:
        return tf.concat(tensors, 4, name=name)
    except Exception:
        # Compatibility fallback for TensorFlow releases whose signature is
        # tf.concat(axis, values). The catch is deliberately broad: which
        # exception the mismatched call raises is not pinned down here, and a
        # genuine error will be raised again by the retry below.
        return tf.concat(4, tensors, name=name)


def inference_deconvolution(image, keep_prob, training):
    """
    FCN with deconvolution decoder definition

    The encoder is four conv + batch-norm stages separated by 2x2x2 average
    pooling; the last pooled tensor additionally goes through two 1x1x1
    convolutions (8, then 64 filters). The decoder upsamples three times with
    stride-2 transposed convolutions, concatenating each result with the
    encoder output of the same stage, and ends with a conv + batch-norm stage,
    dropout and a 1x1x1 convolution with NUM_OF_CLASSESS output channels.

    :param image: input image. Should have values in range 0-255. Must be a
                  5-D tensor; axis 4 is read as the channel count
    :param keep_prob: drop-out argument, forwarded to tf.nn.dropout
    :param training: forwarded unchanged to the batch-normalization helpers
                     in utils
    :return: tuple (output of the final 1x1x1 convolution, sum of everything
             in the 'weight_decay' collection, contents of the 'lr_w'
             collection, contents of the 'lr_b' collection). The collections
             belong to the graph, so they also contain whatever other code
             registered under the same names.
    """
    print("setting up initialized conv layers ...")
    num_conv = 4
    encoder_kernel = [7, 3, 2]  # height, width, depth
    decoder_kernel = [7, 3, 2]  # height1, width1, depth1
    output_depth = 64
    net = {}

    W_t = _deconv_fcn_weight(encoder_kernel + [_deconv_fcn_channels(image), output_depth])
    output = utils.conv3d_strided_batch_normalization(image, W_t, training, stride=1)
    net['conv_batch_normalization1'] = output

    for i in range(1, num_conv):
        output = tf.nn.avg_pool3d(output, ksize=[1, 2, 2, 2, 1], strides=[1, 2, 2, 2, 1],
                                  padding='SAME')
        if i == num_conv - 1:
            # Deepest stage only: two 1x1x1 convolutions, squeezing to 8
            # channels and expanding back to 64.
            for output_mlp in (8, 64):
                W_t = _deconv_fcn_weight([1, 1, 1, _deconv_fcn_channels(output), output_mlp])
                output = utils.conv3d_strided_batch_normalization(output, W_t, training,
                                                                  stride=1)
        net['pool%s' % str(i)] = output

        W_t = _deconv_fcn_weight(encoder_kernel + [_deconv_fcn_channels(output), output_depth])
        output = utils.conv3d_strided_batch_normalization(output, W_t, training, stride=1)
        net['conv_batch_normalization%s' % str(i + 1)] = output

    last_output_num = 64

    # now to upscale to actual image size
    decoded = output
    in_channels = last_output_num
    for stage, level in enumerate((3, 2, 1), start=1):
        skip = net["conv_batch_normalization%d" % level]
        skip_channels = _deconv_fcn_channels(skip)
        # Transposed-convolution filters list output channels before input
        # channels, hence skip_channels comes first.
        W_t = _deconv_fcn_weight(decoder_kernel + [skip_channels, in_channels])
        upsampled = utils.conv3d_transpose_strided_batch_normalization(
            decoded, W_t, training, output_shape=tf.shape(skip), stride=2)
        decoded = _deconv_fcn_concat((upsampled, skip), "fuse_%d" % stage)
        # The concatenation doubles the channel count seen by the next stage.
        in_channels = 2 * skip_channels

    W_t4 = _deconv_fcn_weight(decoder_kernel + [in_channels, last_output_num])
    conv_t4 = utils.conv3d_strided_batch_normalization(decoded, W_t4, training, stride=1)
    conv_t4 = tf.nn.dropout(conv_t4, keep_prob=keep_prob)

    # Final 1x1x1 convolution with bias: one output channel per class.
    W_4 = _deconv_fcn_weight([1, 1, 1, last_output_num, NUM_OF_CLASSESS])
    b_4 = tf.Variable(tf.constant(0.0, shape=[NUM_OF_CLASSESS]))
    tf.add_to_collection('lr_b', b_4)
    conv_t4 = tf.nn.conv3d(conv_t4, filter=W_4, strides=[1, 1, 1, 1, 1], padding="SAME")
    conv_t4 = tf.nn.bias_add(conv_t4, b_4)

    weight_decay_sum = tf.add_n(tf.get_collection('weight_decay'))
    lr_w_vars = tf.get_collection('lr_w')
    lr_b_vars = tf.get_collection('lr_b')
    return conv_t4, weight_decay_sum, lr_w_vars, lr_b_vars
def _unpool_fcn_channels(tensor):
    """Return the static size of axis 4 of `tensor` (read as the channel count)."""
    return tensor.get_shape().as_list()[4]


def _unpool_fcn_weight(shape):
    """Create a weight variable and register it in the graph collections.

    :param shape: list of ints, the full shape of the weight tensor
    :return: the new tf.Variable, initialised from a truncated normal
             distribution with stddev 0.02
    """
    weight = tf.Variable(tf.truncated_normal(shape, stddev=0.02))
    # The L2 penalty and the variable itself are read back from these two
    # collections at the end of the network builder.
    tf.add_to_collection('weight_decay', tf.nn.l2_loss(weight))
    tf.add_to_collection('lr_w', weight)
    return weight


def _unpool_fcn_concat(tensors, name):
    """Concatenate `tensors` along axis 4 under either tf.concat argument order."""
    try:
        return tf.concat(tensors, 4, name=name)
    except Exception:
        # Compatibility fallback for TensorFlow releases whose signature is
        # tf.concat(axis, values). The catch is deliberately broad: which
        # exception the mismatched call raises is not pinned down here, and a
        # genuine error will be raised again by the retry below.
        return tf.concat(4, tensors, name=name)


def inference_median_unpool(image, keep_prob, training):
    """
    FCN with median unpooling decoder definition

    The encoder is four conv + batch-norm stages separated by
    utils.median_pool_3d_with_argmedian (pool size 2x2x2), whose second return
    value is stored for the decoder. Each of the three decoder stages unpools
    with the stored indices, concatenates the result with the encoder output
    of the same stage and applies a conv + batch-norm stage. A final 1x1x1
    convolution produces NUM_OF_CLASSESS output channels.

    :param image: input image; must be a 5-D tensor, axis 4 is read as the
                  channel count
    :param keep_prob: drop-out argument. Currently unused: this network
                      applies no dropout
    :param training: forwarded unchanged to the batch-normalization helpers
                     in utils
    :return: tuple (output of the last decoder stage, output of the final
             1x1x1 convolution, sum of everything in the 'weight_decay'
             collection, contents of the 'lr_w' collection, contents of the
             'lr_b' collection). The collections belong to the graph, so they
             also contain whatever other code registered under the same names.
    """
    print("setting up initialized conv layers ...")
    num_conv = 4
    encoder_kernel = [7, 3, 3]  # height, width, depth
    decoder_kernel = [7, 3, 3]  # height1, width1, depth1
    pool_sz = [2, 2, 2]
    output_depth = 64
    net = {}

    W_t = _unpool_fcn_weight(encoder_kernel + [_unpool_fcn_channels(image), output_depth])
    output = utils.conv3d_strided_batch_normalization(image, W_t, training, stride=1)
    net['conv_batch_normalization1'] = output

    for i in range(1, num_conv):
        print(output.get_shape())
        output, argmax = utils.median_pool_3d_with_argmedian(output, pool_sz)
        print(argmax.get_shape())
        net['pool%s' % str(i)] = output
        net['pool_argmax%s' % str(i)] = argmax

        W_t = _unpool_fcn_weight(encoder_kernel + [_unpool_fcn_channels(output), output_depth])
        output = utils.conv3d_strided_batch_normalization(output, W_t, training, stride=1)
        net['conv_batch_normalization%s' % str(i + 1)] = output

    last_output_num = 64

    # now to upscale to actual image size
    skip3_channels = _unpool_fcn_channels(net["conv_batch_normalization3"])
    skip2_channels = _unpool_fcn_channels(net["conv_batch_normalization2"])
    skip1_channels = _unpool_fcn_channels(net["conv_batch_normalization1"])
    # (encoder stage to fuse with, filter input channels, filter output channels)
    # NOTE(review): the input-channel expressions are not derived from the
    # fused tensors themselves; they line up only because every stage here is
    # 64 channels wide. Revisit before changing any layer width.
    decoder_plan = (
        (3, 2 * skip3_channels, last_output_num),
        (2, 2 * skip3_channels, skip2_channels),
        (1, 2 * skip2_channels, skip1_channels),
    )
    decoded = output
    for stage, (level, in_channels, out_channels) in enumerate(decoder_plan, start=1):
        unpooled = utils.unpool_layer_batch_unraveled_indices(
            decoded, net['pool_argmax%d' % level], pool_sz)
        fused = _unpool_fcn_concat(
            (unpooled, net["conv_batch_normalization%d" % level]), "fuse_%d" % stage)
        W_t = _unpool_fcn_weight(decoder_kernel + [in_channels, out_channels])
        decoded = utils.conv3d_strided_batch_normalization(fused, W_t, training, stride=1)
    conv_t3 = decoded

    # Final 1x1x1 convolution with bias: one output channel per class.
    # NOTE(review): the filter's input width is last_output_num, which equals
    # the channel count of conv_t3 only while both are 64 - confirm if changed.
    W_4 = _unpool_fcn_weight([1, 1, 1, last_output_num, NUM_OF_CLASSESS])
    b_4 = tf.Variable(tf.constant(0.0, shape=[NUM_OF_CLASSESS]))
    tf.add_to_collection('lr_b', b_4)
    conv_t4 = tf.nn.conv3d(conv_t3, filter=W_4, strides=[1, 1, 1, 1, 1], padding="SAME")
    conv_t4 = tf.nn.bias_add(conv_t4, b_4)

    weight_decay_sum = tf.add_n(tf.get_collection('weight_decay'))
    lr_w_vars = tf.get_collection('lr_w')
    lr_b_vars = tf.get_collection('lr_b')
    return conv_t3, conv_t4, weight_decay_sum, lr_w_vars, lr_b_vars