def __init__(
    self,
    cfg,
    confidence_thresholds_for_classes,
    show_mask_heatmaps=False,
    masks_per_dim=2,
    min_image_size=224,
):
    """Build the detector, load its weights and prepare the demo helpers.

    Args:
        cfg: configuration node. A clone is stored on the instance; the
            node itself is used to build the model and the checkpointer.
        confidence_thresholds_for_classes: score thresholds, stored as a
            tensor.
        show_mask_heatmaps: when true the masker threshold is -1 instead
            of 0.5; the flag is also stored on the instance.
        masks_per_dim: stored unchanged on the instance.
        min_image_size: stored unchanged on the instance.
    """
    # Plain settings first, so build_transform() can rely on them.
    self.cfg = cfg.clone()
    self.min_image_size = min_image_size
    self.show_mask_heatmaps = show_mask_heatmaps
    self.masks_per_dim = masks_per_dim
    self.cpu_device = torch.device("cpu")
    self.confidence_thresholds_for_classes = torch.tensor(
        confidence_thresholds_for_classes)
    # used to make colors for each class
    self.palette = torch.tensor([2**25 - 1, 2**15 - 1, 2**21 - 1])

    # Detector in eval mode on the configured device.
    detector = build_detection_model(cfg)
    detector.eval()
    self.device = torch.device(cfg.MODEL.DEVICE)
    detector.to(self.device)
    self.model = detector

    # Restore the weights named in the config.
    weights_loader = DetectronCheckpointer(
        cfg, detector, save_dir=cfg.OUTPUT_DIR)
    weights_loader.load(cfg.MODEL.WEIGHT)

    self.transforms = self.build_transform()

    if show_mask_heatmaps:
        mask_threshold = -1
    else:
        mask_threshold = 0.5
    self.masker = Masker(threshold=mask_threshold, padding=1)
def __init__(
    self,
    model_name="fcos_R_50_FPN_1x",
    nms_thresh=0.6,
    cpu_only=False,
):
    """Configure, build and load a detector by model name.

    Args:
        model_name: key into ``_MODEL_NAMES_TO_INFO_`` and base name of
            the ``<model_name>.yaml`` file in the ``configs`` directory
            next to this file.
        nms_thresh: value written to ``cfg.MODEL.FCOS.NMS_TH``.
        cpu_only: selects ``"cpu"`` as the device when true, ``"cuda"``
            otherwise.
    """
    here = os.path.dirname(os.path.abspath(__file__))
    self.model_name = model_name
    self.config_files_dir = os.path.join(here, "configs")
    self.cfg_name = model_name + ".yaml"
    self.cpu_device = torch.device("cpu")
    self.palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1])

    # Assemble the frozen configuration for this model.
    cfg = base_cfg.clone()
    cfg.merge_from_file(os.path.join(self.config_files_dir, self.cfg_name))
    cfg.MODEL.WEIGHT = _MODEL_NAMES_TO_INFO_[model_name]["url"]
    cfg.MODEL.FCOS.NMS_TH = nms_thresh
    cfg.MODEL.DEVICE = "cpu" if cpu_only else "cuda"
    cfg.freeze()
    self.cfg = cfg

    # Detector in eval mode on the chosen device, then load the weights.
    detector = build_detection_model(cfg)
    detector.eval()
    self.device = torch.device(cfg.MODEL.DEVICE)
    detector.to(self.device)
    self.model = detector

    weights_loader = DetectronCheckpointer(cfg, detector)
    weights_loader.load(cfg.MODEL.WEIGHT)

    self.transforms = self.build_transform()
