def _seg_conv_relu(x, in_ch, out_ch, suffix, ksize=3):
    """Apply conv (``W<suffix>`` / ``b<suffix>``) + ReLU named ``relu<suffix>``.

    Creates a ``[ksize, ksize, in_ch, out_ch]`` weight and an ``[out_ch]`` bias
    through ``utils`` and applies ``utils.conv2d_basic``. The per-layer names
    (W1_1, b1_1, relu1_1, W_s1, W6, ...) are exactly the ones ``inference``
    has always used, so the variable set is unchanged.
    """
    W = utils.weight_variable([ksize, ksize, in_ch, out_ch], name="W" + suffix)
    b = utils.bias_variable([out_ch], name="b" + suffix)
    conv = utils.conv2d_basic(x, W, b)
    return tf.nn.relu(conv, name="relu" + suffix)


def _seg_ra_fuse(x, channels, ratio, stage):
    """Run ``utils.RA_unit`` on ``x`` and fuse its output back to ``channels``.

    The 3x3 fusion conv (``W_s<stage>`` / ``b_s<stage>`` / ``relu_s<stage>``)
    expects ``channels * (1 + ratio)`` input channels, i.e. RA_unit's first
    output is assumed to carry that many channels (TODO confirm against
    utils.RA_unit). RA_unit's second output is not used here.
    """
    # Static H and W are passed to RA_unit, so x presumably needs a fully
    # defined spatial shape at graph-construction time.
    ra, _ = utils.RA_unit(x, x.shape[1].value, x.shape[2].value, ratio)
    return _seg_conv_relu(ra, channels * (1 + ratio), channels, "_s%d" % stage)


def _seg_encoder_stage(x, in_ch, out_ch, n_convs, ratio, stage):
    """One encoder stage: ``n_convs`` 3x3 conv+ReLU layers, RA fusion, max-pool.

    Layer ``i`` (1-based) is named ``W<stage>_<i>`` / ``b<stage>_<i>`` /
    ``relu<stage>_<i>``; the first layer maps ``in_ch`` -> ``out_ch`` and the
    rest keep ``out_ch``. Returns the output of ``utils.max_pool_2x2``.
    """
    h = x
    for i in range(1, n_convs + 1):
        h = _seg_conv_relu(h, in_ch if i == 1 else out_ch, out_ch, "%d_%d" % (stage, i))
    h = _seg_ra_fuse(h, out_ch, ratio, stage)
    return utils.max_pool_2x2(h)


def _seg_upsample_fuse(x, in_ch, skip, idx):
    """Transpose-conv ``x`` up to ``skip``'s shape and add ``skip`` to it.

    Creates ``W_t<idx>`` (``[4, 4, skip_channels, in_ch]``) and ``b_t<idx>``
    (``[skip_channels]``); the sum is named ``fuse_<idx>``.
    """
    skip_ch = skip.get_shape()[3].value
    # Filter layout is [h, w, out_channels, in_channels], assuming utils wraps
    # tf.nn.conv2d_transpose.
    W_t = utils.weight_variable([4, 4, skip_ch, in_ch], name="W_t%d" % idx)
    b_t = utils.bias_variable([skip_ch], name="b_t%d" % idx)
    # output_shape pins the upsampled result to skip's exact (dynamic) shape.
    conv_t = utils.conv2d_transpose_strided(x, W_t, b_t, output_shape=tf.shape(skip))
    return tf.add(conv_t, skip, name="fuse_%d" % idx)


