def inference(image, keep_prob):
    """Build the VGG-initialised encoder and its mirrored deconvolution decoder.

    The graph is assembled in this order:

    1. Model data is fetched with ``utils.get_model_data(FLAGS.model_dir,
       MODEL_URL)``; its ``'normalization'`` entry yields the mean pixel and
       its ``'layers'`` entry the weights handed to ``vgg_net``.
    2. ``image`` is preprocessed with ``utils.process_image`` and run through
       ``vgg_net``, which returns the layer dictionary and ``pool_argmax``.
    3. Two convolutional layers (``fc_6``: 7x7, ``fc_7``: 1x1) sit on ``pool5``.
    4. The decoder alternates ``utils.unpool_layer2x2_batch`` with stacks of
       ``utils.deconv_layer``, consuming ``pool_argmax`` from its last entry
       to its fifth-from-last entry.
    5. A 1x1 ``score_1`` layer with 21 output channels produces the scores.

    Args:
        image: Input tensor passed to ``utils.process_image``. The batch
            unpooling helper is used throughout, so this is presumably a
            batch of images in NHWC layout -- TODO confirm against the caller.
        keep_prob: Not used anywhere in this function; it is kept in the
            signature so existing callers keep working.

    Returns:
        A ``(prediction, logits)`` tuple. ``logits`` is ``score_1`` reshaped
        to ``(-1, 21)``. ``prediction`` is the argmax over axis 3 of the
        softmax of those logits, reshaped back to the shape of ``score_1``.
    """
    print("setting up vgg initialized conv layers ...")
    model_data = utils.get_model_data(FLAGS.model_dir, MODEL_URL)

    mean = model_data['normalization'][0][0][0]
    mean_pixel = np.mean(mean, axis=(0, 1))
    weights = np.squeeze(model_data['layers'])

    processed_image = utils.process_image(image, mean_pixel)

    # Output channels of the final score layer (presumably the class count).
    num_classes = 21

    # NOTE(review): the original indentation was lost, so the nesting of the
    # "Deconv" scope inside "inference" is reconstructed. It determines the
    # variable names (and therefore checkpoint compatibility) -- confirm
    # against an existing checkpoint before relying on it.
    with tf.variable_scope("inference"):
        image_net, pool_argmax = vgg_net(weights, processed_image)
        conv_final_layer = image_net["pool5"]

        fc_6 = utils.conv_layer(x=conv_final_layer, W_shape=[7, 7, 512, 4096],
                                b_shape=4096, name='fc_6', padding='SAME')
        fc_7 = utils.conv_layer(x=fc_6, W_shape=[1, 1, 4096, 4096],
                                b_shape=4096, name='fc_7', padding='SAME')

        with tf.variable_scope("Deconv"):
            deconv_fc_6 = utils.deconv_layer(fc_7, [7, 7, 512, 4096], 512, 'fc6_deconv')

            # Use unpool_layer2x2_batch if the input image is a batch.
            # The single-image form previously used at each unpool step was
            # utils.unpool_layer2x2(x, argmax, tf.shape(<encoder conv>)), with
            # the encoder layers conv5_3, conv4_3, conv3_3, conv2_2, conv1_2.
            unpool_5 = utils.unpool_layer2x2_batch(deconv_fc_6, pool_argmax[-1])
            deconv_5_3 = utils.deconv_layer(unpool_5, [3, 3, 512, 512], 512, 'deconv_5_3')
            deconv_5_2 = utils.deconv_layer(deconv_5_3, [3, 3, 512, 512], 512, 'deconv_5_2')
            deconv_5_1 = utils.deconv_layer(deconv_5_2, [3, 3, 512, 512], 512, 'deconv_5_1')

            unpool_4 = utils.unpool_layer2x2_batch(deconv_5_1, pool_argmax[-2])
            deconv_4_3 = utils.deconv_layer(unpool_4, [3, 3, 512, 512], 512, 'deconv_4_3')
            deconv_4_2 = utils.deconv_layer(deconv_4_3, [3, 3, 512, 512], 512, 'deconv_4_2')
            deconv_4_1 = utils.deconv_layer(deconv_4_2, [3, 3, 256, 512], 256, 'deconv_4_1')

            unpool_3 = utils.unpool_layer2x2_batch(deconv_4_1, pool_argmax[-3])
            deconv_3_3 = utils.deconv_layer(unpool_3, [3, 3, 256, 256], 256, 'deconv_3_3')
            deconv_3_2 = utils.deconv_layer(deconv_3_3, [3, 3, 256, 256], 256, 'deconv_3_2')
            deconv_3_1 = utils.deconv_layer(deconv_3_2, [3, 3, 128, 256], 128, 'deconv_3_1')

            unpool_2 = utils.unpool_layer2x2_batch(deconv_3_1, pool_argmax[-4])
            deconv_2_2 = utils.deconv_layer(unpool_2, [3, 3, 128, 128], 128, 'deconv_2_2')
            deconv_2_1 = utils.deconv_layer(deconv_2_2, [3, 3, 64, 128], 64, 'deconv_2_1')

            unpool_1 = utils.unpool_layer2x2_batch(deconv_2_1, pool_argmax[-5])
            deconv_1_2 = utils.deconv_layer(unpool_1, [3, 3, 64, 64], 64, 'deconv_1_2')
            deconv_1_1 = utils.deconv_layer(deconv_1_2, [3, 3, 32, 64], 32, 'deconv_1_1')

            score_1 = utils.deconv_layer(deconv_1_1, [1, 1, num_classes, 32],
                                         num_classes, 'score_1')

        # Flatten to one row per spatial position so softmax runs over the
        # channel dimension.
        logits = tf.reshape(score_1, (-1, num_classes))
        probabilities = tf.reshape(tf.nn.softmax(logits), tf.shape(score_1))
        # `axis` replaces the deprecated `dimension` keyword of tf.argmax.
        prediction = tf.argmax(probabilities, axis=3)

    print(prediction.shape, logits.shape, deconv_1_2.shape, deconv_1_1.shape)
    return prediction, logits