def train(cfg, local_rank, distributed):
    """Build the model and its training helpers, then run ``do_train``.

    Args:
        cfg: configuration node read for device, SyncBN, output directory,
            initial weights and checkpoint period.
        local_rank: device index given to DistributedDataParallel.
        distributed: when true the model is wrapped in
            DistributedDataParallel and the data loader is built in
            distributed mode.

    Returns:
        The model that was trained (the DDP wrapper when ``distributed``).
    """
    device = torch.device(cfg.MODEL.DEVICE)
    model = build_detection_model(cfg)
    model.to(device)

    if cfg.MODEL.USE_SYNCBN:
        assert is_pytorch_1_1_0_or_later(), \
            "SyncBatchNorm is only available in pytorch >= 1.1.0"
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    optimizer = make_optimizer(cfg, model)
    scheduler = make_lr_scheduler(cfg, optimizer)

    if distributed:
        ddp_options = {
            "device_ids": [local_rank],
            "output_device": local_rank,
            # this should be removed if we update BatchNorm stats
            "broadcast_buffers": False,
        }
        model = torch.nn.parallel.DistributedDataParallel(model, **ddp_options)

    arguments = {"iteration": 0}

    # Define the checkpointer; it is used to load now and to save later.
    checkpointer = DetectronCheckpointer(
        cfg, model, optimizer, scheduler, cfg.OUTPUT_DIR, get_rank() == 0
    )
    arguments.update(checkpointer.load(cfg.MODEL.WEIGHT))

    # Load the dataset, starting from the restored iteration.
    data_loader = make_data_loader(
        cfg,
        is_train=True,
        is_distributed=distributed,
        start_iter=arguments["iteration"],
    )

    do_train(
        model,
        data_loader,
        optimizer,
        scheduler,
        checkpointer,
        device,
        cfg.SOLVER.CHECKPOINT_PERIOD,
        arguments,
    )
    return model
def train(cfg, local_rank, distributed):
    """Build the model and its training helpers, then run ``do_train``.

    The original author's inline notes recorded the values seen in their
    single-GPU run (local_rank 0, distributed False, device "cuda",
    output dir ".", checkpoint period 2500); those are observations, not
    guarantees.

    Args:
        cfg: configuration node.
        local_rank: device index given to DistributedDataParallel.
        distributed: whether to wrap the model in DistributedDataParallel
            and build a distributed data loader.

    Returns:
        The model that was trained (the DDP wrapper when ``distributed``).
    """
    # Instantiate the model and move it to the configured device
    # (author's note: "cuda", i.e. the GPU).
    device = torch.device(cfg.MODEL.DEVICE)
    model = build_detection_model(cfg)
    model.to(device)

    if cfg.MODEL.USE_SYNCBN:  # author's note: False in their setup
        assert is_pytorch_1_1_0_or_later(), \
            "SyncBatchNorm is only available in pytorch >= 1.1.0"
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    # Optimizer for training the network, and its learning-rate schedule.
    optimizer = make_optimizer(cfg, model)
    scheduler = make_lr_scheduler(cfg, optimizer)

    if distributed:  # author's note: False in their setup
        ddp_options = {
            "device_ids": [local_rank],
            "output_device": local_rank,
            # this should be removed if we update BatchNorm stats
            "broadcast_buffers": False,
        }
        model = torch.nn.parallel.DistributedDataParallel(model, **ddp_options)

    arguments = {"iteration": 0}

    # The checkpoint holds the network's pretrained weights.
    checkpointer = DetectronCheckpointer(
        cfg, model, optimizer, scheduler, cfg.OUTPUT_DIR, get_rank() == 0
    )
    # Merge whatever extra data the checkpoint carried into ``arguments``.
    arguments.update(checkpointer.load(cfg.MODEL.WEIGHT))

    # Build the training data loader.
    data_loader = make_data_loader(
        cfg,
        is_train=True,
        is_distributed=distributed,
        start_iter=arguments["iteration"],
    )

    do_train(
        model,
        data_loader,
        optimizer,
        scheduler,
        checkpointer,
        device,
        cfg.SOLVER.CHECKPOINT_PERIOD,
        arguments,
    )
    return model
def __init__(self, cfg, local_rank, distributed):
    """Set up everything an epoch-based training loop needs.

    Builds the model, optimizer, scheduler and checkpointer, restores the
    weights named by ``cfg.MODEL.WEIGHT`` and creates the training loader.

    Args:
        cfg: configuration node.
        local_rank: device index given to DistributedDataParallel.
        distributed: whether to wrap the model in DistributedDataParallel;
            also stored on the instance.
    """
    self.writer = SummaryWriter(log_dir=cfg.OUTPUT_DIR)
    self.start_epoch = 0
    # TODO: derive the epoch count from the iteration budget and the
    # loader length; it is fixed for now.
    self.epochs = 5

    model = build_detection_model(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)

    if cfg.MODEL.USE_SYNCBN:
        assert is_pytorch_1_1_0_or_later(), \
            "SyncBatchNorm is only available in pytorch >= 1.1.0"
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    optimizer = make_optimizer(cfg, model)
    scheduler = make_lr_scheduler(cfg, optimizer)

    if distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[local_rank],
            output_device=local_rank,
            # this should be removed if we update BatchNorm stats
            broadcast_buffers=False,
        )

    arguments = {"iteration": 0}

    output_dir = cfg.OUTPUT_DIR
    save_to_disk = get_rank() == 0
    checkpointer = DetectronCheckpointer(
        cfg, model, optimizer, scheduler, output_dir, save_to_disk
    )
    extra_checkpoint_data = checkpointer.load(cfg.MODEL.WEIGHT)
    arguments.update(extra_checkpoint_data)

    # The core change is on the dataset side: the loaders here are plain
    # torch.utils.data loaders (per the original author's note).
    self.train_loader = make_train_loader(
        cfg, start_iter=arguments["iteration"])

    self.model = model
    self.optimizer = optimizer
    self.scheduler = scheduler  # was assigned twice in the original
    self.checkpointer = checkpointer
    self.device = device
    self.checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
    self.arguments = arguments
    self.distributed = distributed
