class Efficientdet(object):
    """Inference wrapper around an EfficientDet network.

    Configuration is taken from ``_defaults`` and may be overridden through
    keyword arguments to the constructor.  The constructor loads the class
    list, builds one display colour per class and loads the network weights.
    """

    _defaults = {
        #--------------------------------------------------------------------------#
        #   To predict with your own trained model you MUST change model_path
        #   and classes_path!
        #   model_path points to a weight file under the logs folder,
        #   classes_path points to a txt file under model_data.
        #
        #   After training there are several weight files in logs; choose one
        #   with a low validation loss.  A low validation loss does not mean a
        #   high mAP, only that the weights generalise well on the validation set.
        #   If a shape mismatch appears, also check the model_path and
        #   classes_path that were used for training.
        #--------------------------------------------------------------------------#
        "model_path"        : 'model_data/efficientdet-d0.pth',
        "classes_path"      : 'model_data/coco_classes.txt',
        #---------------------------------------------------------------------#
        #   Selects the model version to use, 0-7.
        #---------------------------------------------------------------------#
        "phi"               : 0,
        #---------------------------------------------------------------------#
        #   Only boxes whose score exceeds this confidence are kept.
        #---------------------------------------------------------------------#
        "confidence"        : 0.3,
        #---------------------------------------------------------------------#
        #   IoU threshold used by non-maximum suppression.
        #---------------------------------------------------------------------#
        "nms_iou"           : 0.3,
        #---------------------------------------------------------------------#
        #   Whether to use letterbox_image for an undistorted resize of the
        #   input.  After several tests, a plain resize (letterbox disabled)
        #   was found to work better.
        #---------------------------------------------------------------------#
        "letterbox_image"   : False,
        #---------------------------------------------------------------------#
        #   Whether to use CUDA.  Set to False when no GPU is available.
        #---------------------------------------------------------------------#
        "cuda"              : True
    }

    @classmethod
    def get_defaults(cls, n):
        """Return the default value for option ``n``.

        For an unknown name a descriptive string is returned (no exception is
        raised); callers rely on that behaviour, so it is kept.
        """
        if n in cls._defaults:
            return cls._defaults[n]
        else:
            return "Unrecognized attribute name '" + n + "'"

    #---------------------------------------------------#
    #   Initialise Efficientdet
    #---------------------------------------------------#
    def __init__(self, **kwargs):
        """Apply defaults, then ``kwargs`` overrides, then load classes and weights."""
        self.__dict__.update(self._defaults)
        for name, value in kwargs.items():
            setattr(self, name, value)

        self.input_shape = [image_sizes[self.phi], image_sizes[self.phi]]
        #---------------------------------------------------#
        #   Class names and the total number of classes
        #---------------------------------------------------#
        self.class_names, self.num_classes = get_classes(self.classes_path)

        #---------------------------------------------------#
        #   One distinct colour per class for drawing boxes
        #---------------------------------------------------#
        hsv_tuples = [(x / self.num_classes, 1., 1.) for x in range(self.num_classes)]
        self.colors = list(map(lambda x: colorsys.hsv_to_rgb(*x), hsv_tuples))
        self.colors = list(map(lambda x: (int(x[0] * 255), int(x[1] * 255), int(x[2] * 255)), self.colors))
        self.generate()

    #---------------------------------------------------#
    #   Load the model
    #---------------------------------------------------#
    def generate(self):
        """Build the network, load ``self.model_path`` and switch to eval mode."""
        self.net = EfficientDetBackbone(self.num_classes, self.phi)

        # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.net.load_state_dict(torch.load(self.model_path, map_location=device))
        self.net = self.net.eval()
        print('{} model, anchors, and classes loaded.'.format(self.model_path))

        if self.cuda:
            self.net = nn.DataParallel(self.net)
            self.net = self.net.cuda()

    def _prepare_input(self, image):
        """Convert ``image`` into a network input batch.

        Returns ``(rgb_image, image_shape, images)`` where ``rgb_image`` is the
        result of ``cvtColor``, ``image_shape`` holds the first two entries of
        ``np.shape(image)`` and ``images`` is a float32 tensor with a leading
        batch dimension (moved to the GPU when ``self.cuda`` is set).
        """
        image_shape = np.array(np.shape(image)[0:2])
        # Convert to RGB so that e.g. grayscale inputs do not fail at
        # prediction time; only RGB prediction is supported.
        image = cvtColor(image)
        # Resize to the network input size, with or without letterboxing.
        image_data = resize_image(image, (self.input_shape[1], self.input_shape[0]), self.letterbox_image)
        # Normalise, move channels first and add the batch dimension.
        image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, dtype='float32')), (2, 0, 1)), 0)

        images = torch.from_numpy(image_data)
        if self.cuda:
            images = images.cuda()
        return image, image_shape, images

    def _predict(self, images, image_shape):
        """Run the network, decode the boxes and apply non-maximum suppression."""
        with torch.no_grad():
            _, regression, classification, anchors = self.net(images)
            outputs = decodebox(regression, anchors, self.input_shape)
            results = non_max_suppression(torch.cat([outputs, classification], axis=-1), self.input_shape,
                        image_shape, self.letterbox_image, conf_thres = self.confidence, nms_thres = self.nms_iou)
        return results

    @staticmethod
    def _text_size(draw, label, font):
        """Return the (width, height) needed to draw ``label``.

        ``ImageDraw.textsize`` is used when it exists so that results are
        unchanged on older Pillow releases; newer releases removed it, so
        ``textbbox`` is used as a fallback there.
        """
        if hasattr(draw, 'textsize'):
            return draw.textsize(label, font)
        bbox = draw.textbbox((0, 0), label, font=font)
        return (bbox[2], bbox[3])

    #---------------------------------------------------#
    #   Detect objects in an image
    #---------------------------------------------------#
    def detect_image(self, image, crop = False):
        """Detect objects in ``image`` and draw the boxes onto it.

        Args:
            image: input image (a PIL image is expected by the drawing code).
            crop: when true, every detection is also saved as
                ``img_crop/crop_<i>.png``.

        Returns:
            The RGB-converted image with boxes and labels drawn, or the
            RGB-converted image unchanged when nothing was detected.
        """
        image, image_shape, images = self._prepare_input(image)
        results = self._predict(images, image_shape)

        if results[0] is None:
            return image

        top_label   = np.array(results[0][:, 5], dtype = 'int32')
        top_conf    = results[0][:, 4]
        top_boxes   = results[0][:, :4]
        #---------------------------------------------------------#
        #   Font and box thickness scale with the image size
        #---------------------------------------------------------#
        font        = ImageFont.truetype(font='model_data/simhei.ttf', size=np.floor(3e-2 * image.size[1] + 0.5).astype('int32'))
        thickness   = int(max((image.size[0] + image.size[1]) // np.mean(self.input_shape), 1))

        #---------------------------------------------------------#
        #   Optionally save a crop of every detected object
        #---------------------------------------------------------#
        if crop:
            dir_save_path = "img_crop"
            # Created once up front instead of being re-checked per detection.
            os.makedirs(dir_save_path, exist_ok=True)
            for i, c in enumerate(top_label):
                top, left, bottom, right = top_boxes[i]
                top     = max(0, np.floor(top).astype('int32'))
                left    = max(0, np.floor(left).astype('int32'))
                bottom  = min(image.size[1], np.floor(bottom).astype('int32'))
                right   = min(image.size[0], np.floor(right).astype('int32'))

                crop_image = image.crop([left, top, right, bottom])
                crop_image.save(os.path.join(dir_save_path, "crop_" + str(i) + ".png"), quality=95, subsampling=0)
                print("save crop_" + str(i) + ".png to " + dir_save_path)

        #---------------------------------------------------------#
        #   Draw the boxes and labels
        #---------------------------------------------------------#
        for i, c in enumerate(top_label):
            predicted_class = self.class_names[int(c)]
            box             = top_boxes[i]
            score           = top_conf[i]

            top, left, bottom, right = box

            top     = max(0, np.floor(top).astype('int32'))
            left    = max(0, np.floor(left).astype('int32'))
            bottom  = min(image.size[1], np.floor(bottom).astype('int32'))
            right   = min(image.size[0], np.floor(right).astype('int32'))

            label = '{} {:.2f}'.format(predicted_class, score)
            draw = ImageDraw.Draw(image)
            label_size = self._text_size(draw, label, font)
            label = label.encode('utf-8')
            print(label, top, left, bottom, right)

            # Put the label above the box unless that would leave the image.
            if top - label_size[1] >= 0:
                text_origin = np.array([left, top - label_size[1]])
            else:
                text_origin = np.array([left, top + 1])

            # A thick outline is drawn as several nested 1-pixel rectangles.
            for offset in range(thickness):
                draw.rectangle([left + offset, top + offset, right - offset, bottom - offset], outline=self.colors[c])
            draw.rectangle([tuple(text_origin), tuple(text_origin + label_size)], fill=self.colors[c])
            draw.text(text_origin, str(label,'UTF-8'), fill=(0, 0, 0), font=font)
            del draw

        return image

    def get_FPS(self, image, test_interval):
        """Return the mean wall-clock seconds per prediction.

        One untimed prediction is run first, then ``test_interval`` timed
        predictions (network forward pass, box decoding and NMS) on the
        same preprocessed input.
        """
        _, image_shape, images = self._prepare_input(image)

        # Untimed first pass so one-off start-up cost is not measured.
        self._predict(images, image_shape)

        t1 = time.time()
        for _ in range(test_interval):
            self._predict(images, image_shape)
        t2 = time.time()
        tact_time = (t2 - t1) / test_interval
        return tact_time

    #---------------------------------------------------#
    #   Write detections for mAP evaluation
    #---------------------------------------------------#
    def get_map_txt(self, image_id, image, class_names, map_out_path):
        """Write the detections for ``image`` to a text file.

        The file ``<map_out_path>/detection-results/<image_id>.txt`` is always
        created (it stays empty when nothing is detected).  Each line has the
        form ``class score left top right bottom``; detections whose class is
        not in ``class_names`` are skipped.
        """
        # The context manager closes the file on every path, including the
        # early return below, which previously leaked the handle.
        with open(os.path.join(map_out_path, "detection-results/"+image_id+".txt"),"w") as f:
            _, image_shape, images = self._prepare_input(image)
            results = self._predict(images, image_shape)

            if results[0] is None:
                return

            top_label   = np.array(results[0][:, 5], dtype = 'int32')
            top_conf    = results[0][:, 4]
            top_boxes   = results[0][:, :4]

            for i, c in enumerate(top_label):
                predicted_class = self.class_names[int(c)]
                box             = top_boxes[i]
                score           = str(top_conf[i])

                top, left, bottom, right = box
                if predicted_class not in class_names:
                    continue

                f.write("%s %s %s %s %s %s\n" % (predicted_class, score[:6], str(int(left)), str(int(top)), str(int(right)),str(int(bottom))))

        return
# NOTE(review): `num_classes`, `phi` and `Cuda` are not defined in this
# fragment; presumably they are set earlier in the script - confirm.
model = EfficientDetBackbone(num_classes, phi)

#------------------------------------------------------#
#   See the README for the weight files (Baidu Netdisk download).
#------------------------------------------------------#
model_path = "model_data/efficientdet-d0.pth"
print('Loading weights into state dict...')
model_dict = model.state_dict()
# Load onto the CPU first so a GPU-saved checkpoint also works on a CPU-only
# machine; the model is moved to the GPU below when Cuda is set.
pretrained_dict = torch.load(model_path, map_location='cpu')
# Keep only tensors that exist in this model with a matching shape.  The
# membership test avoids a KeyError for checkpoint keys the model lacks.
pretrained_dict = {
    k: v
    for k, v in pretrained_dict.items()
    if k in model_dict and np.shape(model_dict[k]) == np.shape(v)
}
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)
print('Finished!')

net = model.train()

if Cuda:
    net = torch.nn.DataParallel(model)
    cudnn.benchmark = True
    net = net.cuda()

efficient_loss = FocalLoss()

#----------------------------------------------------#
#   Path of the file listing image paths and labels
#----------------------------------------------------#
annotation_path = '2007_train.txt'
class EfficientDet(object):
    """Inference wrapper around an EfficientDet network.

    Options come from ``_defaults`` and can be overridden with keyword
    arguments to the constructor.
    """

    _defaults = {
        "model_path": 'model_data/efficientdet-d0.pth',
        "classes_path": 'model_data/coco_classes.txt',
        "phi": 0,
        "confidence": 0.3,
        "cuda": True
    }

    @classmethod
    def get_defaults(cls, n):
        """Return the default for option ``n``, or a message string if unknown."""
        if n in cls._defaults:
            return cls._defaults[n]
        else:
            return "Unrecognized attribute name '" + n + "'"

    #---------------------------------------------------#
    #   Initialise Efficientdet
    #---------------------------------------------------#
    def __init__(self, **kwargs):
        """Apply defaults and ``kwargs`` overrides, then load classes and weights."""
        self.__dict__.update(self._defaults)
        # Keyword arguments used to be accepted but silently ignored; apply
        # them so that e.g. EfficientDet(cuda=False) takes effect.
        for name, value in kwargs.items():
            setattr(self, name, value)
        self.class_names = self._get_class()
        self.generate()

    #---------------------------------------------------#
    #   Read all class names
    #---------------------------------------------------#
    def _get_class(self):
        """Return the stripped lines of ``self.classes_path`` as a list."""
        classes_path = os.path.expanduser(self.classes_path)
        with open(classes_path) as f:
            class_names = [line.strip() for line in f]
        return class_names

    #---------------------------------------------------#
    #   Build the network and load the weights
    #---------------------------------------------------#
    def generate(self):
        """Create the network in eval mode, load weights and build class colours."""
        os.environ["CUDA_VISIBLE_DEVICES"] = '0'
        self.net = EfficientDetBackbone(len(self.class_names), self.phi).eval()

        print('Loading weights into state dict...')
        # Load onto the CPU so a GPU-saved checkpoint works without a GPU;
        # the network is moved to the GPU afterwards when requested.
        state_dict = torch.load(self.model_path, map_location='cpu')
        self.net.load_state_dict(state_dict)
        self.net = nn.DataParallel(self.net)
        if self.cuda:
            self.net = self.net.cuda()
        print('Finished!')

        print('{} model, anchors, and classes loaded.'.format(self.model_path))
        # One distinct colour per class for drawing boxes.
        hsv_tuples = [(x / len(self.class_names), 1., 1.)
                      for x in range(len(self.class_names))]
        self.colors = list(map(lambda x: colorsys.hsv_to_rgb(*x), hsv_tuples))
        self.colors = list(
            map(lambda x: (int(x[0] * 255), int(x[1] * 255), int(x[2] * 255)),
                self.colors))

    @staticmethod
    def _text_size(draw, label, font):
        """Return the (width, height) needed to draw ``label``.

        ``ImageDraw.textsize`` is preferred so older Pillow releases behave as
        before; ``textbbox`` is the fallback where ``textsize`` was removed.
        """
        if hasattr(draw, 'textsize'):
            return draw.textsize(label, font)
        bbox = draw.textbbox((0, 0), label, font=font)
        return (bbox[2], bbox[3])

    #---------------------------------------------------#
    #   Detect objects in an image
    #---------------------------------------------------#
    def detect_image(self, image):
        """Detect objects in ``image`` and draw labelled boxes onto it.

        Returns ``image`` itself, unchanged when no detections are available.
        """
        image_shape = np.array(np.shape(image)[0:2])
        input_size = image_sizes[self.phi]

        crop_img = np.array(letterbox_image(image, (input_size, input_size)))
        photo = np.array(crop_img, dtype=np.float32)
        photo = np.transpose(preprocess_input(photo), (2, 0, 1))
        # Batch of one image.
        images = np.asarray([photo])

        with torch.no_grad():
            images = torch.from_numpy(images)
            if self.cuda:
                images = images.cuda()
            _, regression, classification, anchors = self.net(images)

            regression = decodebox(regression, anchors, images)
            detection = torch.cat([regression, classification], axis=-1)
            batch_detections = non_max_suppression(detection, len(self.class_names),
                                                   conf_thres=self.confidence,
                                                   nms_thres=0.2)
        # A missing result (None or an empty container) means "nothing
        # detected".  Only those failure modes are caught, so that real
        # errors are no longer hidden by a bare except.
        try:
            batch_detections = batch_detections[0].cpu().numpy()
        except (AttributeError, IndexError, TypeError):
            return image

        top_index = batch_detections[:, 4] > self.confidence
        top_conf = batch_detections[top_index, 4]
        top_label = np.array(batch_detections[top_index, -1], np.int32)
        top_bboxes = np.array(batch_detections[top_index, :4])
        top_xmin = np.expand_dims(top_bboxes[:, 0], -1)
        top_ymin = np.expand_dims(top_bboxes[:, 1], -1)
        top_xmax = np.expand_dims(top_bboxes[:, 2], -1)
        top_ymax = np.expand_dims(top_bboxes[:, 3], -1)

        # Remove the letterbox (gray bar) offset from the boxes.
        boxes = efficientdet_correct_boxes(
            top_ymin, top_xmin, top_ymax, top_xmax,
            np.array([input_size, input_size]), image_shape)

        font = ImageFont.truetype(font='model_data/simhei.ttf',
                                  size=np.floor(3e-2 * np.shape(image)[1] + 0.5).astype('int32'))

        # At least 1, otherwise small images would get no outline at all.
        thickness = max((np.shape(image)[0] + np.shape(image)[1]) // input_size, 1)

        for i, c in enumerate(top_label):
            predicted_class = self.class_names[c]
            score = top_conf[i]
            color = self.colors[c]

            top, left, bottom, right = boxes[i]
            # Grow the box by 5 pixels on every side.
            top = top - 5
            left = left - 5
            bottom = bottom + 5
            right = right + 5

            top = max(0, np.floor(top + 0.5).astype('int32'))
            left = max(0, np.floor(left + 0.5).astype('int32'))
            bottom = min(np.shape(image)[0], np.floor(bottom + 0.5).astype('int32'))
            right = min(np.shape(image)[1], np.floor(right + 0.5).astype('int32'))

            # Draw the box and its label.
            label = '{} {:.2f}'.format(predicted_class, score)
            draw = ImageDraw.Draw(image)
            label_size = self._text_size(draw, label, font)
            label = label.encode('utf-8')
            print(label)

            # Put the label above the box unless that would leave the image.
            if top - label_size[1] >= 0:
                text_origin = np.array([left, top - label_size[1]])
            else:
                text_origin = np.array([left, top + 1])

            # A thick outline is drawn as several nested 1-pixel rectangles.
            for offset in range(thickness):
                draw.rectangle([left + offset, top + offset, right - offset, bottom - offset],
                               outline=color)
            draw.rectangle(
                [tuple(text_origin), tuple(text_origin + label_size)],
                fill=color)
            draw.text(text_origin, str(label, 'UTF-8'), fill=(0, 0, 0), font=font)
            del draw
        return image
class train_model(object):
    """Two-stage EfficientDet trainer.

    Constructing an instance reads the configuration, builds the model, loads
    matching pretrained weights, splits the annotation lines into train and
    validation sets and then immediately runs both training stages (backbone
    frozen first, then the whole network).
    """

    def __init__(self):
        cfg = configs.parse_config()
        self.input_shape = (cfg.input_sizes[cfg.phi], cfg.input_sizes[cfg.phi])
        self.lr_first = cfg.learning_rate_first_stage
        self.Batch_size_first = cfg.Batch_size_first_stage
        self.Init_Epoch = cfg.Init_Epoch
        self.Freeze_Epoch = cfg.Freeze_Epoch
        self.opt_weight_decay = cfg.opt_weight_decay
        self.CosineAnnealingLR_T_max = cfg.CosineAnnealingLR_T_max
        self.CosineAnnealingLR_eta_min = cfg.CosineAnnealingLR_eta_min
        self.StepLR_step_size = cfg.StepLR_step_size
        self.StepLR_gamma = cfg.StepLR_gamma
        self.num_workers = cfg.num_workers
        self.Save_num_epoch = cfg.Save_num_epoch
        self.lr_second = cfg.learning_rate_second_stage
        self.Batch_size_second = cfg.Batch_size_second_stage
        self.Unfreeze_Epoch = cfg.Unfreeze_Epoch

        # Training tricks / switches.
        self.Cosine_lr, self.mosaic = cfg.Cosine_lr, cfg.use_mosaic
        self.Cuda = torch.cuda.is_available()
        self.smoooth_label = cfg.smoooth_label
        self.Use_Data_Loader, self.annotation_path = cfg.Use_Data_Loader, cfg.train_annotation_path

        # Class names.
        self.classes_path = cfg.classes_path
        self.class_names = self.get_classes(self.classes_path)
        self.num_classes = len(self.class_names)

        # Build the model.
        self.model = EfficientDetBackbone(self.num_classes, cfg.phi)
        pretrain_weight_name = os.listdir(cfg.pretrain_dir)
        # First file whose name contains the phi digit.
        index = [item for item in pretrain_weight_name if str(cfg.phi) in item][0]

        print('Loading pretrain_weights into state dict...')
        model_dict = self.model.state_dict()
        # os.path.join works whether or not pretrain_dir ends with a separator.
        # Loading onto the CPU keeps GPU-saved checkpoints usable without a GPU.
        pretrained_dict = torch.load(os.path.join(cfg.pretrain_dir, index), map_location='cpu')
        # Keep only tensors present in this model with a matching shape; the
        # membership test avoids a KeyError on extra checkpoint keys.
        pretrained_dict = {k: v for k, v in pretrained_dict.items()
                           if k in model_dict and np.shape(model_dict[k]) == np.shape(v)}
        model_dict.update(pretrained_dict)
        self.model.load_state_dict(model_dict)
        print('Finished!')

        self.net = self.model.train()

        if self.Cuda:
            # Multi-GPU training; the original author notes this setting is problematic.
            self.net = torch.nn.DataParallel(self.model)
            cudnn.benchmark = True
            self.net = self.net.cuda()

        # Loss function.
        self.efficient_loss = FocalLoss()

        # A fraction cfg.val_split is used for validation, the rest for training.
        val_split = cfg.val_split
        with open(self.annotation_path) as f:
            self.lines = f.readlines()
        # Fixed seed so the train/validation split is reproducible, then the
        # global NumPy RNG is re-seeded from the OS.
        np.random.seed(101)
        np.random.shuffle(self.lines)
        np.random.seed(None)
        self.num_val = int(len(self.lines) * val_split)
        self.num_train = len(self.lines) - self.num_val

        self.train_first_stage()
        self.train_second_stage()

    def get_classes(self, classes_path):
        """Load the class names.

        :param classes_path: path of a text file with one class name per line
        :return: list of the stripped lines
        """
        with open(classes_path) as f:
            class_names = [line.strip() for line in f]
        return class_names

    def get_lr(self, optimizer):
        """Return the learning rate of the first parameter group (None if there is none)."""
        for param_group in optimizer.param_groups:
            return param_group['lr']

    @staticmethod
    def _to_tensors(images, targets, cuda):
        """Convert a NumPy batch into float tensors, on the GPU when ``cuda`` is set."""
        with torch.no_grad():
            images = torch.from_numpy(images).type(torch.FloatTensor)
            targets = [torch.from_numpy(ann).type(torch.FloatTensor) for ann in targets]
            if cuda:
                images = images.cuda()
                targets = [ann.cuda() for ann in targets]
        return images, targets

    def fit_one_epoch(self, net, model, optimizer, focal_loss, epoch, epoch_size, epoch_size_val, gen, genval, Epoch,
                      cuda):
        """Run one training epoch followed by one validation pass.

        :param net: network used for the forward pass (possibly DataParallel-wrapped)
        :param model: underlying model whose state_dict is saved
        :param optimizer: optimizer stepped after every training batch
        :param focal_loss: callable returning (loss, c_loss, r_loss)
        :param epoch: zero-based index of the current epoch
        :param epoch_size: number of training batches per epoch
        :param epoch_size_val: number of validation batches per epoch
        :param gen: iterable of training batches
        :param genval: iterable of validation batches
        :param Epoch: final epoch number, used for display
        :param cuda: whether tensors are moved to the GPU
        :return: None
        """
        total_r_loss = 0
        total_c_loss = 0
        total_loss = 0
        val_loss = 0
        start_time = time.time()
        with tqdm(total=epoch_size, desc=f'Epoch {epoch + 1}/{Epoch}', postfix=dict, mininterval=0.3) as pbar:
            for iteration, batch in enumerate(gen):
                if iteration >= epoch_size:
                    break
                images, targets = self._to_tensors(batch[0], batch[1], cuda)

                optimizer.zero_grad()
                # regression: (batch_size, H*W*9+..., 4)
                # classification: (batch_size, H*W*9+..., num_classes)
                # anchors: the prior boxes
                _, regression, classification, anchors = net(images)
                loss, c_loss, r_loss = focal_loss(classification, regression, anchors, targets, cuda=cuda)
                loss.backward()
                optimizer.step()

                total_loss += loss.detach().item()
                total_r_loss += r_loss.detach().item()
                total_c_loss += c_loss.detach().item()
                waste_time = time.time() - start_time

                # NOTE(review): 'step/s' actually shows seconds per step; the
                # label is kept because it is runtime output.
                pbar.set_postfix(**{'Conf Loss': total_c_loss / (iteration + 1),
                                    'Regression Loss': total_r_loss / (iteration + 1),
                                    'lr': self.get_lr(optimizer),
                                    'step/s': waste_time})
                pbar.update(1)

                start_time = time.time()

        print('Start Validation')
        with tqdm(total=epoch_size_val, desc=f'Epoch {epoch + 1}/{Epoch}', postfix=dict, mininterval=0.3) as pbar:
            for iteration, batch in enumerate(genval):
                if iteration >= epoch_size_val:
                    break
                images_val, targets_val = self._to_tensors(batch[0], batch[1], cuda)

                # No gradients are needed for validation.
                with torch.no_grad():
                    _, regression, classification, anchors = net(images_val)
                    loss, c_loss, r_loss = focal_loss(classification, regression, anchors, targets_val, cuda=cuda)
                    val_loss += loss.detach().item()

                pbar.set_postfix(**{'total_loss': val_loss / (iteration + 1)})
                pbar.update(1)

        print('Finish Validation')
        print('Epoch:' + str(epoch + 1) + '/' + str(Epoch))
        # NOTE(review): averages divide by size + 1 (this also avoids a zero
        # division when epoch_size_val is 0); kept unchanged because the
        # checkpoint file names are derived from these values.
        print('Total Loss: %.4f || Val Loss: %.4f ' % (total_loss / (epoch_size + 1), val_loss / (epoch_size_val + 1)))
        if (epoch + 1) % self.Save_num_epoch == 0:
            print('Saving state, iter:', str(epoch + 1))
            # Make sure the output directory exists before saving.
            os.makedirs('model_weight', exist_ok=True)
            torch.save(model.state_dict(), 'model_weight/Epoch%d-Total_Loss%.4f-Val_Loss%.4f.pth' % (
                (epoch + 1), total_loss / (epoch_size + 1), val_loss / (epoch_size_val + 1)))

    def _build_scheduler(self, optimizer):
        """Return the configured LR scheduler (cosine annealing or step decay)."""
        if self.Cosine_lr:
            return optim.lr_scheduler.CosineAnnealingLR(
                optimizer, T_max=self.CosineAnnealingLR_T_max, eta_min=self.CosineAnnealingLR_eta_min)
        return optim.lr_scheduler.StepLR(
            optimizer, step_size=self.StepLR_step_size, gamma=self.StepLR_gamma)

    def _build_generators(self, batch_size):
        """Return (train, validation) batch sources for the given batch size."""
        train_lines = self.lines[:self.num_train]
        val_lines = self.lines[self.num_train:]
        shape = (self.input_shape[0], self.input_shape[1])
        if self.Use_Data_Loader:
            train_dataset = EfficientdetDataset(train_lines, shape)
            val_dataset = EfficientdetDataset(val_lines, shape)
            gen = DataLoader(train_dataset, batch_size=batch_size, num_workers=self.num_workers,
                             pin_memory=True, drop_last=True, collate_fn=efficientdet_dataset_collate)
            gen_val = DataLoader(val_dataset, batch_size=batch_size, num_workers=self.num_workers,
                                 pin_memory=True, drop_last=True, collate_fn=efficientdet_dataset_collate)
        else:
            gen = Generator(batch_size, train_lines, shape).generate()
            gen_val = Generator(batch_size, val_lines, shape).generate()
        return gen, gen_val

    def _set_backbone_trainable(self, trainable):
        """Freeze or unfreeze every parameter of the backbone network."""
        for param in self.model.backbone_net.parameters():
            param.requires_grad = trainable

    def train_first_stage(self):
        """Train with the backbone frozen.

        The backbone features are generic, so freezing them speeds up training
        and protects the weights early on.  Runs epochs Init_Epoch up to
        Freeze_Epoch.  Reduce the batch size on out-of-memory errors.

        :return: None
        """
        optimizer_stage1 = optim.Adam(self.net.parameters(), self.lr_first, weight_decay=self.opt_weight_decay)
        lr_scheduler = self._build_scheduler(optimizer_stage1)
        gen, gen_val = self._build_generators(self.Batch_size_first)

        epoch_size = max(1, self.num_train // self.Batch_size_first)
        epoch_size_val = self.num_val // self.Batch_size_first

        # Freeze the backbone.
        self._set_backbone_trainable(False)

        for epoch in range(self.Init_Epoch, self.Freeze_Epoch):
            self.fit_one_epoch(self.net, self.model, optimizer_stage1, self.efficient_loss, epoch, epoch_size,
                               epoch_size_val, gen, gen_val, self.Freeze_Epoch, self.Cuda)
            lr_scheduler.step()

    def train_second_stage(self):
        """Train the whole network, backbone included.

        Runs epochs Freeze_Epoch up to Unfreeze_Epoch.

        :return: None
        """
        optimizer_stage2 = optim.Adam(self.net.parameters(), self.lr_second, weight_decay=self.opt_weight_decay)
        lr_scheduler = self._build_scheduler(optimizer_stage2)
        # The second-stage batch size is used for the loaders as well, so they
        # agree with epoch_size below (the first-stage size was used before).
        gen, gen_val = self._build_generators(self.Batch_size_second)

        epoch_size = max(1, self.num_train // self.Batch_size_second)
        epoch_size_val = self.num_val // self.Batch_size_second

        # Unfreeze the backbone.
        self._set_backbone_trainable(True)

        for epoch in range(self.Freeze_Epoch, self.Unfreeze_Epoch):
            self.fit_one_epoch(self.net, self.model, optimizer_stage2, self.efficient_loss, epoch, epoch_size,
                               epoch_size_val, gen, gen_val, self.Unfreeze_Epoch, self.Cuda)
            lr_scheduler.step()