def train():
    """Run the iteration-based training loop (optionally distributed / fp16).

    Builds the dataloader, model, optimizer and LR scheduler from the
    module-level ``cfg``/``args``, optionally resumes from a checkpoint,
    trains for ``cfg.max_iter`` iterations, periodically logs and saves
    intermediate checkpoints, then dumps the final model and evaluates it.

    Returns:
        None. Side effects: checkpoint files written under
        ``args.saveCheckpointDir`` and log output via ``logging``.
    """
    logger = logging.getLogger()
    is_dist = dist.is_initialized()

    ## dataset
    dl = get_data_loader(
        cfg.im_root, cfg.train_im_anns,
        cfg.ims_per_gpu, cfg.scales, cfg.cropsize,
        cfg.max_iter, mode='train', distributed=is_dist)

    ## model
    net, criteria_pre, criteria_aux = set_model()

    ## optimizer
    optim = set_optimizer(net)

    ## fp16: apex amp patches model/optimizer in place; 'O0' disables mixed precision
    if has_apex:
        opt_level = 'O1' if cfg.use_fp16 else 'O0'
        net, optim = amp.initialize(net, optim, opt_level=opt_level)

    ## ddp training
    net = set_model_dist(net)

    ## meters
    time_meter, loss_meter, loss_pre_meter, loss_aux_meters = set_meters()

    ## lr scheduler
    lr_schdr = WarmupPolyLrScheduler(
        optim, power=0.9,
        max_iter=cfg.max_iter, warmup_iter=cfg.warmup_iters,
        warmup_ratio=0.1, warmup='exp', last_epoch=-1,
    )

    ## load checkpoint if one exists, to resume training
    if args.loadCheckpointLocation is not None:
        net, optim, lr_schdr, start_iteration = load_ckp(
            args.loadCheckpointLocation, net, optim, lr_schdr)
    else:
        start_iteration = 0

    ## train loop
    for current_it, (im, lb) in enumerate(dl):
        # on resumed training 'it' continues from where training left off;
        # otherwise start_iteration is 0 and it == current_it
        it = current_it + start_iteration
        im = im.cuda()
        lb = lb.cuda()

        # labels come as (N, 1, H, W); criteria expect (N, H, W)
        lb = torch.squeeze(lb, 1)

        optim.zero_grad()
        logits, *logits_aux = net(im)
        loss_pre = criteria_pre(logits, lb)
        loss_aux = [crit(lgt, lb)
                    for crit, lgt in zip(criteria_aux, logits_aux)]
        loss = loss_pre + sum(loss_aux)
        if has_apex:
            # amp scales the loss to avoid fp16 gradient underflow
            with amp.scale_loss(loss, optim) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        optim.step()
        torch.cuda.synchronize()
        lr_schdr.step()

        time_meter.update()
        loss_meter.update(loss.item())
        loss_pre_meter.update(loss_pre.item())
        for mter, lss in zip(loss_aux_meters, loss_aux):
            mter.update(lss.item())

        ## print training log message every 100 iterations
        if (it + 1) % 100 == 0:
            lr = lr_schdr.get_lr()
            lr = sum(lr) / len(lr)
            print_log_msg(
                it, cfg.max_iter, lr, time_meter, loss_meter,
                loss_pre_meter, loss_aux_meters)

        ## save an intermediate checkpoint every args.saveOnEveryIt iterations
        if (it + 1) % args.saveOnEveryIt == 0 \
                and args.saveCheckpointDir is not None:
            checkpoint = {
                'iteration': it + 1,
                'state_dict': net.state_dict(),
                'optimizer': optim.state_dict(),
                'lr_schdr': lr_schdr.state_dict(),
            }
            # zero-pad the iteration number so checkpoint files sort lexicographically
            iteration_no_str = str(it + 1).zfill(len(str(cfg.max_iter)))
            ckt_name = 'checkpoint_it_' + iteration_no_str + '.pt'
            save_pth = osp.join(args.saveCheckpointDir, ckt_name)
            logger.info(
                '\nsaving intermidiate checkpoint to {}'.format(save_pth))
            save_ckp(checkpoint, save_pth)

    ## dump the final model and evaluate the result
    checkpoint = {
        'iteration': cfg.max_iter,
        'state_dict': net.state_dict(),
        'optimizer': optim.state_dict(),
        'lr_schdr': lr_schdr.state_dict(),
    }
    save_pth = osp.join(args.saveCheckpointDir, 'model_final.pt')
    logger.info('\nsave Final models to {}'.format(save_pth))
    save_ckp(checkpoint, save_pth)

    logger.info('\nevaluating the final model')
    torch.cuda.empty_cache()
    heads, mious = eval_model(net, 2, cfg.im_root, cfg.val_im_anns)
    logger.info(tabulate([mious, ], headers=heads, tablefmt='orgtbl'))

    return
