def generate(self):
    """Build the CenterNet model, load its weights and prepare per-class box colours.

    Reads ``self.class_names``, ``self.backbone``, ``self.model_path`` and
    ``self.cuda``; sets ``self.num_classes``, ``self.centernet`` and
    ``self.colors``.

    Raises:
        AssertionError: if ``self.backbone`` is neither 'resnet50' nor 'hourglass'.
    """
    # Total number of classes.
    self.num_classes = len(self.class_names)

    assert self.backbone in ['resnet50', 'hourglass']
    if self.backbone == "resnet50":
        self.centernet = CenterNet_Resnet50(num_classes=self.num_classes, pretrain=False)
    else:
        # Presumably head name -> number of output channels; verify against the net.
        self.centernet = CenterNet_HourglassNet({'hm': self.num_classes, 'wh': 2, 'reg': 2})

    print('Loading weights into state dict...')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    state_dict = torch.load(self.model_path, map_location=device)
    # strict=True: the checkpoint keys must match the model's keys exactly.
    self.centernet.load_state_dict(state_dict, strict=True)
    self.centernet = self.centernet.eval()

    if self.cuda:
        # CUDA_VISIBLE_DEVICES is deliberately not set here. It is only read when
        # the CUDA driver initialises, and torch.load onto the GPU above has
        # already done that, so assigning it at this point has no effect.
        # To restrict GPUs, set it before the process first touches CUDA.
        self.centernet = nn.DataParallel(self.centernet)
        self.centernet.cuda()  # Module.cuda() moves parameters in place
    print('{} model, anchors, and classes loaded.'.format(self.model_path))

    # One box colour per class: evenly spaced hues at full saturation/value,
    # converted to (R, G, B) integer tuples in the 0-255 range.
    hsv_tuples = [(x / self.num_classes, 1., 1.) for x in range(self.num_classes)]
    self.colors = [
        tuple(int(channel * 255) for channel in colorsys.hsv_to_rgb(*hsv))
        for hsv in hsv_tuples
    ]
def generate(self, onnx=False):
    """Instantiate the detector for ``self.backbone`` and load ``self.model_path`` into it.

    The network is stored on ``self.net`` in eval mode.

    Args:
        onnx: when True, the network is left unwrapped (no DataParallel and no
            ``.cuda()``) even if ``self.cuda`` is set. Presumably this is so the
            bare module can be exported to ONNX; confirm with the caller.

    Raises:
        AssertionError: if ``self.backbone`` is neither 'resnet50' nor 'hourglass'.
    """
    # Build the network first, then load its weights.
    assert self.backbone in ['resnet50', 'hourglass']
    if self.backbone != "resnet50":
        head_channels = {'hm': self.num_classes, 'wh': 2, 'reg': 2}
        self.net = CenterNet_HourglassNet(head_channels)
    else:
        self.net = CenterNet_Resnet50(num_classes=self.num_classes, pretrained=False)

    use_gpu = torch.cuda.is_available()
    map_device = torch.device('cuda' if use_gpu else 'cpu')
    weights = torch.load(self.model_path, map_location=map_device)
    self.net.load_state_dict(weights)
    self.net = self.net.eval()
    print('{} model, and classes loaded.'.format(self.model_path))

    # Guard clause: only wrap and move the model for GPU inference.
    if onnx or not self.cuda:
        return
    self.net = torch.nn.DataParallel(self.net)
    self.net = self.net.cuda()
#--------------------------------------------#
#   This script only prints the network structure; it is not test code.
#   For mAP evaluation see get_dr_txt.py, get_gt_txt.py
#   and get_map.py.
#--------------------------------------------#
import torch
from torchsummary import summary

from nets.centernet import CenterNet_HourglassNet, CenterNet_Resnet50

if __name__ == "__main__":
    # Fall back to the CPU when CUDA is unavailable instead of crashing in .cuda().
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Select which network to inspect: "resnet50" or "hourglass".
    backbone = "resnet50"
    if backbone == "hourglass":
        model = CenterNet_HourglassNet({'hm': 80, 'wh': 2, 'reg': 2})
        input_size = (3, 128, 128)
    else:
        model = CenterNet_Resnet50()
        input_size = (3, 512, 512)

    model = model.train().to(device)
    summary(model, input_size, device=device.type)