def train(cfg, local_rank, distributed):
    """Build the model and its training helpers, then run ``do_train``.

    Args:
        cfg: configuration node.
        local_rank: device index given to DistributedDataParallel.
        distributed: whether to wrap the model in DistributedDataParallel
            and build a distributed data loader.

    Returns:
        The model that was trained (the DDP wrapper when ``distributed``).
    """
    device = torch.device(cfg.MODEL.DEVICE)
    model = build_detection_model(cfg)
    model.to(device)

    if cfg.MODEL.USE_SYNCBN:
        assert is_pytorch_1_1_0_or_later(), \
            "SyncBatchNorm is only available in pytorch >= 1.1.0"
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    optimizer = make_optimizer(cfg, model)
    scheduler = make_lr_scheduler(cfg, optimizer)

    if distributed:
        ddp_options = {
            "device_ids": [local_rank],
            "output_device": local_rank,
            # this should be removed if we update BatchNorm stats
            "broadcast_buffers": False,
        }
        model = torch.nn.parallel.DistributedDataParallel(model, **ddp_options)

    arguments = {"iteration": 0}

    checkpointer = DetectronCheckpointer(
        cfg, model, optimizer, scheduler, cfg.OUTPUT_DIR, get_rank() == 0
    )
    arguments.update(checkpointer.load(cfg.MODEL.WEIGHT))

    data_loader = make_data_loader(
        cfg,
        is_train=True,
        is_distributed=distributed,
        start_iter=arguments["iteration"],
    )

    # Debug aid: to eyeball one batch, take it from iter(data_loader) and
    # show torchvision.utils.make_grid(images.tensors) with matplotlib
    # (the original snippet added 115 to undo mean subtraction and swapped
    # the channel order before plotting).

    do_train(
        model,
        data_loader,
        optimizer,
        scheduler,
        checkpointer,
        device,
        cfg.SOLVER.CHECKPOINT_PERIOD,
        arguments,
    )
    return model
