def train(cfg, local_rank, distributed):
    """Build, restore, and train the Step-1 label-encoding network.

    Args:
        cfg: frozen project config node.
        local_rank: local GPU rank used for DistributedDataParallel.
        distributed: True when running under torch.distributed.

    Returns:
        The trained model (DDP-wrapped when ``distributed`` is True).
    """
    model = LabelEncStep1Network(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)

    if cfg.MODEL.USE_SYNCBN:
        assert is_pytorch_1_1_0_or_later(), \
            "SyncBatchNorm is only available in pytorch >= 1.1.0"
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    optimizer = make_optimizer(cfg, model)
    scheduler = make_lr_scheduler(cfg, optimizer)

    if distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[local_rank],
            output_device=local_rank,
            # this should be removed if we update BatchNorm stats
            broadcast_buffers=False,
        )

    arguments = {"iteration": 0}

    checkpointer = DetectronCheckpointer(
        cfg, model, optimizer, scheduler, cfg.OUTPUT_DIR, get_rank() == 0)
    # Resume bookkeeping (iteration counter etc.) from the checkpoint, if any.
    arguments.update(checkpointer.load(cfg.MODEL.WEIGHT))

    data_loader = make_data_loader(
        cfg,
        is_train=True,
        is_distributed=distributed,
        start_iter=arguments["iteration"],
    )

    do_train(
        model,
        data_loader,
        optimizer,
        scheduler,
        checkpointer,
        device,
        cfg.SOLVER.CHECKPOINT_PERIOD,
        arguments,
    )
    return model
def run_test(cfg, model, distributed):
    """Evaluate a ``{"backbone": ..., "fcos": ...}`` model dict on every test dataset.

    Args:
        cfg: frozen project config node.
        model: dict holding the "backbone" and "fcos" modules
            (DDP-wrapped when ``distributed`` is True).
        distributed: whether the modules are wrapped in DistributedDataParallel.
    """
    if distributed:
        # Work on a shallow copy so the caller's (possibly still-training)
        # DDP-wrapped modules are not replaced behind its back.
        model = dict(model)
        model["backbone"] = model["backbone"].module
        model["fcos"] = model["fcos"].module
    torch.cuda.empty_cache()  # TODO check if it helps
    iou_types = ("bbox", )
    if cfg.MODEL.MASK_ON:
        iou_types = iou_types + ("segm", )
    if cfg.MODEL.KEYPOINT_ON:
        iou_types = iou_types + ("keypoints", )
    output_folders = [None] * len(cfg.DATASETS.TEST)
    dataset_names = cfg.DATASETS.TEST
    if cfg.OUTPUT_DIR:
        for idx, dataset_name in enumerate(dataset_names):
            output_folder = os.path.join(cfg.OUTPUT_DIR, "inference", dataset_name)
            mkdir(output_folder)
            output_folders[idx] = output_folder
    data_loaders_val = make_data_loader(cfg, is_train=False, is_distributed=distributed)
    for output_folder, dataset_name, data_loader_val in zip(
            output_folders, dataset_names, data_loaders_val):
        inference(
            model,
            data_loader_val,
            dataset_name=dataset_name,
            iou_types=iou_types,
            box_only=False if cfg.MODEL.FCOS_ON or cfg.MODEL.RETINANET_ON else cfg.MODEL.RPN_ONLY,
            device=cfg.MODEL.DEVICE,
            expected_results=cfg.TEST.EXPECTED_RESULTS,
            expected_results_sigma_tol=cfg.TEST.EXPECTED_RESULTS_SIGMA_TOL,
            output_folder=output_folder,
        )
        synchronize()
