예제 #1
0
def predict(image_path):
    # 先读入图像
    image = Image.open(image_path)
    # 对图像resize到416,416,3
    image = image.resize((416, 416))
    image_copy = image.copy()
    image_copy = np.array(image_copy)
    # 图像转为array
    image = np.array(image)
    # 图像归一化
    image = image / 255
    # 弄成批量
    image = image.reshape(1, 416, 416, 3)

    # 初始化模型
    model = mobilenet_segnet(classes_num, 416, 416)
    # 模型加载权重
    model.load_weights(model_weights_path)

    # 进行预测,result = (43264,2)
    result = model.predict(image)[0]
    # (43264,2->208,208,2)
    result = result.reshape(208, 208, classes_num)
    # (208,208,2)->(208,208)
    result = result.argmax(axis=-1)

    seg_img = np.zeros((208, 208, 3))
    for c in range(classes_num):
        seg_img[:, :, 0] += ((result == c) * colors[c][0]).astype('uint8')
        seg_img[:, :, 1] += ((result == c) * colors[c][1]).astype('uint8')
        seg_img[:, :, 1] += ((result == c) * colors[c][2]).astype('uint8')

    seg_img = Image.fromarray(np.uint8(seg_img)).resize(
        (image_copy.shape[0], image_copy.shape[1]))

    image_copy = Image.fromarray(image_copy)
    image = Image.blend(image_copy, seg_img, 0.5)
    image.show()
예제 #2
0
    class_colors = [[0, 0, 0], [0, 255, 0]]
    #---------------------------------------------#
    #   定义输入图片的高和宽,以及种类数量
    #---------------------------------------------#
    HEIGHT = 416
    WIDTH = 416
    #---------------------------------------------#
    #   背景 + 斑马线 = 2
    #---------------------------------------------#
    NCLASSES = 2

    #---------------------------------------------#
    #   载入模型
    #---------------------------------------------#
    model = mobilenet_segnet(n_classes=NCLASSES,
                             input_height=HEIGHT,
                             input_width=WIDTH)
    #--------------------------------------------------#
    #   载入权重,训练好的权重会保存在logs文件夹里面
    #   我们需要将对应的权重载入
    #   修改model_path,将其对应我们训练好的权重即可
    #   下面只是一个示例
    #--------------------------------------------------#
    model.load_weights("logs/ep030-loss0.007-val_loss0.024.h5")

    #--------------------------------------------------#
    #   对imgs文件夹进行一个遍历
    #--------------------------------------------------#
    imgs = os.listdir("./img/")
    for jpg in imgs:
        #--------------------------------------------------#
예제 #3
0
#---------------------------------------------#
#   该部分用于查看网络结构
#---------------------------------------------#
from nets.segnet import mobilenet_segnet

if __name__ == "__main__":
    model = mobilenet_segnet(2, input_height=416, input_width=416)
    model.summary()