def train():
    logger = logging.getLogger()
    is_dist = dist.is_initialized()

    ## dataset
    dl = data_factory[cfg.dataset].get_data_loader(
        cfg.im_root, cfg.train_im_anns,
        cfg.ims_per_gpu, cfg.scales, cfg.cropsize,
        cfg.max_iter, mode='train', distributed=is_dist)

    ## model
    net, criteria_pre, criteria_aux = set_model()

    ## optimizer
    optim = set_optimizer(net)

    ## mixed precision training
    scaler = amp.GradScaler()

    ## ddp training
    net = set_model_dist(net)

    ## meters
    time_meter, loss_meter, loss_pre_meter, loss_aux_meters = set_meters()

    ## lr scheduler
    lr_schdr = WarmupPolyLrScheduler(optim, power=0.9,
        max_iter=cfg.max_iter, warmup_iter=cfg.warmup_iters,
        warmup_ratio=0.1, warmup='exp', last_epoch=-1,)

    ## train loop
    for it, (im, lb) in enumerate(dl):
        im = im.cuda()
        lb = lb.cuda()
        lb = torch.squeeze(lb, 1)

        optim.zero_grad()
        with amp.autocast(enabled=cfg.use_fp16):
            logits, *logits_aux = net(im)
            loss_pre = criteria_pre(logits, lb)
            loss_aux = [crit(lgt, lb) for crit, lgt in zip(criteria_aux, logits_aux)]
            loss = loss_pre + sum(loss_aux)
        scaler.scale(loss).backward()
        scaler.step(optim)
        scaler.update()
        torch.cuda.synchronize()

        time_meter.update()
        loss_meter.update(loss.item())
        loss_pre_meter.update(loss_pre.item())
        _ = [mter.update(lss.item()) for mter, lss in zip(loss_aux_meters, loss_aux)]

        ## print training log message
        if (it + 1) % 100 == 0:
            lr = lr_schdr.get_lr()
            lr = sum(lr) / len(lr)
            print_log_msg(
                it, cfg.max_iter, lr, time_meter, loss_meter,
                loss_pre_meter, loss_aux_meters)
        lr_schdr.step()

    ## dump the final model and evaluate the result
    save_pth = osp.join(cfg.respth, 'model_final.pth')
    logger.info('\nsave models to {}'.format(save_pth))
    state = net.module.state_dict()
    if dist.get_rank() == 0:
        torch.save(state, save_pth)

    logger.info('\nevaluating the final model')
    torch.cuda.empty_cache()
    heads, mious = eval_model(cfg, net, 2, cfg.im_root, cfg.val_im_anns)
    logger.info(tabulate([mious, ], headers=heads, tablefmt='orgtbl'))

    return
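# NOTE: hedged sketch, not the repo's actual scheduler implementation. Every
# variant below constructs WarmupPolyLrScheduler(optim, power=0.9, ...); the
# rule it is expected to follow is a warmup phase followed by polynomial
# decay of the base learning rate. The function name `poly_lr_with_warmup`
# and its exact warmup factors are assumptions made for illustration only.
def poly_lr_with_warmup(it, base_lr, max_iter, power=0.9,
                        warmup_iter=1000, warmup_ratio=0.1, warmup='exp'):
    if it < warmup_iter:
        alpha = it / warmup_iter
        if warmup == 'exp':
            factor = warmup_ratio ** (1.0 - alpha)   # grows from warmup_ratio to 1
        else:  # 'linear'
            factor = warmup_ratio + (1.0 - warmup_ratio) * alpha
        return base_lr * factor
    # polynomial decay over the remaining iterations
    real_it = it - warmup_iter
    real_max = max_iter - warmup_iter
    return base_lr * (1.0 - real_it / real_max) ** power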
def train():
    logger = logging.getLogger()
    is_dist = False

    ## dataset
    dl = get_data_loader(
        cfg.im_root, cfg.train_im_anns,
        cfg.ims_per_gpu, cfg.scales, cfg.cropsize,
        cfg.max_iter, mode='train', distributed=is_dist)
    valid = get_data_loader(
        cfg.im_root, cfg.val_im_anns,
        cfg.ims_per_gpu, cfg.scales, cfg.cropsize,
        cfg.max_iter, mode='val', distributed=is_dist)

    ## model
    net, criteria_pre, criteria_aux = set_model()
    print(net)
    print(f'n_parameters: {sum(p.numel() for p in net.parameters())}')

    ## optimizer
    optim = set_optimizer(net)

    ## fp16
    if has_apex:
        opt_level = 'O1' if cfg.use_fp16 else 'O0'
        net, optim = amp.initialize(net, optim, opt_level=opt_level)

    ## meters
    time_meter, loss_meter, loss_pre_meter, loss_aux_meters = set_meters()

    ## lr scheduler
    lr_schdr = WarmupPolyLrScheduler(optim, power=0.9,
        max_iter=cfg.max_iter, warmup_iter=cfg.warmup_iters,
        warmup_ratio=0.1, warmup='exp', last_epoch=-1,)

    best_validation = np.inf
    for i in range(cfg.n_epochs):
        ## train loop
        for it, (im, lb) in enumerate(Bar(dl)):
            net.train()
            im = im.cuda()
            lb = lb.cuda()
            lb = torch.squeeze(lb, 1)

            optim.zero_grad()
            logits, *logits_aux = net(im)
            loss_pre = criteria_pre(logits, lb)
            loss_aux = [crit(lgt, lb) for crit, lgt in zip(criteria_aux, logits_aux)]
            loss = loss_pre + sum(loss_aux)
            if has_apex:
                with amp.scale_loss(loss, optim) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            optim.step()
            torch.cuda.synchronize()
            lr_schdr.step()

            time_meter.update()
            loss_meter.update(loss.item())
            loss_pre_meter.update(loss_pre.item())
            _ = [mter.update(lss.item()) for mter, lss in zip(loss_aux_meters, loss_aux)]
            del im
            del lb

        ## print training log message
        lr = lr_schdr.get_lr()
        lr = sum(lr) / len(lr)
        print_log_msg(
            i, cfg.max_iter, lr, time_meter, loss_meter,
            loss_pre_meter, loss_aux_meters)

        ## validation loop
        validation_loss = []
        for it, (im, lb) in enumerate(Bar(valid)):
            net.eval()
            im = im.cuda()
            lb = lb.cuda()
            lb = torch.squeeze(lb, 1)
            with torch.no_grad():
                logits, *logits_aux = net(im)
                loss_pre = criteria_pre(logits, lb)
                loss_aux = [crit(lgt, lb) for crit, lgt in zip(criteria_aux, logits_aux)]
                loss = loss_pre + sum(loss_aux)
            validation_loss.append(loss.item())
            del im
            del lb

        ## print validation log message
        validation_loss = sum(validation_loss) / len(validation_loss)
        print(f'Validation loss: {validation_loss}')
        if best_validation > validation_loss:
            print('new best performance, storing model')
            best_validation = validation_loss
            state = net.state_dict()
            torch.save(state, osp.join(cfg.respth, 'best_validation.pth'))

    ## dump the final model and evaluate the result
    save_pth = osp.join(cfg.respth, 'model_final.pth')
    logger.info('\nsave models to {}'.format(save_pth))
    state = net.state_dict()
    torch.save(state, save_pth)

    logger.info('\nevaluating the final model')
    torch.cuda.empty_cache()
    heads, mious = eval_model(net, 2, cfg.im_root, cfg.test_im_anns)
    logger.info(tabulate([mious, ], headers=heads, tablefmt='orgtbl'))

    return
def train():
    logger = logging.getLogger()
    is_dist = dist.is_initialized()

    ## dataset
    dl = get_data_loader(cfg.im_root, cfg.train_im_anns, cfg.ims_per_gpu,
                         cfg.scales, cfg.cropsize, cfg.max_iter,
                         mode='train', distributed=is_dist)

    ## model
    net, criteria_pre, criteria_aux = set_model()

    ## optimizer
    optim = set_optimizer(net)

    ## fp16
    if has_apex:
        opt_level = 'O1' if cfg.use_fp16 else 'O0'
        net, optim = amp.initialize(net, optim, opt_level=opt_level)

    ## ddp training
    net = set_model_dist(net)

    ## meters
    time_meter, loss_meter, loss_pre_meter, loss_aux_meters = set_meters()

    ## lr scheduler
    lr_schdr = WarmupPolyLrScheduler(
        optim, power=0.9, max_iter=cfg.max_iter,
        warmup_iter=cfg.warmup_iters, warmup_ratio=0.1,
        warmup='exp', last_epoch=-1,
    )

    ## load checkpoint if one exists, to resume training
    if args.loadCheckpointLocation is not None:
        net, optim, lr_schdr, start_iteration = load_ckp(
            args.loadCheckpointLocation, net, optim, lr_schdr)
    else:
        start_iteration = 0

    ## train loop
    for current_it, (im, lb) in enumerate(dl):
        # on resumed training 'it' continues from where it left off;
        # otherwise start_iteration is 0 and the sum is unchanged
        it = current_it + start_iteration
        im = im.cuda()
        lb = lb.cuda()
        lb = torch.squeeze(lb, 1)

        optim.zero_grad()
        logits, *logits_aux = net(im)
        loss_pre = criteria_pre(logits, lb)
        loss_aux = [
            crit(lgt, lb) for crit, lgt in zip(criteria_aux, logits_aux)
        ]
        loss = loss_pre + sum(loss_aux)
        if has_apex:
            with amp.scale_loss(loss, optim) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        optim.step()
        torch.cuda.synchronize()
        lr_schdr.step()

        time_meter.update()
        loss_meter.update(loss.item())
        loss_pre_meter.update(loss_pre.item())
        _ = [
            mter.update(lss.item())
            for mter, lss in zip(loss_aux_meters, loss_aux)
        ]

        ## print training log message
        if (it + 1) % 100 == 0:
            lr = lr_schdr.get_lr()
            lr = sum(lr) / len(lr)
            print_log_msg(it, cfg.max_iter, lr, time_meter, loss_meter,
                          loss_pre_meter, loss_aux_meters)

        # save a checkpoint every args.saveOnEveryIt iterations
        if (it + 1) % args.saveOnEveryIt == 0:
            if args.saveCheckpointDir is not None:
                checkpoint = {
                    'iteration': it + 1,
                    'state_dict': net.state_dict(),
                    'optimizer': optim.state_dict(),
                    'lr_schdr': lr_schdr.state_dict(),
                }
                iteration_no_str = (str(it + 1)).zfill(len(str(cfg.max_iter)))
                ckt_name = 'checkpoint_it_' + iteration_no_str + '.pt'
                save_pth = osp.join(args.saveCheckpointDir, ckt_name)
                logger.info(
                    '\nsaving intermediate checkpoint to {}'.format(save_pth))
                save_ckp(checkpoint, save_pth)

    ## dump the final model and evaluate the result
    checkpoint = {
        'iteration': cfg.max_iter,
        'state_dict': net.state_dict(),
        'optimizer': optim.state_dict(),
        'lr_schdr': lr_schdr.state_dict(),
    }
    save_pth = osp.join(args.saveCheckpointDir, 'model_final.pt')
    logger.info('\nsave final model to {}'.format(save_pth))
    save_ckp(checkpoint, save_pth)

    logger.info('\nevaluating the final model')
    torch.cuda.empty_cache()
    heads, mious = eval_model(net, 2, cfg.im_root, cfg.val_im_anns)
    logger.info(tabulate([mious, ], headers=heads, tablefmt='orgtbl'))

    return
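# NOTE: hedged sketch. save_ckp/load_ckp are referenced by the resumable
# variant above but not defined here; a minimal implementation consistent
# with how they are called in that iteration-based variant (the real helpers
# in the source may differ) could look like this:
import torch


def save_ckp(checkpoint, save_pth):
    # checkpoint is a plain dict holding the iteration counter plus the
    # model, optimizer and lr-scheduler state_dicts
    torch.save(checkpoint, save_pth)


def load_ckp(ckp_pth, net, optim, lr_schdr):
    # restore the saved states in place and return the iteration to resume from
    checkpoint = torch.load(ckp_pth, map_location='cpu')
    net.load_state_dict(checkpoint['state_dict'])
    optim.load_state_dict(checkpoint['optimizer'])
    lr_schdr.load_state_dict(checkpoint['lr_schdr'])
    return net, optim, lr_schdr, checkpoint['iteration']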
def train(loginfo):
    logger = logging.getLogger()
    # is_dist = dist.is_initialized()
    logger.info("config: \n{}".format([item for item in cfg.__dict__.items()]))

    # ## dataset
    # dl = get_data_loader(
    #     cfg.train_img_root, cfg.train_img_anns,
    #     cfg.imgs_per_gpu, cfg.scales, cfg.cropsize,
    #     cfg.max_iter, mode='train', distributed=is_dist)
    # dl = get_data_loader(
    #     cfg.train_img_root, cfg.train_img_anns,
    #     cfg.imgs_per_gpu, cfg.scales, cfg.cropsize,
    #     cfg.anns_ignore, cfg.max_iter, mode='train', distributed=False)
    dl = prepare_data_loader(cfg.train_img_root, cfg.train_img_anns,
                             cfg.input_size, cfg.imgs_per_gpu, device_count,
                             cfg.scales, cfg.cropsize, cfg.anns_ignore,
                             mode='train', distributed=False)
    max_iter = cfg.max_epoch * len(dl.dataset) // (cfg.imgs_per_gpu * device_count) \
        if device == 'cuda' else cfg.max_epoch * len(dl.dataset) // cfg.imgs_per_gpu
    progress_iter = len(dl.dataset) / (cfg.imgs_per_gpu * device_count) // 5 \
        if device == 'cuda' else len(dl.dataset) / cfg.imgs_per_gpu // 5

    ## model
    net, criteria_pre, criteria_aux = set_model()
    net.to(device)
    if device_count >= 2:
        net = nn.DataParallel(net)
        torch.backends.cudnn.benchmark = True
        torch.multiprocessing.set_sharing_strategy('file_system')

    ## optimizer
    optim = set_optimizer(net)

    ## fp16
    if has_apex:
        opt_level = 'O1' if cfg.use_fp16 else 'O0'
        net, optim = amp.initialize(net, optim, opt_level=opt_level)

    ## ddp training
    # net = set_model_dist(net)
    # #CHANGED: normal training
    # #FIXME: GETTING STARTED WITH DISTRIBUTED DATA PARALLEL
    # #https://pytorch.org/tutorials/intermediate/ddp_tutorial.html

    ## meters
    time_meter, loss_meter, loss_pre_meter, loss_aux_meters = set_meters(max_iter)

    ## lr scheduler
    lr_schdr = WarmupPolyLrScheduler(
        optim, power=0.9, max_iter=max_iter,
        warmup_iter=cfg.warmup_iters, warmup_ratio=0.1,
        warmup='exp', last_epoch=-1,
    )

    ## train loop
    n_epoch = 0
    n_iter = 0
    best_valid_loss = np.inf
    while n_epoch < cfg.max_epoch:
        net.train()
        # for n_iter, (img, tar) in enumerate(dl):
        # for n_iter, (img, tar) in enumerate(tqdm(dl)):
        for (img, tar) in tqdm(dl, desc='train epoch {:d}/{:d}'.format(
                n_epoch + 1, cfg.max_epoch)):
            img = img.to(device)
            tar = tar.to(device)
            tar = torch.squeeze(tar, 1)

            optim.zero_grad()
            logits, *logits_aux = net(img)
            loss_pre = criteria_pre(logits, tar)
            loss_aux = [
                crit(lgt, tar) for crit, lgt in zip(criteria_aux, logits_aux)
            ]
            loss = loss_pre + sum(loss_aux)
            if has_apex:
                with amp.scale_loss(loss, optim) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            optim.step()
            torch.cuda.synchronize()
            lr_schdr.step()

            time_meter.update()
            loss_meter.update(loss.item())
            loss_pre_meter.update(loss_pre.item())
            _ = [
                mter.update(lss.item())
                for mter, lss in zip(loss_aux_meters, loss_aux)
            ]

            ## print training log message
            # if (n_iter + 1) % 100 == 0:
            if (n_iter + 1) % progress_iter == 0:
                lr = lr_schdr.get_lr()
                lr = sum(lr) / len(lr)
                print_log_msg(n_epoch, cfg.max_epoch, n_iter, max_iter, lr,
                              time_meter, loss_meter, loss_pre_meter,
                              loss_aux_meters)
            n_iter = n_iter + 1

        #CHANGED: save weight with valid loss
        ## dump the final model and evaluate the result
        # save_pth = os.path.join(cfg.weight_path, 'model_final.pth')
        # logger.info('\nsave models to {}'.format(save_pth))
        # state = net.module.state_dict()
        # if dist.get_rank() == 0: torch.save(state, save_pth)
        logger.info('validating the {} epoch model'.format(n_epoch + 1))
        valid_loss = valid(net, criteria_pre, criteria_aux, n_epoch, cfg, logger)
        if valid_loss < best_valid_loss:
            # save_path = os.path.join(cfg.weight_path,
            #     'epoch{:d}_valid_loss_{:.4f}.pth'.format(n_epoch, valid_loss))
            if not os.path.exists(cfg.weight_path):
                os.makedirs(cfg.weight_path)
            save_path = os.path.join(
                cfg.weight_path, 'model_bestValidLoss-{}.pth'.format(loginfo))
            logger.info('save models to {}'.format(save_path))
            torch.save(net.state_dict(), save_path)
            best_valid_loss = valid_loss

        # logger.info('\nevaluating the final model')
        logger.info('evaluating the {} epoch model'.format(n_epoch + 1))
        torch.cuda.empty_cache()  ## release cuda memory held by the cache
        # heads, mious = eval_model(net, 2, cfg.val_img_root, cfg.val_img_anns, cfg.n_classes)
        # logger.info(tabulate([mious, ], headers=heads, tablefmt='orgtbl'))
        # heads, mious, eious = eval_model(net, cfg, device_count, cfg.val_img_root,
        #                                  cfg.val_img_anns, cfg.n_classes, cfg.anns_ignore)
        heads, mious, eious = test_model(net, cfg, device_count, cfg.val_img_root,
                                         cfg.val_img_anns, cfg.n_classes,
                                         cfg.anns_ignore)
        logger.info('\n' + tabulate([mious, ], headers=heads,
                                    tablefmt='github', floatfmt=".8f"))
        logger.info('\n' + tabulate(np.array(eious).transpose(), headers=heads,
                                    tablefmt='github', floatfmt=".8f",
                                    showindex=True))
        n_epoch = n_epoch + 1

    heads, mious, eious = eval_model(net, cfg, device_count, cfg.val_img_root,
                                     cfg.val_img_anns, cfg.n_classes,
                                     cfg.anns_ignore)
    logger.info('\n' + tabulate([mious, ], headers=heads,
                                tablefmt='github', floatfmt=".8f"))
    logger.info('\n' + tabulate(np.array(eious).transpose(), headers=heads,
                                tablefmt='github', floatfmt=".8f",
                                showindex=True))

    return
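# NOTE: hedged sketch. The epoch-based variant above calls a valid() helper
# that is not shown here; a minimal version consistent with its call
# signature, mirroring the inline validation loop of the second variant, is
# given below. Building the loader inside valid() with mode='val' and the
# name `valid_dl` are assumptions made for this sketch only.
@torch.no_grad()
def valid(net, criteria_pre, criteria_aux, n_epoch, cfg, logger):
    net.eval()
    valid_dl = prepare_data_loader(cfg.val_img_root, cfg.val_img_anns,
                                   cfg.input_size, cfg.imgs_per_gpu,
                                   device_count, cfg.scales, cfg.cropsize,
                                   cfg.anns_ignore, mode='val',
                                   distributed=False)
    losses = []
    for img, tar in valid_dl:
        img, tar = img.to(device), tar.to(device)
        tar = torch.squeeze(tar, 1)
        logits, *logits_aux = net(img)
        loss_pre = criteria_pre(logits, tar)
        loss_aux = [crit(lgt, tar) for crit, lgt in zip(criteria_aux, logits_aux)]
        losses.append((loss_pre + sum(loss_aux)).item())
    valid_loss = sum(losses) / len(losses)
    logger.info('epoch {:d} validation loss: {:.4f}'.format(n_epoch + 1, valid_loss))
    net.train()
    return valid_loss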
def train():
    logger = logging.getLogger()
    is_dist = dist.is_initialized()

    ## dataset
    dl = get_data_loader(cfg.im_root, cfg.train_im_anns, cfg.ims_per_gpu,
                         cfg.scales, cfg.cropsize, cfg.max_iter,
                         mode='train', distributed=is_dist)

    ## model
    net, criteria_pre, criteria_aux = set_model()
    if dist.get_rank() == 0:
        exp_name = "cityscapes_repl"
        wandb.init(project="bisenet", name="cityscapes_repl")
        wandb.watch(net)

    ## optimizer
    optim = set_optimizer(net)

    ## fp16
    if has_apex:
        opt_level = 'O1' if cfg.use_fp16 else 'O0'
        net, optim = amp.initialize(net, optim, opt_level=opt_level)

    ## ddp training
    net = set_model_dist(net)

    ## meters
    time_meter, loss_meter, loss_pre_meter, loss_aux_meters = set_meters()

    ## lr scheduler
    lr_schdr = WarmupPolyLrScheduler(
        optim, power=0.9, max_iter=cfg.max_iter,
        warmup_iter=cfg.warmup_iters, warmup_ratio=0.1,
        warmup='exp', last_epoch=-1,
    )

    ## train loop
    for it, (im, lb) in enumerate(dl):
        net.train()
        im = im.cuda()
        lb = lb.cuda()
        lb = torch.squeeze(lb, 1)

        optim.zero_grad()
        logits, *logits_aux = net(im)
        loss_pre = criteria_pre(logits, lb)
        loss_aux = [
            crit(lgt, lb) for crit, lgt in zip(criteria_aux, logits_aux)
        ]
        loss = loss_pre + sum(loss_aux)
        if has_apex:
            with amp.scale_loss(loss, optim) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        optim.step()
        torch.cuda.synchronize()
        lr_schdr.step()

        time_meter.update()
        loss_meter.update(loss.item())
        loss_pre_meter.update(loss_pre.item())
        _ = [
            mter.update(lss.item())
            for mter, lss in zip(loss_aux_meters, loss_aux)
        ]
        lr = lr_schdr.get_lr()
        lr = sum(lr) / len(lr)

        ## print training log message
        if dist.get_rank() == 0:
            loss_avg = loss_meter.get()[0]
            wandb.log(
                {
                    "lr": lr,
                    "time": time_meter.get()[0],
                    "loss": loss_avg,
                    "loss_pre": loss_pre_meter.get()[0],
                    **{
                        f"loss_aux_{el.name}": el.get()[0]
                        for el in loss_aux_meters
                    }
                },
                commit=False)
            if (it + 1) % 100 == 0:
                print(it, ' - ', lr, ' - ', loss_avg)
            if (it + 1) % 2000 == 0:
                # dump the model and evaluate the result
                save_pth = osp.join(cfg.respth, f"{exp_name}_{it}.pth")
                state = net.module.state_dict()
                torch.save(state, save_pth)
                wandb.save(save_pth)

        if ((it + 1) % 2000 == 0):
            logger.info('\nevaluating the model')
            heads, mious = eval_model(net, 2, cfg.im_root, cfg.val_im_anns, it)
            logger.info(tabulate([mious, ], headers=heads, tablefmt='orgtbl'))
            if (dist.get_rank() == 0):
                wandb.log({k: v for k, v in zip(heads, mious)}, commit=False)

        if (dist.get_rank() == 0):
            wandb.log({"t": it}, step=it)

    return
def train():
    logger = logging.getLogger()
    is_dist = dist.is_initialized()

    ## dataset
    dl = get_data_loader(cfg, mode='train', distributed=is_dist)

    ## model
    net, criteria_pre, criteria_aux = set_model()

    ## optimizer
    optim = set_optimizer(net)

    ## fp16
    if has_apex:
        opt_level = 'O1' if cfg.use_fp16 else 'O0'
        net, optim = amp.initialize(net, optim, opt_level=opt_level)

    ## ddp training
    net = set_model_dist(net)

    ## meters
    time_meter, loss_meter, loss_pre_meter, loss_aux_meters = set_meters()

    ## lr scheduler
    lr_schdr = WarmupPolyLrScheduler(optim, power=0.9,
        max_iter=cfg.max_iter, warmup_iter=cfg.warmup_iters,
        warmup_ratio=0.1, warmup='exp', last_epoch=-1,)

    ## train loop
    for it, (im, lb) in enumerate(dl):
        im = im.cuda()
        lb = lb.cuda()
        lb = torch.squeeze(lb, 1)

        optim.zero_grad()
        logits, *logits_aux = net(im)
        loss_pre = criteria_pre(logits, lb)
        loss_aux = [crit(lgt, lb) for crit, lgt in zip(criteria_aux, logits_aux)]
        loss = loss_pre + sum(loss_aux)
        if has_apex:
            with amp.scale_loss(loss, optim) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        optim.step()
        torch.cuda.synchronize()
        lr_schdr.step()

        time_meter.update()
        loss_meter.update(loss.item())
        loss_pre_meter.update(loss_pre.item())
        _ = [mter.update(lss.item()) for mter, lss in zip(loss_aux_meters, loss_aux)]

        ## print training log message
        if (it + 1) % 100 == 0:
            lr = lr_schdr.get_lr()
            lr = sum(lr) / len(lr)
            print_log_msg(
                it, cfg.max_iter, lr, time_meter, loss_meter,
                loss_pre_meter, loss_aux_meters)

    ## dump the final model and evaluate the result
    save_pth = osp.join(cfg.respth, 'model_final.pth')
    logger.info('\nsave models to {}'.format(save_pth))
    state = net.module.state_dict()
    if dist.get_rank() == 0:
        torch.save(state, save_pth, _use_new_zipfile_serialization=False)

    logger.info('\nevaluating the final model')
    torch.cuda.empty_cache()
    heads, mious = eval_model(cfg, net.module)
    logger.info(tabulate([mious, ], headers=heads, tablefmt='orgtbl'))

    return
def main():
    if not osp.exists(cfg.respth):
        os.makedirs(cfg.respth)
    setup_logger('{}-train'.format('banet'), cfg.respth)
    best_prec1 = (-1)
    logger = logging.getLogger()

    ## model
    net, criteria = set_model()

    ## optimizer
    optim = set_optimizer(net)

    ## fp16
    if has_apex:
        opt_level = 'O1' if cfg.use_fp16 else 'O0'
        net, optim = amp.initialize(net, optim, opt_level=opt_level)

    ## lr scheduler
    lr_schdr = WarmupPolyLrScheduler(
        optim, power=0.9, max_iter=cfg.epoch * 371,
        warmup_iter=cfg.warmup_iters * 371, warmup_ratio=0.1,
        warmup='exp', last_epoch=-1,
    )

    for epoch in range(cfg.start_epoch, args.epoch_to_train):
        lr_schdr, time_meter, loss_meter = train(epoch, optim, net, criteria, lr_schdr)

        if True:  # if ((epoch+1) != cfg.epoch):
            lr = lr_schdr.get_lr()
            print(lr)
            lr = sum(lr) / len(lr)
            loss_avg = print_log_msg(epoch, cfg.epoch, lr, time_meter, loss_meter)
            writer.add_scalar('loss', loss_avg, epoch + 1)

        if ((epoch + 1) == cfg.epoch) or ((epoch + 1) == args.epoch_to_train):
            # if ((epoch+1) % 1 == 0) and ((epoch+1) > cfg.warmup_iters):
            torch.cuda.empty_cache()
            heads, mious, miou = eval_model(net, ims_per_gpu=2,
                                            im_root=cfg.im_root,
                                            im_anns=cfg.val_im_anns, it=epoch)
            filename = osp.join(cfg.respth, args.store_name)
            state = net.state_dict()
            save_checkpoint(state, False, filename=filename)
            # writer.add_scalar('mIOU', miou, epoch + 1)

            with open('lr_record.txt', 'w') as m:
                print('lr to store', lr)
                m.seek(0)
                m.write((str(epoch + 1) + ' '))
                m.write(str(lr))
                m.truncate()
                m.close()

            with open('best_miou.txt', 'r+') as f:
                best_miou = f.read()
                # print(best_miou)
                best_miou = best_miou.replace('\n', ' ')
                x = best_miou.split(' ')
                while ('' in x):
                    x.remove('')
                best_miou = eval(x[-1])
                is_best = miou > best_miou
                if is_best:
                    best_miou = miou
                print('Is best? : ', is_best)
                f.seek(0)
                f.write((str(epoch + 1) + ' '))
                f.write(str(best_miou))
                f.truncate()
                f.close()

            save_checkpoint(state, is_best, filename)
            print('Have Stored Checkpoint')

            # if ((epoch+1) == cfg.epoch) or ((epoch+1) == args.epoch_to_train):
            state = net.state_dict()
            torch.cuda.empty_cache()
            # heads, mious = eval_model(net, 2, cfg.im_root, cfg.val_im_anns, it=epoch)
            logger.info(tabulate([mious, ], headers=heads, tablefmt='orgtbl'))
            save_checkpoint(state, False, filename)
            print('Have Saved Final Model')
            break
def train():
    logger = logging.getLogger()

    ## dataset
    dl = get_data_loader(cfg.im_root, cfg.train_im_anns, cfg.ims_per_gpu,
                         cfg.scales, cfg.cropsize, cfg.max_iter,
                         mode='train', distributed=False)
    # send a few training images to tensorboard
    addImage_Tensorboard(dl)

    # find the max epoch to train the dataset
    dataset_length = len(dl.dataset)
    print("Dataset length: ", dataset_length)
    batch_size = cfg.ims_per_gpu
    print("Batch Size: ", batch_size)
    iteration_per_epoch = int(dataset_length / batch_size)
    max_epoch = int(cfg.max_iter / iteration_per_epoch)
    print("Max_epoch: ", max_epoch)

    ## model
    net, criteria_pre, criteria_aux = set_model()

    ## optimizer
    optim = set_optimizer(net)

    ## fp16
    if has_apex:
        opt_level = 'O1' if cfg.use_fp16 else 'O0'
        net, optim = amp.initialize(net, optim, opt_level=opt_level)

    ## meters
    time_meter, loss_meter, loss_pre_meter, loss_aux_meters = set_meters()

    ## lr scheduler
    lr_schdr = WarmupPolyLrScheduler(
        optim, power=0.9, max_iter=cfg.max_iter,
        warmup_iter=cfg.warmup_iters, warmup_ratio=0.1,
        warmup='exp', last_epoch=-1,
    )

    ## load checkpoint if one exists, to resume training
    if args.loadCheckpointLocation is not None:
        net, optim, lr_schdr, start_epoch = load_ckp(
            args.loadCheckpointLocation, net, optim, lr_schdr)
    else:
        start_epoch = 0

    # send the model structure to tensorboard
    addGraph_Tensorboard(net, dl)

    ## train loop
    for current_epoch in range(max_epoch):
        # on resumed training 'epoch' continues from where it left off;
        # otherwise start_epoch is 0 and the sum is unchanged
        epoch = start_epoch + current_epoch
        for it, (im, lb) in enumerate(dl):
            im = im.to(device)
            lb = lb.to(device)
            lb = torch.squeeze(lb, 1)

            optim.zero_grad()
            logits, *logits_aux = net(im)
            loss_pre = criteria_pre(logits, lb)
            loss_aux = [
                crit(lgt, lb) for crit, lgt in zip(criteria_aux, logits_aux)
            ]
            loss = loss_pre + sum(loss_aux)
            if has_apex:
                with amp.scale_loss(loss, optim) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            optim.step()
            lr_schdr.step()

            time_meter.update()
            loss_meter.update(loss.item())
            loss_pre_meter.update(loss_pre.item())
            _ = [
                mter.update(lss.item())
                for mter, lss in zip(loss_aux_meters, loss_aux)
            ]

            ## print training log message
            global_it = it + epoch * iteration_per_epoch
            if (global_it + 1) % 100 == 0:
                lr = lr_schdr.get_lr()
                lr = sum(lr) / len(lr)
                # write important scalars to tensorboard
                addScalars_loss_Tensorboard(global_it, loss_meter)
                addScalars_lr_Tensorboard(global_it, lr)
                print_log_msg(global_it, cfg.max_iter, lr, time_meter,
                              loss_meter, loss_pre_meter, loss_aux_meters)

        # save a checkpoint every args.saveOnEveryEpoch epochs
        if (epoch + 1) % args.saveOnEveryEpoch == 0:
            if args.saveCheckpointDir is not None:
                checkpoint = {
                    'epoch': epoch + 1,
                    'state_dict': net.state_dict(),
                    'optimizer': optim.state_dict(),
                    'lr_schdr': lr_schdr.state_dict(),
                }
                epoch_no_str = (str(epoch + 1)).zfill(len(str(cfg.max_iter)))
                ckt_name = 'checkpoint_epoch_' + epoch_no_str + '.pt'
                save_pth = osp.join(args.saveCheckpointDir, ckt_name)
                logger.info(
                    '\nsaving intermediate checkpoint to {}'.format(save_pth))
                save_ckp(checkpoint, save_pth)

            # compute validation accuracy in terms of mIOUs
            logger.info('\nevaluating the model after ' + str(epoch + 1) + ' epochs')
            heads, mious = eval_model(net, 2, cfg.im_root, cfg.val_im_anns, cfg.cropsize)
            addScalars_val_accuracy_Tensorboard(global_it, heads, mious)
            # set back to training mode
            net.train()

    ## dump the final model and evaluate the result
    checkpoint = {
        'epoch': max_epoch,
        'state_dict': net.state_dict(),
        'optimizer': optim.state_dict(),
        'lr_schdr': lr_schdr.state_dict(),
    }
    save_pth = osp.join(args.saveCheckpointDir, 'model_final.pt')
    logger.info('\nsave final model to {}'.format(save_pth))
    save_ckp(checkpoint, save_pth)

    logger.info('\nevaluating the final model')
    torch.cuda.empty_cache()
    heads, mious = eval_model(net, 2, cfg.im_root, cfg.val_im_anns, cfg.cropsize)
    logger.info(tabulate([mious, ], headers=heads, tablefmt='orgtbl'))

    return