def run_test(cfg, model, distributed):
    """Evaluate on the first test dataset and gather per-rank results.

    Args:
        cfg: frozen project config node.
        model: dict holding the "backbone" and "fcos" modules.
        distributed: whether the modules are wrapped in DistributedDataParallel.

    Returns:
        List of per-process inference results collected via ``all_gather``.
    """
    # Build an unwrapped view of the model without mutating the caller's dict.
    model_test = {}
    if distributed:
        model_test["backbone"] = model["backbone"].module
        model_test["fcos"] = model["fcos"].module
    else:
        # BUGFIX: previously model_test stayed empty in the non-distributed
        # case, so inference received a dict with no modules at all.
        model_test["backbone"] = model["backbone"]
        model_test["fcos"] = model["fcos"]
    torch.cuda.empty_cache()  # TODO check if it helps
    iou_types = ("bbox", )
    if cfg.MODEL.MASK_ON:
        iou_types = iou_types + ("segm", )
    if cfg.MODEL.KEYPOINT_ON:
        iou_types = iou_types + ("keypoints", )
    dataset_name = cfg.DATASETS.TEST[0]
    # BUGFIX: give output_folder a default so the inference call below does
    # not hit a NameError when cfg.OUTPUT_DIR is empty.
    output_folder = None
    if cfg.OUTPUT_DIR:
        output_folder = os.path.join(cfg.OUTPUT_DIR, "inference", dataset_name)
        mkdir(output_folder)
    data_loaders_val = make_data_loader(cfg, is_train=False, is_distributed=distributed)
    results = inference(
        model_test,
        data_loaders_val[0],
        dataset_name=dataset_name,
        iou_types=iou_types,
        box_only=False if cfg.MODEL.FCOS_ON or cfg.MODEL.RETINANET_ON else cfg.MODEL.RPN_ONLY,
        device=cfg.MODEL.DEVICE,
        expected_results=cfg.TEST.EXPECTED_RESULTS,
        expected_results_sigma_tol=cfg.TEST.EXPECTED_RESULTS_SIGMA_TOL,
        output_folder=output_folder,
    )
    synchronize()
    # Collect results from every process so rank 0 sees the full evaluation.
    results = all_gather(results)
    return results
def run_test(cfg, model, distributed):
    """Run inference on every configured test set and write summary strings.

    For each dataset a ``summaryStrs.txt`` file is dumped into that
    dataset's inference output folder.
    """
    if distributed:
        model = model.module
    torch.cuda.empty_cache()  # TODO check if it helps

    iou_types = ("bbox", )
    if cfg.MODEL.MASK_ON:
        iou_types += ("segm", )
    if cfg.MODEL.KEYPOINT_ON:
        iou_types += ("keypoints", )

    dataset_names = cfg.DATASETS.TEST
    output_folders = [None] * len(dataset_names)
    if cfg.OUTPUT_DIR:
        for idx, name in enumerate(dataset_names):
            folder = os.path.join(cfg.OUTPUT_DIR, "inference", name)
            mkdir(folder)
            output_folders[idx] = folder

    loaders = make_data_loader(cfg, is_train=False, is_distributed=distributed)
    for output_folder, dataset_name, data_loader_val in zip(
            output_folders, dataset_names, loaders):
        inference_result = inference(
            model,
            data_loader_val,
            dataset_name=dataset_name,
            iou_types=iou_types,
            box_only=False if cfg.MODEL.FCOS_ON or cfg.MODEL.RETINANET_ON else cfg.MODEL.RPN_ONLY,
            device=cfg.MODEL.DEVICE,
            expected_results=cfg.TEST.EXPECTED_RESULTS,
            expected_results_sigma_tol=cfg.TEST.EXPECTED_RESULTS_SIGMA_TOL,
            output_folder=output_folder,
        )
        synchronize()
        # Tidy the COCO-style result block into printable summary lines.
        summaryStrs = get_neat_inference_result(inference_result[2][0])
        with open(output_folder + '/summaryStrs.txt', 'w') as f_summaryStrs:
            f_summaryStrs.write('\n'.join(summaryStrs))
