def generate(self):
    """Build the YOLOv4-tiny network, load its weights and set up box decoding.

    Reads ``self.anchors``, ``self.class_names``, ``self.model_path``,
    ``self.model_image_size`` and ``self.cuda``; populates ``self.net``,
    ``self.anchors_mask``, ``self.yolo_decodes`` and ``self.colors``.
    """
    num_classes = len(self.class_names)
    self.net = YoloBody(len(self.anchors[0]), num_classes).eval()

    # Load the trained weights; tensors are mapped onto the GPU when one is
    # available, otherwise onto the CPU.
    print('Loading weights into state dict...')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    state_dict = torch.load(self.model_path, map_location=device)
    self.net.load_state_dict(state_dict)

    if self.cuda:
        # NOTE(review): torch.cuda.is_available() / torch.load(..., map_location=cuda)
        # above may already have initialised CUDA, in which case this assignment
        # no longer restricts the visible devices. Set it before any CUDA call if
        # limiting the process to GPU 0 is really intended.
        os.environ["CUDA_VISIBLE_DEVICES"] = '0'
        self.net = nn.DataParallel(self.net)
        self.net = self.net.cuda()
    print('Finished!')

    # One decoder per entry in anchors_mask; each mask picks rows of the
    # anchors flattened to shape (-1, 2).
    # NOTE(review): [1, 2, 3] (rather than [0, 1, 2]) presumably mirrors the
    # darknet yolov4-tiny config — confirm against the anchors file in use.
    self.anchors_mask = [[3, 4, 5], [1, 2, 3]]
    flat_anchors = np.reshape(self.anchors, [-1, 2])
    # model_image_size indices 1 and 0, in that order (presumably width/height
    # as DecodeBox expects — confirm).
    decode_size = (self.model_image_size[1], self.model_image_size[0])
    self.yolo_decodes = [
        DecodeBox(flat_anchors[mask], num_classes, decode_size)
        for mask in self.anchors_mask
    ]
    print('{} model, anchors, and classes loaded.'.format(self.model_path))

    # Give each class its own box colour: evenly spaced hues at full
    # saturation/value, converted to 0-255 integer RGB triples.
    self.colors = []
    for idx in range(num_classes):
        r, g, b = colorsys.hsv_to_rgb(idx / num_classes, 1., 1.)
        self.colors.append((int(r * 255), int(g * 255), int(b * 255)))
#-------------------------------#
#   Whether to use the DataLoader
#-------------------------------#
Use_Data_Loader = True
annotation_path = '2007_train.txt'
#-------------------------------#
#   Load the anchors and class names
#-------------------------------#
anchors_path = 'model_data/yolo_anchors.txt'
classes_path = 'model_data/voc_classes.txt'
class_names = get_classes(classes_path)
anchors = get_anchors(anchors_path)
num_classes = len(class_names)

# Create the model
model = YoloBody(len(anchors[0]), num_classes)

#-------------------------------------------#
#   See the README for how to download the weights file
#-------------------------------------------#
model_path = "model_data/yolov4_tiny_weights_coco.pth"
# Load the pretrained weights, keeping only tensors whose key exists in this
# model and whose shape matches; everything else is skipped.
print('Loading weights into state dict...')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_dict = model.state_dict()
pretrained_dict = torch.load(model_path, map_location=device)
# Check key membership before indexing model_dict: a checkpoint containing a
# parameter this model does not define previously raised KeyError here.
skipped_keys = [
    k for k, v in pretrained_dict.items()
    if k not in model_dict or np.shape(model_dict[k]) != np.shape(v)
]
pretrained_dict = {
    k: v for k, v in pretrained_dict.items()
    if k in model_dict and np.shape(model_dict[k]) == np.shape(v)
}
if skipped_keys:
    print('Skipped {} pretrained tensors (missing key or shape mismatch).'.format(
        len(skipped_keys)))
model_dict.update(pretrained_dict)
#--------------------------------------------#
#   This script only prints the network structure; it is not a test script.
#   For mAP evaluation see get_dr_txt.py, get_gt_txt.py
#   and get_map.py
#--------------------------------------------#
import torch
from torchsummary import summary

from nets.yolo4_tiny import YoloBody


def _print_network_summary():
    """Build YoloBody(3, 20) and print its layer summary for a 3x416x416 input."""
    # Pick the device the network runs on: GPU when available, else CPU.
    if torch.cuda.is_available():
        target = torch.device('cuda')
    else:
        target = torch.device('cpu')
    net = YoloBody(3, 20)
    net = net.to(target)
    summary(net, input_size=(3, 416, 416))


if __name__ == "__main__":
    _print_network_summary()
#--------------------------------------------#
#   This script only prints the network structure; it is not a test script.
#   For mAP evaluation see get_dr_txt.py, get_gt_txt.py
#   and get_map.py
#--------------------------------------------#
import torch
from torchsummary import summary

from nets.yolo4_tiny import YoloBody


def _print_network_summary():
    """Build YoloBody(3, 1) and print its layer summary for a 3x768x768 input."""
    # Pick the device the network runs on: GPU when available, else CPU.
    if torch.cuda.is_available():
        target = torch.device('cuda')
    else:
        target = torch.device('cpu')
    net = YoloBody(3, 1)
    net = net.to(target)
    summary(net, input_size=(3, 768, 768))


if __name__ == "__main__":
    _print_network_summary()