def finetune_first_image(model, images, targets, optimizer, scheduler, logger, cfg):
    total_iter_finetune = cfg.FINETUNE.TOTAL_ITER
    model.train()
    meters = MetricLogger(delimiter=" ")
    for iteration in range(total_iter_finetune):
        scheduler.step()
        loss_dict, _ = model(images, targets)
        losses = sum(loss for loss in loss_dict.values())

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = reduce_loss_dict(loss_dict)
        losses_reduced = sum(loss for loss in loss_dict_reduced.values())
        meters.update(total_loss=losses_reduced, **loss_dict_reduced)

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        meters.update(lr=optimizer.param_groups[0]["lr"])
        # log the meters twice per finetuning run (at the start and halfway through)
        if iteration % max(total_iter_finetune // 2, 1) == 0:
            logger.info(
                meters.delimiter.join(["{meters}"]).format(meters=str(meters))
            )
    model.eval()
    return model
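# --- Usage sketch (added for illustration; not part of the original source) ---
# finetune_first_image() expects a model whose forward returns (loss_dict, _) in
# train mode. A minimal call might look like the commented line below, assuming
# the surrounding script already built `cfg`, `model`, `optimizer`, `scheduler`,
# and `logger`, and that `first_images` / `first_targets` hold the annotated
# first frame of a sequence (these names are assumptions, not the project's).
#
#   model = finetune_first_image(model, first_images, first_targets,
#                                optimizer, scheduler, logger, cfg)
#   # the model comes back in eval() mode, ready for inference on later frames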
def trainIters(args):
    """
    main: loop over epochs and data splits
    """
    epoch_resume = args.epoch_resume
    model_dir = os.path.join(args.models_root, args.model_name)
    board_dir = os.path.join(args.models_root, 'boards', args.model_name)
    if args.local_rank == 0:
        make_dir(board_dir)
        make_dir(model_dir)
    start = time.time()
    meters = {
        args.train_split: MetricLogger(delimiter=" "),
        args.eval_split: MetricLogger(delimiter=" ")
    }
    args.model_dir = model_dir
    enc_opt, dec_opt, trainer = build_model(args)
    max_eval_iter = args.max_eval_iter

    # save parameters for future use
    if args.local_rank == 0:
        tb_writer = SummaryWriter(board_dir)
        pickle.dump(args, open(os.path.join(model_dir, timestr + '_args.pkl'), 'wb'))
        pickle.dump(args, open(os.path.join(model_dir, 'args.pkl'), 'wb'))  # overwrite the latest args
        logging.info('save args in %s' % os.path.join(model_dir, timestr + '_args.pkl'))
        logging.info('{}'.format(args))
    start = time.time()

    # vars for early stopping
    best_val_loss = args.best_val_loss
    best_val_epo = 0
    acc_patience = 0
    mt_val = -1

    # keep track of the number of batches in each epoch for continuity when plotting curves
    if args.local_rank == 0:
        logging.info('init_dataloaders')
    start = time.time()
    loaders = init_dataloaders(args)
    num_batches = {args.train_split: 0, args.eval_split: 0}
    if args.local_rank == 0:
        logging.info('dataloader %.3f' % (time.time() - start))

    for e in range(args.max_epoch):
        # check if it's time to switch on inference-mask sampling
        if e + epoch_resume >= args.finetune_after and not args.sample_inference_mask \
                and args.finetune_after != -1:
            args.sample_inference_mask = 1
            logging.info('=' * 10 + '> start sample_inference_mask')
            acc_patience, best_val_loss = 0, 0

        # in the current epoch, loop over splits; we validate after each epoch
        if max_eval_iter > 0 and e == 0:
            splits = [args.eval_split, args.train_split, args.eval_split]
        elif max_eval_iter == 0:
            splits = [args.train_split]
        else:
            splits = [args.train_split, args.eval_split]

        for split in splits:
            if split == args.eval_split:
                trainer.eval()
            # loop over batches in the current epoch
            if args.local_rank == 0:
                logging.info('epoch %d - %s; ' % (e + epoch_resume, split))
                logging.info('-- loss weight loss_weight_match: {} loss_weight_iouraw {}; '
                             .format(args.loss_weight_match, args.loss_weight_iouraw))
            sd = time.time()
            start = time.time()
            iter_time = []
            for batch_idx, (inputs, imgs_names, targets, seq_name, starting_frame) in enumerate(loaders[split]):
                # imgs_names can be proposals: List[tuple(BoxList)], len of list = Nframe, len of tuple = BatchSize
                if args.local_rank == 0:
                    start_iter = time.time()
                dataT = time.time() - sd
                assert type(targets) == list
                inputs = [sub.to(args.device) for sub in inputs]
                targets = [sub.to(args.device) for sub in targets]
                if args.load_proposals_dataset:
                    proposals_cur_batch = imgs_names  # len = framelen
                    proposals = []  # BoxList of the current batch
                    for p in proposals_cur_batch:
                        boxlist = list(p)  # BoxList of the current batch
                        proposals.append([b.to(args.device) for b in boxlist])  # len = BatchSize
                    imgs_names = None
                else:
                    proposals = None

                # forward
                if split == args.eval_split:
                    with torch.no_grad():
                        loss, losses = trainer(batch_idx, inputs, imgs_names, targets, seq_name,
                                               starting_frame, split, args, proposals)
                else:
                    loss, losses = trainer(batch_idx, inputs, imgs_names, targets, seq_name,
                                           starting_frame, split, args, proposals)

                # import pdb; pdb.set_trace()
                # if DEBUG:
                #     logging.info('>> profile ')
                #     logging.info('seq_name {}, inputs sum {}; proposals: {} imgs_names {}'.format(
                #         seq_name, inputs[0].sum(), proposals[0][0].bbox.sum(), imgs_names))
                #     info = {'batch_idx': batch_idx,
                #             'info': [seq_name, inputs[0].shape, inputs[0].sum(),
                #                      proposals[0][0].bbox.sum(), imgs_names, losses, loss]}
                #     check_info = torch.load('../../drvos/src/debug/%d.pth' % batch_idx)
                #     CHECKDEBUG(info, check_info)

                loss = loss.mean()  # reduce_loss_dict({'loss': loss})
                if split == args.train_split:  # and args.local_rank == 0:
                    dec_opt.zero_grad()
                    enc_opt.zero_grad()
                    if loss.requires_grad:
                        loss.backward()
                        if args.distributed:
                            average_gradients(trainer, args.local_rank)
                            torch.cuda.synchronize()
                        dec_opt.step()
                        enc_opt.step()

                # record the losses; store loss values in the dictionary separately
                if args.distributed:
                    losses = reduce_loss_dict(losses)
                if args.ngpus > 1 and args.local_rank == 0:
                    for k, v in losses.items():
                        if not args.distributed:
                            losses[k] = v.mean()
                        tb_writer.add_scalar('%s/%s' % (k, split), losses[k],
                                             batch_idx + (e + epoch_resume) * len(loaders[split]))
                elif args.local_rank == 0:
                    for k, v in losses.items():
                        tb_writer.add_scalar('%s/%s' % (k, split), v,
                                             batch_idx + (e + epoch_resume) * len(loaders[split]))
                if args.local_rank == 0:
                    meters[split].update(**losses)

                # print after some iterations
                if (batch_idx + 1) % args.print_every == 0 and args.local_rank == 0:
                    te = time.time() - start_iter
                    iter_time.append(te)
                    remain_t = (sum(iter_time) / len(iter_time)
                                * (len(loaders[split]) - batch_idx)) / 60.0 / 60.0
                    max_mem = "mem: {memory:.0f}".format(
                        memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0)
                    meters[split].update(time=te, dt=dataT)
                    logging.info("%s:%s:p%d(%d-%.2f):E%d it%d/%d: rt(%.2fh) %s|%s" % (
                        args.model_name, split, acc_patience, best_val_epo, best_val_loss,
                        (e + epoch_resume), batch_idx, len(loaders[split]), remain_t,
                        str(meters[split]), max_mem))
                    start = time.time()

                if args.local_rank == 0 and split == args.train_split and \
                        (((batch_idx + 1) % args.save_every == 0) or batch_idx + 1 == len(loaders[split])):
                    logging.info('save model at {} {}'.format(batch_idx, e + epoch_resume))
                    save_checkpoint_iter(trainer, args,
                                         'epo%02d_iter%05d' % (e + epoch_resume, batch_idx),
                                         enc_opt, dec_opt)
                sd = time.time()

            # out of the for-all-batches loop in the current split
            num_batches[split] = batch_idx + 1

            # for loss_name in ['loss', 'match_loss', 'iou', 'iouraw', 'hard_iou_raw']:
            if split == args.eval_split:
                for loss_name in ['hard_iou_raw', 'hard_iou']:  # prefer hard_iou over hard_iou_raw
                    if loss_name in meters[args.eval_split].fields() and max_eval_iter != 0:
                        mt_val = meters[args.eval_split].load_field(loss_name).global_avg
                meters[args.eval_split] = MetricLogger(delimiter=" ")

                if mt_val > (best_val_loss + args.min_delta):
                    logging.info("Saving checkpoint.")
                    best_val_loss = mt_val
                    best_val_epo = e + epoch_resume
                    # saves model, params, and optimizers
                    save_checkpoint_iter(trainer, args,
                                         'best_%.3f_epo%02d' % (best_val_loss, e + epoch_resume),
                                         enc_opt, dec_opt)
                    acc_patience = 0
                else:
                    acc_patience += 1

        if acc_patience > args.patience_stop:
            logging.info('acc_patience reached its maximum, stopping training early: Bye')
            break
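# --- Hedged sketch (added for illustration) ---
# trainIters() calls average_gradients(trainer, args.local_rank) in its
# distributed branch, but that helper is not defined in this file. A common
# implementation all-reduces every gradient and divides by the world size; the
# version below is an assumption sketching that pattern, not the project's code.
import torch.distributed as dist


def average_gradients_sketch(model, local_rank=None):
    """All-reduce gradients across ranks and average them in place."""
    world_size = float(dist.get_world_size())
    for param in model.parameters():
        if param.grad is not None:
            dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
            param.grad.data /= world_size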
def train(cfg, local_rank, distributed, logger):
    if is_main_process():
        wandb.init(project='scene-graph', entity='sgg-speaker-listener', config=cfg.LISTENER)
    debug_print(logger, 'prepare training')
    model = build_detection_model(cfg)
    listener = build_listener(cfg)
    speaker_listener = SpeakerListener(model, listener, cfg, is_joint=cfg.LISTENER.JOINT)
    if is_main_process():
        wandb.watch(listener)
    debug_print(logger, 'end model construction')

    # modules that should always be set in eval mode
    # their eval() method should be called after model.train() is called
    eval_modules = (model.rpn, model.backbone, model.roi_heads.box,)
    fix_eval_modules(eval_modules)

    # NOTE: we slow down the LR of the layers whose names start with the entries in slow_heads
    if cfg.MODEL.ROI_RELATION_HEAD.PREDICTOR == "IMPPredictor":
        slow_heads = ["roi_heads.relation.box_feature_extractor",
                      "roi_heads.relation.union_feature_extractor.feature_extractor"]
    else:
        slow_heads = []

    # load pretrained layers into new layers
    load_mapping = {
        "roi_heads.relation.box_feature_extractor": "roi_heads.box.feature_extractor",
        "roi_heads.relation.union_feature_extractor.feature_extractor": "roi_heads.box.feature_extractor",
    }
    if cfg.MODEL.ATTRIBUTE_ON:
        load_mapping["roi_heads.relation.att_feature_extractor"] = "roi_heads.attribute.feature_extractor"
        load_mapping["roi_heads.relation.union_feature_extractor.att_feature_extractor"] = "roi_heads.attribute.feature_extractor"

    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)
    listener.to(device)

    num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
    num_batch = cfg.SOLVER.IMS_PER_BATCH
    optimizer = make_optimizer(cfg, model, logger, slow_heads=slow_heads,
                               slow_ratio=10.0, rl_factor=float(num_batch))
    listener_optimizer = make_listener_optimizer(cfg, listener)
    scheduler = make_lr_scheduler(cfg, optimizer, logger)
    listener_scheduler = None
    debug_print(logger, 'end optimizer and schedule')

    if cfg.LISTENER.JOINT:
        speaker_listener_optimizer = make_speaker_listener_optimizer(
            cfg, speaker_listener.speaker, speaker_listener.listener)

    # Initialize mixed-precision training
    use_mixed_precision = cfg.DTYPE == "float16"
    amp_opt_level = 'O1' if use_mixed_precision else 'O0'
    if cfg.LISTENER.JOINT:
        speaker_listener, speaker_listener_optimizer = amp.initialize(
            speaker_listener, speaker_listener_optimizer, opt_level='O0')
    else:
        speaker_listener, listener_optimizer = amp.initialize(
            speaker_listener, listener_optimizer, opt_level='O0')
    # listener, listener_optimizer = amp.initialize(listener, listener_optimizer, opt_level='O0')
    # [model, listener], [optimizer, listener_optimizer] = amp.initialize(
    #     [model, listener], [optimizer, listener_optimizer], opt_level='O1', loss_scale=1)
    # model = amp.initialize(model, opt_level='O1')

    if distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[local_rank], output_device=local_rank,
            # this should be removed if we update BatchNorm stats
            broadcast_buffers=False,
            find_unused_parameters=True,
        )
        listener = torch.nn.parallel.DistributedDataParallel(
            listener, device_ids=[local_rank], output_device=local_rank,
            # this should be removed if we update BatchNorm stats
            broadcast_buffers=False,
            find_unused_parameters=True,
        )
    debug_print(logger, 'end distributed')

    arguments = {}
    arguments["iteration"] = 0

    output_dir = cfg.OUTPUT_DIR
    listener_dir = cfg.LISTENER_DIR
    save_to_disk = get_rank() == 0
    speaker_checkpointer = DetectronCheckpointer(cfg, model, optimizer, scheduler,
                                                 output_dir, save_to_disk,
                                                 custom_scheduler=True)
    listener_checkpointer = Checkpointer(listener, optimizer=listener_optimizer,
                                         save_dir=listener_dir, save_to_disk=save_to_disk,
                                         custom_scheduler=False)
    speaker_listener.add_listener_checkpointer(listener_checkpointer)
    speaker_listener.add_speaker_checkpointer(speaker_checkpointer)
    speaker_listener.load_listener()
    speaker_listener.load_speaker(load_mapping=load_mapping)
    debug_print(logger, 'end load checkpointer')

    train_data_loader = make_data_loader(cfg, mode='train', is_distributed=distributed,
                                         start_iter=arguments["iteration"], ret_images=True)
    val_data_loaders = make_data_loader(cfg, mode='val', is_distributed=distributed, ret_images=True)
    debug_print(logger, 'end dataloader')

    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD

    if cfg.SOLVER.PRE_VAL:
        logger.info("Validate before training")
        # output = run_val(cfg, model, listener, val_data_loaders, distributed, logger)
        # print('OUTPUT: ', output)
        # (sg_loss, img_loss, sg_acc, img_acc) = output

    logger.info("Start training")
    meters = MetricLogger(delimiter=" ")
    max_iter = len(train_data_loader)
    start_iter = arguments["iteration"]
    start_training_time = time.time()
    end = time.time()

    print_first_grad = True
    listener_loss_func = torch.nn.MarginRankingLoss(margin=1, reduction='none')
    mistake_saver = None
    if is_main_process():
        ds_catalog = DatasetCatalog()
        dict_file_path = os.path.join(
            ds_catalog.DATA_DIR,
            ds_catalog.DATASETS['VG_stanford_filtered_with_attribute']['dict_file'])
        ind_to_classes, ind_to_predicates = load_vg_info(dict_file_path)
        ind_to_classes = {k: v for k, v in enumerate(ind_to_classes)}
        ind_to_predicates = {k: v for k, v in enumerate(ind_to_predicates)}
        print('ind to classes:', ind_to_classes, '\n ind to predicates:', ind_to_predicates)
        mistake_saver = MistakeSaver('/Scene-Graph-Benchmark.pytorch/filenames_masked',
                                     ind_to_classes, ind_to_predicates)
    # is_printed = False

    while True:
        try:
            listener_iteration = 0
            for iteration, (images, targets, image_ids) in enumerate(train_data_loader, start_iter):
                if cfg.LISTENER.JOINT:
                    speaker_listener_optimizer.zero_grad()
                else:
                    listener_optimizer.zero_grad()
                # print(f'ITERATION NUMBER: {iteration}')
                if any(len(target) < 1 for target in targets):
                    logger.error(
                        f"Iteration={iteration + 1} || Image Ids used for training {image_ids} || "
                        f"targets Length={[len(target) for target in targets]}")
                if len(images) <= 1:
                    continue

                data_time = time.time() - end
                iteration = iteration + 1
                listener_iteration += 1
                arguments["iteration"] = iteration

                model.train()
                fix_eval_modules(eval_modules)

                images_list = deepcopy(images)
                images_list = to_image_list(images_list, cfg.DATALOADER.SIZE_DIVISIBILITY).to(device)

                for i in range(len(images)):
                    images[i] = images[i].unsqueeze(0)
                    images[i] = F.interpolate(images[i], size=(224, 224),
                                              mode='bilinear', align_corners=False)
                    images[i] = images[i].squeeze()
                images = torch.stack(images).to(device)
                # images.requires_grad_()

                targets = [target.to(device) for target in targets]

                speaker_loss_dict = {}
                if not cfg.LISTENER.JOINT:
                    score_matrix = speaker_listener(images_list, targets, images)
                else:
                    score_matrix, _, speaker_loss_dict = speaker_listener(images_list, targets, images)

                speaker_summed_losses = sum(loss for loss in speaker_loss_dict.values())

                # reduce losses over all GPUs for logging purposes
                if cfg.LISTENER.JOINT:
                    speaker_loss_dict_reduced = reduce_loss_dict(speaker_loss_dict)
                    speaker_losses_reduced = sum(loss for loss in speaker_loss_dict_reduced.values())
                    speaker_losses_reduced /= num_gpus
                    if is_main_process():
                        wandb.log({"Train Speaker Loss": speaker_losses_reduced}, listener_iteration)
                listener_loss = 0
                gap_reward = 0
                avg_acc = 0
                num_correct = 0
                score_matrix = score_matrix.to(device)

                # fill the loss matrix
                loss_matrix = torch.zeros((2, images.size(0), images.size(0)), device=device)
                # sg-centered scores
                for true_index in range(loss_matrix.size(1)):
                    row_score = score_matrix[true_index]
                    (true_scores, predicted_scores, binary) = format_scores(row_score, true_index, device)
                    loss_vec = listener_loss_func(true_scores, predicted_scores, binary)
                    loss_matrix[0][true_index] = loss_vec
                # image-centered scores
                transposed_score_matrix = score_matrix.t()
                for true_index in range(loss_matrix.size(1)):
                    row_score = transposed_score_matrix[true_index]
                    (true_scores, predicted_scores, binary) = format_scores(row_score, true_index, device)
                    loss_vec = listener_loss_func(true_scores, predicted_scores, binary)
                    loss_matrix[1][true_index] = loss_vec

                print('iteration:', listener_iteration)
                sg_acc = 0
                img_acc = 0
                # calculate accuracy
                for i in range(loss_matrix.size(1)):
                    temp_sg_acc = 0
                    temp_img_acc = 0
                    for j in range(loss_matrix.size(2)):
                        if loss_matrix[0][i][i] > loss_matrix[0][i][j]:
                            temp_sg_acc += 1
                        else:
                            if cfg.LISTENER.HTML:
                                if is_main_process() and listener_iteration >= 600 \
                                        and listener_iteration % 25 == 0 and i != j:
                                    detached_sg_i = (sgs[i][0].detach(), sgs[i][1], sgs[i][2].detach())
                                    detached_sg_j = (sgs[j][0].detach(), sgs[j][1], sgs[j][2].detach())
                                    mistake_saver.add_mistake((image_ids[i], image_ids[j]),
                                                              (detached_sg_i, detached_sg_j),
                                                              listener_iteration, 'SG')
                        if loss_matrix[1][i][i] > loss_matrix[1][j][i]:
                            temp_img_acc += 1
                        else:
                            if cfg.LISTENER.HTML:
                                if is_main_process() and listener_iteration >= 600 \
                                        and listener_iteration % 25 == 0 and i != j:
                                    detached_sg_i = (sgs[i][0].detach(), sgs[i][1], sgs[i][2].detach())
                                    detached_sg_j = (sgs[j][0].detach(), sgs[j][1], sgs[j][2].detach())
                                    mistake_saver.add_mistake((image_ids[i], image_ids[j]),
                                                              (detached_sg_i, detached_sg_j),
                                                              listener_iteration, 'IMG')
                    temp_sg_acc = temp_sg_acc * 100 / (loss_matrix.size(1) - 1)
                    temp_img_acc = temp_img_acc * 100 / (loss_matrix.size(1) - 1)
                    sg_acc += temp_sg_acc
                    img_acc += temp_img_acc

                if cfg.LISTENER.HTML:
                    if is_main_process() and listener_iteration % 100 == 0 and listener_iteration >= 600:
                        mistake_saver.toHtml('/www')

                sg_acc /= loss_matrix.size(1)
                img_acc /= loss_matrix.size(1)

                avg_sg_acc = torch.tensor([sg_acc]).to(device)
                avg_img_acc = torch.tensor([img_acc]).to(device)
                # reduce accuracy over all GPUs
                avg_acc = {'sg_acc': avg_sg_acc, 'img_acc': avg_img_acc}
                avg_acc_reduced = reduce_loss_dict(avg_acc)
                sg_acc = sum(acc for acc in avg_acc_reduced['sg_acc'])
                img_acc = sum(acc for acc in avg_acc_reduced['img_acc'])
                # log accuracy to wandb
                if is_main_process():
                    wandb.log({
                        "Train SG Accuracy": sg_acc.item(),
                        "Train IMG Accuracy": img_acc.item(),
                    })

                sg_loss = 0
                img_loss = 0
                # zero out the diagonal (self-comparison) entries of the loss matrix
                for i in range(loss_matrix.size(0)):
                    for j in range(loss_matrix.size(1)):
                        loss_matrix[i][j][j] = 0.
                for i in range(loss_matrix.size(1)):
                    sg_loss += torch.max(loss_matrix[0][i])
                    img_loss += torch.max(loss_matrix[1][:][i])
                sg_loss = sg_loss / loss_matrix.size(1)
                img_loss = img_loss / loss_matrix.size(1)
                sg_loss = sg_loss.to(device)
                img_loss = img_loss.to(device)

                loss_dict = {'sg_loss': sg_loss, 'img_loss': img_loss}
                losses = sum(loss for loss in loss_dict.values())
                # reduce losses over all GPUs for logging purposes
                loss_dict_reduced = reduce_loss_dict(loss_dict)
                sg_loss_reduced = loss_dict_reduced['sg_loss']
                img_loss_reduced = loss_dict_reduced['img_loss']
                if is_main_process():
                    wandb.log({"Train SG Loss": sg_loss_reduced})
                    wandb.log({"Train IMG Loss": img_loss_reduced})
                losses_reduced = sum(loss for loss in loss_dict_reduced.values())
                meters.update(loss=losses_reduced, **loss_dict_reduced)

                losses = losses + speaker_summed_losses * cfg.LISTENER.LOSS_COEF

                # Note: If mixed precision is not used, this ends up doing nothing
                # Otherwise apply loss scaling for mixed-precision recipe
                # losses.backward()
                if not cfg.LISTENER.JOINT:
                    with amp.scale_loss(losses, listener_optimizer) as scaled_losses:
                        scaled_losses.backward()
                else:
                    with amp.scale_loss(losses, speaker_listener_optimizer) as scaled_losses:
                        scaled_losses.backward()

                verbose = (iteration % cfg.SOLVER.PRINT_GRAD_FREQ) == 0 or print_first_grad  # print grad or not
                print_first_grad = False
                # clip_grad_value([(n, p) for n, p in listener.named_parameters() if p.requires_grad],
                #                 cfg.LISTENER.CLIP_VALUE, logger=logger, verbose=True, clip=True)

                if not cfg.LISTENER.JOINT:
                    listener_optimizer.step()
                else:
                    speaker_listener_optimizer.step()

                batch_time = time.time() - end
                end = time.time()
                meters.update(time=batch_time, data=data_time)

                eta_seconds = meters.time.global_avg * (max_iter - iteration)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

                if iteration % 200 == 0 or iteration == max_iter:
                    log_optimizer = speaker_listener_optimizer if cfg.LISTENER.JOINT else listener_optimizer
                    logger.info(
                        meters.delimiter.join([
                            "eta: {eta}",
                            "iter: {iter}",
                            "{meters}",
                            "lr: {lr:.6f}",
                            "max mem: {memory:.0f}",
                        ]).format(
                            eta=eta_string,
                            iter=iteration,
                            meters=str(meters),
                            lr=log_optimizer.param_groups[-1]["lr"],
                            memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0,
                        ))

                if iteration % checkpoint_period == 0:
                    # print('Model before save')
                    # print('****************************')
                    # print(listener.gnn.conv1.node_model.node_mlp_1[0].weight)
                    # print('****************************')
                    if not cfg.LISTENER.JOINT:
                        listener_checkpointer.save("model_{:07d}".format(listener_iteration),
                                                   amp=amp.state_dict())
                    else:
                        speaker_checkpointer.save("model_speaker{:07d}".format(iteration))
                        listener_checkpointer.save("model_listener{:07d}".format(listener_iteration),
                                                   amp=amp.state_dict())

                if iteration == max_iter:
                    if not cfg.LISTENER.JOINT:
                        listener_checkpointer.save("model_{:07d}".format(listener_iteration),
                                                   amp=amp.state_dict())
                    else:
                        speaker_checkpointer.save("model_{:07d}".format(iteration))
                        listener_checkpointer.save("model_{:07d}".format(listener_iteration),
                                                   amp=amp.state_dict())

                val_result = None  # used for scheduler updating
                if cfg.SOLVER.TO_VAL and iteration % cfg.SOLVER.VAL_PERIOD == 0:
                    logger.info("Start validating")
                    val_result = run_val(cfg, model, listener, val_data_loaders, distributed, logger)
                    (sg_loss, img_loss, sg_acc, img_acc, speaker_val) = val_result
                    if is_main_process():
                        wandb.log({
                            "Validation SG Accuracy": sg_acc,
                            "Validation IMG Accuracy": img_acc,
                            "Validation SG Loss": sg_loss,
                            "Validation IMG Loss": img_loss,
                            "Validation Speaker": speaker_val,
                        })
                    # logger.info("Validation Result: %.4f" % val_result)
        except Exception as err:
            raise err

        print('Dataset finished, creating new')
        train_data_loader = make_data_loader(cfg, mode='train', is_distributed=distributed,
                                             start_iter=arguments["iteration"], ret_images=True)

    total_training_time = time.time() - start_training_time
    total_time_str = str(datetime.timedelta(seconds=total_training_time))
    logger.info("Total training time: {} ({:.4f} s / it)".format(
        total_time_str, total_training_time / (max_iter)))
    return listener
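# --- Hedged sketch (added for illustration) ---
# The listener loss above pairs each diagonal ("matching") score with every
# score in the same row through nn.MarginRankingLoss(margin=1, reduction='none'),
# so format_scores() presumably expands the true score against the whole row and
# returns a +1 target vector. The helper below is an assumption about that
# contract (the self-comparison entry of the resulting loss row is zeroed later
# by the training loop), not the project's actual implementation.
import torch


def format_scores_sketch(row_score, true_index, device):
    n = row_score.size(0)
    true_scores = row_score[true_index].expand(n)   # score of the correct pair, repeated
    predicted_scores = row_score                    # scores against every candidate
    binary = torch.ones(n, device=device)           # "true should rank higher" targets
    return true_scores, predicted_scores, binary

# With these inputs, MarginRankingLoss(margin=1, reduction='none') returns
# max(0, -(true - pred) + 1) per candidate; the diagonal term is removed when
# the loop sets loss_matrix[i][j][j] = 0.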
def train(cfg, local_rank, distributed, logger):
    model = build_detection_model(cfg)

    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)

    optimizer = make_optimizer(cfg, model, logger, rl_factor=float(cfg.SOLVER.IMS_PER_BATCH))
    scheduler = make_lr_scheduler(cfg, optimizer)

    # Initialize mixed-precision training
    use_mixed_precision = cfg.DTYPE == "float16"
    amp_opt_level = 'O1' if use_mixed_precision else 'O0'
    model, optimizer = amp.initialize(model, optimizer, opt_level=amp_opt_level)

    if distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[local_rank], output_device=local_rank,
            # this should be removed if we update BatchNorm stats
            broadcast_buffers=False,
        )

    arguments = {}
    arguments["iteration"] = 0

    output_dir = cfg.OUTPUT_DIR
    save_to_disk = get_rank() == 0
    checkpointer = DetectronCheckpointer(cfg, model, optimizer, scheduler, output_dir, save_to_disk)
    extra_checkpoint_data = checkpointer.load(cfg.MODEL.WEIGHT,
                                              update_schedule=cfg.SOLVER.UPDATE_SCHEDULE_DURING_LOAD)
    arguments.update(extra_checkpoint_data)

    train_data_loader = make_data_loader(
        cfg,
        mode='train',
        is_distributed=distributed,
        start_iter=arguments["iteration"],
    )
    val_data_loaders = make_data_loader(
        cfg,
        mode='val',
        is_distributed=distributed,
    )

    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD

    if cfg.SOLVER.PRE_VAL:
        logger.info("Validate before training")
        run_val(cfg, model, val_data_loaders, distributed)

    logger.info("Start training")
    meters = MetricLogger(delimiter=" ")
    max_iter = len(train_data_loader)
    start_iter = arguments["iteration"]
    start_training_time = time.time()
    end = time.time()
    for iteration, (images, targets, _) in enumerate(train_data_loader, start_iter):
        model.train()
        if any(len(target) < 1 for target in targets):
            logger.error(
                f"Iteration={iteration + 1} || Image Ids used for training {_} || "
                f"targets Length={[len(target) for target in targets]}")
        data_time = time.time() - end
        iteration = iteration + 1
        arguments["iteration"] = iteration

        scheduler.step()

        images = images.to(device)
        targets = [target.to(device) for target in targets]

        loss_dict = model(images, targets)
        losses = sum(loss for loss in loss_dict.values())

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = reduce_loss_dict(loss_dict)
        losses_reduced = sum(loss for loss in loss_dict_reduced.values())
        meters.update(loss=losses_reduced, **loss_dict_reduced)

        optimizer.zero_grad()
        # Note: If mixed precision is not used, this ends up doing nothing
        # Otherwise apply loss scaling for mixed-precision recipe
        with amp.scale_loss(losses, optimizer) as scaled_losses:
            scaled_losses.backward()
        optimizer.step()

        batch_time = time.time() - end
        end = time.time()
        meters.update(time=batch_time, data=data_time)

        eta_seconds = meters.time.global_avg * (max_iter - iteration)
        eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

        if iteration % 200 == 0 or iteration == max_iter:
            logger.info(
                meters.delimiter.join([
                    "eta: {eta}",
                    "iter: {iter}",
                    "{meters}",
                    "lr: {lr:.6f}",
                    "max mem: {memory:.0f}",
                ]).format(
                    eta=eta_string,
                    iter=iteration,
                    meters=str(meters),
                    lr=optimizer.param_groups[0]["lr"],
                    memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0,
                ))

        if cfg.SOLVER.TO_VAL and iteration % cfg.SOLVER.VAL_PERIOD == 0:
            logger.info("Start validating")
            run_val(cfg, model, val_data_loaders, distributed)

        if iteration % checkpoint_period == 0:
            checkpointer.save("model_{:07d}".format(iteration), **arguments)
        if iteration == max_iter:
            checkpointer.save("model_final", **arguments)

    total_training_time = time.time() - start_training_time
    total_time_str = str(datetime.timedelta(seconds=total_training_time))
    logger.info("Total training time: {} ({:.4f} s / it)".format(
        total_time_str, total_training_time / (max_iter)))
    return model
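# --- Hedged sketch (added for illustration) ---
# Every loop above logs loss_dict_reduced = reduce_loss_dict(loss_dict). In
# maskrcnn_benchmark-style code this helper gathers the per-GPU loss values on
# rank 0 and divides by the world size so the logged numbers are comparable to
# single-GPU training. The simplified restatement below is for illustration
# only; the repository's own utility should be preferred.
import torch
import torch.distributed as dist


def reduce_loss_dict_sketch(loss_dict):
    world_size = dist.get_world_size() if dist.is_available() and dist.is_initialized() else 1
    if world_size < 2:
        return loss_dict
    with torch.no_grad():
        names = sorted(loss_dict.keys())
        stacked = torch.stack([loss_dict[k] for k in names], dim=0)
        dist.reduce(stacked, dst=0)        # sum the losses onto rank 0
        if dist.get_rank() == 0:
            stacked /= world_size          # average for logging
        return {k: v for k, v in zip(names, stacked)}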
def train(cfg, local_rank, distributed, logger):
    debug_print(logger, 'prepare training')
    model = build_detection_model(cfg)
    debug_print(logger, 'end model construction')

    # modules that should always be set in eval mode
    # their eval() method should be called after model.train() is called
    eval_modules = (model.rpn, model.backbone, model.roi_heads.box,)
    fix_eval_modules(eval_modules)

    # NOTE: we slow down the LR of the layers whose names start with the entries in slow_heads
    if cfg.MODEL.ROI_RELATION_HEAD.PREDICTOR == "IMPPredictor":
        slow_heads = ["roi_heads.relation.box_feature_extractor",
                      "roi_heads.relation.union_feature_extractor.feature_extractor"]
    else:
        slow_heads = []

    # load pretrained layers into new layers
    load_mapping = {
        "roi_heads.relation.box_feature_extractor": "roi_heads.box.feature_extractor",
        "roi_heads.relation.union_feature_extractor.feature_extractor": "roi_heads.box.feature_extractor",
    }
    if cfg.MODEL.ATTRIBUTE_ON:
        load_mapping["roi_heads.relation.att_feature_extractor"] = "roi_heads.attribute.feature_extractor"
        load_mapping["roi_heads.relation.union_feature_extractor.att_feature_extractor"] = "roi_heads.attribute.feature_extractor"

    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)

    num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
    num_batch = cfg.SOLVER.IMS_PER_BATCH
    optimizer = make_optimizer(cfg, model, logger, slow_heads=slow_heads, slow_ratio=10.0,
                               rl_factor=float(num_batch))
    scheduler = make_lr_scheduler(cfg, optimizer, logger)
    debug_print(logger, 'end optimizer and schedule')

    # Initialize mixed-precision training
    use_mixed_precision = cfg.DTYPE == "float16"
    amp_opt_level = 'O1' if use_mixed_precision else 'O0'
    model, optimizer = amp.initialize(model, optimizer, opt_level=amp_opt_level)

    if distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[local_rank], output_device=local_rank,
            # this should be removed if we update BatchNorm stats
            broadcast_buffers=False,
            find_unused_parameters=True,
        )
    debug_print(logger, 'end distributed')

    arguments = {}
    arguments["iteration"] = 0

    output_dir = cfg.OUTPUT_DIR
    save_to_disk = get_rank() == 0
    checkpointer = DetectronCheckpointer(
        cfg, model, optimizer, scheduler, output_dir, save_to_disk, custom_scheduler=True
    )
    # if there is a checkpoint in output_dir, load it; otherwise load the pretrained detector
    if checkpointer.has_checkpoint():
        extra_checkpoint_data = checkpointer.load(cfg.MODEL.PRETRAINED_DETECTOR_CKPT,
                                                  update_schedule=cfg.SOLVER.UPDATE_SCHEDULE_DURING_LOAD)
        arguments.update(extra_checkpoint_data)
        if cfg.SOLVER.UPDATE_SCHEDULE_DURING_LOAD:
            checkpointer.scheduler.last_epoch = extra_checkpoint_data["iteration"]
            logger.info("update last epoch of scheduler to iter: {}".format(
                str(extra_checkpoint_data["iteration"])))
    else:
        # load_mapping is only used when we initialize the current model from the detection model
        checkpointer.load(cfg.MODEL.PRETRAINED_DETECTOR_CKPT, with_optim=False, load_mapping=load_mapping)
    debug_print(logger, 'end load checkpointer')

    train_data_loader = make_data_loader(
        cfg,
        mode='train',
        is_distributed=distributed,
        start_iter=arguments["iteration"],
    )
    val_data_loaders = make_data_loader(
        cfg,
        mode='val',
        is_distributed=distributed,
    )
    debug_print(logger, 'end dataloader')

    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD

    if cfg.SOLVER.PRE_VAL:
        logger.info("Validate before training")
        run_val(cfg, model, val_data_loaders, distributed, logger)

    logger.info("Start training")
    meters = MetricLogger(delimiter=" ")
    max_iter = len(train_data_loader)
    start_iter = arguments["iteration"]
    start_training_time = time.time()
    end = time.time()

    print_first_grad = True
    for iteration, (images, targets, _) in enumerate(train_data_loader, start_iter):
        if any(len(target) < 1 for target in targets):
            logger.error(
                f"Iteration={iteration + 1} || Image Ids used for training {_} || "
                f"targets Length={[len(target) for target in targets]}")
        data_time = time.time() - end
        iteration = iteration + 1
        arguments["iteration"] = iteration

        model.train()
        fix_eval_modules(eval_modules)

        images = images.to(device)
        targets = [target.to(device) for target in targets]

        loss_dict = model(images, targets)
        losses = sum(loss for loss in loss_dict.values())

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = reduce_loss_dict(loss_dict)
        losses_reduced = sum(loss for loss in loss_dict_reduced.values())
        meters.update(loss=losses_reduced, **loss_dict_reduced)

        optimizer.zero_grad()
        # Note: If mixed precision is not used, this ends up doing nothing
        # Otherwise apply loss scaling for mixed-precision recipe
        with amp.scale_loss(losses, optimizer) as scaled_losses:
            scaled_losses.backward()

        # add clip_grad_norm from MOTIFS; tracks the gradient, used for debugging
        verbose = (iteration % cfg.SOLVER.PRINT_GRAD_FREQ) == 0 or print_first_grad  # print grad or not
        print_first_grad = False
        clip_grad_norm([(n, p) for n, p in model.named_parameters() if p.requires_grad],
                       max_norm=cfg.SOLVER.GRAD_NORM_CLIP, logger=logger, verbose=verbose, clip=True)

        optimizer.step()

        batch_time = time.time() - end
        end = time.time()
        meters.update(time=batch_time, data=data_time)

        eta_seconds = meters.time.global_avg * (max_iter - iteration)
        eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

        if iteration % 200 == 0 or iteration == max_iter:
            logger.info(
                meters.delimiter.join([
                    "eta: {eta}",
                    "iter: {iter}",
                    "{meters}",
                    "lr: {lr:.6f}",
                    "max mem: {memory:.0f}",
                ]).format(
                    eta=eta_string,
                    iter=iteration,
                    meters=str(meters),
                    lr=optimizer.param_groups[-1]["lr"],
                    memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0,
                ))

        if iteration % checkpoint_period == 0:
            checkpointer.save("model_{:07d}".format(iteration), **arguments)
        if iteration == max_iter:
            checkpointer.save("model_final", **arguments)

        val_result = None  # used for scheduler updating
        if cfg.SOLVER.TO_VAL and iteration % cfg.SOLVER.VAL_PERIOD == 0:
            logger.info("Start validating")
            val_result = run_val(cfg, model, val_data_loaders, distributed, logger)
            logger.info("Validation Result: %.4f" % val_result)

        # scheduler should be called after optimizer.step() in pytorch>=1.1.0
        # https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate
        if cfg.SOLVER.SCHEDULE.TYPE == "WarmupReduceLROnPlateau":
            scheduler.step(val_result, epoch=iteration)
            if scheduler.stage_count >= cfg.SOLVER.SCHEDULE.MAX_DECAY_STEP:
                logger.info("Trigger MAX_DECAY_STEP at iteration {}.".format(iteration))
                break
        else:
            scheduler.step()

    total_training_time = time.time() - start_training_time
    total_time_str = str(datetime.timedelta(seconds=total_training_time))
    logger.info(
        "Total training time: {} ({:.4f} s / it)".format(
            total_time_str, total_training_time / (max_iter)
        )
    )
    return model
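# --- Hedged sketch (added for illustration) ---
# fix_eval_modules(eval_modules) is called before and inside the training loops
# above to keep the detector backbone, RPN, and box head frozen while the
# relation head trains. Consistent with the comments above ("their eval() method
# should be called after model.train() is called"), a typical implementation
# switches those modules to eval() and disables their gradients. The version
# below is an assumption, not necessarily the repository's exact helper.
def fix_eval_modules_sketch(eval_modules):
    for module in eval_modules:
        module.eval()
        for param in module.parameters():
            param.requires_grad = False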
def do_train(model, data_loader, optimizer, scheduler, checkpointer, device,
             validation_period, checkpoint_period, arguments, run_validation):
    logger = logging.getLogger("maskrcnn_benchmark.trainer")
    logger.info("Start training")
    meters = MetricLogger(delimiter=" ")
    max_iter = len(data_loader)
    start_iter = arguments["iteration"]
    model.train()
    start_training_time = time.time()
    end = time.time()

    saved_models = {}
    best_metric = float("-inf")
    best_model_iter = None

    for iteration, (images, targets, _) in enumerate(data_loader, start_iter):
        data_time = time.time() - end
        batch_start = time.time()
        arguments["iteration"] = iteration

        if iteration % validation_period == 0:
            results = validate_and_log(model, run_validation, iteration)
            first_dataset_results = next(iter(results.items()))
            dataset_name = first_dataset_results[0]
            metric_name = "[email protected]"
            metric = first_dataset_results[1][metric_name]
            if metric > best_metric:
                logger.info(
                    f"Found a new current best model: iter {iteration}, "
                    f"{metric_name} on {dataset_name} = {metric:0.4f}")
                # checkpoint the best model
                best_metric = metric
                best_model_iter = iteration
                model_filename = 'model_best'
                checkpointer.save(model_filename, **arguments)

        if iteration % checkpoint_period == 0:
            model_filename = 'model_{:07d}'.format(iteration)
            checkpointer.save(model_filename, **arguments)
            saved_models[iteration] = model_filename + '.pth'

        model.train()
        scheduler.step()

        images = images.to(device)
        targets = [target.to(device) for target in targets]

        loss_dict = model(images, targets)
        losses = sum(loss for loss in loss_dict.values())

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = reduce_loss_dict(loss_dict)
        losses_reduced = sum(loss for loss in loss_dict_reduced.values())
        meters.update(loss=losses_reduced, **loss_dict_reduced)

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        batch_time = time.time() - batch_start
        end = time.time()
        meters.update(time=batch_time, data=data_time)

        eta_seconds = meters.time.global_avg * (max_iter - iteration)
        eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
        logger.info(
            meters.delimiter.join([
                "eta: {eta}",
                "iter: {iter}",
                "{meters}",
                "lr: {lr:.6f}",
                "max mem: {memory:.0f}",
            ]).format(
                eta=eta_string,
                iter=iteration,
                meters=str(meters),
                lr=optimizer.param_groups[0]["lr"],
                memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0,
            ))
        losses_str = meters.delimiter.join(
            ["Loss: {:.4f}".format(losses.item())]
            + ["{0}: {1:.4f}".format(k, v.item()) for k, v in loss_dict_reduced.items()])
        logger.info(losses_str)

    validate_and_log(model, run_validation, arguments["iteration"])
    checkpointer.save("model_final", **arguments)

    if max_iter > 0:
        total_training_time = time.time() - start_training_time
        total_time_str = str(datetime.timedelta(seconds=total_training_time))
        logger.info("Total training time: {} ({:.4f} s / it)".format(
            total_time_str, total_training_time / (max_iter)))
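# --- Hedged sketch (added for illustration) ---
# do_train() above reads the return value of validate_and_log() as a mapping
# {dataset_name: {metric_name: value}} and tracks the first dataset's metric as
# the current best. The stub below only illustrates that presumed shape (the
# names and the run_validation signature are assumptions); it is not the
# project's validation code.
def validate_and_log_sketch(model, run_validation, iteration):
    model.eval()
    results = run_validation(model)   # e.g. {"val_dataset": {"some_metric": 0.42}}
    # ... log `results` for `iteration` here ...
    model.train()
    return results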
def do_train(model, data_loader, optimizer, scheduler, checkpointer, device,
             checkpoint_period, arguments, logger,
             tensorboard_writer: TensorboardWriter = None):
    logger.info("Start training")
    meters = MetricLogger(delimiter=" ")
    max_iter = len(data_loader)
    start_iter = arguments["iteration"]
    model.train()
    start_training_time = time.time()
    end = time.time()
    for iteration, (images, targets, _) in enumerate(data_loader, start_iter):
        if any(len(target) < 1 for target in targets):
            logger.error(
                f"Iteration={iteration + 1} || Image Ids used for training {_} || "
                f"targets Length={[len(target) for target in targets]}")
            continue
        data_time = time.time() - end
        iteration = iteration + 1
        arguments["iteration"] = iteration

        scheduler.step()

        images = images.to(device)
        targets = [target.to(device) for target in targets]

        result, loss_dict = model(images, targets)
        losses = sum(loss for loss in loss_dict.values())

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = reduce_loss_dict(loss_dict)
        losses_reduced = sum(loss for loss in loss_dict_reduced.values())
        meters.update(loss=losses_reduced, **loss_dict_reduced)

        optimizer.zero_grad()
        # Note: If mixed precision is not used, this ends up doing nothing
        # Otherwise apply loss scaling for mixed-precision recipe
        with amp.scale_loss(losses, optimizer) as scaled_losses:
            scaled_losses.backward()
        optimizer.step()

        # write images / ground truth / evaluation metrics to tensorboard
        tensorboard_writer(iteration, losses_reduced, loss_dict_reduced, images, targets)

        batch_time = time.time() - end
        end = time.time()
        meters.update(time=batch_time, data=data_time)

        eta_seconds = meters.time.global_avg * (max_iter - iteration)
        eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

        if get_world_size() < 2 or dist.get_rank() == 0:
            if iteration % 20 == 0 or iteration == max_iter:
                logger.info(
                    meters.delimiter.join([
                        "eta: {eta}",
                        "iter: {iter}",
                        "{meters}",
                        "lr: {lr:.6f}",
                    ]).format(
                        eta=eta_string,
                        iter=iteration,
                        meters=str(meters),
                        lr=optimizer.param_groups[0]["lr"],
                    ))

        if iteration % checkpoint_period == 0:
            checkpointer.save("model_{:07d}".format(iteration), **arguments)
        if iteration == max_iter:
            checkpointer.save("model_final", **arguments)

    total_training_time = time.time() - start_training_time
    total_time_str = str(datetime.timedelta(seconds=total_training_time))
    logger.info("Total training time: {} ({:.4f} s / it)".format(
        total_time_str, total_training_time / (max_iter)))
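# --- Hedged sketch (added for illustration) ---
# The loop above treats tensorboard_writer as a callable taking
# (iteration, losses_reduced, loss_dict_reduced, images, targets). A minimal
# stand-in built on torch.utils.tensorboard is sketched below; the real
# TensorboardWriter in this codebase may also log images and ground truth.
from torch.utils.tensorboard import SummaryWriter


class TensorboardWriterSketch:
    def __init__(self, log_dir):
        self.writer = SummaryWriter(log_dir=log_dir)

    def __call__(self, iteration, losses_reduced, loss_dict_reduced, images, targets):
        # log the total loss and each individual loss term as scalars
        self.writer.add_scalar("train/loss", float(losses_reduced), iteration)
        for name, value in loss_dict_reduced.items():
            self.writer.add_scalar(f"train/{name}", float(value), iteration)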