def main():
    """Entry point: evaluate an exported ONNX FCOS model on the test datasets."""
    parser = argparse.ArgumentParser(description="Test onnx models of FCOS")
    parser.add_argument(
        "--config-file",
        default="/private/home/fmassa/github/detectron.pytorch_v2/configs/e2e_faster_rcnn_R_50_C4_1x_caffe2.yaml",
        metavar="FILE",
        help="path to config file",
    )
    parser.add_argument(
        "--onnx-model",
        default="fcos_imprv_R_50_FPN_1x.onnx",
        metavar="FILE",
        help="path to the onnx model",
    )
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )
    args = parser.parse_args()

    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    # The onnx model can only be used with DATALOADER.NUM_WORKERS = 0
    cfg.DATALOADER.NUM_WORKERS = 0
    cfg.freeze()

    logger = setup_logger("fcos_core", "", get_rank())
    logger.info(cfg)
    logger.info("Collecting env info (might take some time)")
    logger.info("\n" + collect_env_info())

    model = ONNX_FCOS(args.onnx_model, cfg)
    model.to(cfg.MODEL.DEVICE)

    iou_types = ("bbox",)
    if cfg.MODEL.MASK_ON:
        iou_types += ("segm",)
    if cfg.MODEL.KEYPOINT_ON:
        iou_types += ("keypoints",)

    names = cfg.DATASETS.TEST
    folders = [None] * len(names)
    if cfg.OUTPUT_DIR:
        for idx, name in enumerate(names):
            folder = os.path.join(cfg.OUTPUT_DIR, "inference", name)
            mkdir(folder)
            folders[idx] = folder

    loaders = make_data_loader(cfg, is_train=False, is_distributed=False)
    for folder, name, loader in zip(folders, names, loaders):
        inference(
            model,
            loader,
            dataset_name=name,
            iou_types=iou_types,
            box_only=False if cfg.MODEL.FCOS_ON or cfg.MODEL.RETINANET_ON else cfg.MODEL.RPN_ONLY,
            device=cfg.MODEL.DEVICE,
            expected_results=cfg.TEST.EXPECTED_RESULTS,
            expected_results_sigma_tol=cfg.TEST.EXPECTED_RESULTS_SIGMA_TOL,
            output_folder=folder,
        )
        synchronize()
def train(cfg, local_rank, distributed):
    """Build the detector, restore a checkpoint, and run the training loop.

    Args:
        cfg: frozen project config node.
        local_rank: local GPU rank used for DistributedDataParallel.
        distributed: True when running under torch.distributed.

    Returns:
        The trained model (DDP-wrapped when ``distributed`` is True).
    """
    model = build_detection_model(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)
    if cfg.MODEL.USE_SYNCBN:
        assert is_pytorch_1_1_0_or_later(), \
            "SyncBatchNorm is only available in pytorch >= 1.1.0"
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    optimizer = make_optimizer(cfg, model)
    scheduler = make_lr_scheduler(cfg, optimizer)
    if distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[local_rank],
            output_device=local_rank,
            # this should be removed if we update BatchNorm stats
            broadcast_buffers=False,
        )
    arguments = {}
    arguments["iteration"] = 0
    output_dir = cfg.OUTPUT_DIR
    save_to_disk = get_rank() == 0
    checkpointer = DetectronCheckpointer(cfg, model, optimizer, scheduler,
                                         output_dir, save_to_disk)
    extra_checkpoint_data = checkpointer.load(cfg.MODEL.WEIGHT)
    arguments.update(extra_checkpoint_data)
    data_loader = make_data_loader(
        cfg,
        is_train=True,
        is_distributed=distributed,
        start_iter=arguments["iteration"],
    )
    # NOTE: a large commented-out matplotlib/torchvision snippet that
    # previewed one training batch was removed here; recover it from VCS
    # history if batch visualization is needed again.
    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
    do_train(
        model,
        data_loader,
        optimizer,
        scheduler,
        checkpointer,
        device,
        checkpoint_period,
        arguments,
    )
    return model
