def train(cfg, local_rank, distributed):
    model = build_detection_model(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)

    print(model)
    print("")
    print("##############################################################")
    print("")
    # summary(model, (3, 2048, 1024))

    if cfg.MODEL.USE_SYNCBN:
        assert is_pytorch_1_1_0_or_later(), \
            "SyncBatchNorm is only available in pytorch >= 1.1.0"
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    optimizer = make_optimizer(cfg, model)
    scheduler = make_lr_scheduler(cfg, optimizer)

    if distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[local_rank], output_device=local_rank,
            # this should be removed if we update BatchNorm stats
            broadcast_buffers=False,
        )

    arguments = {}
    arguments["iteration"] = 0

    output_dir = cfg.OUTPUT_DIR
    save_to_disk = get_rank() == 0
    checkpointer = DetectronCheckpointer(
        cfg, model, optimizer, scheduler, output_dir, save_to_disk
    )
    extra_checkpoint_data = checkpointer.load(cfg.MODEL.WEIGHT)
    arguments.update(extra_checkpoint_data)

    data_loader = make_data_loader(
        cfg,
        is_train=True,
        is_distributed=distributed,
        start_iter=arguments["iteration"],
    )

    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD

    do_train(
        model,
        data_loader,
        optimizer,
        scheduler,
        checkpointer,
        device,
        checkpoint_period,
        arguments,
    )

    return model
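# A minimal sketch of the launcher that typically drives a train() like the one
# above, modeled on maskrcnn-benchmark's tools/train_net.py. The exact flag
# names and the global yacs `cfg` are assumptions, not this repo's verified
# entry point.
import argparse
import os

import torch


def main():
    parser = argparse.ArgumentParser(description="PyTorch detection training (sketch)")
    parser.add_argument("--config-file", default="", metavar="FILE")
    # torch.distributed.launch passes --local_rank to every worker process
    parser.add_argument("--local_rank", type=int, default=0)
    args = parser.parse_args()

    num_gpus = int(os.environ.get("WORLD_SIZE", 1))
    distributed = num_gpus > 1
    if distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend="nccl", init_method="env://")

    cfg.merge_from_file(args.config_file)  # assumes the project's global yacs cfg
    cfg.freeze()

    train(cfg, args.local_rank, distributed)


if __name__ == "__main__":
    main()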
def train(cfg, local_rank, distributed):
    model = build_detection_model(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)
    # summary(model, input_size=(2, 3, 1333, 800))

    if cfg.MODEL.USE_SYNCBN:
        assert is_pytorch_1_1_0_or_later(), \
            "SyncBatchNorm is only available in pytorch >= 1.1.0"
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    optimizer = make_optimizer(cfg, model)
    scheduler = make_lr_scheduler(cfg, optimizer)

    if distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[local_rank], output_device=local_rank,
            # this should be removed if we update BatchNorm stats
            broadcast_buffers=False,
        )

    arguments = {}
    arguments["iteration"] = 0

    output_dir = cfg.OUTPUT_DIR
    save_to_disk = get_rank() == 0
    checkpointer = DetectronCheckpointer(
        cfg, model, optimizer, scheduler, output_dir, save_to_disk
    )
    extra_checkpoint_data = checkpointer.load(
        cfg.MODEL.WEIGHT,
        cfg.MODEL.CLS_WEIGHT,
        cfg.MODEL.REG_WEIGHT,
        init_div=True,
        init_opti=False,
        init_model=True,
    )
    arguments.update(extra_checkpoint_data)

    data_loader = make_data_loader(
        cfg,
        is_train=True,
        is_distributed=distributed,
        start_iter=arguments["iteration"],
    )

    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD

    do_train(
        model,
        data_loader,
        optimizer,
        scheduler,
        checkpointer,
        device,
        checkpoint_period,
        arguments,
    )

    return model
def train(cfg, local_rank, distributed):  # cfg, 0, False
    model = build_detection_model(cfg)  # instantiate the model
    device = torch.device(cfg.MODEL.DEVICE)  # cfg.MODEL.DEVICE = "cuda": allocate torch tensors on CUDA, i.e. the GPU
    model.to(device)  # move the model to the GPU

    if cfg.MODEL.USE_SYNCBN:  # False
        assert is_pytorch_1_1_0_or_later(), \
            "SyncBatchNorm is only available in pytorch >= 1.1.0"
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    optimizer = make_optimizer(cfg, model)  # build the training optimizer
    scheduler = make_lr_scheduler(cfg, optimizer)  # set up the learning-rate schedule

    if distributed:  # False
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[local_rank], output_device=local_rank,
            # this should be removed if we update BatchNorm stats
            broadcast_buffers=False,
        )

    arguments = {}
    arguments["iteration"] = 0

    output_dir = cfg.OUTPUT_DIR  # "."
    save_to_disk = get_rank() == 0  # True

    # the checkpoint holds the network's pretrained weights
    checkpointer = DetectronCheckpointer(
        cfg, model, optimizer, scheduler, output_dir, save_to_disk
    )
    extra_checkpoint_data = checkpointer.load(cfg.MODEL.WEIGHT)
    arguments.update(extra_checkpoint_data)  # merge the values from extra_checkpoint_data into arguments

    # make_data_loader
    data_loader = make_data_loader(
        cfg,
        is_train=True,
        is_distributed=distributed,  # False
        start_iter=arguments["iteration"],  # 0
    )

    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD  # SOLVER.CHECKPOINT_PERIOD = 2500

    do_train(
        model,
        data_loader,
        optimizer,
        scheduler,
        checkpointer,
        device,
        checkpoint_period,
        arguments,
    )

    return model
def train(cfg, local_rank, distributed):
    model = build_detection_model(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)

    if cfg.MODEL.USE_SYNCBN:
        assert is_pytorch_1_1_0_or_later(), \
            "SyncBatchNorm is only available in pytorch >= 1.1.0"
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    optimizer = make_optimizer(cfg, model)
    scheduler = make_lr_scheduler(cfg, optimizer)

    if distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[local_rank], output_device=local_rank,
            # this should be removed if we update BatchNorm stats
            broadcast_buffers=False,
        )

    arguments = {}
    arguments["iteration"] = 0

    output_dir = cfg.OUTPUT_DIR
    save_to_disk = get_rank() == 0
    checkpointer = DetectronCheckpointer(
        cfg, model, optimizer, scheduler, output_dir, save_to_disk
    )
    extra_checkpoint_data = checkpointer.load(cfg.MODEL.WEIGHT)
    arguments.update(extra_checkpoint_data)

    data_loader = make_data_loader(
        cfg,
        is_train=True,
        is_distributed=distributed,
        start_iter=arguments["iteration"],
    )

    # Debugging helper: preview one batch from the data loader.
    # import matplotlib.pyplot as plt
    # import numpy as np
    #
    # def imshow(img):
    #     # img = img / 2 + 0.5  # unnormalize
    #     img = img + 115
    #     img = img[[2, 1, 0]]
    #     npimg = img.numpy().astype(np.int64)
    #     plt.imshow(np.transpose(npimg, (1, 2, 0)))
    #     plt.show()
    #
    # import torchvision
    # dataiter = iter(data_loader)
    # images, target, _ = next(dataiter)  # target and pixel values are in the hundreds
    #
    # imshow(torchvision.utils.make_grid(images.tensors))

    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD

    do_train(
        model,
        data_loader,
        optimizer,
        scheduler,
        checkpointer,
        device,
        checkpoint_period,
        arguments,
    )

    return model
def train(cfg, local_rank, distributed):
    model = build_detection_model(cfg)  # build the model with build_detection_model
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)

    if cfg.MODEL.USE_SYNCBN:  # SyncBatchNorm: batch norm synchronized across GPUs
        assert is_pytorch_1_1_0_or_later(), \
            "SyncBatchNorm is only available in pytorch >= 1.1.0"
        # convert the model's BatchNorm layers to their synchronized version
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    optimizer = make_optimizer(cfg, model)
    scheduler = make_lr_scheduler(cfg, optimizer)

    if distributed:  # whether to use distributed training
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[local_rank], output_device=local_rank,
            # this should be removed if we update BatchNorm stats
            broadcast_buffers=False,
        )

    arguments = {}  # create the arguments dict
    arguments["iteration"] = 0

    output_dir = cfg.OUTPUT_DIR
    save_to_disk = get_rank() == 0
    checkpointer = DetectronCheckpointer(
        cfg, model, optimizer, scheduler, output_dir, save_to_disk  # checkpoint
    )
    extra_checkpoint_data = checkpointer.load(cfg.MODEL.WEIGHT)
    arguments.update(extra_checkpoint_data)

    data_loader = make_data_loader(  # read the data with make_data_loader
        cfg,
        is_train=True,
        is_distributed=distributed,
        start_iter=arguments["iteration"],
    )

    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD

    do_train(
        model,
        data_loader,
        optimizer,
        scheduler,
        checkpointer,
        device,
        checkpoint_period,
        arguments,
    )

    return model
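# make_optimizer and make_lr_scheduler are repo helpers. A rough, hedged sketch
# of what they typically build in maskrcnn-benchmark-style code: SGD with
# per-parameter lr/weight-decay overrides (biases get a larger lr and no weight
# decay) plus a stepped schedule. The SOLVER field names below are assumptions,
# not verified against this repo's config.
import torch


def make_optimizer_sketch(cfg, model):
    params = []
    for name, p in model.named_parameters():
        if not p.requires_grad:
            continue
        lr = cfg.SOLVER.BASE_LR
        weight_decay = cfg.SOLVER.WEIGHT_DECAY
        if "bias" in name:
            lr = cfg.SOLVER.BASE_LR * cfg.SOLVER.BIAS_LR_FACTOR
            weight_decay = cfg.SOLVER.WEIGHT_DECAY_BIAS
        params.append({"params": [p], "lr": lr, "weight_decay": weight_decay})
    return torch.optim.SGD(params, lr=cfg.SOLVER.BASE_LR,
                           momentum=cfg.SOLVER.MOMENTUM)


def make_lr_scheduler_sketch(cfg, optimizer):
    # the real helper is usually a warmup-aware WarmupMultiStepLR;
    # MultiStepLR is the closest stock PyTorch equivalent
    return torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=list(cfg.SOLVER.STEPS), gamma=cfg.SOLVER.GAMMA)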
def train(cfg, local_rank, distributed):
    writer = SummaryWriter('runs/{}'.format(cfg.OUTPUT_DIR))

    ##########################################################################
    ############################# Initial Model ##############################
    ##########################################################################
    model = {}
    device = torch.device(cfg.MODEL.DEVICE)
    backbone = build_backbone(cfg).to(device)
    fcos = build_rpn(cfg, backbone.out_channels).to(device)

    adv = cfg.MODEL.ADV
    # FPN levels with an enabled discriminator, in the order the per-level
    # branches check them
    levels = [l for l in ("P7", "P6", "P5", "P4", "P3")
              if getattr(adv, "USE_DIS_" + l)]

    # One discriminator per enabled level for each enabled alignment type:
    # global (GA), center-aware (CA), conditional (Cond), head-aware (HA).
    discriminators = {}
    if adv.USE_DIS_GLOBAL:
        for l in levels:
            discriminators["dis_" + l] = FCOSDiscriminator(
                num_convs=getattr(adv, "DIS_{}_NUM_CONVS".format(l)),
                grad_reverse_lambda=getattr(adv, "GRL_WEIGHT_" + l),
                grl_applied_domain=adv.GRL_APPLIED_DOMAIN).to(device)
    if adv.USE_DIS_CENTER_AWARE:
        for l in levels:
            discriminators["dis_{}_CA".format(l)] = FCOSDiscriminator_CA(
                num_convs=getattr(adv, "CA_DIS_{}_NUM_CONVS".format(l)),
                grad_reverse_lambda=getattr(adv, "CA_GRL_WEIGHT_" + l),
                center_aware_weight=adv.CENTER_AWARE_WEIGHT,
                center_aware_type=adv.CENTER_AWARE_TYPE,
                grl_applied_domain=adv.GRL_APPLIED_DOMAIN).to(device)
    if adv.USE_DIS_CONDITIONAL:
        for l in levels:
            discriminators["dis_{}_Cond".format(l)] = FCOSDiscriminator_CondA(
                num_convs=getattr(adv, "COND_DIS_{}_NUM_CONVS".format(l)),
                grad_reverse_lambda=getattr(adv, "COND_GRL_WEIGHT_" + l),
                center_aware_weight=adv.CENTER_AWARE_WEIGHT,
                center_aware_type=adv.CENTER_AWARE_TYPE,
                grl_applied_domain=adv.GRL_APPLIED_DOMAIN,
                class_align=adv.COND_CLASS,
                reg_left_align=adv.COND_REG.LEFT,
                reg_top_align=adv.COND_REG.TOP,
                expand_dim=adv.COND_EXPAND).to(device)
    if adv.USE_DIS_HEAD:
        for l in levels:
            discriminators["dis_{}_HA".format(l)] = FCOSDiscriminator_HA(
                num_convs=getattr(adv, "HA_DIS_{}_NUM_CONVS".format(l)),
                grad_reverse_lambda=getattr(adv, "HA_GRL_WEIGHT_" + l),
                center_aware_weight=adv.CENTER_AWARE_WEIGHT,
                grl_applied_domain=adv.GRL_APPLIED_DOMAIN).to(device)

    if cfg.MODEL.USE_SYNCBN:
        assert is_pytorch_1_1_0_or_later(), \
            "SyncBatchNorm is only available in pytorch >= 1.1.0"
        backbone = torch.nn.SyncBatchNorm.convert_sync_batchnorm(backbone)
        fcos = torch.nn.SyncBatchNorm.convert_sync_batchnorm(fcos)
        for name in discriminators:
            discriminators[name] = \
                torch.nn.SyncBatchNorm.convert_sync_batchnorm(discriminators[name])

    ##########################################################################
    #################### Initial Optimizer and Scheduler #####################
    ##########################################################################
    optimizer = {}
    optimizer["backbone"] = make_optimizer(cfg, backbone, name='backbone')
    optimizer["fcos"] = make_optimizer(cfg, fcos, name='fcos')
    for name, dis in discriminators.items():
        optimizer[name] = make_optimizer(cfg, dis, name='discriminator')

    scheduler = {}
    scheduler["backbone"] = make_lr_scheduler(cfg, optimizer["backbone"],
                                              name='backbone')
    scheduler["fcos"] = make_lr_scheduler(cfg, optimizer["fcos"], name='fcos')
    for name in discriminators:
        scheduler[name] = make_lr_scheduler(cfg, optimizer[name],
                                            name='discriminator')

    ##########################################################################
    ######################## DistributedDataParallel #########################
    ##########################################################################
    if distributed:
        def ddp(module):
            return torch.nn.parallel.DistributedDataParallel(
                module,
                device_ids=[local_rank],
                output_device=local_rank,
                # this should be removed if we update BatchNorm stats
                broadcast_buffers=False)

        backbone = ddp(backbone)
        fcos = ddp(fcos)
        for name in discriminators:
            discriminators[name] = ddp(discriminators[name])

    ##########################################################################
    ########################### Save Model to Dict ###########################
    ##########################################################################
    model["backbone"] = backbone
    model["fcos"] = fcos
    model.update(discriminators)

    ##########################################################################
    ################################ Training ################################
    ##########################################################################
    arguments = {}
    arguments["iteration"] = 0
    arguments["use_dis_global"] = adv.USE_DIS_GLOBAL
    arguments["use_dis_ca"] = adv.USE_DIS_CENTER_AWARE
    arguments["use_dis_conditional"] = adv.USE_DIS_CONDITIONAL
    arguments["use_dis_ha"] = adv.USE_DIS_HEAD
    arguments["ga_dis_lambda"] = adv.GA_DIS_LAMBDA
    arguments["ca_dis_lambda"] = adv.CA_DIS_LAMBDA
    arguments["cond_dis_lambda"] = adv.COND_DIS_LAMBDA
    arguments["ha_dis_lambda"] = adv.HA_DIS_LAMBDA
    arguments["use_feature_layers"] = list(levels)

    output_dir = cfg.OUTPUT_DIR
    save_to_disk = get_rank() == 0
    checkpointer = DetectronCheckpointer(cfg, model, optimizer, scheduler,
                                         output_dir, save_to_disk)
    extra_checkpoint_data = checkpointer.load(f=cfg.MODEL.WEIGHT,
                                              load_dis=True,
                                              load_opt_sch=False)
    # arguments.update(extra_checkpoint_data)

    # Initial dataloader (both target and source domain)
    data_loader = {}
    data_loader["source"] = make_data_loader_source(
        cfg,
        is_train=True,
        is_distributed=distributed,
        start_iter=arguments["iteration"],
    )
    data_loader["target"] = make_data_loader_target(
        cfg,
        is_train=True,
        is_distributed=distributed,
        start_iter=arguments["iteration"],
    )

    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD

    do_train(model, data_loader, optimizer, scheduler, checkpointer, device,
             checkpoint_period, arguments, cfg, run_test, distributed, writer)

    return model
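# The discriminators above all take a grad_reverse_lambda, which points to a
# gradient reversal layer (GRL) sitting between the detector features and each
# domain discriminator. Below is a minimal DANN-style sketch of such a layer;
# the names are hypothetical, not this repo's actual implementation.
import torch


class _GradReverse(torch.autograd.Function):
    """Identity on the forward pass; negates and scales gradients on backward."""

    @staticmethod
    def forward(ctx, x, lambd):
        ctx.lambd = lambd
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # the feature extractor receives -lambda * dL/dx, so it learns to fool
        # the domain discriminator while the discriminator trains normally
        return grad_output.neg() * ctx.lambd, None


def grad_reverse(x, lambd=1.0):
    return _GradReverse.apply(x, lambd)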
def train(cfg, local_rank, distributed, iter_clear, ignore_head):
    model = build_detection_model(cfg)
    # model, conversion_count = convert_to_shift_dbg(
    #     model,
    #     cfg.DEEPSHIFT_DEPTH,
    #     cfg.DEEPSHIFT_TYPE,
    #     convert_weights=True,
    #     use_kernel=cfg.DEEPSHIFT_USEKERNEL,
    #     rounding=cfg.DEEPSHIFT_ROUNDING,
    #     shift_range=cfg.DEEPSHIFT_RANGE)
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)

    if cfg.MODEL.USE_SYNCBN:
        assert is_pytorch_1_1_0_or_later(), \
            "SyncBatchNorm is only available in pytorch >= 1.1.0"
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    output_dir = cfg.OUTPUT_DIR
    save_to_disk = get_rank() == 0

    if iter_clear:
        load_opt = False
        load_sch = False
    else:
        load_opt = True
        load_sch = True

    if ignore_head:
        load_body = True
        load_fpn = True
        load_head = False
    else:
        load_body = True
        load_fpn = True
        load_head = True

    # Preload weights: either a regular full-precision model or a DeepShift model.
    if cfg.MODEL.WEIGHT:
        checkpointer = DetectronCheckpointer(
            cfg, model, None, None, output_dir, save_to_disk
        )
        extra_checkpoint_data = checkpointer.load(
            cfg.MODEL.WEIGHT,
            load_opt=False,
            load_sch=False,
            load_body=load_body,
            load_fpn=load_fpn,
            load_head=load_head)
        model, conversion_count = convert_to_shift(
            model,
            cfg.DEEPSHIFT_DEPTH,
            cfg.DEEPSHIFT_TYPE,
            convert_weights=True,
            use_kernel=cfg.DEEPSHIFT_USEKERNEL,
            rounding=cfg.DEEPSHIFT_ROUNDING,
            shift_range=cfg.DEEPSHIFT_RANGE)
        optimizer = make_optimizer(cfg, model)
        scheduler = make_lr_scheduler(cfg, optimizer)
        checkpointer = DetectronCheckpointer(
            cfg, model, optimizer, scheduler, output_dir, save_to_disk
        )
    else:
        model, conversion_count = convert_to_shift(
            model,
            cfg.DEEPSHIFT_DEPTH,
            cfg.DEEPSHIFT_TYPE,
            convert_weights=True,
            use_kernel=cfg.DEEPSHIFT_USEKERNEL,
            rounding=cfg.DEEPSHIFT_ROUNDING,
            shift_range=cfg.DEEPSHIFT_RANGE)
        optimizer = make_optimizer(cfg, model)
        scheduler = make_lr_scheduler(cfg, optimizer)
        checkpointer = DetectronCheckpointer(
            cfg, model, optimizer, scheduler, output_dir, save_to_disk
        )
        extra_checkpoint_data = checkpointer.load(
            cfg.MODEL.WEIGHT,
            load_opt=False,
            load_sch=False,
            load_body=load_body,
            load_fpn=load_fpn,
            load_head=load_head)

    conv2d_layers_count = count_layer_type(model, torch.nn.Conv2d)
    linear_layers_count = count_layer_type(model, torch.nn.Linear)
    print("###### conversion_count: {}, unconverted conv2d layers: {}, linear layers: {}".format(
        conversion_count, conv2d_layers_count, linear_layers_count))

    if distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[local_rank], output_device=local_rank,
            # this should be removed if we update BatchNorm stats
            broadcast_buffers=False,
        )

    arguments = {}
    arguments["iteration"] = 0
    arguments.update(extra_checkpoint_data)
    if iter_clear:
        arguments["iteration"] = 0

    data_loader = make_data_loader(
        cfg,
        is_train=True,
        is_distributed=distributed,
        start_iter=arguments["iteration"],
    )

    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD

    do_train(
        model,
        data_loader,
        optimizer,
        scheduler,
        checkpointer,
        device,
        checkpoint_period,
        arguments,
    )

    model = round_shift_weights(model)
    torch.save({"model": model.state_dict()},
               os.path.join(output_dir, "model_final_round.pth"))

    return model
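# count_layer_type is a DeepShift-repo helper used above to report how many
# Conv2d/Linear layers were left unconverted. A minimal sketch of what it
# plausibly does (the exact-type match is an assumption; shift layers that
# subclass Conv2d would otherwise be counted too):
import torch


def count_layer_type_sketch(model, layer_type):
    # count modules whose concrete class is exactly layer_type
    return sum(1 for m in model.modules() if type(m) == layer_type)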
def train(cfg, local_rank, distributed, iter_clear, ignore_head):
    model = build_detection_model(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)

    if cfg.MODEL.USE_SYNCBN:
        assert is_pytorch_1_1_0_or_later(), \
            "SyncBatchNorm is only available in pytorch >= 1.1.0"
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    optimizer = make_optimizer(cfg, model)
    scheduler = make_lr_scheduler(cfg, optimizer)

    if distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[local_rank], output_device=local_rank,
            # this should be removed if we update BatchNorm stats
            broadcast_buffers=False,
        )

    arguments = {}
    arguments["iteration"] = 0

    output_dir = cfg.OUTPUT_DIR
    save_to_disk = get_rank() == 0
    checkpointer = DetectronCheckpointer(
        cfg, model, optimizer, scheduler, output_dir, save_to_disk
    )

    if iter_clear:
        load_opt = False
        load_sch = False
    else:
        load_opt = True
        load_sch = True

    if ignore_head:
        load_body = True
        load_fpn = True
        load_head = False
    else:
        load_body = True
        load_fpn = True
        load_head = True

    extra_checkpoint_data = checkpointer.load(
        cfg.MODEL.WEIGHT,
        load_opt=load_opt,
        load_sch=load_sch,
        load_body=load_body,
        load_fpn=load_fpn,
        load_head=load_head)
    arguments.update(extra_checkpoint_data)
    if iter_clear:
        arguments["iteration"] = 0

    data_loader = make_data_loader(
        cfg,
        is_train=True,
        is_distributed=distributed,
        start_iter=arguments["iteration"],
    )

    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD

    do_train(
        model,
        data_loader,
        optimizer,
        scheduler,
        checkpointer,
        device,
        checkpoint_period,
        arguments,
    )

    return model
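# checkpointer.load's load_body/load_fpn/load_head flags imply that the
# checkpoint's state_dict is filtered by submodule prefix before loading. A
# hedged sketch of that filtering step; the "backbone.body"/"backbone.fpn"/
# "rpn.head" prefixes are assumptions based on maskrcnn-benchmark naming, not
# this repo's verified keys.
def filter_checkpoint(state_dict, load_body=True, load_fpn=True, load_head=True):
    keep_prefixes = []
    if load_body:
        keep_prefixes.append("backbone.body")
    if load_fpn:
        keep_prefixes.append("backbone.fpn")
    if load_head:
        keep_prefixes.append("rpn.head")
    return {k: v for k, v in state_dict.items()
            if any(k.startswith(p) for p in keep_prefixes)}


# usage sketch:
#   ckpt = torch.load(cfg.MODEL.WEIGHT, map_location="cpu")
#   model.load_state_dict(filter_checkpoint(ckpt["model"], load_head=False),
#                         strict=False)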