# Code example #1
# 0
            Y_train.append(seg_labels)

            # 读完一个周期后重新开始
            i = (i + 1) % n
        yield (np.array(X_train), np.array(Y_train))


def loss(y_true, y_pred):
    """Categorical cross-entropy loss for model training.

    Args:
        y_true: target tensor, passed as the first (target) argument.
        y_pred: model output tensor, passed as the second (output) argument.

    Returns:
        Whatever ``K.categorical_crossentropy(y_true, y_pred)`` returns.
    """
    return K.categorical_crossentropy(y_true, y_pred)


if __name__ == "__main__":
    import os

    log_dir = "logs/"
    # Build the model.
    model = Deeplabv3(classes=NCLASSES, input_shape=(HEIGHT, WIDTH, 3))
    # model.summary()

    # Fetch the pretrained MobileNetV2 DeepLabV3 weights (get_file reuses the
    # cached copy under the 'models' cache subdir when present), then load
    # them by layer name, skipping layers whose weights do not match in count
    # or shape (e.g. a classifier sized for a different class count).
    weights_path = get_file(
        'deeplabv3_mobilenetv2_tf_dim_ordering_tf_kernels.h5',
        WEIGHTS_PATH_MOBILE,
        cache_subdir='models')
    model.load_weights(weights_path, by_name=True, skip_mismatch=True)

    # Open the dataset index txt. The path is assembled with os.path.join
    # instead of the hard-coded Windows form ".\dataset2\train_data.txt":
    # on POSIX a backslash is an ordinary filename character, so that literal
    # never resolves there. On Windows the joined path is identical.
    with open(os.path.join(".", "dataset2", "train_data.txt"), "r") as f:
        lines = f.readlines()

    # Shuffle the lines: this txt mainly drives data reading for training,
    # and shuffled data is better for training. A fixed seed keeps the
    # order reproducible.
    # NOTE(review): only the seeding is visible here; the shuffle call
    # itself is not part of this excerpt.
    np.random.seed(10101)
# Code example #2
# 0
from nets.deeplab import Deeplabv3
from PIL import Image
import numpy as np
import random
import copy
import os

class_colors = [[0, 0, 0], [0, 255, 0]]
NCLASSES = 2
HEIGHT = 416
WIDTH = 416

# Build the network with the same class count used by the reshape below
# (previously hard-coded to 2 via a duplicated `model = model =` assignment),
# then load the trained weights.
model = Deeplabv3(classes=NCLASSES, input_shape=(HEIGHT, WIDTH, 3))
model.load_weights("logs/last1.h5")
imgs = os.listdir("./building_img")

for jpg in imgs:
    img = Image.open(os.path.join("./building_img", jpg))
    # Untouched copy of the image as loaded, taken before any resizing.
    old_img = copy.deepcopy(img)
    # PIL reports size as (width, height); this equals
    # np.array(img).shape[1], np.array(img).shape[0] without converting twice.
    orininal_w, orininal_h = img.size

    # The network takes 3-channel input. Convert grayscale / RGBA / palette
    # images to RGB first, otherwise the reshape to (..., 3) below fails.
    # For images that are already RGB this is a no-op copy.
    img = img.convert("RGB").resize((WIDTH, HEIGHT))
    # Scale pixel values to [0, 1] and add a leading batch axis.
    img = np.asarray(img) / 255
    img = img.reshape(-1, HEIGHT, WIDTH, 3)
    pr = model.predict(img)[0]

    # Per-pixel index of the highest-scoring class.
    pr = pr.reshape((HEIGHT, WIDTH, NCLASSES)).argmax(axis=-1)
# Code example #3
# 0
from nets.deeplab import Deeplabv3
# Build a two-class DeepLabV3 for 1024x1024 (512 * 2 per side), 3-channel
# input, with OS=16 (presumably the output stride), and print its summary.
model = Deeplabv3(classes=2, OS=16, input_shape=(1024, 1024, 3))
model.summary()
# Code example #4
# 0
# File: test.py  Project: JamesXiaoFF/deeplab
from nets.deeplab import Deeplabv3
from keras.utils import plot_model

# Build DeepLabV3 using only the builder's default arguments and print the
# Keras layer-by-layer summary of the resulting model.
deeplab_model = Deeplabv3()
deeplab_model.summary()
# Code example #5
# 0
#---------------------------------------------#
#   This part is used to inspect the network structure
#---------------------------------------------#
from nets.deeplab import Deeplabv3

if __name__ == "__main__":
    # Build a two-class DeepLabV3 with OS=16 and inspect its structure.
    model = Deeplabv3(classes=2, OS=16)
    # Keras tabular summary first, then each layer's position and name.
    model.summary()
    for i, layer in enumerate(model.layers):
        print(i, layer.name)