def main():
    """Distributed-aware inference entry point: load a checkpoint and evaluate.

    Parses the command line, initializes torch.distributed when WORLD_SIZE
    indicates multiple GPUs, restores cfg.MODEL.WEIGHT, and runs inference on
    every configured test dataset.
    """
    parser = argparse.ArgumentParser(
        description="PyTorch Object Detection Inference")
    parser.add_argument(
        "--config-file",
        default="/private/home/fmassa/github/detectron.pytorch_v2/configs/e2e_faster_rcnn_R_50_C4_1x_caffe2.yaml",
        metavar="FILE",
        help="path to config file",
    )
    parser.add_argument("--local_rank", type=int, default=0)
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )
    args = parser.parse_args()
    num_gpus = int(
        os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
    distributed = num_gpus > 1
    if distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend="nccl",
                                             init_method="env://")
        synchronize()
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()
    save_dir = ""
    logger = setup_logger("fcos_core", save_dir, get_rank())
    logger.info("Using {} GPUs".format(num_gpus))
    logger.info(cfg)
    logger.info("Collecting env info (might take some time)")
    logger.info("\n" + collect_env_info())
    model = build_detection_model(cfg)
    model.to(cfg.MODEL.DEVICE)
    output_dir = cfg.OUTPUT_DIR
    checkpointer = DetectronCheckpointer(cfg, model, save_dir=output_dir)
    _ = checkpointer.load(cfg.MODEL.WEIGHT)
    # BUGFIX: "segm" was appended unconditionally AND a second time when
    # MASK_ON was set, yielding a duplicate ("bbox", "segm", "segm").
    # Evaluate each IoU type exactly once.
    iou_types = ("bbox", "segm")
    if cfg.MODEL.KEYPOINT_ON:
        iou_types = iou_types + ("keypoints", )
    output_folders = [None] * len(cfg.DATASETS.TEST)
    dataset_names = cfg.DATASETS.TEST
    if cfg.OUTPUT_DIR:
        for idx, dataset_name in enumerate(dataset_names):
            output_folder = os.path.join(cfg.OUTPUT_DIR, "inference", dataset_name)
            mkdir(output_folder)
            output_folders[idx] = output_folder
    data_loaders_val = make_data_loader(cfg, is_train=False,
                                        is_distributed=distributed)
    for output_folder, dataset_name, data_loader_val in zip(
            output_folders, dataset_names, data_loaders_val):
        inference(
            model,
            data_loader_val,
            dataset_name=dataset_name,
            iou_types=iou_types,
            box_only=False if cfg.MODEL.FCOS_ON or cfg.MODEL.SIPMASK_ON or cfg.MODEL.RETINANET_ON else cfg.MODEL.RPN_ONLY,
            device=cfg.MODEL.DEVICE,
            expected_results=cfg.TEST.EXPECTED_RESULTS,
            expected_results_sigma_tol=cfg.TEST.EXPECTED_RESULTS_SIGMA_TOL,
            output_folder=output_folder,
        )
        synchronize()
def train(cfg, local_rank, distributed, labelenc_fpath):
    """Train the Step-2 network, initializing parts from Step-1 weights.

    Args:
        cfg: frozen project config node.
        local_rank: local GPU rank used for DistributedDataParallel.
        distributed: True when running under torch.distributed.
        labelenc_fpath: path to the Step-1 checkpoint holding the
            'label_encoding_function', 'rpn' and 'roi_heads' state dicts.

    Returns:
        The trained model (DDP-wrapped when ``distributed`` is True).
    """
    model = LabelEncStep2Network(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)
    if cfg.MODEL.USE_SYNCBN:
        assert is_pytorch_1_1_0_or_later(), \
            "SyncBatchNorm is only available in pytorch >= 1.1.0"
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    optimizer = make_optimizer(cfg, model)
    scheduler = make_lr_scheduler(cfg, optimizer)
    if distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[local_rank],
            output_device=local_rank,
            # this should be removed if we update BatchNorm stats
            broadcast_buffers=False,
        )
    arguments = {}
    arguments["iteration"] = 0
    output_dir = cfg.OUTPUT_DIR
    save_to_disk = get_rank() == 0
    checkpointer = DetectronCheckpointer(
        cfg, model, optimizer, scheduler, output_dir, save_to_disk
    )
    extra_checkpoint_data = checkpointer.load(cfg.MODEL.WEIGHT)
    arguments.update(extra_checkpoint_data)

    # Load LabelEncodingFunction
    # Initialize FPN and Head from Step1 weights
    if not checkpointer.has_checkpoint():
        labelenc_weights = torch.load(labelenc_fpath,
                                      map_location=torch.device('cpu'))
        # BUGFIX: `.module` only exists after DDP wrapping; fall back to the
        # bare model so single-process runs do not raise AttributeError.
        net = model.module if distributed else model
        # load LabelEncodingFunction
        net.label_encoding_function.load_state_dict(
            labelenc_weights['label_encoding_function'], strict=True)
        # Initialize Head
        net.rpn.load_state_dict(labelenc_weights['rpn'], strict=True)
        if net.roi_heads:
            net.roi_heads.load_state_dict(
                labelenc_weights['roi_heads'], strict=True)
        # Initialize FPN from the label-encoding function's FPN weights.
        fpn_weight = net.label_encoding_function.fpn.state_dict()
        net.backbone.fpn.load_state_dict(fpn_weight, strict=True)

    data_loader = make_data_loader(
        cfg,
        is_train=True,
        is_distributed=distributed,
        start_iter=arguments["iteration"],
    )
    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
    do_train(
        model,
        data_loader,
        optimizer,
        scheduler,
        checkpointer,
        device,
        checkpoint_period,
        arguments,
    )
    return model