def main():
    """Run inference and evaluation on every dataset in ``cfg.DATASETS.TEST``.

    Parses ``--config-file``, ``--local_rank`` and trailing config
    overrides, initialises distributed mode when ``WORLD_SIZE`` > 1, builds
    the model, loads ``cfg.MODEL.WEIGHT`` and calls ``inference`` once per
    test dataset. Results go to ``<OUTPUT_DIR>/inference/<dataset>`` when
    ``cfg.OUTPUT_DIR`` is set.
    """
    parser = argparse.ArgumentParser(
        description="PyTorch Object Detection Inference")
    parser.add_argument(
        "--config-file",
        default="/private/home/fmassa/github/detectron.pytorch_v2/configs/e2e_faster_rcnn_R_50_C4_1x_caffe2.yaml",
        metavar="FILE",
        help="path to config file",
    )
    parser.add_argument("--local_rank", type=int, default=0)
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )
    args = parser.parse_args()

    num_gpus = int(
        os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
    distributed = num_gpus > 1

    if distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend="nccl",
                                             init_method="env://")
        synchronize()

    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()

    save_dir = ""
    logger = setup_logger("fcos_core", save_dir, get_rank())
    logger.info("Using {} GPUs".format(num_gpus))
    logger.info(cfg)

    logger.info("Collecting env info (might take some time)")
    logger.info("\n" + collect_env_info())

    model = build_detection_model(cfg)
    model.to(cfg.MODEL.DEVICE)

    output_dir = cfg.OUTPUT_DIR
    checkpointer = DetectronCheckpointer(cfg, model, save_dir=output_dir)
    _ = checkpointer.load(cfg.MODEL.WEIGHT)

    # This script always evaluates "segm", so MASK_ON must not append it a
    # second time (the original produced ("bbox", "segm", "segm")).
    iou_types = ("bbox", "segm")
    if cfg.MODEL.KEYPOINT_ON:
        iou_types = iou_types + ("keypoints", )

    dataset_names = cfg.DATASETS.TEST
    output_folders = [None] * len(dataset_names)
    if cfg.OUTPUT_DIR:
        for idx, dataset_name in enumerate(dataset_names):
            output_folder = os.path.join(cfg.OUTPUT_DIR, "inference",
                                         dataset_name)
            mkdir(output_folder)
            output_folders[idx] = output_folder

    data_loaders_val = make_data_loader(cfg,
                                        is_train=False,
                                        is_distributed=distributed)

    # These detectors always produce final boxes; only other models honour
    # RPN_ONLY for box-only evaluation.
    if cfg.MODEL.FCOS_ON or cfg.MODEL.SIPMASK_ON or cfg.MODEL.RETINANET_ON:
        box_only = False
    else:
        box_only = cfg.MODEL.RPN_ONLY

    for output_folder, dataset_name, data_loader_val in zip(
            output_folders, dataset_names, data_loaders_val):
        inference(
            model,
            data_loader_val,
            dataset_name=dataset_name,
            iou_types=iou_types,
            box_only=box_only,
            device=cfg.MODEL.DEVICE,
            expected_results=cfg.TEST.EXPECTED_RESULTS,
            expected_results_sigma_tol=cfg.TEST.EXPECTED_RESULTS_SIGMA_TOL,
            output_folder=output_folder,
        )
        synchronize()
def create_model(cfg, device):
    """Build a detection model from a frozen private copy of ``cfg``.

    The caller's ``cfg`` is left untouched: it is deep-copied before being
    frozen.

    Args:
        cfg: configuration node.
        device: target passed to the model's ``.to()``.

    Returns:
        The result of moving the freshly built model to ``device``.
    """
    frozen_cfg = copy.deepcopy(cfg)
    frozen_cfg.freeze()
    detector = build_detection_model(frozen_cfg)
    return detector.to(device)
def main():
    """Evaluate every saved checkpoint found in a training run directory.

    The run directory must contain ``new_config.yml``. Each file in it
    whose name contains ``model_`` and ends with ``00.pth`` is loaded in
    turn and evaluated on the first dataset of ``cfg.DATASETS.TEST``.
    Logs and per-checkpoint results go under ``<run-dir>/run_res_tidy``.
    """
    parser = argparse.ArgumentParser(
        description="PyTorch Object Detection Training")
    parser.add_argument(
        "--run-dir",
        default="run/fcos_imprv_R_50_FPN_1x/Baseline_lr1en4_191209",
        metavar="FILE",
        help="path to config file",
        type=str,
    )
    args = parser.parse_args()

    target_dir = args.run_dir
    dir_files = sorted(glob.glob(target_dir + '/*'))

    cfg_file = target_dir + '/new_config.yml'
    assert cfg_file in dir_files, \
        "Error! No cfg file found! check if the dir is right."

    model_files = [
        f for f in dir_files if f.endswith('00.pth') and 'model_' in f
    ]

    tidyed_before = (target_dir + '/run_res_tidy') in dir_files
    if tidyed_before:
        # NOTE(review): deliberate interactive pause kept from the original.
        # Results already exist, and continuing fails at os.makedirs() below
        # for any checkpoint folder that is already present.
        import pdb
        pdb.set_trace()
    else:
        os.makedirs(target_dir + '/run_res_tidy')

    cfg.merge_from_file(cfg_file)
    cfg.freeze()

    logger = setup_logger("fcos_core",
                          target_dir + '/run_res_tidy',
                          0,
                          filename="test_log.txt")
    logger.info(cfg)

    model = build_detection_model(cfg)
    model.to(cfg.MODEL.DEVICE)

    checkpointer = DetectronCheckpointer(cfg,
                                         model,
                                         save_dir=target_dir +
                                         '/run_res_tidy/')

    iou_types = ("bbox", )
    if cfg.MODEL.MASK_ON:
        iou_types = iou_types + ("segm", )
    if cfg.MODEL.KEYPOINT_ON:
        iou_types = iou_types + ("keypoints", )

    # Only the first test dataset is evaluated.
    dataset_names = cfg.DATASETS.TEST
    data_loaders_val = make_data_loader(cfg,
                                        is_train=False,
                                        is_distributed=False)
    dataset_name = dataset_names[0]
    data_loader_val = data_loaders_val[0]

    for i, model_f in enumerate(model_files):
        _ = checkpointer.load(model_f)
        # Folder name: <dataset>_<checkpoint file name without ".pth">.
        output_folder = target_dir + '/run_res_tidy/' + dataset_name + '_' + (
            model_f.split('/')[-1][:-4])
        os.makedirs(output_folder)
        # Fixed: the total is the number of checkpoints, not the length of
        # the current checkpoint's path string.
        logger.info('Processing {}/{}: {}'.format(i, len(model_files),
                                                  output_folder))
        inference_result = inference(
            model,
            data_loader_val,
            dataset_name=dataset_name,
            iou_types=iou_types,
            box_only=False if cfg.MODEL.FCOS_ON or cfg.MODEL.RETINANET_ON else
            cfg.MODEL.RPN_ONLY,
            device=cfg.MODEL.DEVICE,
            expected_results=cfg.TEST.EXPECTED_RESULTS,
            expected_results_sigma_tol=cfg.TEST.EXPECTED_RESULTS_SIGMA_TOL,
            output_folder=output_folder,
        )
        summaryStrs = get_neat_inference_result(inference_result[2][0])
        logger.info(output_folder.split('/')[-1])
        logger.info('\n'.join(summaryStrs))
