示例#1
0
    def generate(self):
        """Build the CenterNet network, load its weights and pick box colours.

        Reads ``self.class_names``, ``self.backbone``, ``self.model_path`` and
        ``self.cuda``; sets ``self.num_classes``, ``self.centernet`` and
        ``self.colors`` (one ``(r, g, b)`` tuple of 0-255 ints per class).
        """
        # Total number of classes.
        self.num_classes = len(self.class_names)

        assert self.backbone in ['resnet50', 'hourglass']
        if self.backbone == "resnet50":
            self.centernet = CenterNet_Resnet50(num_classes=self.num_classes,
                                                pretrain=False)
        else:
            self.centernet = CenterNet_HourglassNet({
                'hm': self.num_classes,
                'wh': 2,
                'reg': 2
            })

        print('Loading weights into state dict...')
        # Stage the checkpoint on the GPU only when the model will actually run
        # there. load_state_dict copies the tensors into the model's own
        # parameters, so the final weights are the same either way.
        use_gpu = self.cuda and torch.cuda.is_available()
        device = torch.device('cuda' if use_gpu else 'cpu')
        state_dict = torch.load(self.model_path, map_location=device)
        self.centernet.load_state_dict(state_dict, strict=True)
        self.centernet = self.centernet.eval()

        if self.cuda:
            # NOTE(review): CUDA_VISIBLE_DEVICES only takes effect if it is set
            # before the process initialises CUDA. torch.cuda may already be
            # initialised here, so this assignment may be a no-op; confirm
            # whether it should move to program start-up.
            os.environ["CUDA_VISIBLE_DEVICES"] = '0'
            self.centernet = nn.DataParallel(self.centernet)
            self.centernet.cuda()

        print('{} model, anchors, and classes loaded.'.format(self.model_path))

        # Give each class its own box colour. Use evenly spaced hues at full
        # saturation and value, converted to 0-255 integer RGB tuples.
        hsv_tuples = [(x / self.num_classes, 1., 1.)
                      for x in range(self.num_classes)]
        self.colors = [
            tuple(int(channel * 255) for channel in colorsys.hsv_to_rgb(*hsv))
            for hsv in hsv_tuples
        ]
示例#2
0
    def generate(self, onnx=False):
        """Build the CenterNet network and load its trained weights.

        Args:
            onnx: when True the network is neither wrapped in
                ``torch.nn.DataParallel`` nor moved to the GPU (presumably so
                it can be exported to ONNX; TODO confirm with the caller).

        Reads ``self.backbone``, ``self.num_classes``, ``self.model_path`` and
        ``self.cuda``; sets ``self.net``.
        """
        #-------------------------------#
        #   Build the model and load the weights
        #-------------------------------#
        assert self.backbone in ['resnet50', 'hourglass']
        if self.backbone == "resnet50":
            self.net = CenterNet_Resnet50(num_classes=self.num_classes,
                                          pretrained=False)
        else:
            self.net = CenterNet_HourglassNet({
                'hm': self.num_classes,
                'wh': 2,
                'reg': 2
            })

        # The network only moves to the GPU when self.cuda is set and this is
        # not an ONNX build, so stage the checkpoint on that same device.
        # load_state_dict copies the tensors into the model's parameters,
        # so the loaded weights are identical either way.
        use_gpu = self.cuda and not onnx and torch.cuda.is_available()
        device = torch.device('cuda' if use_gpu else 'cpu')
        self.net.load_state_dict(
            torch.load(self.model_path, map_location=device))
        self.net = self.net.eval()
        print('{} model, and classes loaded.'.format(self.model_path))
        if not onnx and self.cuda:
            self.net = torch.nn.DataParallel(self.net)
            self.net = self.net.cuda()
示例#3
0
#--------------------------------------------#
#   该部分代码只用于看网络结构,并非测试代码
#   map测试请看get_dr_txt.py、get_gt_txt.py
#   和get_map.py
#--------------------------------------------#
import torch
from torchsummary import summary

from nets.centernet import CenterNet_HourglassNet, CenterNet_Resnet50

if __name__ == "__main__":
    # Use the GPU when present, otherwise fall back to the CPU instead of
    # crashing in .cuda() on machines without CUDA.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # To inspect the Hourglass backbone instead:
    #   model = CenterNet_HourglassNet({'hm': 80, 'wh': 2, 'reg': 2}).train().to(device)
    #   summary(model, (3, 128, 128))
    model = CenterNet_Resnet50().train().to(device)
    summary(model, (3, 512, 512))