def main():
    """Evaluate every saved checkpoint found in a run directory.

    Scans ``--run-dir`` for ``model_*00.pth`` checkpoints, evaluates each on
    the first test dataset, and logs tidy result summaries under
    ``<run-dir>/run_res_tidy``.
    """
    parser = argparse.ArgumentParser(
        description="PyTorch Object Detection Training")
    parser.add_argument(
        "--run-dir",
        default="run/fcos_imprv_R_50_FPN_1x/Baseline_lr1en4_191209",
        metavar="FILE",
        help="path to config file",
        type=str,
    )
    args = parser.parse_args()
    target_dir = args.run_dir
    dir_files = sorted(glob.glob(target_dir + '/*'))
    assert (
        target_dir + '/new_config.yml'
    ) in dir_files, "Error! No cfg file found! check if the dir is right."
    cfg_file = target_dir + '/new_config.yml' if (
        target_dir + '/new_config.yml') in dir_files else None
    model_files = [
        f for f in dir_files if f.endswith('00.pth') and 'model_' in f
    ]
    tidyed_before = (target_dir + '/run_res_tidy') in dir_files
    if tidyed_before:
        # Results were tidied before: drop into the debugger so the operator
        # can decide whether to overwrite previous output.
        import pdb
        pdb.set_trace()
        pass
    else:
        os.makedirs(target_dir + '/run_res_tidy')
    cfg.merge_from_file(cfg_file)
    cfg.freeze()
    logger = setup_logger("fcos_core", target_dir + '/run_res_tidy', 0,
                          filename="test_log.txt")
    logger.info(cfg)
    model = build_detection_model(cfg)
    model.to(cfg.MODEL.DEVICE)
    checkpointer = DetectronCheckpointer(cfg, model,
                                         save_dir=target_dir + '/run_res_tidy/')
    iou_types = ("bbox", )
    if cfg.MODEL.MASK_ON:
        iou_types = iou_types + ("segm", )
    if cfg.MODEL.KEYPOINT_ON:
        iou_types = iou_types + ("keypoints", )
    dataset_names = cfg.DATASETS.TEST
    data_loaders_val = make_data_loader(cfg, is_train=False,
                                        is_distributed=False)
    # Only the first test dataset is evaluated for each checkpoint.
    dataset_name = dataset_names[0]
    data_loader_val = data_loaders_val[0]
    for i, model_f in enumerate(model_files):
        _ = checkpointer.load(model_f)
        output_folder = target_dir + '/run_res_tidy/' + dataset_name + '_' + (
            model_f.split('/')[-1][:-4])
        os.makedirs(output_folder)
        # BUGFIX: the progress total used len(model_f) — the length of the
        # checkpoint *path string* — instead of the number of checkpoints.
        logger.info('Processing {}/{}: {}'.format(i, len(model_files),
                                                  output_folder))
        inference_result = inference(
            model,
            data_loader_val,
            dataset_name=dataset_name,
            iou_types=iou_types,
            box_only=False if cfg.MODEL.FCOS_ON or cfg.MODEL.RETINANET_ON else cfg.MODEL.RPN_ONLY,
            device=cfg.MODEL.DEVICE,
            expected_results=cfg.TEST.EXPECTED_RESULTS,
            expected_results_sigma_tol=cfg.TEST.EXPECTED_RESULTS_SIGMA_TOL,
            output_folder=output_folder,
        )
        summaryStrs = get_neat_inference_result(inference_result[2][0])
        logger.info(output_folder.split('/')[-1])
        logger.info('\n'.join(summaryStrs))
