def train(cfg, local_rank, distributed):
    model = build_detection_model(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)

    print(model)
    print("")
    print("##############################################################")
    print("")
    # summary(model, (3, 2048, 1024))

    if cfg.MODEL.USE_SYNCBN:
        assert is_pytorch_1_1_0_or_later(), \
            "SyncBatchNorm is only available in pytorch >= 1.1.0"
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    optimizer = make_optimizer(cfg, model)
    scheduler = make_lr_scheduler(cfg, optimizer)

    if distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[local_rank], output_device=local_rank,
            # this should be removed if we update BatchNorm stats
            broadcast_buffers=False,
        )

    arguments = {}
    arguments["iteration"] = 0

    output_dir = cfg.OUTPUT_DIR

    save_to_disk = get_rank() == 0
    checkpointer = DetectronCheckpointer(
        cfg, model, optimizer, scheduler, output_dir, save_to_disk
    )
    extra_checkpoint_data = checkpointer.load(cfg.MODEL.WEIGHT)
    arguments.update(extra_checkpoint_data)

    data_loader = make_data_loader(
        cfg,
        is_train=True,
        is_distributed=distributed,
        start_iter=arguments["iteration"],
    )

    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD

    do_train(
        model,
        data_loader,
        optimizer,
        scheduler,
        checkpointer,
        device,
        checkpoint_period,
        arguments,
    )

    return model
def main():
    parser = argparse.ArgumentParser(description="PyTorch Object Detection Training")
    parser.add_argument(
        "--config-file",
        default="",
        metavar="FILE",
        help="path to config file",
        type=str,
    )
    parser.add_argument("--local_rank", type=int, default=0)
    parser.add_argument(
        "--skip-test",
        dest="skip_test",
        help="Do not test the final model",
        action="store_true",
    )
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )

    args = parser.parse_args()

    num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
    args.distributed = num_gpus > 1

    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(
            backend="nccl", init_method="env://"
        )
        synchronize()

    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()

    output_dir = cfg.OUTPUT_DIR
    if output_dir:
        mkdir(output_dir)

    logger = setup_logger("fcos_core", output_dir, get_rank())
    logger.info("Using {} GPUs".format(num_gpus))
    logger.info(args)

    logger.info("Collecting env info (might take some time)")
    logger.info("\n" + collect_env_info())

    logger.info("Loaded configuration file {}".format(args.config_file))
    with open(args.config_file, "r") as cf:
        config_str = "\n" + cf.read()
        logger.info(config_str)
    logger.info("Running with config:\n{}".format(cfg))

    model = train(cfg, args.local_rank, args.distributed)

    if not args.skip_test:
        run_test(cfg, model, args.distributed)
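# For reference: a minimal sketch of the usual entry-point guard and launch
# command for a script with this main(); the file and config paths below are
# illustrative assumptions, not taken from this code.
#
#   python -m torch.distributed.launch --nproc_per_node=4 \
#       tools/train_net.py --config-file configs/fcos/fcos_R_50_FPN_1x.yaml
#
# torch.distributed.launch sets WORLD_SIZE and passes --local_rank to every
# worker process, which is exactly what main() above keys off.
if __name__ == "__main__":
    main()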
def train(cfg, local_rank, distributed):
    model = build_detection_model(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)
    # summary(model, input_size=(2, 3, 1333, 800))

    if cfg.MODEL.USE_SYNCBN:
        assert is_pytorch_1_1_0_or_later(), \
            "SyncBatchNorm is only available in pytorch >= 1.1.0"
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    optimizer = make_optimizer(cfg, model)
    scheduler = make_lr_scheduler(cfg, optimizer)

    if distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[local_rank], output_device=local_rank,
            # this should be removed if we update BatchNorm stats
            broadcast_buffers=False,
        )

    arguments = {}
    arguments["iteration"] = 0

    output_dir = cfg.OUTPUT_DIR

    save_to_disk = get_rank() == 0
    checkpointer = DetectronCheckpointer(
        cfg, model, optimizer, scheduler, output_dir, save_to_disk
    )
    extra_checkpoint_data = checkpointer.load(
        cfg.MODEL.WEIGHT,
        cfg.MODEL.CLS_WEIGHT,
        cfg.MODEL.REG_WEIGHT,
        init_div=True,
        init_opti=False,
        init_model=True,
    )
    arguments.update(extra_checkpoint_data)

    data_loader = make_data_loader(
        cfg,
        is_train=True,
        is_distributed=distributed,
        start_iter=arguments["iteration"],
    )

    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD

    do_train(
        model,
        data_loader,
        optimizer,
        scheduler,
        checkpointer,
        device,
        checkpoint_period,
        arguments,
    )

    return model
def train(cfg, local_rank, distributed):  # cfg, 0, False
    model = build_detection_model(cfg)  # instantiate the model
    device = torch.device(cfg.MODEL.DEVICE)  # cfg.MODEL.DEVICE = "cuda": place torch tensors on CUDA, i.e. the GPU
    model.to(device)  # run the model on the GPU

    if cfg.MODEL.USE_SYNCBN:  # False
        assert is_pytorch_1_1_0_or_later(), \
            "SyncBatchNorm is only available in pytorch >= 1.1.0"
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    optimizer = make_optimizer(cfg, model)  # define the network's training optimizer
    scheduler = make_lr_scheduler(cfg, optimizer)  # set up the learning-rate schedule

    if distributed:  # False
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[local_rank], output_device=local_rank,
            # this should be removed if we update BatchNorm stats
            broadcast_buffers=False,
        )

    arguments = {}
    arguments["iteration"] = 0

    output_dir = cfg.OUTPUT_DIR  # "."

    save_to_disk = get_rank() == 0  # True
    # the checkpoint holds the network's pre-trained weights
    checkpointer = DetectronCheckpointer(
        cfg, model, optimizer, scheduler, output_dir, save_to_disk
    )
    extra_checkpoint_data = checkpointer.load(cfg.MODEL.WEIGHT)
    # merge the values from the extra_checkpoint_data dict into arguments
    arguments.update(extra_checkpoint_data)

    # make_data_loader
    data_loader = make_data_loader(
        cfg,
        is_train=True,
        is_distributed=distributed,  # False
        start_iter=arguments["iteration"],  # 0
    )

    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD  # SOLVER.CHECKPOINT_PERIOD = 2500

    do_train(
        model,
        data_loader,
        optimizer,
        scheduler,
        checkpointer,
        device,
        checkpoint_period,
        arguments,
    )

    return model
def train(cfg, local_rank, distributed, labelenc_fpath):
    model = LabelEncStep2Network(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)

    if cfg.MODEL.USE_SYNCBN:
        assert is_pytorch_1_1_0_or_later(), \
            "SyncBatchNorm is only available in pytorch >= 1.1.0"
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    optimizer = make_optimizer(cfg, model)
    scheduler = make_lr_scheduler(cfg, optimizer)

    if distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[local_rank], output_device=local_rank,
            # this should be removed if we update BatchNorm stats
            broadcast_buffers=False,
        )

    arguments = {}
    arguments["iteration"] = 0

    output_dir = cfg.OUTPUT_DIR

    save_to_disk = get_rank() == 0
    checkpointer = DetectronCheckpointer(
        cfg, model, optimizer, scheduler, output_dir, save_to_disk
    )
    extra_checkpoint_data = checkpointer.load(cfg.MODEL.WEIGHT)
    arguments.update(extra_checkpoint_data)

    # Load LabelEncodingFunction
    # Initialize FPN and Head from Step1 weights
    if not checkpointer.has_checkpoint():
        labelenc_weights = torch.load(labelenc_fpath,
                                      map_location=torch.device('cpu'))
        # load LabelEncodingFunction
        model.module.label_encoding_function.load_state_dict(
            labelenc_weights['label_encoding_function'], strict=True)
        # Initialize Head
        model.module.rpn.load_state_dict(
            labelenc_weights['rpn'], strict=True)
        if model.module.roi_heads:
            model.module.roi_heads.load_state_dict(
                labelenc_weights['roi_heads'], strict=True)
        # Initialize FPN
        fpn_weight = model.module.label_encoding_function.fpn.state_dict()
        model.module.backbone.fpn.load_state_dict(fpn_weight, strict=True)

    data_loader = make_data_loader(
        cfg,
        is_train=True,
        is_distributed=distributed,
        start_iter=arguments["iteration"],
    )

    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD

    do_train(
        model,
        data_loader,
        optimizer,
        scheduler,
        checkpointer,
        device,
        checkpoint_period,
        arguments,
    )

    return model
def train(cfg, local_rank, distributed):
    model = build_detection_model(cfg)  # build the model via build_detection_model
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)

    if cfg.MODEL.USE_SYNCBN:  # SyncBatchNorm: batch normalization synchronized across GPUs
        assert is_pytorch_1_1_0_or_later(), \
            "SyncBatchNorm is only available in pytorch >= 1.1.0"
        # convert the model's BatchNorm layers to their synchronized variant
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    optimizer = make_optimizer(cfg, model)
    scheduler = make_lr_scheduler(cfg, optimizer)

    if distributed:  # whether to train with DistributedDataParallel
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[local_rank], output_device=local_rank,
            # this should be removed if we update BatchNorm stats
            broadcast_buffers=False,
        )

    arguments = {}  # dict holding the training state
    arguments["iteration"] = 0

    output_dir = cfg.OUTPUT_DIR

    save_to_disk = get_rank() == 0
    checkpointer = DetectronCheckpointer(
        cfg, model, optimizer, scheduler, output_dir, save_to_disk  # checkpoint
    )
    extra_checkpoint_data = checkpointer.load(cfg.MODEL.WEIGHT)
    arguments.update(extra_checkpoint_data)

    data_loader = make_data_loader(  # load the data via make_data_loader
        cfg,
        is_train=True,
        is_distributed=distributed,
        start_iter=arguments["iteration"],
    )

    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD

    do_train(
        model,
        data_loader,
        optimizer,
        scheduler,
        checkpointer,
        device,
        checkpoint_period,
        arguments,
    )

    return model
def main():
    # Parse the command-line arguments, e.g.
    # --config-file configs/fcos/fcos_imprv_R_50_FPN_1x.yaml
    parser = argparse.ArgumentParser(
        description="PyTorch Object Detection Training")
    parser.add_argument(
        "--config-file",
        default="",
        metavar="FILE",
        help="path to config file",
        type=str,
    )
    # This argument is passed in by torch.distributed.launch; we add an
    # argument to receive it. local_rank is the index of the GPU used by the
    # current process.
    parser.add_argument("--local_rank", type=int, default=0)
    parser.add_argument(
        "--skip-test",
        dest="skip_test",
        help="Do not test the final model",
        action="store_true",
    )
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )
    args = parser.parse_args()

    # Count the GPUs on this machine; distributed training is used
    # automatically when there is more than one.
    # WORLD_SIZE is produced by torch.distributed.launch.py and equals
    # nproc_per_node * node (node being the number of hosts).
    num_gpus = int(
        os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
    args.distributed = num_gpus > 1

    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend="nccl",
                                             init_method="env://")
        synchronize()

    # Defaults live in fcos_core/config/defaults.py; config_file and opts
    # override them.
    cfg.merge_from_file(args.config_file)  # read parameters from the yaml file
    cfg.merge_from_list(args.opts)  # parameters can also be overridden on the command line
    cfg.freeze()  # freeze the config to guard against accidental changes; cfg is passed into train()
    # You can print cfg here to inspect it; fcos_R_50_FPN_1x.yaml is used as the example.

    output_dir = cfg.OUTPUT_DIR  # create the output directory for logs
    if output_dir:
        mkdir(output_dir)

    # Write the log file: GPU count, system environment, config parameters, etc.
    logger = setup_logger("fcos_core", output_dir, get_rank())
    logger.info("Using {} GPUs".format(num_gpus))
    logger.info(args)

    logger.info("Collecting env info (might take some time)")
    logger.info("\n" + collect_env_info())

    logger.info("Loaded configuration file {}".format(args.config_file))
    with open(args.config_file, "r") as cf:
        config_str = "\n" + cf.read()
        logger.info(config_str)
    logger.info("Running with config:\n{}".format(cfg))

    # The next entry point: train(); its first step is building the model.
    model = train(cfg, args.local_rank, args.distributed)

    if not args.skip_test:
        run_test(cfg, model, args.distributed)
def train(cfg, local_rank, distributed, iter_clear, ignore_head):
    model = build_detection_model(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)

    if cfg.MODEL.USE_SYNCBN:
        assert is_pytorch_1_1_0_or_later(), \
            "SyncBatchNorm is only available in pytorch >= 1.1.0"
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    optimizer = make_optimizer(cfg, model)
    scheduler = make_lr_scheduler(cfg, optimizer)

    if distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[local_rank], output_device=local_rank,
            # this should be removed if we update BatchNorm stats
            broadcast_buffers=False,
        )

    arguments = {}
    arguments["iteration"] = 0

    output_dir = cfg.OUTPUT_DIR

    save_to_disk = get_rank() == 0
    checkpointer = DetectronCheckpointer(cfg, model, optimizer, scheduler,
                                         output_dir, save_to_disk)

    # iter_clear drops the stored optimizer/scheduler state; ignore_head
    # reloads body and FPN but leaves the head freshly initialized.
    load_opt = not iter_clear
    load_sch = not iter_clear
    load_body = True
    load_fpn = True
    load_head = not ignore_head

    extra_checkpoint_data = checkpointer.load(cfg.MODEL.WEIGHT,
                                              load_opt=load_opt,
                                              load_sch=load_sch,
                                              load_body=load_body,
                                              load_fpn=load_fpn,
                                              load_head=load_head)
    arguments.update(extra_checkpoint_data)
    if iter_clear:
        arguments["iteration"] = 0

    data_loader = make_data_loader(
        cfg,
        is_train=True,
        is_distributed=distributed,
        start_iter=arguments["iteration"],
    )

    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD

    do_train(
        model,
        data_loader,
        optimizer,
        scheduler,
        checkpointer,
        device,
        checkpoint_period,
        arguments,
    )

    return model
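# Hypothetical call site for the two extra flags above; the --iter-clear and
# --ignore-head argparse wiring is an assumption, not shown in this snippet.
#
#   parser.add_argument("--iter-clear", action="store_true")
#   parser.add_argument("--ignore-head", action="store_true")
#   args = parser.parse_args()
#   model = train(cfg, args.local_rank, args.distributed,
#                 iter_clear=args.iter_clear, ignore_head=args.ignore_head)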
def main():
    parser = argparse.ArgumentParser(description="Test onnx models of FCOS")
    parser.add_argument(
        "--config-file",
        default="/private/home/fmassa/github/detectron.pytorch_v2/configs/e2e_faster_rcnn_R_50_C4_1x_caffe2.yaml",
        metavar="FILE",
        help="path to config file",
    )
    parser.add_argument(
        "--onnx-model",
        default="fcos_imprv_R_50_FPN_1x.onnx",
        metavar="FILE",
        help="path to the onnx model",
    )
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )
    args = parser.parse_args()

    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    # The onnx model can only be used with DATALOADER.NUM_WORKERS = 0
    cfg.DATALOADER.NUM_WORKERS = 0
    cfg.freeze()

    save_dir = ""
    logger = setup_logger("fcos_core", save_dir, get_rank())
    logger.info(cfg)

    logger.info("Collecting env info (might take some time)")
    logger.info("\n" + collect_env_info())

    model = ONNX_FCOS(args.onnx_model, cfg)
    model.to(cfg.MODEL.DEVICE)

    iou_types = ("bbox",)
    if cfg.MODEL.MASK_ON:
        iou_types = iou_types + ("segm",)
    if cfg.MODEL.KEYPOINT_ON:
        iou_types = iou_types + ("keypoints",)

    output_folders = [None] * len(cfg.DATASETS.TEST)
    dataset_names = cfg.DATASETS.TEST
    if cfg.OUTPUT_DIR:
        for idx, dataset_name in enumerate(dataset_names):
            output_folder = os.path.join(cfg.OUTPUT_DIR, "inference",
                                         dataset_name)
            mkdir(output_folder)
            output_folders[idx] = output_folder

    data_loaders_val = make_data_loader(cfg, is_train=False,
                                        is_distributed=False)
    for output_folder, dataset_name, data_loader_val in zip(
            output_folders, dataset_names, data_loaders_val):
        inference(
            model,
            data_loader_val,
            dataset_name=dataset_name,
            iou_types=iou_types,
            box_only=(False if cfg.MODEL.FCOS_ON or cfg.MODEL.RETINANET_ON
                      else cfg.MODEL.RPN_ONLY),
            device=cfg.MODEL.DEVICE,
            expected_results=cfg.TEST.EXPECTED_RESULTS,
            expected_results_sigma_tol=cfg.TEST.EXPECTED_RESULTS_SIGMA_TOL,
            output_folder=output_folder,
        )
        synchronize()
def main():
    # Parse command-line arguments such as --config-file.
    parser = argparse.ArgumentParser(description="PyTorch Object Detection Training")
    parser.add_argument(
        "--config-file",  # config file
        default="",
        metavar="FILE",
        help="path to config file",
        type=str,
    )
    # This argument is passed in by torch.distributed.launch; we add an
    # argument to receive it. local_rank is the index of the GPU used by the
    # current process.
    parser.add_argument("--local_rank", type=int, default=0)
    parser.add_argument(
        "--skip-test",
        dest="skip_test",
        help="Do not test the final model",
        action="store_true",
    )
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,  # all remaining command-line arguments are collected into a list
    )
    args = parser.parse_args()

    # Count the GPUs on this machine; distributed training is used
    # automatically when there is more than one.
    # WORLD_SIZE is produced by torch.distributed.launch.py and equals
    # nproc_per_node * node (node being the number of hosts).
    num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
    # If WORLD_SIZE is not among the environment variables, num_gpus = 1.
    args.distributed = num_gpus > 1  # False

    if args.distributed:  # False
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(
            backend="nccl", init_method="env://"
        )
        synchronize()

    # For yacs usage details, see my Evernote notes.
    # Defaults live in fcos_core/config/defaults.py; the rest is overridden by
    # config_file and opts.
    cfg.merge_from_file(args.config_file)  # read parameters from the yaml file, i.e. configs/fcos/fcos_R_50_FPN_1x.yaml
    cfg.merge_from_list(args.opts)  # parameters can also be overridden on the command line
    cfg.freeze()  # freeze the config to prevent accidental changes; cfg is passed into train()

    output_dir = cfg.OUTPUT_DIR  # output path for the model and log files
    if output_dir:
        mkdir(output_dir)  # create the corresponding output directory

    # Write the log file: GPU count, system environment, config parameters, etc.
    logger = setup_logger("fcos_core", output_dir, get_rank())
    logger.info("Using {} GPUs".format(num_gpus))
    logger.info(args)

    logger.info("Collecting env info (might take some time)")
    logger.info("\n" + collect_env_info())

    logger.info("Loaded configuration file {}".format(args.config_file))
    with open(args.config_file, "r") as cf:
        config_str = "\n" + cf.read()
        logger.info(config_str)
    logger.info("Running with config:\n{}".format(cfg))

    model = train(cfg, args.local_rank, args.distributed)  # local_rank=0, distributed=False

    if not args.skip_test:
        run_test(cfg, model, args.distributed)
def train(cfg, local_rank, distributed, device_ids, use_tensorboard=False):
    # ------------------------------- more configs
    half_data = [0, 0]  # do not split
    first_order = cfg.SOLVER.SEARCH.FIRST_ORDER
    alpha_lr = cfg.SOLVER.SEARCH.BASE_LR_ALPHA
    alpha_weight_decay = 1e-3
    device_ids = [int(x) for x in device_ids]

    if cfg.MODEL.FAD.CLSTOWER or cfg.MODEL.FAD.BOXTOWER:
        n_cells = cfg.MODEL.FAD.NUM_CELLS_CLS
        if cfg.MODEL.FAD.CLSTOWER and cfg.MODEL.FAD.BOXTOWER:
            n_nodes = cfg.MODEL.FAD.NUM_NODES_CLS
            n_module = 2
        elif cfg.MODEL.FAD.CLSTOWER:
            n_nodes = cfg.MODEL.FAD.NUM_NODES_CLS
            n_module = 1
        else:
            n_nodes = cfg.MODEL.FAD.NUM_NODES_BOX
            n_module = 1
    else:
        pdb.set_trace()

    # build model
    model = SearchRCNNController(n_cells,
                                 n_nodes=n_nodes,
                                 device_ids=device_ids,
                                 cfg_det=cfg,
                                 n_module=n_module)
    device = torch.device(cfg.MODEL.DEVICE)
    model = model.to(device)
    torch.cuda.set_device(0)
    distributed = False

    if first_order:
        print('Using 1st order approximation for the search')

    if cfg.MODEL.USE_SYNCBN:
        assert is_pytorch_1_1_0_or_later(), \
            "SyncBatchNorm is only available in pytorch >= 1.1.0"
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    optimizer = make_optimizer(cfg, model)
    scheduler = make_lr_scheduler(cfg, optimizer)

    # ---------------------- optimize alpha
    arch = Architect(model, cfg.SOLVER.MOMENTUM, cfg.SOLVER.WEIGHT_DECAY)
    alpha_optim = torch.optim.Adam(model.alphas(),
                                   alpha_lr,
                                   betas=(0.5, 0.999),
                                   weight_decay=alpha_weight_decay)

    if distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[local_rank], output_device=local_rank,
            # this should be removed if we update BatchNorm stats
            broadcast_buffers=False,
        )

    arguments = {}
    arguments["iteration"] = 0

    output_dir = cfg.OUTPUT_DIR

    # ------ tensorboard
    tb_info = {"tb_logger": None}
    if use_tensorboard:
        tb_logger = get_tensorboard_writer(output_dir)
        tb_info['tb_logger'] = tb_logger
        tb_info['prefix'] = cfg.TENSORBOARD.PREFIX

    save_to_disk = get_rank() == 0
    checkpointer = DetectronCheckpointer(cfg, model, optimizer, scheduler,
                                         output_dir, save_to_disk)
    extra_checkpoint_data = checkpointer.load(cfg.MODEL.WEIGHT)
    arguments.update(extra_checkpoint_data)

    data_loader = make_data_loader(cfg,
                                   is_train=True,
                                   is_distributed=distributed,
                                   start_iter=arguments["iteration"],
                                   half=half_data[0])
    val_loader = make_data_loader(cfg,
                                  is_train=True,
                                  is_distributed=distributed,
                                  start_iter=arguments["iteration"],
                                  half=half_data[1])

    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD

    do_train(
        model,
        arch,
        data_loader,
        val_loader,
        optimizer,
        alpha_optim,
        scheduler,
        checkpointer,
        device,
        checkpoint_period,
        arguments,
        cfg,
        tb_info=tb_info,
        first_order=first_order,
    )
    return model
def main():
    parser = argparse.ArgumentParser(
        description="PyTorch Object Detection Training")
    parser.add_argument(
        "--config-file",
        default="",
        metavar="FILE",
        help="path to config file",
        type=str,
    )
    parser.add_argument("--local_rank", type=int, default=0)
    parser.add_argument("--device_ids", type=list, default=[0])
    parser.add_argument(
        "--skip-test",
        dest="skip_test",
        help="Do not test the final model",
        action="store_true",
    )
    parser.add_argument(
        "--use-tensorboard",
        dest="use_tensorboard",
        help="Use tensorboardX logger (Requires tensorboardX installed)",
        action="store_true",
        default=False)
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )
    args = parser.parse_args()

    # set device_ids according to the number of visible GPUs
    num_gpus = len(os.environ["CUDA_VISIBLE_DEVICES"].split(","))
    args.device_ids = list(map(str, range(num_gpus)))

    # do not use torch.distributed
    args.distributed = False
    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend="nccl",
                                             init_method="env://")
        synchronize()

    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()

    output_dir = cfg.OUTPUT_DIR
    if output_dir:
        mkdir(output_dir)

    logger = setup_logger("fad_core", output_dir, get_rank())
    logger.info("Using {} GPUs".format(num_gpus))
    logger.info(args)

    logger.info("Collecting env info (might take some time)")
    logger.info("\n" + collect_env_info())

    logger.info("Loaded configuration file {}".format(args.config_file))
    with open(args.config_file, "r") as cf:
        config_str = "\n" + cf.read()
        logger.info(config_str)
    logger.info("Running with config:\n{}".format(cfg))

    model = train(cfg,
                  args.local_rank,
                  args.distributed,
                  args.device_ids,
                  use_tensorboard=args.use_tensorboard)

    if not args.skip_test:
        run_test(cfg, model, args.distributed)
def train(cfg, local_rank, distributed):
    writer = SummaryWriter('runs/{}'.format(cfg.OUTPUT_DIR))

    ##########################################################################
    ############################# Initial Model ##############################
    ##########################################################################
    model = {}
    device = torch.device(cfg.MODEL.DEVICE)
    backbone = build_backbone(cfg).to(device)
    fcos = build_rpn(cfg, backbone.out_channels).to(device)

    adv = cfg.MODEL.ADV
    # FPN levels with an enabled discriminator, in the order the per-level
    # blocks run (P7 down to P3).
    levels = [l for l in ("P7", "P6", "P5", "P4", "P3")
              if getattr(adv, "USE_DIS_" + l)]

    # One discriminator per enabled level for each enabled alignment type;
    # keys are "dis_P7", "dis_P7_CA", "dis_P7_Cond", "dis_P7_HA", and so on.
    discriminators = {}
    if adv.USE_DIS_GLOBAL:
        for l in levels:
            discriminators["dis_" + l] = FCOSDiscriminator(
                num_convs=getattr(adv, "DIS_{}_NUM_CONVS".format(l)),
                grad_reverse_lambda=getattr(adv, "GRL_WEIGHT_" + l),
                grl_applied_domain=adv.GRL_APPLIED_DOMAIN).to(device)
    if adv.USE_DIS_CENTER_AWARE:
        for l in levels:
            discriminators["dis_{}_CA".format(l)] = FCOSDiscriminator_CA(
                num_convs=getattr(adv, "CA_DIS_{}_NUM_CONVS".format(l)),
                grad_reverse_lambda=getattr(adv, "CA_GRL_WEIGHT_" + l),
                center_aware_weight=adv.CENTER_AWARE_WEIGHT,
                center_aware_type=adv.CENTER_AWARE_TYPE,
                grl_applied_domain=adv.GRL_APPLIED_DOMAIN).to(device)
    if adv.USE_DIS_CONDITIONAL:
        for l in levels:
            discriminators["dis_{}_Cond".format(l)] = FCOSDiscriminator_CondA(
                num_convs=getattr(adv, "COND_DIS_{}_NUM_CONVS".format(l)),
                grad_reverse_lambda=getattr(adv, "COND_GRL_WEIGHT_" + l),
                center_aware_weight=adv.CENTER_AWARE_WEIGHT,
                center_aware_type=adv.CENTER_AWARE_TYPE,
                grl_applied_domain=adv.GRL_APPLIED_DOMAIN,
                class_align=adv.COND_CLASS,
                reg_left_align=adv.COND_REG.LEFT,
                reg_top_align=adv.COND_REG.TOP,
                expand_dim=adv.COND_EXPAND).to(device)
    if adv.USE_DIS_HEAD:
        for l in levels:
            discriminators["dis_{}_HA".format(l)] = FCOSDiscriminator_HA(
                num_convs=getattr(adv, "HA_DIS_{}_NUM_CONVS".format(l)),
                grad_reverse_lambda=getattr(adv, "HA_GRL_WEIGHT_" + l),
                center_aware_weight=adv.CENTER_AWARE_WEIGHT,
                grl_applied_domain=adv.GRL_APPLIED_DOMAIN).to(device)

    if cfg.MODEL.USE_SYNCBN:
        assert is_pytorch_1_1_0_or_later(), \
            "SyncBatchNorm is only available in pytorch >= 1.1.0"
        backbone = torch.nn.SyncBatchNorm.convert_sync_batchnorm(backbone)
        fcos = torch.nn.SyncBatchNorm.convert_sync_batchnorm(fcos)
        for name in discriminators:
            discriminators[name] = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
                discriminators[name])

    ##########################################################################
    #################### Initial Optimizer and Scheduler #####################
    ##########################################################################
    optimizer = {}
    optimizer["backbone"] = make_optimizer(cfg, backbone, name='backbone')
    optimizer["fcos"] = make_optimizer(cfg, fcos, name='fcos')
    for name, dis in discriminators.items():
        optimizer[name] = make_optimizer(cfg, dis, name='discriminator')

    scheduler = {}
    scheduler["backbone"] = make_lr_scheduler(cfg, optimizer["backbone"],
                                              name='backbone')
    scheduler["fcos"] = make_lr_scheduler(cfg, optimizer["fcos"], name='fcos')
    for name in discriminators:
        scheduler[name] = make_lr_scheduler(cfg, optimizer[name],
                                            name='discriminator')

    ##########################################################################
    ######################## DistributedDataParallel #########################
    ##########################################################################
    if distributed:
        def wrap_ddp(module):
            return torch.nn.parallel.DistributedDataParallel(
                module,
                device_ids=[local_rank],
                output_device=local_rank,
                # this should be removed if we update BatchNorm stats
                broadcast_buffers=False)

        backbone = wrap_ddp(backbone)
        fcos = wrap_ddp(fcos)
        for name in discriminators:
            discriminators[name] = wrap_ddp(discriminators[name])

    ##########################################################################
    ########################### Save Model to Dict ###########################
    ##########################################################################
    model["backbone"] = backbone
    model["fcos"] = fcos
    model.update(discriminators)

    ##########################################################################
    ################################ Training ################################
    ##########################################################################
    arguments = {}
    arguments["iteration"] = 0
    arguments["use_dis_global"] = adv.USE_DIS_GLOBAL
    arguments["use_dis_ca"] = adv.USE_DIS_CENTER_AWARE
    arguments["use_dis_conditional"] = adv.USE_DIS_CONDITIONAL
    arguments["use_dis_ha"] = adv.USE_DIS_HEAD
    arguments["ga_dis_lambda"] = adv.GA_DIS_LAMBDA
    arguments["ca_dis_lambda"] = adv.CA_DIS_LAMBDA
    arguments["cond_dis_lambda"] = adv.COND_DIS_LAMBDA
    arguments["ha_dis_lambda"] = adv.HA_DIS_LAMBDA
    arguments["use_feature_layers"] = list(levels)

    output_dir = cfg.OUTPUT_DIR

    save_to_disk = get_rank() == 0
    checkpointer = DetectronCheckpointer(cfg, model, optimizer, scheduler,
                                         output_dir, save_to_disk)
    extra_checkpoint_data = checkpointer.load(f=cfg.MODEL.WEIGHT,
                                              load_dis=True,
                                              load_opt_sch=False)
    # arguments.update(extra_checkpoint_data)

    # Initial dataloader (both target and source domain)
    data_loader = {}
    data_loader["source"] = make_data_loader_source(
        cfg,
        is_train=True,
        is_distributed=distributed,
        start_iter=arguments["iteration"],
    )
    data_loader["target"] = make_data_loader_target(
        cfg,
        is_train=True,
        is_distributed=distributed,
        start_iter=arguments["iteration"],
    )

    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD

    do_train(model, data_loader, optimizer, scheduler, checkpointer, device,
             checkpoint_period, arguments, cfg, run_test, distributed, writer)
    return model
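# All discriminators above sit behind a gradient reversal layer, scaled by the
# grad_reverse_lambda arguments. A minimal self-contained sketch of the
# textbook GRL construction follows; it is for orientation only and is not
# necessarily this repo's exact implementation.
import torch


class _GradReverse(torch.autograd.Function):
    """Identity on the forward pass; negated, scaled gradient going backward."""

    @staticmethod
    def forward(ctx, x, lambd):
        ctx.lambd = lambd
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # The domain-classifier gradient is flipped (and scaled) before it
        # reaches the features, pushing them toward domain invariance.
        return grad_output.neg() * ctx.lambd, None


def grad_reverse(x, lambd=1.0):
    return _GradReverse.apply(x, lambd)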
def train(cfg, local_rank, distributed, iter_clear, ignore_head):
    model = build_detection_model(cfg)
    # model, conversion_count = convert_to_shift_dbg(
    #     model,
    #     cfg.DEEPSHIFT_DEPTH,
    #     cfg.DEEPSHIFT_TYPE,
    #     convert_weights=True,
    #     use_kernel=cfg.DEEPSHIFT_USEKERNEL,
    #     rounding=cfg.DEEPSHIFT_ROUNDING,
    #     shift_range=cfg.DEEPSHIFT_RANGE)
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)

    if cfg.MODEL.USE_SYNCBN:
        assert is_pytorch_1_1_0_or_later(), \
            "SyncBatchNorm is only available in pytorch >= 1.1.0"
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    output_dir = cfg.OUTPUT_DIR
    save_to_disk = get_rank() == 0

    # iter_clear drops the stored optimizer/scheduler state; ignore_head
    # reloads body and FPN but leaves the head freshly initialized.
    load_opt = not iter_clear
    load_sch = not iter_clear
    load_body = True
    load_fpn = True
    load_head = not ignore_head

    # Preload the weights: either a regular model or a DeepShift model.
    if cfg.MODEL.WEIGHT:
        # Regular checkpoint: load it first, then convert to shift layers.
        checkpointer = DetectronCheckpointer(
            cfg, model, None, None, output_dir, save_to_disk
        )
        extra_checkpoint_data = checkpointer.load(
            cfg.MODEL.WEIGHT,
            load_opt=False,
            load_sch=False,
            load_body=load_body,
            load_fpn=load_fpn,
            load_head=load_head)
        model, conversion_count = convert_to_shift(
            model,
            cfg.DEEPSHIFT_DEPTH,
            cfg.DEEPSHIFT_TYPE,
            convert_weights=True,
            use_kernel=cfg.DEEPSHIFT_USEKERNEL,
            rounding=cfg.DEEPSHIFT_ROUNDING,
            shift_range=cfg.DEEPSHIFT_RANGE)
        optimizer = make_optimizer(cfg, model)
        scheduler = make_lr_scheduler(cfg, optimizer)
        checkpointer = DetectronCheckpointer(
            cfg, model, optimizer, scheduler, output_dir, save_to_disk
        )
    else:
        # No explicit weight file: convert first, then let the checkpointer
        # pick up whatever is resumable.
        model, conversion_count = convert_to_shift(
            model,
            cfg.DEEPSHIFT_DEPTH,
            cfg.DEEPSHIFT_TYPE,
            convert_weights=True,
            use_kernel=cfg.DEEPSHIFT_USEKERNEL,
            rounding=cfg.DEEPSHIFT_ROUNDING,
            shift_range=cfg.DEEPSHIFT_RANGE)
        optimizer = make_optimizer(cfg, model)
        scheduler = make_lr_scheduler(cfg, optimizer)
        checkpointer = DetectronCheckpointer(
            cfg, model, optimizer, scheduler, output_dir, save_to_disk
        )
        extra_checkpoint_data = checkpointer.load(
            cfg.MODEL.WEIGHT,
            load_opt=False,
            load_sch=False,
            load_body=load_body,
            load_fpn=load_fpn,
            load_head=load_head)

    conv2d_layers_count = count_layer_type(model, torch.nn.Conv2d)
    linear_layers_count = count_layer_type(model, torch.nn.Linear)
    print("###### conversion_count: {}, unconverted conv2d layers: {}, linear layers: {}".format(
        conversion_count, conv2d_layers_count, linear_layers_count))

    if distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[local_rank], output_device=local_rank,
            # this should be removed if we update BatchNorm stats
            broadcast_buffers=False,
        )

    arguments = {}
    arguments["iteration"] = 0
    arguments.update(extra_checkpoint_data)
    if iter_clear:
        arguments["iteration"] = 0

    data_loader = make_data_loader(
        cfg,
        is_train=True,
        is_distributed=distributed,
        start_iter=arguments["iteration"],
    )

    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD

    do_train(
        model,
        data_loader,
        optimizer,
        scheduler,
        checkpointer,
        device,
        checkpoint_period,
        arguments,
    )

    # Round shift weights to discrete powers of two and save the final model.
    model = round_shift_weights(model)
    torch.save({"model": model.state_dict()},
               os.path.join(output_dir, "model_final_round.pth"))

    return model
def main():
    parser = argparse.ArgumentParser(
        description="PyTorch Object Detection Training")
    parser.add_argument(
        "--config-file",
        default="",
        metavar="FILE",
        help="path to config file",
        type=str,
    )
    parser.add_argument("--local_rank", type=int, default=0)
    parser.add_argument(
        "--skip-test",
        dest="skip_test",
        help="Do not test the final model",
        action="store_true",
    )
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )
    args = parser.parse_args()

    num_gpus = int(
        os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
    args.distributed = num_gpus > 1

    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend="nccl",
                                             init_method="env://")
        synchronize()

    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    # add distance loss warmup iters
    cfg.SOLVER.MAX_ITER += cfg.MODEL.LABELENC.DISTANCE_LOSS_WARMUP_ITERS
    cfg.SOLVER.STEPS = tuple([
        i + cfg.MODEL.LABELENC.DISTANCE_LOSS_WARMUP_ITERS
        for i in cfg.SOLVER.STEPS
    ])
    cfg.freeze()

    output_dir = cfg.OUTPUT_DIR
    if output_dir:
        mkdir(output_dir)

    logger = setup_logger("fcos_core", output_dir, get_rank())
    logger.info("Using {} GPUs".format(num_gpus))
    logger.info(args)

    logger.info("Collecting env info (might take some time)")
    logger.info("\n" + collect_env_info())

    logger.info("Loaded configuration file {}".format(args.config_file))
    with open(args.config_file, "r") as cf:
        config_str = "\n" + cf.read()
        logger.info(config_str)
    logger.info("Running with config:\n{}".format(cfg))

    model = train(cfg, args.local_rank, args.distributed)

    if not args.skip_test:
        run_test(cfg, model, args.distributed)

    if args.distributed:
        model = model.module
    if not args.distributed or dist.get_rank() == 0:
        label_encoding_function = model.label_encoding_function.state_dict()
        rpn = model.rpn.state_dict()
        saved_weights = {
            'label_encoding_function': label_encoding_function,
            'rpn': rpn
        }
        if model.roi_heads:
            roi_heads = model.roi_heads.state_dict()
            saved_weights.update({'roi_heads': roi_heads})
        torch.save(saved_weights,
                   os.path.join(cfg.OUTPUT_DIR, "label_encoding_function.pth"))
        logger.info("Successfully save label encoding function weights to " +
                    os.path.join(cfg.OUTPUT_DIR, "label_encoding_function.pth"))
    synchronize()
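# The label_encoding_function.pth written above is what the step-2 trainer
# (the train(cfg, local_rank, distributed, labelenc_fpath) variant earlier in
# this section) consumes. A hypothetical hand-off:
#
#   labelenc_fpath = os.path.join(cfg.OUTPUT_DIR, "label_encoding_function.pth")
#   model = train(cfg, args.local_rank, args.distributed, labelenc_fpath)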
def main():
    parser = argparse.ArgumentParser(
        description="PyTorch Object Detection Inference")
    parser.add_argument(
        "--config-file",
        default="/private/home/fmassa/github/detectron.pytorch_v2/configs/e2e_faster_rcnn_R_50_C4_1x_caffe2.yaml",
        metavar="FILE",
        help="path to config file",
    )
    parser.add_argument("--local_rank", type=int, default=0)
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )
    args = parser.parse_args()

    num_gpus = int(
        os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
    distributed = num_gpus > 1

    if distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend="nccl",
                                             init_method="env://")
        synchronize()

    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()

    save_dir = ""
    logger = setup_logger("fcos_core", save_dir, get_rank())
    logger.info("Using {} GPUs".format(num_gpus))
    logger.info(cfg)

    logger.info("Collecting env info (might take some time)")
    logger.info("\n" + collect_env_info())

    model = build_detection_model(cfg)
    model.to(cfg.MODEL.DEVICE)

    output_dir = cfg.OUTPUT_DIR
    checkpointer = DetectronCheckpointer(cfg, model, save_dir=output_dir)
    _ = checkpointer.load(cfg.MODEL.WEIGHT)

    # "segm" is always evaluated for this model (it predicts masks even when
    # MASK_ON is unset), so it is included once up front.
    iou_types = ("bbox", "segm")
    if cfg.MODEL.KEYPOINT_ON:
        iou_types = iou_types + ("keypoints",)

    output_folders = [None] * len(cfg.DATASETS.TEST)
    dataset_names = cfg.DATASETS.TEST
    if cfg.OUTPUT_DIR:
        for idx, dataset_name in enumerate(dataset_names):
            output_folder = os.path.join(cfg.OUTPUT_DIR, "inference",
                                         dataset_name)
            mkdir(output_folder)
            output_folders[idx] = output_folder

    data_loaders_val = make_data_loader(cfg, is_train=False,
                                        is_distributed=distributed)
    for output_folder, dataset_name, data_loader_val in zip(
            output_folders, dataset_names, data_loaders_val):
        inference(
            model,
            data_loader_val,
            dataset_name=dataset_name,
            iou_types=iou_types,
            box_only=(False if cfg.MODEL.FCOS_ON or cfg.MODEL.SIPMASK_ON
                      or cfg.MODEL.RETINANET_ON else cfg.MODEL.RPN_ONLY),
            device=cfg.MODEL.DEVICE,
            expected_results=cfg.TEST.EXPECTED_RESULTS,
            expected_results_sigma_tol=cfg.TEST.EXPECTED_RESULTS_SIGMA_TOL,
            output_folder=output_folder,
        )
        synchronize()
def train(cfg, local_rank, distributed):
    model = build_detection_model(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)

    if cfg.MODEL.USE_SYNCBN:
        assert is_pytorch_1_1_0_or_later(), \
            "SyncBatchNorm is only available in pytorch >= 1.1.0"
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    optimizer = make_optimizer(cfg, model)
    scheduler = make_lr_scheduler(cfg, optimizer)

    if distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[local_rank], output_device=local_rank,
            # this should be removed if we update BatchNorm stats
            broadcast_buffers=False,
        )

    arguments = {}
    arguments["iteration"] = 0

    output_dir = cfg.OUTPUT_DIR

    save_to_disk = get_rank() == 0
    checkpointer = DetectronCheckpointer(cfg, model, optimizer, scheduler,
                                         output_dir, save_to_disk)
    extra_checkpoint_data = checkpointer.load(cfg.MODEL.WEIGHT)
    arguments.update(extra_checkpoint_data)

    data_loader = make_data_loader(
        cfg,
        is_train=True,
        is_distributed=distributed,
        start_iter=arguments["iteration"],
    )

    # Debug helper for visualizing a batch (targets and pixel values are in
    # the hundreds because the inputs are not normalized to [0, 1]):
    # import matplotlib.pyplot as plt
    # import numpy as np
    # import torchvision
    #
    # def imshow(img):
    #     # img = img / 2 + 0.5  # unnormalize
    #     img = img + 115
    #     img = img[[2, 1, 0]]
    #     npimg = img.numpy().astype(np.int)
    #     plt.imshow(np.transpose(npimg, (1, 2, 0)))
    #     plt.show()
    #
    # dataiter = iter(data_loader)
    # images, target, _ = dataiter.next()
    # imshow(torchvision.utils.make_grid(images.tensors))

    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD

    do_train(
        model,
        data_loader,
        optimizer,
        scheduler,
        checkpointer,
        device,
        checkpoint_period,
        arguments,
    )

    return model
def main():
    parser = argparse.ArgumentParser(
        description="Export model to the onnx format")
    parser.add_argument(
        "--config-file",
        default="configs/fcos/fcos_imprv_R_50_FPN_1x.yaml",
        metavar="FILE",
        help="path to config file",
    )
    parser.add_argument(
        "--output",
        default="fcos.onnx",
        metavar="FILE",
        help="path to the output onnx file",
    )
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )
    args = parser.parse_args()

    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()
    assert cfg.MODEL.FCOS_ON, "This script is only tested for the detector FCOS."

    save_dir = ""
    logger = setup_logger("fcos_core", save_dir, get_rank())
    logger.info(cfg)

    logger.info("Collecting env info (might take some time)")
    logger.info("\n" + collect_env_info())

    model = build_detection_model(cfg)
    model.to(cfg.MODEL.DEVICE)

    output_dir = cfg.OUTPUT_DIR
    checkpointer = DetectronCheckpointer(cfg, model, save_dir=output_dir)
    _ = checkpointer.load(cfg.MODEL.WEIGHT)

    onnx_model = torch.nn.Sequential(OrderedDict([
        ('backbone', model.backbone),
        ('heads', model.rpn.head),
    ]))

    input_names = ["input_image"]
    dummy_input = torch.zeros((1, 3, 800, 1216)).to(cfg.MODEL.DEVICE)
    output_names = []
    for l in range(len(cfg.MODEL.FCOS.FPN_STRIDES)):
        fpn_name = "P{}/".format(3 + l)
        output_names.extend([
            fpn_name + "logits",
            fpn_name + "bbox_reg",
            fpn_name + "centerness"
        ])

    torch.onnx.export(onnx_model,
                      dummy_input,
                      args.output,
                      verbose=True,
                      input_names=input_names,
                      output_names=output_names,
                      keep_initializers_as_inputs=True)

    logger.info("Done. The onnx model is saved into {}.".format(args.output))
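# A quick sanity check of the exported file, assuming onnxruntime is
# installed; the input/output names mirror the export call above.
import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("fcos.onnx", providers=["CPUExecutionProvider"])
dummy = np.zeros((1, 3, 800, 1216), dtype=np.float32)
outputs = sess.run(None, {"input_image": dummy})
# One (logits, bbox_reg, centerness) triple per FPN level, P3..P7.
print([o.shape for o in outputs])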
def __init__(self, cfg, local_rank, distributed):
    self.writer = SummaryWriter(log_dir=cfg.OUTPUT_DIR)
    self.start_epoch = 0
    # self.epochs = cfg.MAX_ITER / len()
    self.epochs = 5

    model = build_detection_model(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)

    if cfg.MODEL.USE_SYNCBN:
        assert is_pytorch_1_1_0_or_later(), \
            "SyncBatchNorm is only available in pytorch >= 1.1.0"
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    optimizer = make_optimizer(cfg, model)
    scheduler = make_lr_scheduler(cfg, optimizer)

    if distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[local_rank], output_device=local_rank,
            # this should be removed if we update BatchNorm stats
            broadcast_buffers=False,
        )

    arguments = {}
    arguments["iteration"] = 0

    output_dir = cfg.OUTPUT_DIR

    save_to_disk = get_rank() == 0
    checkpointer = DetectronCheckpointer(cfg, model, optimizer, scheduler,
                                         output_dir, save_to_disk)
    extra_checkpoint_data = checkpointer.load(cfg.MODEL.WEIGHT)
    arguments.update(extra_checkpoint_data)

    # The core change is in the dataset; the data loaders are all plain
    # torch.utils.data.DataLoader objects.
    # import pdb; pdb.set_trace()
    # train_loader = build_single_data_loader(cfg)
    self.train_loader = make_train_loader(
        cfg, start_iter=arguments["iteration"])
    # self.val_loader = make_val_loader(cfg)
    # train_data_loader = make_data_loader(
    #     cfg,
    #     is_train=True,
    #     is_distributed=distributed,
    #     start_iter=arguments["iteration"],
    # )

    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD

    self.model = model
    self.optimizer = optimizer
    self.scheduler = scheduler
    self.checkpointer = checkpointer
    self.device = device
    self.checkpoint_period = checkpoint_period
    self.arguments = arguments
    self.distributed = distributed
def train(cfg, local_rank, distributed):
    writer = SummaryWriter('runs/{}'.format(cfg.OUTPUT_DIR))

    ##########################################################################
    ############################# Initial Model ##############################
    ##########################################################################
    model = {}
    device = torch.device(cfg.MODEL.DEVICE)
    backbone = build_backbone(cfg).to(device)
    fcos = build_rpn(cfg, backbone.out_channels).to(device)

    if cfg.MODEL.USE_SYNCBN:
        assert is_pytorch_1_1_0_or_later(), \
            "SyncBatchNorm is only available in pytorch >= 1.1.0"
        backbone = torch.nn.SyncBatchNorm.convert_sync_batchnorm(backbone)
        fcos = torch.nn.SyncBatchNorm.convert_sync_batchnorm(fcos)

    ##########################################################################
    #################### Initial Optimizer and Scheduler #####################
    ##########################################################################
    optimizer = {}
    optimizer["backbone"] = make_optimizer(cfg, backbone, name='backbone')
    optimizer["fcos"] = make_optimizer(cfg, fcos, name='fcos')

    scheduler = {}
    scheduler["backbone"] = make_lr_scheduler(cfg, optimizer["backbone"],
                                              name='backbone')
    scheduler["fcos"] = make_lr_scheduler(cfg, optimizer["fcos"], name='fcos')

    if distributed:
        backbone = torch.nn.parallel.DistributedDataParallel(
            backbone, device_ids=[local_rank], output_device=local_rank,
            # this should be removed if we update BatchNorm stats
            broadcast_buffers=False
        )
        fcos = torch.nn.parallel.DistributedDataParallel(
            fcos, device_ids=[local_rank], output_device=local_rank,
            # this should be removed if we update BatchNorm stats
            broadcast_buffers=False
        )

    ##########################################################################
    ########################### Save Model to Dict ###########################
    ##########################################################################
    model["backbone"] = backbone
    model["fcos"] = fcos

    ##########################################################################
    ################################ Training ################################
    ##########################################################################
    arguments = {}
    arguments["iteration"] = 0

    output_dir = cfg.OUTPUT_DIR

    save_to_disk = get_rank() == 0
    checkpointer = DetectronCheckpointer(
        cfg, model, optimizer, scheduler, output_dir, save_to_disk
    )
    extra_checkpoint_data = checkpointer.load(cfg.MODEL.WEIGHT)
    arguments.update(extra_checkpoint_data)

    data_loader = make_data_loader_source(
        cfg,
        is_train=True,
        is_distributed=distributed,
        start_iter=arguments["iteration"],
    )

    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD

    do_train_base(
        model, data_loader, optimizer, scheduler, checkpointer, device,
        checkpoint_period, arguments, cfg, run_test, distributed, writer
    )
    return model