예제 #1
0
    def generate(self):
        """Build the RetinaNet model, load its weights and set up box colours.

        Reads ``self.model_path``, ``self.class_names`` and
        ``self.model_image_size``; sets ``self.num_classes``,
        ``self.retinanet_model`` and ``self.colors`` (one 0-255 RGB tuple per
        class). Prints the model summary and a confirmation message.

        Raises:
            AssertionError: if the (user-expanded) model path does not end
                in ``.h5``.
        """
        # Expand "~" once so the same resolved path is validated and loaded.
        model_path = os.path.expanduser(self.model_path)
        assert model_path.endswith(
            '.h5'), 'Keras model or weights must be a .h5 file.'

        # Total number of classes.
        self.num_classes = len(self.class_names)

        # Build the network structure first, then load the weights by layer
        # name into it.
        inputs = Input(self.model_image_size)
        self.retinanet_model = retinanet.resnet_retinanet(
            self.num_classes, inputs)
        # Load from the expanded path so that a "~" in self.model_path
        # resolves the same way as in the check above.
        self.retinanet_model.load_weights(model_path, by_name=True)

        self.retinanet_model.summary()
        print('{} model, anchors, and classes loaded.'.format(model_path))

        # Give each class its own box colour: evenly spaced hues at full
        # saturation and value, converted to integer RGB in 0-255.
        hsv_tuples = [(x / self.num_classes, 1., 1.)
                      for x in range(self.num_classes)]
        self.colors = [
            tuple(int(channel * 255) for channel in colorsys.hsv_to_rgb(*hsv))
            for hsv in hsv_tuples
        ]
예제 #2
0
    def generate(self):
        """Build the RetinaNet model, load its weights and set up box colours.

        Reads ``self.model_path``, ``self.class_names`` and
        ``self.model_image_size``; sets ``self.num_classes``,
        ``self.retinanet_model`` and ``self.colors`` (one 0-255 RGB tuple per
        class). Prints a confirmation message.

        Raises:
            AssertionError: if the (user-expanded) model path does not end
                in ``.h5``.
        """
        # Expand "~" once so the same resolved path is validated and loaded.
        model_path = os.path.expanduser(self.model_path)
        assert model_path.endswith(
            '.h5'), 'Keras model or weights must be a .h5 file.'
        #------------------#
        #   Number of classes
        #------------------#
        self.num_classes = len(self.class_names)

        #------------------#
        #   Build the model and load its weights
        #------------------#
        self.retinanet_model = retinanet.resnet_retinanet(
            self.num_classes, self.model_image_size)
        # Load from the expanded path so that a "~" in self.model_path
        # resolves the same way as in the check above.
        self.retinanet_model.load_weights(model_path, by_name=True)

        print('{} model, anchors, and classes loaded.'.format(model_path))

        # Give each class its own box colour: evenly spaced hues at full
        # saturation and value, converted to integer RGB in 0-255.
        hsv_tuples = [(x / self.num_classes, 1., 1.)
                      for x in range(self.num_classes)]
        self.colors = [
            tuple(int(channel * 255) for channel in colorsys.hsv_to_rgb(*hsv))
            for hsv in hsv_tuples
        ]
예제 #3
0
import nets.retinanet as retinanet
import numpy as np
import keras
from keras.optimizers import Adam
from nets.retinanet_training import Generator
from nets.retinanet_training import focal, smooth_l1
from keras.callbacks import TensorBoard, ModelCheckpoint, ReduceLROnPlateau, EarlyStopping
from utils.utils import BBoxUtility
from utils.anchors import get_anchors

# Training-script entry point. The visible snippet ends after seeding NumPy's
# RNG; `bbox_util`, `val_split` and `lines` are not used within these lines,
# presumably by later code not shown here.
if __name__ == "__main__":
    # Number of object classes the network is built for.
    # NOTE(review): 20 presumably matches the class list behind
    # 2007_train.txt (VOC-style) — confirm.
    NUM_CLASSES = 20
    # Per-sample input shape passed to keras Input (batch dimension excluded);
    # presumably (height, width, channels).
    input_shape = (600, 600, 3)
    # Annotation list file; its lines are read below.
    annotation_path = '2007_train.txt'
    inputs = keras.layers.Input(shape=input_shape)
    model = retinanet.resnet_retinanet(NUM_CLASSES, inputs)
    # Anchors ("priors") obtained from the built model and handed to the box
    # utility together with the class count.
    priors = get_anchors(model)
    bbox_util = BBoxUtility(NUM_CLASSES, priors)

    #-------------------------------------------#
    #   See the README for where to download the weight file.
    #-------------------------------------------#
    # by_name=True loads weights only into layers with matching names;
    # skip_mismatch=True skips layers whose weights do not match in shape or
    # count (presumably the class-dependent head, since this checkpoint's
    # name indicates COCO weights while NUM_CLASSES is 20 here).
    model.load_weights("model_data/resnet50_coco_best_v2.1.0.h5",
                       by_name=True,
                       skip_mismatch=True)

    # 0.1 of the data is used for validation, 0.9 for training.
    val_split = 0.1
    # NOTE(review): no encoding is given, so the platform default is used.
    with open(annotation_path) as f:
        lines = f.readlines()
    # Fixed seed, presumably so a later shuffle/split is reproducible.
    np.random.seed(10101)
예제 #4
0
# ------------------------------------------------------------------
#   This script only displays the network structure; it is not
#   evaluation code. For mAP evaluation, see get_dr_txt.py,
#   get_gt_txt.py and get_map.py.
# ------------------------------------------------------------------
from nets.retinanet import resnet_retinanet

if __name__ == "__main__":
    # Build a RetinaNet for 80 classes (presumably the COCO class count)
    # and print its summary.
    # NOTE(review): other call sites pass an input tensor/shape as a second
    # argument; this call relies on resnet_retinanet accepting one argument
    # — confirm against nets/retinanet.py.
    network = resnet_retinanet(80)
    network.summary()