def train(cfg, local_rank, distributed, device_ids, use_tensorboard=False):
    # Architecture-search training loop: alternates detector-weight updates
    # with architecture-parameter (alpha) updates via an Architect helper.
    # ------------------------------- more configs
    half_data = [0, 0]  # do not split
    first_order = cfg.SOLVER.SEARCH.FIRST_ORDER
    alpha_lr = cfg.SOLVER.SEARCH.BASE_LR_ALPHA
    alpha_weight_decay = 1e-3
    device_ids = [int(x) for x in device_ids]
    # Derive search-space size from whichever FAD towers are enabled.
    if cfg.MODEL.FAD.CLSTOWER or cfg.MODEL.FAD.BOXTOWER:
        n_cells = cfg.MODEL.FAD.NUM_CELLS_CLS
        if cfg.MODEL.FAD.CLSTOWER and cfg.MODEL.FAD.BOXTOWER:
            n_nodes = cfg.MODEL.FAD.NUM_NODES_CLS
            n_module = 2
        elif cfg.MODEL.FAD.CLSTOWER:
            n_nodes = cfg.MODEL.FAD.NUM_NODES_CLS
            n_module = 1
        else:
            n_nodes = cfg.MODEL.FAD.NUM_NODES_BOX
            n_module = 1
    else:
        # Neither tower enabled: unsupported configuration, halt for debugging.
        pdb.set_trace()
    # build model
    model = SearchRCNNController(n_cells,
                                 n_nodes=n_nodes,
                                 device_ids=device_ids,
                                 cfg_det=cfg,
                                 n_module=n_module)
    device = torch.device(cfg.MODEL.DEVICE)
    model = model.to(device)
    torch.cuda.set_device(0)
    # NOTE(review): DDP is force-disabled here regardless of the caller's
    # `distributed` argument — presumably the controller handles multi-GPU
    # internally via `device_ids`; confirm before re-enabling.
    distributed = False
    if first_order:
        print('Using 1st order approximationfor the search')
    if cfg.MODEL.USE_SYNCBN:
        assert is_pytorch_1_1_0_or_later(), \
            "SyncBatchNorm is only available in pytorch >= 1.1.0"
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    optimizer = make_optimizer(cfg, model)
    scheduler = make_lr_scheduler(cfg, optimizer)
    # ---------------------- optimize alpha
    # Architect performs the bilevel (architecture) update; alpha parameters
    # get their own Adam optimizer, separate from the detector weights.
    arch = Architect(model, cfg.SOLVER.MOMENTUM, cfg.SOLVER.WEIGHT_DECAY)
    alpha_optim = torch.optim.Adam(model.alphas(),
                                   alpha_lr,
                                   betas=(0.5, 0.999),
                                   weight_decay=alpha_weight_decay)
    # Dead branch in practice: `distributed` was forced False above.
    if distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[local_rank],
            output_device=local_rank,
            # this should be removed if we update BatchNorm stats
            broadcast_buffers=False,
        )
    arguments = {}
    arguments["iteration"] = 0
    output_dir = cfg.OUTPUT_DIR
    # ------ tensorboard
    tb_info = {"tb_logger": None}
    if use_tensorboard:
        tb_logger = get_tensorboard_writer(output_dir)
        tb_info['tb_logger'] = tb_logger
        tb_info['prefix'] = cfg.TENSORBOARD.PREFIX
    save_to_disk = get_rank() == 0
    checkpointer = DetectronCheckpointer(cfg, model, optimizer, scheduler,
                                         output_dir, save_to_disk)
    extra_checkpoint_data = checkpointer.load(cfg.MODEL.WEIGHT)
    arguments.update(extra_checkpoint_data)
    # Two loaders over training data; `half` selects which split each sees
    # (both 0 here — "do not split", so weight and alpha updates share data).
    data_loader = make_data_loader(cfg,
                                   is_train=True,
                                   is_distributed=distributed,
                                   start_iter=arguments["iteration"],
                                   half=half_data[0])
    val_loader = make_data_loader(cfg,
                                  is_train=True,
                                  is_distributed=distributed,
                                  start_iter=arguments["iteration"],
                                  half=half_data[1])
    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
    do_train(
        model,
        arch,
        data_loader,
        val_loader,
        optimizer,
        alpha_optim,
        scheduler,
        checkpointer,
        device,
        checkpoint_period,
        arguments,
        cfg,
        tb_info=tb_info,
        first_order=first_order,
    )
    return model