def main():
    """Export the backbone and head of an FCOS model to an ONNX file.

    Parses ``--config-file``, ``--output`` and trailing config overrides,
    loads ``cfg.MODEL.WEIGHT``, wraps ``model.backbone`` and
    ``model.rpn.head`` in a ``Sequential`` and exports it with a zero
    input of shape (1, 3, 800, 1216). Three outputs are named per entry of
    ``cfg.MODEL.FCOS.FPN_STRIDES``, starting at ``P3``.
    """
    parser = argparse.ArgumentParser(
        description="Export model to the onnx format")
    parser.add_argument(
        "--config-file",
        default="configs/fcos/fcos_imprv_R_50_FPN_1x.yaml",
        metavar="FILE",
        help="path to config file",
    )
    parser.add_argument(
        "--output",
        default="fcos.onnx",
        metavar="FILE",
        help="path to the output onnx file",
    )
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )
    args = parser.parse_args()

    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()

    assert cfg.MODEL.FCOS_ON, "This script is only tested for the detector FCOS."

    logger = setup_logger("fcos_core", "", get_rank())
    logger.info(cfg)
    logger.info("Collecting env info (might take some time)")
    logger.info("\n" + collect_env_info())

    model = build_detection_model(cfg)
    model.to(cfg.MODEL.DEVICE)

    weights_loader = DetectronCheckpointer(cfg, model,
                                           save_dir=cfg.OUTPUT_DIR)
    weights_loader.load(cfg.MODEL.WEIGHT)

    # Only the backbone and the head are exported.
    exported_parts = OrderedDict()
    exported_parts['backbone'] = model.backbone
    exported_parts['heads'] = model.rpn.head
    onnx_model = torch.nn.Sequential(exported_parts)

    dummy_input = torch.zeros((1, 3, 800, 1216)).to(cfg.MODEL.DEVICE)

    # Output names: logits / bbox_reg / centerness for P3, P4, ...
    output_names = []
    level_count = len(cfg.MODEL.FCOS.FPN_STRIDES)
    for level in range(3, 3 + level_count):
        fpn_name = "P{}/".format(level)
        for suffix in ("logits", "bbox_reg", "centerness"):
            output_names.append(fpn_name + suffix)

    torch.onnx.export(onnx_model,
                      dummy_input,
                      args.output,
                      verbose=True,
                      input_names=["input_image"],
                      output_names=output_names,
                      keep_initializers_as_inputs=True)

    logger.info("Done. The onnx model is saved into {}.".format(args.output))
def _reset_dir(path):
    """Remove ``path`` if it already exists, then create it empty."""
    if os.path.exists(path):
        shutil.rmtree(path)
    os.mkdir(path)