def inference(image, keep_prob):
    """Build the segmentation network inside variable scope ``seg_inference``.

    Encoder: five stages of 3x3 conv+ReLU layers (2, 2, 3, 3, 3 convs with
    64, 128, 256, 512, 512 channels). Each stage ends with an RA unit plus a
    3x3 fusion conv, then ``utils.max_pool_2x2``. Decoder (FCN-style): fc6
    (7x7, 4096) and fc7 (1x1, 4096) as convolutions with dropout, a 1x1 conv
    to ``NUM_OF_CLASSESS`` scores, then transpose-conv upsampling with additive
    skip connections from pool3 and pool2, and a final stride-4 transpose conv
    back to the input's spatial size.

    Args:
        image: 4-D input tensor indexed [batch, H, W, C] with C == 3 (W1_1
            expects 3 input channels). H and W presumably must be statically
            known (they are read via ``.value`` for RA_unit) and divisible by
            16 so the pool3/pool2 skip shapes line up — TODO confirm.
        keep_prob: keep probability passed to ``tf.nn.dropout`` after fc6 and
            fc7.

    Returns:
        ``(annotation_pred, conv_t3)``: ``conv_t3`` is the final per-pixel
        class-score tensor with shape [batch, H, W, NUM_OF_CLASSESS];
        ``annotation_pred`` (op name "prediction") is its argmax over axis 3.

    Depends on the module-level names ``tf``, ``utils`` and
    ``NUM_OF_CLASSESS``.
    """
    # Mean subtraction is currently disabled; the image is fed in as-is.
    # NOTE(review): the disabled IMG_MEAN below has 4 entries while W1_1
    # expects 3 input channels — reconcile before re-enabling.
    # IMG_MEAN = np.array((104.00698793/255,116.66876762/255,122.67891434/255,146.01657/255), dtype=np.float32)
    # processed_image = utils.process_image(image, IMG_MEAN)

    with tf.variable_scope("seg_inference"):
        # ---- Encoder ----
        pool1 = _seg_encoder_stage(image, 3, 64, 2, 16, 1)
        pool2 = _seg_encoder_stage(pool1, 64, 128, 2, 16, 2)
        pool3 = _seg_encoder_stage(pool2, 128, 256, 3, 16, 3)
        pool4 = _seg_encoder_stage(pool3, 256, 512, 3, 16, 4)
        # NOTE(review): stage 5 is still built, so its variables (W5_*, W_s5,
        # ...) exist as before, but fc6 reads pool4, so pool5 never reaches
        # the returned tensors. Kept to leave the variable set unchanged.
        _seg_encoder_stage(pool4, 512, 512, 3, 8, 5)

        # ---- fc6 / fc7 as convolutions ----
        # fc6 consumes pool4, so W6's input depth is taken from pool4 (it was
        # previously read from pool5, which only matched because both are 512).
        # Working at pool4 resolution is what lets the first upsample below
        # land exactly on pool3's shape.
        relu6 = _seg_conv_relu(pool4, pool4.shape[3].value, 4096, "6", ksize=7)
        relu_dropout6 = tf.nn.dropout(relu6, keep_prob=keep_prob)

        relu7 = _seg_conv_relu(relu_dropout6, 4096, 4096, "7", ksize=1)
        relu_dropout7 = tf.nn.dropout(relu7, keep_prob=keep_prob)

        # Per-class scores; in our case NUM_OF_CLASSESS = 2 (road, non-road).
        W8 = utils.weight_variable([1, 1, 4096, NUM_OF_CLASSESS], name="W8")
        b8 = utils.bias_variable([NUM_OF_CLASSESS], name="b8")
        conv8 = utils.conv2d_basic(relu_dropout7, W8, b8)

        # ---- Decoder: upscale to the actual image size ----
        fuse_1 = _seg_upsample_fuse(conv8, NUM_OF_CLASSESS, pool3, 1)
        fuse_2 = _seg_upsample_fuse(fuse_1, pool3.get_shape()[3].value, pool2, 2)
        print("fuse_2 shape:")
        print(fuse_2.shape)

        # Final upsample from pool2 resolution back to the input's H x W.
        shape = tf.shape(image)
        deconv_shape3 = tf.stack([shape[0], shape[1], shape[2], NUM_OF_CLASSESS])
        W_t3 = utils.weight_variable([16, 16, NUM_OF_CLASSESS, fuse_2.shape[3].value], name="W_t3")
        b_t3 = utils.bias_variable([NUM_OF_CLASSESS], name="b_t3")
        conv_t3 = utils.conv2d_transpose_strided(fuse_2, W_t3, b_t3, output_shape=deconv_shape3, stride=4, stride_y=4)

        # `axis` replaces the deprecated `dimension` keyword of tf.argmax.
        annotation_pred = tf.argmax(conv_t3, axis=3, name="prediction")

    return annotation_pred, conv_t3  # conv_t3 is the final result