#   Whether to use ImageNet-pretrained ResNet-50 backbone weights.
#   Only has an effect when the backbone is resnet50.
#   Defaults to False.
#-------------------------------------------#
pretrain = False
#-------------------------------------------#
#   Backbone feature extractor to use:
#   resnet50 or hourglass
#-------------------------------------------#
backbone = "resnet50"
Cuda = True

assert backbone in ['resnet50', 'hourglass']
if backbone == "resnet50":
    model = CenterNet_Resnet50(num_classes, pretrain=pretrain)
else:
    model = CenterNet_HourglassNet({'hm': num_classes, 'wh': 2, 'reg': 2})

#------------------------------------------------------#
#   For the weight file see the README (Baidu Netdisk download).
#------------------------------------------------------#
model_path = r"model_data/centernet_resnet50_voc.pth"
# Initialise from pretrained weights to speed up training.
print('Loading weights into state dict...')
model_dict = model.state_dict()
# Load onto the CPU so a GPU-saved checkpoint also opens on CPU-only machines;
# load_state_dict copies the values into the model's own tensors anyway.
pretrained_dict = torch.load(model_path, map_location='cpu')
# Keep only entries whose key exists in this model AND whose shape matches.
# For example, the class-dependent 'hm' head is presumably skipped when
# num_classes differs from the checkpoint's. Unknown keys are ignored
# instead of raising KeyError.
pretrained_dict = {k: v for k, v in pretrained_dict.items()
                   if k in model_dict and np.shape(model_dict[k]) == np.shape(v)}
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)
print('Finished!')
#   Change this to the classes you need to detect.
# -----------------------------#
classes_path = hyp.get('classes_path')
# ----------------------------------------------------#
#   Get the class names and their count
# ----------------------------------------------------#
class_names = get_classes(classes_path)
num_classes = len(class_names)
# ----------------------------------------------------#
#   Build the CenterNet model
# ----------------------------------------------------#
model = CenterNet_Resnet50(num_classes)
# ------------------------------------------------------#
#   For the weight file see the README (Baidu Netdisk download).
# ------------------------------------------------------#
model_path = hyp.get('model_path')
if model_path:
    print('Loading weights into state dict...')
    model_dict = model.state_dict()
    # Load onto the CPU so a GPU-saved checkpoint also opens on CPU-only
    # machines; load_state_dict copies values into the model's tensors anyway.
    pretrained_dict = torch.load(model_path, map_location='cpu')
    # Keep only entries whose key exists in this model AND whose shape matches.
    # Keys absent from the model are ignored instead of raising KeyError.
    pretrained_dict = {k: v for k, v in pretrained_dict.items()
                       if k in model_dict and np.shape(model_dict[k]) == np.shape(v)}
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)
    print('Finished!')
#--------------------------------------------#
#   This script reports the network's parameters:
#   a layer summary, FLOPs and parameter count.
#--------------------------------------------#
import torch
from thop import clever_format, profile
from torchsummary import summary

from nets.centernet import CenterNet_HourglassNet, CenterNet_Resnet50

if __name__ == "__main__":
    input_shape = [512, 512]
    num_classes = 20

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Pass num_classes explicitly so the reported numbers reflect the
    # configured class count rather than the constructor's default.
    model = CenterNet_Resnet50(num_classes=num_classes).to(device)
    summary(model, (3, input_shape[0], input_shape[1]))

    dummy_input = torch.randn(1, 3, input_shape[0], input_shape[1]).to(device)
    # The model is already on `device`; no need to move it again.
    flops, params = profile(model, (dummy_input, ), verbose=False)
    #--------------------------------------------------------#
    #   flops * 2: profile does not count a convolution as two operations.
    #   Some papers count a convolution as two operations (multiply and add),
    #   in which case the count is doubled; others count only multiplications
    #   and ignore additions, in which case it is not.
    #   This code doubles it, following YOLOX.
    #--------------------------------------------------------#
    flops = flops * 2
    flops, params = clever_format([flops, params], "%.3f")
    print('Total GFLOPS: %s' % (flops))
    print('Total params: %s' % (params))