def train():
    """Run the epoch-based, single-process training loop with TensorBoard logging.

    Builds the dataloader, model, optimizer and LR scheduler from the
    module-level ``cfg``/``args``, derives the number of epochs from
    ``cfg.max_iter``, optionally resumes from a checkpoint, logs scalars and
    validation mIoU to TensorBoard, saves a checkpoint every
    ``args.saveOnEveryEpoch`` epochs, then dumps and evaluates the final model.

    Returns:
        None. Side effects: checkpoint files written under
        ``args.saveCheckpointDir``, TensorBoard events, and log output.
    """
    logger = logging.getLogger()

    ## dataset (non-distributed in this variant)
    dl = get_data_loader(
        cfg.im_root, cfg.train_im_anns,
        cfg.ims_per_gpu, cfg.scales, cfg.cropsize,
        cfg.max_iter, mode='train', distributed=False)

    # send a few training images to tensorboard
    addImage_Tensorboard(dl)

    # derive the epoch budget from the configured iteration budget
    dataset_length = len(dl.dataset)
    print("Dataset length: ", dataset_length)
    batch_size = cfg.ims_per_gpu
    print("Batch Size: ", batch_size)
    # guard against a dataset smaller than one batch, which would make
    # iteration_per_epoch zero and crash the max_epoch division below
    iteration_per_epoch = max(1, dataset_length // batch_size)
    max_epoch = int(cfg.max_iter / iteration_per_epoch)
    print("Max_epoch: ", max_epoch)

    ## model
    net, criteria_pre, criteria_aux = set_model()

    ## optimizer
    optim = set_optimizer(net)

    ## fp16: apex amp patches model/optimizer in place; 'O0' disables mixed precision
    if has_apex:
        opt_level = 'O1' if cfg.use_fp16 else 'O0'
        net, optim = amp.initialize(net, optim, opt_level=opt_level)

    ## meters
    time_meter, loss_meter, loss_pre_meter, loss_aux_meters = set_meters()

    ## lr scheduler
    lr_schdr = WarmupPolyLrScheduler(
        optim, power=0.9,
        max_iter=cfg.max_iter, warmup_iter=cfg.warmup_iters,
        warmup_ratio=0.1, warmup='exp', last_epoch=-1,
    )

    ## load checkpoint if one exists, to resume training
    if args.loadCheckpointLocation is not None:
        net, optim, lr_schdr, start_epoch = load_ckp(
            args.loadCheckpointLocation, net, optim, lr_schdr)
    else:
        start_epoch = 0

    # send the model structure to tensorboard
    addGraph_Tensorboard(net, dl)

    ## train loop
    for current_epoch in range(max_epoch):
        # on resumed training 'epoch' continues from where training left off;
        # otherwise start_epoch is 0 and epoch == current_epoch
        epoch = start_epoch + current_epoch
        for it, (im, lb) in enumerate(dl):
            # NOTE(review): 'device' is assumed to be a module-level global
            im = im.to(device)
            lb = lb.to(device)

            # labels come as (N, 1, H, W); criteria expect (N, H, W)
            lb = torch.squeeze(lb, 1)

            optim.zero_grad()
            logits, *logits_aux = net(im)
            loss_pre = criteria_pre(logits, lb)
            loss_aux = [crit(lgt, lb)
                        for crit, lgt in zip(criteria_aux, logits_aux)]
            loss = loss_pre + sum(loss_aux)
            if has_apex:
                # amp scales the loss to avoid fp16 gradient underflow
                with amp.scale_loss(loss, optim) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            optim.step()
            lr_schdr.step()

            time_meter.update()
            loss_meter.update(loss.item())
            loss_pre_meter.update(loss_pre.item())
            for mter, lss in zip(loss_aux_meters, loss_aux):
                mter.update(lss.item())

            ## print training log message every 100 global iterations
            global_it = it + epoch * iteration_per_epoch
            if (global_it + 1) % 100 == 0:
                lr = lr_schdr.get_lr()
                lr = sum(lr) / len(lr)
                # write important scalars to tensorboard
                addScalars_loss_Tensorboard(global_it, loss_meter)
                addScalars_lr_Tensorboard(global_it, lr)
                print_log_msg(
                    global_it, cfg.max_iter, lr, time_meter, loss_meter,
                    loss_pre_meter, loss_aux_meters)

        ## save a checkpoint every args.saveOnEveryEpoch epochs
        if (epoch + 1) % args.saveOnEveryEpoch == 0 \
                and args.saveCheckpointDir is not None:
            checkpoint = {
                'epoch': epoch + 1,
                'state_dict': net.state_dict(),
                'optimizer': optim.state_dict(),
                'lr_schdr': lr_schdr.state_dict(),
            }
            # zero-pad the epoch number so checkpoint files sort lexicographically
            epoch_no_str = str(epoch + 1).zfill(len(str(cfg.max_iter)))
            ckt_name = 'checkpoint_epoch_' + epoch_no_str + '.pt'
            save_pth = osp.join(args.saveCheckpointDir, ckt_name)
            logger.info(
                '\nsaving intermidiate checkpoint to {}'.format(save_pth))
            save_ckp(checkpoint, save_pth)

            # compute validation accuracy in terms of mious
            logger.info(
                '\nevaluating the model after ' + str(epoch + 1) + ' epoches')
            heads, mious = eval_model(
                net, 2, cfg.im_root, cfg.val_im_anns, cfg.cropsize)
            addScalars_val_accuracy_Tensorboard(global_it, heads, mious)
            # set back to training mode (eval_model switches the net to eval)
            net.train()

    ## dump the final model and evaluate the result
    checkpoint = {
        'epoch': max_epoch,
        'state_dict': net.state_dict(),
        'optimizer': optim.state_dict(),
        'lr_schdr': lr_schdr.state_dict(),
    }
    save_pth = osp.join(args.saveCheckpointDir, 'model_final.pt')
    logger.info('\nsave Final models to {}'.format(save_pth))
    save_ckp(checkpoint, save_pth)

    logger.info('\nevaluating the final model')
    torch.cuda.empty_cache()
    heads, mious = eval_model(net, 2, cfg.im_root, cfg.val_im_anns, cfg.cropsize)
    logger.info(tabulate([mious, ], headers=heads, tablefmt='orgtbl'))

    return