示例#4
0
    #   是否使用imagenet-resnet50的预训练权重。
    #   仅在主干网络为resnet50时有作用。
    #   默认为False
    #-------------------------------------------#
    pretrain = False
    #-------------------------------------------#
    #   Backbone feature extractor:
    #   "resnet50" or "hourglass"
    #-------------------------------------------#
    backbone = "resnet50"

    Cuda = True

    assert backbone in ['resnet50', 'hourglass']
    if backbone == "resnet50":
        model = CenterNet_Resnet50(num_classes, pretrain=pretrain)
    else:
        model = CenterNet_HourglassNet({'hm': num_classes, 'wh': 2, 'reg': 2})

    #------------------------------------------------------#
    #   For the weight file, see the README (Baidu Netdisk download)
    #------------------------------------------------------#
    model_path = r"model_data/centernet_resnet50_voc.pth"
    # Partial warm start. Copy every checkpoint tensor whose name exists in
    # the model and whose shape matches. Keep the model's own initialisation
    # for the rest (e.g. a head built for a different number of classes).
    print('Loading weights into state dict...')
    model_dict = model.state_dict()
    # map_location='cpu' lets a GPU-saved checkpoint load on CPU-only machines.
    # load_state_dict copies the values into the model's own parameters.
    pretrained_dict = torch.load(model_path, map_location='cpu')
    # Checking membership first avoids a KeyError on checkpoint-only keys.
    pretrained_dict = {
        k: v for k, v in pretrained_dict.items()
        if k in model_dict and np.shape(model_dict[k]) == np.shape(v)
    }
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)
    print('Finished!')
示例#5
0
    #   修改成自己需要分的类
    # -----------------------------#
    classes_path = hyp.get('classes_path')
    # ----------------------------------------------------#
    #   Read the class names and count them
    # ----------------------------------------------------#
    class_names = get_classes(classes_path)
    num_classes = len(class_names)

    # ----------------------------------------------------#
    #   Build the CenterNet model
    # ----------------------------------------------------#
    model = CenterNet_Resnet50(num_classes)

    # ------------------------------------------------------#
    #   For the weight file, see the README (Baidu Netdisk download).
    #   Loading is skipped when hyp has no (or an empty) 'model_path'.
    # ------------------------------------------------------#
    model_path = hyp.get('model_path')
    if model_path:
        print('Loading weights into state dict...')
        model_dict = model.state_dict()
        # map_location='cpu' lets a GPU-saved checkpoint load on CPU-only
        # machines. load_state_dict copies the values into the model's
        # own parameters.
        pretrained_dict = torch.load(model_path, map_location='cpu')
        # Keep only tensors whose name exists in the model and whose shape
        # matches. The membership test avoids a KeyError on checkpoint-only keys.
        pretrained_dict = {
            k: v for k, v in pretrained_dict.items()
            if k in model_dict and np.shape(model_dict[k]) == np.shape(v)
        }
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)
        print('Finished!')
示例#6
0
#--------------------------------------------#
#   该部分代码用于看网络参数
#--------------------------------------------#
import torch
from thop import clever_format, profile
from torchsummary import summary

from nets.centernet import CenterNet_HourglassNet, CenterNet_Resnet50

if __name__ == "__main__":
    input_shape = [512, 512]
    num_classes = 20

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Pass num_classes explicitly so the profiled detection head matches the
    # setting above.
    model = CenterNet_Resnet50(num_classes=num_classes).to(device)
    summary(model, (3, input_shape[0], input_shape[1]))

    dummy_input = torch.randn(1, 3, input_shape[0], input_shape[1]).to(device)
    # The model is already on `device`; no need to move it again.
    flops, params = profile(model, (dummy_input, ), verbose=False)
    #--------------------------------------------------------#
    #   flops is doubled because profile does not count a
    #   convolution as two operations.
    #   Some papers count a convolution's multiply and add as
    #   two operations; they multiply by 2.
    #   Others count only the multiplications and ignore the
    #   additions; they do not multiply by 2.
    #   This code multiplies by 2, following YOLOX.
    #--------------------------------------------------------#
    flops = flops * 2
    flops, params = clever_format([flops, params], "%.3f")
    print('Total GFLOPS: %s' % (flops))
    print('Total params: %s' % (params))