def train(train_loader, net, criterion, optim, curr_epoch, scheduler, max_iter):
    """
    Runs the training loop for one epoch.

    train_loader: data loader for the training set
    net: the network
    criterion: loss function
    optim: optimizer
    curr_epoch: current epoch
    scheduler: learning-rate scheduler
    max_iter: maximum number of training iterations
    return: the updated global iteration counter
    """
    net.train()

    train_total_loss = AverageMeter()
    time_meter = AverageMeter()

    curr_iter = curr_epoch * len(train_loader)

    for i, data in enumerate(train_loader):
        if curr_iter >= max_iter:
            break

        start_ts = time.time()

        inputs, gts = data
        # Number of pixels in the batch, used to weight the running loss average.
        batch_pixel_size = inputs.size(0) * inputs.size(2) * inputs.size(3)
        inputs, gts = inputs.cuda(), gts.cuda()

        optim.zero_grad()
        outputs = net(inputs)
        total_loss = criterion(outputs, gts)

        # Detached copy used only for logging, so the graph is not kept alive.
        log_total_loss = total_loss.clone().detach_()
        train_total_loss.update(log_total_loss.item(), batch_pixel_size)

        total_loss.backward()
        optim.step()
        scheduler.step()

        time_meter.update(time.time() - start_ts)

        del total_loss
        curr_iter += 1

        if i % 50 == 49:
            msg = '[epoch {}], [iter {} / {} : {}], [loss {:0.6f}], [lr {:0.6f}], [time {:0.4f}]'.format(
                curr_epoch, i + 1, len(train_loader), curr_iter, train_total_loss.avg,
                optim.param_groups[-1]['lr'], time_meter.avg / args.batch_size)
            logging.info(msg)

            train_total_loss.reset()
            time_meter.reset()

    return curr_iter
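# The two train() variants in this file rely on an AverageMeter helper, the
# `time`, `logging`, and `torch` modules, and a module-level `args` namespace,
# all assumed to be imported/defined elsewhere in the project. As a hedged
# reference only (the project's own utility should be preferred if it differs),
# AverageMeter is assumed to follow the standard PyTorch ImageNet-example
# pattern of tracking a running sum and count:
class AverageMeterSketch(object):
    """Minimal sketch of a running-average tracker (hypothetical name)."""

    def __init__(self):
        self.reset()

    def reset(self):
        # Clear the latest value and the running statistics.
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        # Record a new value observed over `n` samples (here, pixels).
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count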
def train(train_loader, net, optim, curr_epoch, writer, scheduler, max_iter):
    """
    Runs the training loop for one epoch.

    train_loader: data loader for the training set
    net: the network
    optim: optimizer
    curr_epoch: current epoch
    writer: tensorboard writer
    scheduler: learning-rate scheduler
    max_iter: maximum number of training iterations
    return: the updated global iteration counter
    """
    net.train()

    train_total_loss = AverageMeter()
    time_meter = AverageMeter()

    curr_iter = curr_epoch * len(train_loader)

    for i, data in enumerate(train_loader):
        if curr_iter >= max_iter:
            break

        inputs, gts, _, aux_gts = data

        # Multi-source and AGG case: a 5-D batch stacks the source domains along
        # dim 1, so split it into one (B, C, H, W) batch per domain.
        if len(inputs.shape) == 5:
            B, D, C, H, W = inputs.shape
            num_domains = D
            inputs = inputs.transpose(0, 1)
            gts = gts.transpose(0, 1).squeeze(2)
            aux_gts = aux_gts.transpose(0, 1).squeeze(2)

            inputs = [input.squeeze(0) for input in torch.chunk(inputs, num_domains, 0)]
            gts = [gt.squeeze(0) for gt in torch.chunk(gts, num_domains, 0)]
            aux_gts = [aux_gt.squeeze(0) for aux_gt in torch.chunk(aux_gts, num_domains, 0)]
        else:
            B, C, H, W = inputs.shape
            num_domains = 1
            inputs = [inputs]
            gts = [gts]
            aux_gts = [aux_gts]

        batch_pixel_size = C * H * W

        for di, ingredients in enumerate(zip(inputs, gts, aux_gts)):
            input, gt, aux_gt = ingredients

            start_ts = time.time()

            img_gt = None
            input, gt = input.cuda(), gt.cuda()

            optim.zero_grad()

            if args.use_isw:
                # The whitening-transform loss is only applied after covariance
                # statistics have been collected for cov_stat_epoch epochs.
                outputs = net(input, gts=gt, aux_gts=aux_gt, img_gt=img_gt,
                              visualize=args.visualize_feature,
                              apply_wtloss=False if curr_epoch <= args.cov_stat_epoch else True)
            else:
                outputs = net(input, gts=gt, aux_gts=aux_gt, img_gt=img_gt,
                              visualize=args.visualize_feature)

            outputs_index = 0
            main_loss = outputs[outputs_index]
            outputs_index += 1
            aux_loss = outputs[outputs_index]
            outputs_index += 1
            total_loss = main_loss + (0.4 * aux_loss)

            if args.use_wtloss and (not args.use_isw or (args.use_isw and curr_epoch > args.cov_stat_epoch)):
                wt_loss = outputs[outputs_index]
                outputs_index += 1
                total_loss = total_loss + (args.wt_reg_weight * wt_loss)
            else:
                wt_loss = 0

            if args.visualize_feature:
                f_cor_arr = outputs[outputs_index]
                outputs_index += 1

            # Average the logged loss across distributed workers before recording it.
            log_total_loss = total_loss.clone().detach_()
            torch.distributed.all_reduce(log_total_loss, torch.distributed.ReduceOp.SUM)
            log_total_loss = log_total_loss / args.world_size
            train_total_loss.update(log_total_loss.item(), batch_pixel_size)

            total_loss.backward()
            optim.step()

            time_meter.update(time.time() - start_ts)

            del total_loss, log_total_loss

            if args.local_rank == 0:
                if i % 50 == 49:
                    if args.visualize_feature:
                        visualize_matrix(writer, f_cor_arr, curr_iter, '/Covariance/Feature-')

                    msg = '[epoch {}], [iter {} / {} : {}], [loss {:0.6f}], [lr {:0.6f}], [time {:0.4f}]'.format(
                        curr_epoch, i + 1, len(train_loader), curr_iter, train_total_loss.avg,
                        optim.param_groups[-1]['lr'], time_meter.avg / args.train_batch_size)
                    logging.info(msg)
                    if args.use_wtloss:
                        print("Whitening Loss", wt_loss)

                    # Log tensorboard metrics for each iteration of the training phase
                    writer.add_scalar('loss/train_loss', train_total_loss.avg, curr_iter)

                    train_total_loss.reset()
                    time_meter.reset()

        curr_iter += 1
        scheduler.step()

        if i > 5 and args.test_mode:
            return curr_iter

    return curr_iter
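# A minimal sketch of how an outer epoch loop might drive the distributed
# train() variant above. The setup helpers (build_train_loader, build_net,
# build_optimizer_and_scheduler), the log directory, and args.max_iter are
# hypothetical placeholders for illustration, not part of this project's API.
def main_sketch():
    from torch.utils.tensorboard import SummaryWriter

    train_loader = build_train_loader(args)                        # hypothetical helper
    net = build_net(args).cuda()                                   # hypothetical helper
    optim, scheduler = build_optimizer_and_scheduler(args, net)    # hypothetical helper
    writer = SummaryWriter(log_dir='./logs')

    curr_epoch = 0
    curr_iter = 0
    # train() returns the updated global iteration counter, so the outer loop
    # can stop exactly at max_iter rather than at an epoch boundary.
    while curr_iter < args.max_iter:
        curr_iter = train(train_loader, net, optim, curr_epoch, writer,
                          scheduler, args.max_iter)
        curr_epoch += 1

    writer.close()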