def _load_gt_boxes(anno_path, w, h):
    """Read every ``bndbox`` of a VOC-style XML file, rescaled to (w, h)."""
    root = ET.parse(anno_path).getroot()

    # Start from the decoded image size so a missing <size> tag cannot leave
    # the annotation size unbound or carried over from a previous image.
    width_, height_ = w, h
    for siz in root.findall('size'):
        width_ = siz.find('width').text
        height_ = siz.find('height').text
    if not int(width_) or not int(height_):
        width_, height_ = w, h
    width_, height_ = int(width_), int(height_)

    gt_boxes = []
    for obj in root.findall('object'):
        for bndbox in obj.findall('bndbox'):
            xmin = int(bndbox.find('xmin').text)
            ymin = int(bndbox.find('ymin').text)
            xmax = int(bndbox.find('xmax').text)
            ymax = int(bndbox.find('ymax').text)
            gt_boxes.append([
                int(xmin * w / width_),
                int(ymin * h / height_),
                int(xmax * w / width_),
                int(ymax * h / height_),
            ])
    return gt_boxes


def run_accCal(model_path, test_base_path, save_base_path, labels_dict,
               config_file, input_size=640, confidence_thresholds=(0.3, )):
    """Run image-level accuracy / precision / recall on a VOC2007 test set.

    Each ``*.jpg`` under ``<test_base_path>/VOC2007/JPEGImages`` is run
    through the detector and compared with its XML annotation. An image
    counts as a miss when any annotated box has no detection with
    IoU > 0.3, and as a false alarm when any detection is left unmatched.
    For the first threshold only, annotated images are written to the
    ``recall`` / ``ero`` folders and the source files copied to ``ori``.
    Results are printed per threshold.

    Args:
        model_path: weight file assigned to ``cfg.MODEL.WEIGHT``.
        test_base_path: root containing ``VOC2007/JPEGImages`` and
            ``VOC2007/Annotations``.
        save_base_path: existing directory in which the ``all``,
            ``recall``, ``ero`` and ``ori`` folders are recreated.
        labels_dict: unused here; kept for interface compatibility.
        config_file: config merged into the global ``cfg``.
        input_size: size passed to ``T.Resize``.
        confidence_thresholds: score thresholds to evaluate.
    """
    save_res_path = os.path.join(save_base_path, 'all')
    save_recall_path = os.path.join(save_base_path, 'recall')
    save_ero_path = os.path.join(save_base_path, 'ero')
    save_ori_path = os.path.join(save_base_path, 'ori')
    for out_dir in (save_res_path, save_recall_path, save_ero_path,
                    save_ori_path):
        _reset_dir(out_dir)

    test_img_path = os.path.join(test_base_path, 'VOC2007/JPEGImages')
    test_ano_path = os.path.join(test_base_path, 'VOC2007/Annotations')
    img_list = glob.glob(test_img_path + '/*.jpg')
    if not img_list:
        # Nothing to evaluate; the ratios below would divide by zero.
        print("no .jpg images found under {}".format(test_img_path))
        return

    cfg.merge_from_file(config_file)
    cfg.MODEL.WEIGHT = model_path
    cfg.TEST.IMS_PER_BATCH = 1  # only test single image
    cfg.freeze()

    model = build_detection_model(cfg)
    model.to(cfg.MODEL.DEVICE)
    checkpointer = DetectronCheckpointer(cfg, model, save_dir=cfg.OUTPUT_DIR)
    checkpointer.load(cfg.MODEL.WEIGHT)
    model.eval()

    normalize_transform = T.Normalize(
        mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD
    )
    transform = T.Compose(
        [
            T.ToPILImage(),
            T.Resize(input_size),
            T.ToTensor(),
            T.Lambda(lambda x: x * 255),
            normalize_transform,
        ]
    )

    # Per-threshold image counters: fully correct, with false alarms, with
    # misses.
    sad_accuracy = [0] * len(confidence_thresholds)
    sad_precision = [0] * len(confidence_thresholds)
    sad_recall = [0] * len(confidence_thresholds)
    spend_time = []

    for img_idx, img_name in enumerate(img_list):
        progress(int(img_idx / len(img_list) * 100))
        base_img_name = os.path.split(img_name)[-1]
        frame = cv2.imread(img_name)
        ori_frame = copy.deepcopy(frame)  # untouched copy for cropping
        h, w = frame.shape[:2]

        image = transform(frame)
        image_list = to_image_list(image, cfg.DATALOADER.SIZE_DIVISIBILITY)
        image_list = image_list.to(cfg.MODEL.DEVICE)

        start_time = time.time()
        with torch.no_grad():
            predictions = model(image_list)
            prediction = predictions[0].to("cpu")
        end_time = time.time()
        spend_time.append(end_time - start_time)

        prediction = prediction.resize((w, h)).convert("xyxy")

        # Order detections by descending score.
        scores = prediction.get_field("scores")
        _, order = scores.sort(0, descending=True)
        prediction = prediction[order]

        scores = prediction.get_field("scores").numpy()
        labels = prediction.get_field("labels").numpy()
        bboxes = prediction.bbox.numpy().astype(np.int32)
        bboxes_area = (bboxes[:, 2] - bboxes[:, 0]) * (bboxes[:, 3] - bboxes[:, 1])

        # The annotation does not depend on the threshold: parse it once.
        xml_name = base_img_name[:-4] + '.xml'
        anno_path = os.path.join(test_ano_path, xml_name)
        gt_boxes = _load_gt_boxes(anno_path, w, h)

        for ii, confidence_threshold in enumerate(confidence_thresholds):
            _keep = np.where((scores > confidence_threshold) & (bboxes_area > 0), True, False)
            _scores = scores[_keep].tolist()
            _labels = labels[_keep].tolist()
            _bboxes = bboxes[_keep].tolist()

            _labels, _bboxes, _scores = soft_nms(_labels, _bboxes, _scores, confidence_threshold)

            if ii == 0:
                for i, b in enumerate(_bboxes):  # draw all kept detections
                    frame = cv2.rectangle(frame, (b[0], b[1]), (b[2], b[3]), (100, 220, 200), 2)
                    frame = cv2.putText(frame, str(_labels[i]) + '-' + str(int(_scores[i] * 100)),
                                        (b[0], b[1]), 1, 1, (0, 0, 255), 1)

            # Detections not yet matched to a ground-truth box.
            boxes_list_tmp = copy.deepcopy(_bboxes)
            classes_list_tmp = copy.deepcopy(_labels)

            recall_flag = False
            rc_box = []  # ground-truth boxes nothing matched
            for tmp_bbox in gt_boxes:
                map_flag = False
                for bbox_idx in range(len(boxes_list_tmp)):
                    _, _, _, iou_score = get_iou(tmp_bbox, boxes_list_tmp[bbox_idx])
                    if iou_score > 0.3:
                        map_flag = True
                        del classes_list_tmp[bbox_idx]
                        del boxes_list_tmp[bbox_idx]
                        break
                # No match found: a missed detection, counted against recall.
                if not map_flag:
                    recall_flag = True
                    rc_box.append(tmp_bbox)

            if recall_flag:
                sad_recall[ii] += 1
                if ii == 0:
                    for x1, y1, x2, y2 in rc_box:
                        rca_frame = cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)),
                                                  (255, 0, 0), 4)
                    cv2.imwrite(os.path.join(save_recall_path, base_img_name), rca_frame)
                    shutil.copy(img_name, os.path.join(save_ori_path, base_img_name))
                    shutil.copy(anno_path, os.path.join(save_ori_path, xml_name))

            # Leftover detections are false alarms (no such ground-truth
            # box), counted against precision.
            if len(classes_list_tmp) > 0:
                sad_precision[ii] += 1
                if ii == 0:
                    for box_idx, (x1, y1, x2, y2) in enumerate(boxes_list_tmp):
                        ero_frame = cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)),
                                                  (0, 0, 255), 4)
                        err_rect_name = base_img_name[:-4] + '_' + str(box_idx) + '.jpg'
                        cv2.imwrite(os.path.join(save_ero_path, err_rect_name),
                                    ori_frame[y1: y2, x1: x2, :])
                    cv2.imwrite(os.path.join(save_ero_path, base_img_name), ero_frame)
                    shutil.copy(img_name, os.path.join(save_ori_path, base_img_name))
                    shutil.copy(anno_path, os.path.join(save_ori_path, xml_name))

            if not recall_flag and len(classes_list_tmp) == 0:
                sad_accuracy[ii] += 1

    # An image is correct only when every box is detected correctly: one
    # box short is a miss, one box extra is a false alarm. mAP is not used.
    print('\nfps is : ', 1 / np.average(spend_time))
    for ii, confidence_threshold in enumerate(confidence_thresholds):
        print("confidence th is : {}".format(confidence_threshold))
        accuracy = float(sad_accuracy[ii] / len(img_list))
        print("accuracy is : {}".format(accuracy))
        precision = 1 - float(sad_precision[ii] / len(img_list))
        print("precision is : {}".format(precision))
        recall = 1 - float(sad_recall[ii] / len(img_list))
        print("recall is : {}\n".format(recall))
