def train(cfg, local_rank, distributed):
    # Build the detection model and move it to the configured device
    model = build_detection_model(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)

    # Construct the optimizer and lr scheduler from the config
    optimizer = make_optimizer(cfg, model)
    scheduler = make_lr_scheduler(cfg, optimizer)

    if distributed:
        # Only used for distributed training; can be ignored in single-GPU runs
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[local_rank], output_device=local_rank,
            # this should be removed if we update BatchNorm stats
            broadcast_buffers=False,
        )

    arguments = {}
    arguments["iteration"] = 0

    # output_dir is '.'
    output_dir = cfg.OUTPUT_DIR

    # True on the main process (rank 0), so only one process writes to disk
    save_to_disk = get_rank() == 0
    # This checkpointer works like the one in Faster R-CNN: training state can be
    # resumed from a checkpoint. The DetectronCheckpointer file is used unmodified.
    checkpointer = DetectronCheckpointer(cfg, model, optimizer, scheduler, output_dir, save_to_disk)
    # cfg.MODEL.WEIGHT is catalog://ImageNetPretrained/MSRA/R-50
    extra_checkpoint_data = checkpointer.load(cfg.MODEL.WEIGHT)
    # update() overwrites existing keys, so arguments["iteration"] may change here
    # and is then passed as start_iter to the data loaders below
    arguments.update(extra_checkpoint_data)

    # This value is 2500
    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD

    if cfg.MODEL.DOMAIN_ADAPTATION_ON:
        source_data_loader = make_data_loader(
            cfg,
            is_train=True,
            is_source=True,
            is_distributed=distributed,
            start_iter=arguments["iteration"],
        )
        target_data_loader = make_data_loader(
            cfg,
            is_train=True,
            is_source=False,
            is_distributed=distributed,
            start_iter=arguments["iteration"],
        )
        do_da_train(
            model,
            source_data_loader,
            target_data_loader,
            optimizer,
            scheduler,
            checkpointer,
            device,
            checkpoint_period,
            arguments,
            cfg,
        )
    else:
        data_loader = make_data_loader(
            cfg,
            is_train=True,
            is_distributed=distributed,
            start_iter=arguments["iteration"],
        )
        do_train(
            model,
            data_loader,
            optimizer,
            scheduler,
            checkpointer,
            device,
            checkpoint_period,
            arguments,
        )

    return model
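For context, below is a minimal sketch of how a launcher would call this train function. This follows the usual maskrcnn-benchmark pattern (tools/train_net.py), but the import path for cfg and the config-file name are assumptions, not taken from this repo; treat it as an illustration of the local_rank/distributed arguments, not as the project's actual entry point.

import argparse
import os

import torch

from maskrcnn_benchmark.config import cfg  # assumed yacs config object


def main():
    parser = argparse.ArgumentParser(description="PyTorch detection training")
    parser.add_argument("--config-file", default="", metavar="FILE")
    # torch.distributed.launch passes --local_rank to each worker process
    parser.add_argument("--local_rank", type=int, default=0)
    args = parser.parse_args()

    # torch.distributed.launch also sets WORLD_SIZE; >1 means multi-GPU training
    num_gpus = int(os.environ.get("WORLD_SIZE", "1"))
    distributed = num_gpus > 1
    if distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend="nccl", init_method="env://")

    cfg.merge_from_file(args.config_file)
    cfg.freeze()

    # train() is the function walked through above
    model = train(cfg, args.local_rank, distributed)


if __name__ == "__main__":
    main()

With this setup, save_to_disk = get_rank() == 0 inside train() ensures that only the rank-0 worker writes checkpoints, while every worker wraps its model replica in DistributedDataParallel.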