def predict(image_path): # 先读入图像 image = Image.open(image_path) # 对图像resize到416,416,3 image = image.resize((416, 416)) image_copy = image.copy() image_copy = np.array(image_copy) # 图像转为array image = np.array(image) # 图像归一化 image = image / 255 # 弄成批量 image = image.reshape(1, 416, 416, 3) # 初始化模型 model = mobilenet_segnet(classes_num, 416, 416) # 模型加载权重 model.load_weights(model_weights_path) # 进行预测,result = (43264,2) result = model.predict(image)[0] # (43264,2->208,208,2) result = result.reshape(208, 208, classes_num) # (208,208,2)->(208,208) result = result.argmax(axis=-1) seg_img = np.zeros((208, 208, 3)) for c in range(classes_num): seg_img[:, :, 0] += ((result == c) * colors[c][0]).astype('uint8') seg_img[:, :, 1] += ((result == c) * colors[c][1]).astype('uint8') seg_img[:, :, 1] += ((result == c) * colors[c][2]).astype('uint8') seg_img = Image.fromarray(np.uint8(seg_img)).resize( (image_copy.shape[0], image_copy.shape[1])) image_copy = Image.fromarray(image_copy) image = Image.blend(image_copy, seg_img, 0.5) image.show()
class_colors = [[0, 0, 0], [0, 255, 0]] #---------------------------------------------# # 定义输入图片的高和宽,以及种类数量 #---------------------------------------------# HEIGHT = 416 WIDTH = 416 #---------------------------------------------# # 背景 + 斑马线 = 2 #---------------------------------------------# NCLASSES = 2 #---------------------------------------------# # 载入模型 #---------------------------------------------# model = mobilenet_segnet(n_classes=NCLASSES, input_height=HEIGHT, input_width=WIDTH) #--------------------------------------------------# # 载入权重,训练好的权重会保存在logs文件夹里面 # 我们需要将对应的权重载入 # 修改model_path,将其对应我们训练好的权重即可 # 下面只是一个示例 #--------------------------------------------------# model.load_weights("logs/ep030-loss0.007-val_loss0.024.h5") #--------------------------------------------------# # 对imgs文件夹进行一个遍历 #--------------------------------------------------# imgs = os.listdir("./img/") for jpg in imgs: #--------------------------------------------------#
#---------------------------------------------# # 该部分用于查看网络结构 #---------------------------------------------# from nets.segnet import mobilenet_segnet if __name__ == "__main__": model = mobilenet_segnet(2, input_height=416, input_width=416) model.summary()