def train(cfg, local_rank, distributed, iter_clear, ignore_head):
    """Convert a detector with ``convert_to_shift``, train it and save it.

    When ``cfg.MODEL.WEIGHT`` is set, the weights are loaded into the
    unconverted model before conversion. Otherwise the model is converted
    first and the checkpointer's ``load`` is called afterwards. After
    training, the weights are passed through ``round_shift_weights`` and
    saved to ``<OUTPUT_DIR>/model_final_round.pth``.

    NOTE(review): optimizer and scheduler state are never restored here
    (``load_opt`` / ``load_sch`` are always False). The original computed
    those flags from ``iter_clear`` but did not pass them on; confirm
    whether that was intended.

    Args:
        cfg: configuration node.
        local_rank: device index given to DistributedDataParallel.
        distributed: whether to wrap the model in DistributedDataParallel
            and build a distributed data loader.
        iter_clear: when true, the start iteration is reset to 0 after the
            checkpoint data has been merged.
        ignore_head: when true, ``load_head=False`` is passed to the
            checkpoint load.

    Returns:
        The model after ``round_shift_weights``.
    """
    model = build_detection_model(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)

    if cfg.MODEL.USE_SYNCBN:
        assert is_pytorch_1_1_0_or_later(), \
            "SyncBatchNorm is only available in pytorch >= 1.1.0"
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    output_dir = cfg.OUTPUT_DIR
    save_to_disk = get_rank() == 0

    # Body and FPN are always loaded; the head only unless ignore_head.
    load_kwargs = dict(
        load_opt=False,
        load_sch=False,
        load_body=True,
        load_fpn=True,
        load_head=not ignore_head,
    )

    # The preloaded weights may be a regular model or a deepshift model.
    preload = bool(cfg.MODEL.WEIGHT)
    if preload:
        # Load into the unconverted model with a checkpointer that has no
        # optimizer or scheduler.
        plain_checkpointer = DetectronCheckpointer(
            cfg, model, None, None, output_dir, save_to_disk
        )
        extra_checkpoint_data = plain_checkpointer.load(
            cfg.MODEL.WEIGHT, **load_kwargs)

    model, conversion_count = convert_to_shift(
        model,
        cfg.DEEPSHIFT_DEPTH,
        cfg.DEEPSHIFT_TYPE,
        convert_weights=True,
        use_kernel=cfg.DEEPSHIFT_USEKERNEL,
        rounding=cfg.DEEPSHIFT_ROUNDING,
        shift_range=cfg.DEEPSHIFT_RANGE)

    # Optimizer, scheduler and checkpointer are built on the converted
    # model.
    optimizer = make_optimizer(cfg, model)
    scheduler = make_lr_scheduler(cfg, optimizer)
    checkpointer = DetectronCheckpointer(
        cfg, model, optimizer, scheduler, output_dir, save_to_disk
    )

    if not preload:
        extra_checkpoint_data = checkpointer.load(
            cfg.MODEL.WEIGHT, **load_kwargs)

    conv2d_layers_count = count_layer_type(model, torch.nn.Conv2d)
    linear_layers_count = count_layer_type(model, torch.nn.Linear)
    print("###### conversion_count: {}, not convert conv2d layer: {}, linear layer: {}".format(
        conversion_count, conv2d_layers_count, linear_layers_count))

    if distributed:
        ddp_options = {
            "device_ids": [local_rank],
            "output_device": local_rank,
            # this should be removed if we update BatchNorm stats
            "broadcast_buffers": False,
        }
        model = torch.nn.parallel.DistributedDataParallel(model, **ddp_options)

    arguments = {"iteration": 0}
    arguments.update(extra_checkpoint_data)
    if iter_clear:
        arguments["iteration"] = 0

    data_loader = make_data_loader(
        cfg,
        is_train=True,
        is_distributed=distributed,
        start_iter=arguments["iteration"],
    )

    do_train(
        model,
        data_loader,
        optimizer,
        scheduler,
        checkpointer,
        device,
        cfg.SOLVER.CHECKPOINT_PERIOD,
        arguments,
    )

    model = round_shift_weights(model)
    final_path = os.path.join(output_dir, "model_final_round.pth")
    torch.save({"model": model.state_dict()}, final_path)
    return model