Exemplo n.º 1
0
    def generate(self):
        """Build the CenterNet model in predict mode, load its weights and set up box colours.

        Reads ``self.model_path``, ``self.class_names``, ``self.input_shape`` and
        ``self.backbone``; sets ``self.num_classes``, ``self.centernet`` and
        ``self.colors``.

        Raises:
            AssertionError: if the user-expanded model path does not end in ``.h5``.
        """
        # Expand a leading ``~`` so user-relative paths work. The expanded path is
        # the one that is validated AND the one that is loaded below.
        model_path = os.path.expanduser(self.model_path)
        assert model_path.endswith(
            '.h5'), 'Keras model or weights must be a .h5 file.'

        #----------------------------------------#
        #   Number of classes
        #----------------------------------------#
        self.num_classes = len(self.class_names)

        #----------------------------------------#
        #   Build the CenterNet model
        #----------------------------------------#
        self.centernet = centernet(self.input_shape,
                                   num_classes=self.num_classes,
                                   backbone=self.backbone,
                                   mode='predict')
        # Load from the expanded path, not the raw ``self.model_path``, so that a
        # ``~/...`` path which passed the check above also loads successfully.
        self.centernet.load_weights(model_path, by_name=True)

        print('{} model, anchors, and classes loaded.'.format(self.model_path))

        # One colour per class for drawing boxes: evenly spaced hues at full
        # saturation/value, converted to 0-255 integer (R, G, B) tuples.
        num_colors = len(self.class_names)
        self.colors = [
            tuple(int(channel * 255)
                  for channel in colorsys.hsv_to_rgb(x / num_colors, 1., 1.))
            for x in range(num_colors)
        ]
Exemplo n.º 2
0
 def generate(self):
     """Build the CenterNet model (predict or heatmap mode) and load its weights.

     Reads ``self.model_path``, ``self.input_shape``, ``self.num_classes``,
     ``self.backbone`` and ``self.heatmap``; sets ``self.centernet``.

     Raises:
         AssertionError: if the user-expanded model path does not end in ``.h5``.
     """
     # Expand a leading ``~``; the expanded path is both validated and loaded.
     model_path = os.path.expanduser(self.model_path)
     assert model_path.endswith('.h5'), 'Keras model or weights must be a .h5 file.'

     #----------------------------------------#
     #   Build the CenterNet model
     #----------------------------------------#
     # Only the first two entries of input_shape are used; the channel count is fixed at 3.
     # A truthy self.heatmap selects 'heatmap' mode, otherwise 'predict' mode.
     self.centernet = centernet([self.input_shape[0], self.input_shape[1], 3],
                                num_classes=self.num_classes,
                                backbone=self.backbone,
                                mode='predict' if not self.heatmap else 'heatmap')
     # Load from the expanded path so a ``~/...`` path that passed the check also loads.
     self.centernet.load_weights(model_path, by_name=True)
     print('{} model, anchors, and classes loaded.'.format(self.model_path))
Exemplo n.º 3
0
#--------------------------------------------#
#   该部分代码只用于看网络结构,并非测试代码
#   map测试请看get_dr_txt.py、get_gt_txt.py
#   和get_map.py
#--------------------------------------------#
from nets.centernet import centernet

if __name__ == "__main__":
    # Build a ResNet-50 CenterNet for a 512x512, 3-channel input with 20 classes,
    # purely so its structure can be inspected.
    net = centernet([512, 512, 3], 20, backbone='resnet50')

    # Layer-by-layer structure of the network.
    net.summary()

    # Every layer's index and name, one per line.
    for idx, lyr in enumerate(net.layers):
        print(idx, lyr.name)
Exemplo n.º 4
0
    #   Load the class names and count them
    #----------------------------------------------------#
    # NOTE(review): classes_path is presumably defined earlier in this scope;
    # get_classes is assumed to return a sized sequence of class names.
    class_names = get_classes(classes_path)
    num_classes = len(class_names)
    #-----------------------------#
    #   Backbone feature-extraction network; options:
    #   resnet50
    #   hourglass
    #-----------------------------#
    backbone = "resnet50"

    #----------------------------------------------------#
    #   Build the CenterNet model in training mode
    #----------------------------------------------------#
    model = centernet(input_shape,
                      num_classes=num_classes,
                      backbone=backbone,
                      mode='train')

    #------------------------------------------------------#
    #   For the weights file see the README (Baidu Netdisk download).
    #   When training on your own dataset, dimension-mismatch warnings are
    #   expected: the prediction targets differ, so some layer shapes differ.
    #   by_name=True matches layers by name; skip_mismatch=True skips layers
    #   whose weight shapes do not match instead of raising.
    #------------------------------------------------------#
    model_path = r"model_data/centernet_resnet50_voc.h5"
    model.load_weights(model_path, by_name=True, skip_mismatch=True)

    #----------------------------------------------------#
    #   Annotation file listing image paths and their labels
    #----------------------------------------------------#
    annotation_path = '2007_train.txt'
    #----------------------------------------------------------------------#
Exemplo n.º 5
0
    # With more than one GPU, use tf.distribute.MirroredStrategy so the model
    # can be built inside its scope; otherwise no distribution strategy is used.
    if ngpus_per_node > 1:
        strategy = tf.distribute.MirroredStrategy()
    else:
        strategy = None
    print('Number of devices: {}'.format(ngpus_per_node))

    #----------------------------------------------------#
    #   Load the class names and the number of classes
    #----------------------------------------------------#
    # NOTE(review): classes_path is presumably defined earlier in this scope.
    class_names, num_classes = get_classes(classes_path)

    if ngpus_per_node > 1:
        with strategy.scope():
            model = centernet([input_shape[0], input_shape[1], 3],
                              num_classes=num_classes,
                              backbone=backbone,
                              mode='train')
            if model_path != '':
                #------------------------------------------------------#
                #   载入预训练权重
                #------------------------------------------------------#
                print('Load weights {}.'.format(model_path))
                model.load_weights(model_path,
                                   by_name=True,
                                   skip_mismatch=True)
    else:
        model = centernet([input_shape[0], input_shape[1], 3],
                          num_classes=num_classes,
                          backbone=backbone,
                          mode='train')
        if model_path != '':
Exemplo n.º 6
0
#--------------------------------------------#
#   该部分代码用于看网络结构
#--------------------------------------------#
from nets.centernet import centernet
from utils.utils import net_flops

if __name__ == "__main__":
    input_shape     = [512, 512]
    num_classes     = 20

    # Append a fixed 3-channel dimension to the spatial input shape.
    model = centernet(input_shape + [3], num_classes, backbone='resnet50')

    # Show the network structure.
    model.summary()

    # Compute the network's FLOPs with net_flops (table=False is passed through).
    net_flops(model, table=False)

    # Optional: uncomment to print each layer's index and name.
    # for i, layer in enumerate(model.layers):
    #     print(i, layer.name)