def train(cfg, local_rank, distributed, iter_clear, ignore_head):
    # Train a detector whose conv layers are converted to DeepShift
    # (shift-based) layers after (optionally) loading pretrained weights.
    model = build_detection_model(cfg)
    # model, conversion_count = convert_to_shift_dbg(
    #     model,
    #     cfg.DEEPSHIFT_DEPTH,
    #     cfg.DEEPSHIFT_TYPE,
    #     convert_weights=True,
    #     use_kernel=cfg.DEEPSHIFT_USEKERNEL,
    #     rounding=cfg.DEEPSHIFT_ROUNDING,
    #     shift_range=cfg.DEEPSHIFT_RANGE)
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)
    if cfg.MODEL.USE_SYNCBN:
        assert is_pytorch_1_1_0_or_later(), \
            "SyncBatchNorm is only available in pytorch >= 1.1.0"
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    output_dir = cfg.OUTPUT_DIR
    save_to_disk = get_rank() == 0
    # iter_clear: restart the schedule (do not restore optimizer/scheduler).
    # NOTE(review): load_opt/load_sch computed here are never passed to
    # checkpointer.load below — both call sites hard-code load_opt=False,
    # load_sch=False, so these flags are currently dead; confirm intent.
    if iter_clear:
        load_opt = False
        load_sch = False
    else:
        load_opt = True
        load_sch = True
    # ignore_head: restore backbone/FPN weights but leave the head
    # randomly initialized.
    if ignore_head:
        load_body = True
        load_fpn = True
        load_head = False
    else:
        load_body = True
        load_fpn = True
        load_head = True
    # Preload either a regular (float) model or a deepshift model.
    if cfg.MODEL.WEIGHT:
        # Load pretrained weights into the float model first, then convert to
        # shift layers, then rebuild the checkpointer with the real
        # optimizer/scheduler for subsequent periodic saves.
        checkpointer = DetectronCheckpointer(
            cfg, model, None, None, output_dir, save_to_disk
        )
        extra_checkpoint_data = checkpointer.load(
            cfg.MODEL.WEIGHT,
            load_opt=False,
            load_sch=False,
            load_body=load_body,
            load_fpn=load_fpn,
            load_head=load_head)
        model, conversion_count = convert_to_shift(
            model,
            cfg.DEEPSHIFT_DEPTH,
            cfg.DEEPSHIFT_TYPE,
            convert_weights=True,
            use_kernel=cfg.DEEPSHIFT_USEKERNEL,
            rounding=cfg.DEEPSHIFT_ROUNDING,
            shift_range=cfg.DEEPSHIFT_RANGE)
        optimizer = make_optimizer(cfg, model)
        scheduler = make_lr_scheduler(cfg, optimizer)
        checkpointer = DetectronCheckpointer(
            cfg, model, optimizer, scheduler, output_dir, save_to_disk
        )
    else:
        # No pretrained weights: convert first, then try to resume from the
        # working directory (cfg.MODEL.WEIGHT is empty here, so load() falls
        # back to the last checkpoint if one exists).
        model, conversion_count = convert_to_shift(
            model,
            cfg.DEEPSHIFT_DEPTH,
            cfg.DEEPSHIFT_TYPE,
            convert_weights=True,
            use_kernel=cfg.DEEPSHIFT_USEKERNEL,
            rounding=cfg.DEEPSHIFT_ROUNDING,
            shift_range=cfg.DEEPSHIFT_RANGE)
        optimizer = make_optimizer(cfg, model)
        scheduler = make_lr_scheduler(cfg, optimizer)
        checkpointer = DetectronCheckpointer(
            cfg, model, optimizer, scheduler, output_dir, save_to_disk
        )
        extra_checkpoint_data = checkpointer.load(
            cfg.MODEL.WEIGHT,
            load_opt=False,
            load_sch=False,
            load_body=load_body,
            load_fpn=load_fpn,
            load_head=load_head)
    # Report how many layers were converted vs. left as plain conv/linear.
    conv2d_layers_count = count_layer_type(model, torch.nn.Conv2d)
    linear_layers_count = count_layer_type(model, torch.nn.Linear)
    print("###### conversion_count: {}, not convert conv2d layer: {}, linear layer: {}".format(
        conversion_count, conv2d_layers_count, linear_layers_count))
    if distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[local_rank],
            output_device=local_rank,
            # this should be removed if we update BatchNorm stats
            broadcast_buffers=False,
        )
    arguments = {}
    arguments["iteration"] = 0
    arguments.update(extra_checkpoint_data)
    # Reset the iteration counter after restoring checkpoint bookkeeping.
    if iter_clear:
        arguments["iteration"] = 0
    data_loader = make_data_loader(
        cfg,
        is_train=True,
        is_distributed=distributed,
        start_iter=arguments["iteration"],
    )
    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
    do_train(
        model,
        data_loader,
        optimizer,
        scheduler,
        checkpointer,
        device,
        checkpoint_period,
        arguments,
    )
    # Round shift weights to their discrete values and save the final model.
    model = round_shift_weights(model)
    torch.save({"model": model.state_dict()},
               os.path.join(output_dir, "model_final_round.pth"))
    return model