def train(self, train_loader):
    """Run one training epoch of AANet over ``train_loader``.

    For every batch: forward the stereo pair, compute a weighted smooth-L1
    loss over the disparity pyramid (optionally plus a pseudo-ground-truth
    loss on pixels the real GT does not cover), and take one optimizer step.
    Periodically logs progress and writes image/scalar summaries to
    TensorBoard.  At epoch end, optionally saves the latest / periodic
    checkpoints (only when validation is disabled).

    Args:
        train_loader: iterable of sample dicts with keys 'left', 'right',
            'disp' (and 'pseudo_disp' when ``args.load_pseudo_gt``).
    """
    args = self.args
    logger = self.logger
    steps_per_epoch = len(train_loader)
    device = self.device

    self.aanet.train()

    if args.freeze_bn:
        # Keep BatchNorm layers in eval mode (frozen running stats),
        # e.g. when fine-tuning with small batch sizes.
        def set_bn_eval(m):
            classname = m.__class__.__name__
            if classname.find('BatchNorm') != -1:
                m.eval()

        self.aanet.apply(set_bn_eval)

    # Learning rate summary (two param groups: backbone and offset branch)
    base_lr = self.optimizer.param_groups[0]['lr']
    offset_lr = self.optimizer.param_groups[1]['lr']
    self.train_writer.add_scalar('base_lr', base_lr, self.epoch + 1)
    self.train_writer.add_scalar('offset_lr', offset_lr, self.epoch + 1)

    last_print_time = time.time()

    for i, sample in enumerate(train_loader):
        left = sample['left'].to(device)  # [B, 3, H, W]
        right = sample['right'].to(device)
        gt_disp = sample['disp'].to(device)  # [B, H, W]
        # Valid pixels only: disparity 0 marks invalid ground truth.
        mask = (gt_disp > 0) & (gt_disp < args.max_disp)

        if args.load_pseudo_gt:
            pseudo_gt_disp = sample['pseudo_disp'].to(device)
            # inverse mask: supervise with pseudo GT only where real GT is invalid
            pseudo_mask = (pseudo_gt_disp > 0) & (
                pseudo_gt_disp < args.max_disp) & (~mask)

        if not mask.any():
            continue  # no valid GT pixels in this batch

        pred_disp_pyramid = self.aanet(
            left, right)  # list of H/12, H/6, H/3, H/2, H

        if args.highest_loss_only:
            pred_disp_pyramid = [
                pred_disp_pyramid[-1]
            ]  # only the last highest resolution output

        disp_loss = 0
        pseudo_disp_loss = 0
        pyramid_loss = []
        pseudo_pyramid_loss = []

        # Loss weights per pyramid scale (low resolution gets smaller weight)
        if len(pred_disp_pyramid) == 5:
            pyramid_weight = [1 / 3, 2 / 3, 1.0, 1.0, 1.0]  # AANet
        elif len(pred_disp_pyramid) == 4:
            pyramid_weight = [1 / 3, 2 / 3, 1.0, 1.0]  # AANet+
        elif len(pred_disp_pyramid) == 3:
            pyramid_weight = [1.0, 1.0, 1.0]  # 1 scale only
        elif len(pred_disp_pyramid) == 1:
            pyramid_weight = [1.0]  # highest loss only
        else:
            raise NotImplementedError

        assert len(pyramid_weight) == len(pred_disp_pyramid)

        for k in range(len(pred_disp_pyramid)):
            pred_disp = pred_disp_pyramid[k]
            weight = pyramid_weight[k]

            if pred_disp.size(-1) != gt_disp.size(-1):
                pred_disp = pred_disp.unsqueeze(1)  # [B, 1, H, W]
                # Upsample to GT resolution; disparity values must be scaled
                # by the same factor as the width.
                pred_disp = F.interpolate(
                    pred_disp,
                    size=(gt_disp.size(-2), gt_disp.size(-1)),
                    mode='bilinear') * (gt_disp.size(-1) / pred_disp.size(-1))
                pred_disp = pred_disp.squeeze(1)  # [B, H, W]

            curr_loss = F.smooth_l1_loss(pred_disp[mask], gt_disp[mask],
                                         reduction='mean')
            disp_loss += weight * curr_loss
            pyramid_loss.append(curr_loss)

            # Pseudo gt loss
            if args.load_pseudo_gt:
                pseudo_curr_loss = F.smooth_l1_loss(
                    pred_disp[pseudo_mask], pseudo_gt_disp[pseudo_mask],
                    reduction='mean')
                pseudo_disp_loss += weight * pseudo_curr_loss
                pseudo_pyramid_loss.append(pseudo_curr_loss)

        total_loss = disp_loss + pseudo_disp_loss

        self.optimizer.zero_grad()
        total_loss.backward()
        self.optimizer.step()

        self.num_iter += 1

        if self.num_iter % args.print_freq == 0:
            this_cycle = time.time() - last_print_time
            last_print_time += this_cycle
            logger.info(
                'Epoch: [%3d/%3d] [%5d/%5d] time: %4.2fs disp_loss: %.3f' %
                (self.epoch + 1, args.max_epoch, i + 1, steps_per_epoch,
                 this_cycle, disp_loss.item()))

        if self.num_iter % args.summary_freq == 0:
            img_summary = dict()
            img_summary['left'] = left
            img_summary['right'] = right
            img_summary['gt_disp'] = gt_disp

            if args.load_pseudo_gt:
                img_summary['pseudo_gt_disp'] = pseudo_gt_disp

            # Save pyramid disparity prediction
            for s in range(len(pred_disp_pyramid)):
                # Scale from low to high, reverse the index so that
                # pred_disp0 is the highest-resolution output.
                save_name = 'pred_disp' + str(len(pred_disp_pyramid) - s - 1)
                save_value = pred_disp_pyramid[s]
                img_summary[save_name] = save_value

            pred_disp = pred_disp_pyramid[-1]

            if pred_disp.size(-1) != gt_disp.size(-1):
                pred_disp = pred_disp.unsqueeze(1)  # [B, 1, H, W]
                pred_disp = F.interpolate(
                    pred_disp,
                    size=(gt_disp.size(-2), gt_disp.size(-1)),
                    mode='bilinear') * (gt_disp.size(-1) / pred_disp.size(-1))
                pred_disp = pred_disp.squeeze(1)  # [B, H, W]

            img_summary['disp_error'] = disp_error_img(pred_disp, gt_disp)

            save_images(self.train_writer, 'train', img_summary,
                        self.num_iter)

            epe = F.l1_loss(gt_disp[mask], pred_disp[mask],
                            reduction='mean')
            self.train_writer.add_scalar('train/epe', epe.item(),
                                         self.num_iter)
            self.train_writer.add_scalar('train/disp_loss', disp_loss.item(),
                                         self.num_iter)
            self.train_writer.add_scalar('train/total_loss',
                                         total_loss.item(), self.num_iter)

            # Save loss of different scale (same reversed index as pred_disp)
            for s in range(len(pyramid_loss)):
                save_name = 'train/loss' + str(len(pyramid_loss) - s - 1)
                save_value = pyramid_loss[s]
                self.train_writer.add_scalar(save_name, save_value,
                                             self.num_iter)

            d1 = d1_metric(pred_disp, gt_disp, mask)
            self.train_writer.add_scalar('train/d1', d1.item(),
                                         self.num_iter)

            thres1 = thres_metric(pred_disp, gt_disp, mask, 1.0)
            thres2 = thres_metric(pred_disp, gt_disp, mask, 2.0)
            thres3 = thres_metric(pred_disp, gt_disp, mask, 3.0)
            self.train_writer.add_scalar('train/thres1', thres1.item(),
                                         self.num_iter)
            self.train_writer.add_scalar('train/thres2', thres2.item(),
                                         self.num_iter)
            self.train_writer.add_scalar('train/thres3', thres3.item(),
                                         self.num_iter)

    self.epoch += 1

    # Always save the latest model for resuming training
    # (only when validation is disabled; otherwise validate() saves it).
    if args.no_validate:
        utils.save_checkpoint(args.checkpoint_dir, self.optimizer,
                              self.aanet, epoch=self.epoch,
                              num_iter=self.num_iter, epe=-1,
                              best_epe=self.best_epe,
                              best_epoch=self.best_epoch,
                              filename='aanet_latest.pth')

        # Save checkpoint of specific epoch
        if self.epoch % args.save_ckpt_freq == 0:
            model_dir = os.path.join(args.checkpoint_dir, 'models')
            utils.check_path(model_dir)
            utils.save_checkpoint(model_dir, self.optimizer, self.aanet,
                                  epoch=self.epoch, num_iter=self.num_iter,
                                  epe=-1, best_epe=self.best_epe,
                                  best_epoch=self.best_epoch,
                                  save_optimizer=False)
def validate(self, val_loader):
    """Evaluate AANet on ``val_loader`` and track the best checkpoint.

    Computes mean EPE, D1 and threshold metrics over all valid samples,
    appends the results to ``val_results.txt``, writes TensorBoard
    summaries (when training), and saves best/latest/periodic checkpoints
    according to ``args.val_metric``.

    In evaluate-only mode a pretrained model is first loaded from
    ``args.pretrained_aanet`` or the checkpoint directory.
    """
    args = self.args
    logger = self.logger
    logger.info('=> Start validation...')

    if args.evaluate_only is True:
        # Load a trained model from disk; prefer the explicit path, then
        # the best checkpoint, then the latest one (KITTI without validation).
        if args.pretrained_aanet is not None:
            pretrained_aanet = args.pretrained_aanet
        else:
            model_name = 'aanet_best.pth'
            pretrained_aanet = os.path.join(args.checkpoint_dir, model_name)
            if not os.path.exists(
                    pretrained_aanet):  # KITTI without validation
                pretrained_aanet = pretrained_aanet.replace(
                    model_name, 'aanet_latest.pth')
        logger.info('=> loading pretrained aanet: %s' % pretrained_aanet)
        utils.load_pretrained_net(self.aanet, pretrained_aanet,
                                  no_strict=True)

    self.aanet.eval()

    num_samples = len(val_loader)
    logger.info('=> %d samples found in the validation set' % num_samples)

    # Running sums of per-sample metrics; averaged over valid_samples below.
    val_epe = 0
    val_d1 = 0
    val_thres1 = 0
    val_thres2 = 0
    val_thres3 = 0

    val_count = 0  # index of visualization snapshots written so far

    val_file = os.path.join(args.checkpoint_dir, 'val_results.txt')

    num_imgs = 0
    valid_samples = 0

    for i, sample in enumerate(val_loader):
        if i % 100 == 0:
            logger.info('=> Validating %d/%d' % (i, num_samples))

        left = sample['left'].to(self.device)  # [B, 3, H, W]
        right = sample['right'].to(self.device)
        gt_disp = sample['disp'].to(self.device)  # [B, H, W]
        mask = (gt_disp > 0) & (gt_disp < args.max_disp)

        if not mask.any():
            continue  # skip samples with no valid GT pixels

        valid_samples += 1
        num_imgs += gt_disp.size(0)

        with torch.no_grad():
            pred_disp = self.aanet(left, right)[-1]  # [B, H, W]

            if pred_disp.size(-1) < gt_disp.size(-1):
                pred_disp = pred_disp.unsqueeze(1)  # [B, 1, H, W]
                # Upsample to GT resolution and rescale disparity values
                # by the width ratio.
                pred_disp = F.interpolate(
                    pred_disp, (gt_disp.size(-2), gt_disp.size(-1)),
                    mode='bilinear') * (gt_disp.size(-1) / pred_disp.size(-1))
                pred_disp = pred_disp.squeeze(1)  # [B, H, W]

        epe = F.l1_loss(gt_disp[mask], pred_disp[mask], reduction='mean')
        d1 = d1_metric(pred_disp, gt_disp, mask)
        thres1 = thres_metric(pred_disp, gt_disp, mask, 1.0)
        thres2 = thres_metric(pred_disp, gt_disp, mask, 2.0)
        thres3 = thres_metric(pred_disp, gt_disp, mask, 3.0)

        val_epe += epe.item()
        val_d1 += d1.item()
        val_thres1 += thres1.item()
        val_thres2 += thres2.item()
        val_thres3 += thres3.item()

        # Save 3 images for visualization (quarter points of the set)
        if not args.evaluate_only:
            if i in [
                    num_samples // 4, num_samples // 2, num_samples // 4 * 3
            ]:
                img_summary = dict()
                img_summary['disp_error'] = disp_error_img(
                    pred_disp, gt_disp)
                img_summary['left'] = left
                img_summary['right'] = right
                img_summary['gt_disp'] = gt_disp
                img_summary['pred_disp'] = pred_disp
                save_images(self.train_writer, 'val' + str(val_count),
                            img_summary, self.epoch)
                val_count += 1

    logger.info('=> Validation done!')

    mean_epe = val_epe / valid_samples
    mean_d1 = val_d1 / valid_samples
    mean_thres1 = val_thres1 / valid_samples
    mean_thres2 = val_thres2 / valid_samples
    mean_thres3 = val_thres3 / valid_samples

    # Save validation results
    with open(val_file, 'a') as f:
        f.write('epoch: %03d\t' % self.epoch)
        f.write('epe: %.3f\t' % mean_epe)
        f.write('d1: %.4f\t' % mean_d1)
        f.write('thres1: %.4f\t' % mean_thres1)
        f.write('thres2: %.4f\t' % mean_thres2)
        f.write('thres3: %.4f\n' % mean_thres3)

    logger.info('=> Mean validation epe of epoch %d: %.3f' %
                (self.epoch, mean_epe))

    if not args.evaluate_only:
        self.train_writer.add_scalar('val/epe', mean_epe, self.epoch)
        self.train_writer.add_scalar('val/d1', mean_d1, self.epoch)
        self.train_writer.add_scalar('val/thres1', mean_thres1, self.epoch)
        self.train_writer.add_scalar('val/thres2', mean_thres2, self.epoch)
        self.train_writer.add_scalar('val/thres3', mean_thres3, self.epoch)

    if not args.evaluate_only:
        # Track the best checkpoint by the configured validation metric.
        if args.val_metric == 'd1':
            if mean_d1 < self.best_epe:
                # Actually best_epe here is d1
                self.best_epe = mean_d1
                self.best_epoch = self.epoch

                utils.save_checkpoint(args.checkpoint_dir, self.optimizer,
                                      self.aanet, epoch=self.epoch,
                                      num_iter=self.num_iter, epe=mean_d1,
                                      best_epe=self.best_epe,
                                      best_epoch=self.best_epoch,
                                      filename='aanet_best.pth')
        elif args.val_metric == 'epe':
            if mean_epe < self.best_epe:
                self.best_epe = mean_epe
                self.best_epoch = self.epoch

                utils.save_checkpoint(args.checkpoint_dir, self.optimizer,
                                      self.aanet, epoch=self.epoch,
                                      num_iter=self.num_iter, epe=mean_epe,
                                      best_epe=self.best_epe,
                                      best_epoch=self.best_epoch,
                                      filename='aanet_best.pth')
        else:
            raise NotImplementedError

        if self.epoch == args.max_epoch:
            # Save best validation results at the end of training.
            with open(val_file, 'a') as f:
                f.write('\nbest epoch: %03d \t best %s: %.3f\n\n' %
                        (self.best_epoch, args.val_metric, self.best_epe))
            logger.info('=> best epoch: %03d \t best %s: %.3f\n' %
                        (self.best_epoch, args.val_metric, self.best_epe))

    # Always save the latest model for resuming training
    if not args.evaluate_only:
        utils.save_checkpoint(args.checkpoint_dir, self.optimizer,
                              self.aanet, epoch=self.epoch,
                              num_iter=self.num_iter, epe=mean_epe,
                              best_epe=self.best_epe,
                              best_epoch=self.best_epoch,
                              filename='aanet_latest.pth')

        # Save checkpoint of specific epochs
        if self.epoch % args.save_ckpt_freq == 0:
            model_dir = os.path.join(args.checkpoint_dir, 'models')
            utils.check_path(model_dir)
            utils.save_checkpoint(model_dir, self.optimizer, self.aanet,
                                  epoch=self.epoch, num_iter=self.num_iter,
                                  epe=mean_epe, best_epe=self.best_epe,
                                  best_epoch=self.best_epoch,
                                  save_optimizer=False)
def train(self, train_loader, local_master, trainLoss_dict, trainLossKey):
    """Run one training epoch with gradient accumulation and optional DDP.

    Like the plain ``train``, but: (1) gradients are accumulated over
    ``args.accumulation_steps`` batches before each optimizer step, with
    DDP gradient synchronization suppressed (``no_sync``) on non-step
    iterations; (2) per-epoch average metrics are appended to
    ``trainLoss_dict[trainLossKey]`` (later dumped to a MATLAB .mat file
    for analysis); (3) checkpoints are written only by the local master
    process.

    Args:
        train_loader: iterable of sample dicts ('left', 'right', 'disp',
            optionally 'pseudo_disp').
        local_master: True on the process that should write checkpoints.
        trainLoss_dict: dict of metric histories, keyed by trainLossKey.
        trainLossKey: key selecting this run's entry in trainLoss_dict.
    """
    args = self.args
    logger = self.logger
    # len(train_loader) is the number of batches; one optimizer update
    # happens every args.accumulation_steps batches.
    steps_per_epoch = len(train_loader) / args.accumulation_steps
    device = self.device

    self.aanet.train()  # set the model to training mode

    if args.freeze_bn:
        # Keep BatchNorm layers in eval mode (frozen running stats).
        def set_bn_eval(m):
            classname = m.__class__.__name__
            if classname.find('BatchNorm') != -1:
                m.eval()

        # Module.apply(fn) applies fn recursively to every submodule
        # (as returned by .children()) as well as self.
        self.aanet.apply(set_bn_eval)

    # Learning rate summary
    base_lr = self.optimizer.param_groups[0]['lr']
    offset_lr = self.optimizer.param_groups[1]['lr']
    self.train_writer.add_scalar('lr/base_lr', base_lr, self.epoch + 1)
    self.train_writer.add_scalar('lr/offset_lr', offset_lr, self.epoch + 1)

    last_print_time = time.time()

    # Per-epoch running sums for the analysis dict below.
    validate_count = 0
    total_epe = 0
    total_d1 = 0
    total_thres1 = 0
    total_thres2 = 0
    total_thres3 = 0
    total_thres10 = 0
    total_thres20 = 0
    loss_acum = 0

    for i, sample in enumerate(train_loader):
        left = sample['left'].to(device)  # [B, 3, H, W]
        right = sample['right'].to(device)
        gt_disp = sample['disp'].to(device)  # [B, H, W]
        # KITTI convention: disparity 0 means invalid ground truth.
        mask = (gt_disp > 0) & (gt_disp < args.max_disp)

        if args.load_pseudo_gt:
            pseudo_gt_disp = sample['pseudo_disp'].to(device)
            # inverse mask: pixels whose real GT is invalid and must be
            # supervised with the pseudo ground truth instead
            pseudo_mask = (pseudo_gt_disp > 0) & (
                pseudo_gt_disp < args.max_disp) & (~mask)

        if not mask.any():  # Tensor.any(): True if any element is True
            continue

        # Distributed training with gradient accumulation:
        # use no_sync (skip DDP gradient all-reduce) on every iteration
        # that is NOT a multiple of args.accumulation_steps.
        # Reference: https://blog.csdn.net/a40850273/article/details/111829836
        my_context = self.aanet.no_sync if args.distributed and (
            i + 1) % args.accumulation_steps != 0 else nullcontext
        with my_context():
            pred_disp_pyramid = self.aanet(
                left, right)  # list of H/12, H/6, H/3, H/2, H

            if args.highest_loss_only:
                pred_disp_pyramid = [
                    pred_disp_pyramid[-1]
                ]  # only the last highest resolution output

            disp_loss = 0
            pseudo_disp_loss = 0
            pyramid_loss = []
            pseudo_pyramid_loss = []

            # Loss weights per pyramid scale
            if len(pred_disp_pyramid) == 5:
                pyramid_weight = [1 / 3, 2 / 3, 1.0, 1.0, 1.0]  # AANet and AANet+
            elif len(pred_disp_pyramid) == 4:
                pyramid_weight = [1 / 3, 2 / 3, 1.0, 1.0]
            elif len(pred_disp_pyramid) == 3:
                pyramid_weight = [1.0, 1.0, 1.0]  # 1 scale only
            elif len(pred_disp_pyramid) == 1:
                pyramid_weight = [1.0]  # highest loss only
            else:
                raise NotImplementedError

            assert len(pyramid_weight) == len(pred_disp_pyramid)

            for k in range(len(pred_disp_pyramid)):
                pred_disp = pred_disp_pyramid[k]
                weight = pyramid_weight[k]

                if pred_disp.size(-1) != gt_disp.size(-1):
                    pred_disp = pred_disp.unsqueeze(1)  # [B, 1, H, W]
                    # The final scale factor is required: when the image is
                    # enlarged, disparities must grow by the same ratio.
                    pred_disp = F.interpolate(
                        pred_disp,
                        size=(gt_disp.size(-2), gt_disp.size(-1)),
                        mode='bilinear', align_corners=False) * (
                            gt_disp.size(-1) / pred_disp.size(-1))
                    pred_disp = pred_disp.squeeze(1)  # [B, H, W]

                curr_loss = F.smooth_l1_loss(pred_disp[mask], gt_disp[mask],
                                             reduction='mean')
                disp_loss += weight * curr_loss
                pyramid_loss.append(curr_loss)

                # Pseudo gt loss
                if args.load_pseudo_gt:
                    pseudo_curr_loss = F.smooth_l1_loss(
                        pred_disp[pseudo_mask],
                        pseudo_gt_disp[pseudo_mask], reduction='mean')
                    pseudo_disp_loss += weight * pseudo_curr_loss
                    pseudo_pyramid_loss.append(pseudo_curr_loss)

            total_loss = disp_loss + pseudo_disp_loss
            # Normalize so the accumulated gradient matches a large-batch step.
            total_loss /= args.accumulation_steps
            total_loss.backward()

            # For logging and analysis only (no gradients needed)
            with torch.no_grad():
                validate_count += 1
                total_epe += F.l1_loss(
                    gt_disp[mask], pred_disp_pyramid[-1][mask],
                    reduction='mean').detach().cpu().numpy()
                total_d1 += d1_metric(pred_disp_pyramid[-1], gt_disp,
                                      mask).detach().cpu().numpy()
                total_thres1 += thres_metric(
                    pred_disp_pyramid[-1], gt_disp, mask,
                    1.0).detach().cpu().numpy()  # mask.shape=[B, H, W]
                total_thres2 += thres_metric(
                    pred_disp_pyramid[-1], gt_disp, mask,
                    2.0).detach().cpu().numpy()
                total_thres3 += thres_metric(
                    pred_disp_pyramid[-1], gt_disp, mask,
                    3.0).detach().cpu().numpy()
                total_thres10 += thres_metric(
                    pred_disp_pyramid[-1], gt_disp, mask,
                    10.0).detach().cpu().numpy()
                total_thres20 += thres_metric(
                    pred_disp_pyramid[-1], gt_disp, mask,
                    20.0).detach().cpu().numpy()
                loss_acum += total_loss.detach().cpu().numpy()

        # Optimizer step only every accumulation_steps batches.
        if (i + 1) % args.accumulation_steps == 0:
            self.optimizer.step()
            self.optimizer.zero_grad()
            self.num_iter += 1

            if self.num_iter % args.print_freq == 0:
                this_cycle = time.time() - last_print_time
                last_print_time += this_cycle
                # Estimated remaining training time, in hours.
                time_to_finish = (args.max_epoch - self.epoch) * (
                    1.0 * steps_per_epoch / args.print_freq) * \
                    this_cycle / 3600.0
                logger.info(
                    'Epoch: [%3d/%3d] [%5d/%5d] time: %4.2fs remainT: %4.2fh disp_loss: %.3f' %
                    (self.epoch + 1, args.max_epoch, self.num_iter,
                     steps_per_epoch, this_cycle, time_to_finish,
                     disp_loss.item()))
                # self.num_iter: total number of iterations so far; one
                #   parameter update counts as one iteration.
                # steps_per_epoch: number of iterations (parameter updates)
                #   per epoch.

            if self.num_iter % args.summary_freq == 0:
                img_summary = dict()
                img_summary['left'] = left  # [B, C=3, H, W]
                img_summary['right'] = right  # [B, C=3, H, W]
                img_summary['gt_disp'] = gt_disp  # [B, H, W]

                if args.load_pseudo_gt:
                    img_summary['pseudo_gt_disp'] = pseudo_gt_disp

                # Save pyramid disparity prediction
                for s in range(len(pred_disp_pyramid)):
                    # Scale from low to high, reverse the index:
                    # pred_disp0 --> pred_disp4 is high --> low resolution.
                    save_name = 'pred_disp' + str(
                        len(pred_disp_pyramid) - s - 1)
                    save_value = pred_disp_pyramid[s]
                    img_summary[save_name] = save_value

                pred_disp = pred_disp_pyramid[-1]

                if pred_disp.size(-1) != gt_disp.size(-1):
                    pred_disp = pred_disp.unsqueeze(1)  # [B, 1, H, W]
                    pred_disp = F.interpolate(
                        pred_disp,
                        size=(gt_disp.size(-2), gt_disp.size(-1)),
                        mode='bilinear', align_corners=False) * (
                            gt_disp.size(-1) / pred_disp.size(-1))
                    pred_disp = pred_disp.squeeze(1)  # [B, H, W]

                img_summary['disp_error'] = disp_error_img(
                    pred_disp, gt_disp)  # [B, C=3, H, W]

                save_images(self.train_writer, 'train', img_summary,
                            self.num_iter)

                epe = F.l1_loss(gt_disp[mask], pred_disp[mask],
                                reduction='mean')
                self.train_writer.add_scalar('train/epe', epe.item(),
                                             self.num_iter)
                self.train_writer.add_scalar('train/disp_loss',
                                             disp_loss.item(),
                                             self.num_iter)
                self.train_writer.add_scalar('train/total_loss',
                                             total_loss.item(),
                                             self.num_iter)

                # Save loss of different scale
                # (index mirrors the reversed pred_disp naming above)
                for s in range(len(pyramid_loss)):
                    save_name = 'train/loss' + str(len(pyramid_loss) - s - 1)
                    save_value = pyramid_loss[s]
                    self.train_writer.add_scalar(save_name, save_value,
                                                 self.num_iter)

                d1 = d1_metric(pred_disp, gt_disp,
                               mask)  # both [B, H, W]
                self.train_writer.add_scalar('train/d1', d1.item(),
                                             self.num_iter)

                thres1 = thres_metric(pred_disp, gt_disp, mask,
                                      1.0)  # mask.shape=[B, H, W]
                thres2 = thres_metric(pred_disp, gt_disp, mask, 2.0)
                thres3 = thres_metric(pred_disp, gt_disp, mask, 3.0)
                thres10 = thres_metric(pred_disp, gt_disp, mask, 10.0)
                thres20 = thres_metric(pred_disp, gt_disp, mask, 20.0)
                self.train_writer.add_scalar('train/thres1', thres1.item(),
                                             self.num_iter)
                self.train_writer.add_scalar('train/thres2', thres2.item(),
                                             self.num_iter)
                self.train_writer.add_scalar('train/thres3', thres3.item(),
                                             self.num_iter)
                self.train_writer.add_scalar('train/thres10',
                                             thres10.item(), self.num_iter)
                self.train_writer.add_scalar('train/thres20',
                                             thres20.item(), self.num_iter)

    self.epoch += 1

    # Record per-epoch averages (later saved as a MATLAB .mat file for
    # analysis and comparison).
    trainLoss_dict[trainLossKey]['epochs'].append(self.epoch)
    trainLoss_dict[trainLossKey]['avgEPE'].append(total_epe / validate_count)
    trainLoss_dict[trainLossKey]['avg_d1'].append(total_d1 / validate_count)
    trainLoss_dict[trainLossKey]['avg_thres1'].append(total_thres1 /
                                                      validate_count)
    trainLoss_dict[trainLossKey]['avg_thres2'].append(total_thres2 /
                                                      validate_count)
    trainLoss_dict[trainLossKey]['avg_thres3'].append(total_thres3 /
                                                      validate_count)
    trainLoss_dict[trainLossKey]['avg_thres10'].append(total_thres10 /
                                                       validate_count)
    trainLoss_dict[trainLossKey]['avg_thres20'].append(total_thres20 /
                                                       validate_count)
    trainLoss_dict[trainLossKey]['avg_loss'].append(loss_acum /
                                                    validate_count)

    # End of epoch:
    # If args.no_validate is True, self.validate() will not run afterwards,
    # so the information below must be recorded here; otherwise
    # self.validate() records it and nothing is needed here.
    # Recorded information:
    # 1. the latest trained model and state (aanet_latest.pth /
    #    optimizer_latest.pth) for resuming training;
    # 2. a checkpoint of specific epochs.

    # Always save the latest model for resuming training
    if args.no_validate:
        # Only the local master process writes checkpoints.
        utils.save_checkpoint(args.checkpoint_dir, self.optimizer,
                              self.aanet, epoch=self.epoch,
                              num_iter=self.num_iter, epe=-1,
                              best_epe=self.best_epe,
                              best_epoch=self.best_epoch,
                              filename='aanet_latest.pth') if local_master else None

        # Save checkpoint of specific epoch
        if self.epoch % args.save_ckpt_freq == 0:
            model_dir = os.path.join(args.checkpoint_dir, 'models')
            utils.check_path(model_dir)
            utils.save_checkpoint(model_dir, self.optimizer, self.aanet,
                                  epoch=self.epoch, num_iter=self.num_iter,
                                  epe=-1, best_epe=self.best_epe,
                                  best_epoch=self.best_epoch,
                                  save_optimizer=False) if local_master else None
def test(dataloader, model, log): stages = 1#len(args.loss_weights) EPEs = [AverageMeter() for _ in range(stages)] thres1 = [AverageMeter() for _ in range(stages)] length_loader = len(dataloader) model.eval() padding_len = 16 inference_time = 0 for batch_idx, (imgL, imgR, disp_L) in enumerate(dataloader): imgL = imgL.float().cuda() imgR = imgR.float().cuda() disp_L = disp_L.float().cuda() mask = disp_L < args.maxdisp if imgL.shape[2] % padding_len != 0: times = imgL.shape[2]//padding_len #print("times:", times) if times % 2 == 0: top_pad = (times + 2) * padding_len - imgL.shape[2] else: top_pad = (times+1)*padding_len -imgL.shape[2] else: top_pad = 0 right_pad = 0 imgL = F.pad(imgL,(0,right_pad, top_pad,0)) imgR = F.pad(imgR,(0,right_pad, top_pad,0)) #print("After padding imgL size:", imgL.shape) with torch.no_grad(): time_start = time.perf_counter() outputs = model(imgL, imgR) single_inference_time = time.perf_counter() - time_start #print("single_inference_time:", single_inference_time ) inference_time += single_inference_time #print("output size:", outputs[0].shape) for x in range(stages): if len(disp_L[mask]) == 0: EPEs[x].update(0) thres1[x].update(0) continue output = torch.squeeze(outputs[x], 1) if top_pad != 0: output = output[:, top_pad:, :] else: output = output if args.visualization and batch_idx <= 50: #vis # print("outputs", output.shape ) # print("disp_L size:", disp_L.shape) # GT = disp_L[:, :, :]/255.0 # torchvision.utils.save_image(GT, join(args.save_path, "iter-%d.jpg" % batch_idx)) #output = output.cpu() _, H, W = output.shape # all_results_color = torch.zeros((H, 2*W)) GT_color = torch.zeros((H, W )) GT_color[:, :W] = disp_L[0,:, :] GT_color = cv.applyColorMap(np.array(GT_color * 2, dtype=np.uint8), cv.COLORMAP_JET) cv.imwrite(join(args.save_path, "iter-%d_GT_color.jpg" % batch_idx), GT_color) pred_color = torch.zeros((H, W )) pred_color[:, :W] = output[0, :, :] # all_results_color = torch.zeros((H, 2 * W + 20)) # # # all_results_color[:,:W]= 
output[0, :, :] # # #all_results_color[:, W:30] = output[0, :, :] # # all_results_color[:,W+20:2*W+20]= disp_L[0,:, :] pred_color = cv.applyColorMap(np.array(pred_color*2, dtype=np.uint8), cv.COLORMAP_JET) error = (output[mask] - disp_L[mask]).abs().mean() cv.imwrite(join(args.save_path, "iter-%d_pred_color-%.3f.jpg" %( batch_idx , error)),pred_color) EPEs[x].update((output[mask] - disp_L[mask]).abs().mean()) thres1[x].update(thres_metric(output, disp_L, mask, 1.0)) if not batch_idx % args.print_freq: print("single_inference_time:", single_inference_time) info_str = '\t'.join(['Stage {} = {:.2f}({:.2f})'.format(x, EPEs[x].val, EPEs[x].avg) for x in range(stages)]) log.info('EPEs [{}/{}] {}'.format(batch_idx, length_loader, info_str)) info_str = '\t'.join(['Stage {} = {:.3f}({:.3f})'.format(x, thres1[x].val, thres1[x].avg) for x in range(stages)]) log.info('thres1 [{}/{}] {}'.format(batch_idx, length_loader, info_str)) log.info(('=> Mean inference time for %d images: %.3fs' % ( length_loader, inference_time / length_loader))) info_str = ', '.join(['Stage {}={:.2f}'.format(x, EPEs[x].avg) for x in range(stages)]) log.info('Average test EPE = ' + info_str) info_str = ', '.join(['Stage {}={:.3f}'.format(x, thres1[x].avg) for x in range(stages)]) log.info('Average test thres1 = ' + info_str)
def validate(self, val_loader, local_master, valLossDict, valLossKey):
    """Extended validation with extra thresholds and analysis logging.

    Like the plain ``validate`` but additionally: computes thres10/thres20
    metrics, appends per-epoch means to ``valLossDict[valLossKey]`` (later
    dumped to a MATLAB .mat file), saves per-sample error-analysis images
    via ``saveImgErrorAnalysis``, writes disparity-error histograms, and
    gates all checkpoint writes on ``local_master`` (distributed runs).

    Args:
        val_loader: iterable of sample dicts ('left', 'right', 'disp',
            'left_name').
        local_master: True on the process that should write checkpoints.
        valLossDict: dict of metric histories, keyed by valLossKey.
        valLossKey: key selecting this run's entry in valLossDict.
    """
    args = self.args
    logger = self.logger
    logger.info('=> Start validation...')

    # When only evaluating, load the trained model from a file; otherwise
    # use the in-memory self.aanet that is still being trained.
    if args.evaluate_only is True:
        if args.pretrained_aanet is not None:
            pretrained_aanet = args.pretrained_aanet
        else:
            model_name = 'aanet_best.pth'
            pretrained_aanet = os.path.join(args.checkpoint_dir, model_name)
            if not os.path.exists(
                    pretrained_aanet):  # KITTI without validation
                pretrained_aanet = pretrained_aanet.replace(
                    model_name, 'aanet_latest.pth')
        logger.info('=> loading pretrained aanet: %s' % pretrained_aanet)
        utils.load_pretrained_net(self.aanet, pretrained_aanet,
                                  no_strict=True)

    self.aanet.eval()

    num_samples = len(val_loader)
    logger.info('=> %d samples found in the validation set' % num_samples)

    # Running sums of per-sample metrics; averaged over valid_samples below.
    val_epe = 0
    val_d1 = 0
    val_thres1 = 0
    val_thres2 = 0
    val_thres3 = 0
    val_thres10 = 0
    val_thres20 = 0

    val_count = 0  # index of visualization snapshots written so far

    val_file = os.path.join(args.checkpoint_dir, 'val_results.txt')

    num_imgs = 0
    valid_samples = 0

    # Iterate over validation or test samples
    for i, sample in enumerate(val_loader):
        if (i + 1) % 100 == 0:
            logger.info('=> Validating %d/%d' % (i, num_samples))

        left = sample['left'].to(self.device)  # [B, 3, H, W]
        right = sample['right'].to(self.device)
        gt_disp = sample['disp'].to(self.device)  # [B, H, W]
        mask = (gt_disp > 0) & (gt_disp < args.max_disp)

        if not mask.any():
            continue  # skip samples with no valid GT pixels

        valid_samples += 1
        num_imgs += gt_disp.size(0)

        with torch.no_grad():
            disparity_pyramid = self.aanet(left, right)  # [B, H, W]
            pred_disp = disparity_pyramid[-1]

            if pred_disp.size(-1) < gt_disp.size(-1):
                pred_disp = pred_disp.unsqueeze(1)  # [B, 1, H, W]
                # Upsample to GT resolution and rescale disparity values
                # by the width ratio.
                pred_disp = F.interpolate(
                    pred_disp, (gt_disp.size(-2), gt_disp.size(-1)),
                    mode='bilinear', align_corners=False) * (
                        gt_disp.size(-1) / pred_disp.size(-1))
                pred_disp = pred_disp.squeeze(1)  # [B, H, W]

        epe = F.l1_loss(gt_disp[mask], pred_disp[mask], reduction='mean')
        d1 = d1_metric(pred_disp, gt_disp, mask)
        thres1 = thres_metric(pred_disp, gt_disp, mask, 1.0)
        thres2 = thres_metric(pred_disp, gt_disp, mask, 2.0)
        thres3 = thres_metric(pred_disp, gt_disp, mask, 3.0)
        thres10 = thres_metric(pred_disp, gt_disp, mask, 10.0)
        thres20 = thres_metric(pred_disp, gt_disp, mask, 20.0)

        val_epe += epe.item()
        val_d1 += d1.item()
        val_thres1 += thres1.item()
        val_thres2 += thres2.item()
        val_thres3 += thres3.item()
        val_thres10 += thres10.item()
        val_thres20 += thres20.item()

        # Save image for error analysis.
        # saveForErrorAnalysis(index, img_name, dstPath, dstName, left,
        #                      right, gt_disp, disparity_pyramid)
        with torch.no_grad():
            saveImgErrorAnalysis(i, sample['left_name'],
                                 './myDataAnalysis',
                                 'SceneFlow_valIdx_{}'.format(i), left,
                                 right, gt_disp, disparity_pyramid,
                                 disp_error_img(pred_disp, gt_disp))

        # Save images for visualization (five evenly spaced samples)
        if not args.evaluate_only or args.mode == 'test':
            # if i in [num_samples // 4, num_samples // 2, num_samples // 4 * 3]:
            if i in [
                    num_samples // 6, num_samples // 6 * 2,
                    num_samples // 6 * 3, num_samples // 6 * 4,
                    num_samples // 6 * 5
            ]:
                img_summary = dict()
                img_summary['disp_error'] = disp_error_img(
                    pred_disp, gt_disp)
                img_summary['left'] = left
                img_summary['right'] = right
                img_summary['gt_disp'] = gt_disp
                img_summary['pred_disp'] = pred_disp
                save_images(self.train_writer, 'val' + str(val_count),
                            img_summary, self.epoch)
                disp_error = disp_error_hist(pred_disp, gt_disp,
                                             args.max_disp)
                save_hist(self.train_writer,
                          '{}/{}'.format('val' + str(val_count), 'hist'),
                          disp_error, self.epoch)
                val_count += 1

    # Finished iterating over validation/test samples
    logger.info('=> Validation done!')

    mean_epe = val_epe / valid_samples
    mean_d1 = val_d1 / valid_samples
    mean_thres1 = val_thres1 / valid_samples
    mean_thres2 = val_thres2 / valid_samples
    mean_thres3 = val_thres3 / valid_samples
    mean_thres10 = val_thres10 / valid_samples
    mean_thres20 = val_thres20 / valid_samples

    # Record the data (later saved as a MATLAB .mat file for analysis
    # and comparison).
    valLossDict[valLossKey]["epochs"].append(self.epoch)
    valLossDict[valLossKey]["avgEPE"].append(mean_epe)
    valLossDict[valLossKey]["avg_d1"].append(mean_d1)
    valLossDict[valLossKey]["avg_thres1"].append(mean_thres1)
    valLossDict[valLossKey]["avg_thres2"].append(mean_thres2)
    valLossDict[valLossKey]["avg_thres3"].append(mean_thres3)
    valLossDict[valLossKey]["avg_thres10"].append(mean_thres10)
    valLossDict[valLossKey]["avg_thres20"].append(mean_thres20)

    # Save validation results
    with open(val_file, 'a') as f:
        f.write('epoch: %03d\t' % self.epoch)
        f.write('epe: %.3f\t' % mean_epe)
        f.write('d1: %.4f\t' % mean_d1)
        f.write('thres1: %.4f\t' % mean_thres1)
        f.write('thres2: %.4f\t' % mean_thres2)
        f.write('thres3: %.4f\t' % mean_thres3)
        f.write('thres10: %.4f\t' % mean_thres10)
        f.write('thres20: %.4f\n' % mean_thres20)
        f.write('dataset_name= %s\t mode=%s\n' %
                (args.dataset_name, args.mode))

    logger.info('=> Mean validation epe of epoch %d: %.3f' %
                (self.epoch, mean_epe))

    if not args.evaluate_only:
        self.train_writer.add_scalar('val/epe', mean_epe, self.epoch)
        self.train_writer.add_scalar('val/d1', mean_d1, self.epoch)
        self.train_writer.add_scalar('val/thres1', mean_thres1, self.epoch)
        self.train_writer.add_scalar('val/thres2', mean_thres2, self.epoch)
        self.train_writer.add_scalar('val/thres3', mean_thres3, self.epoch)
        self.train_writer.add_scalar('val/thres10', mean_thres10,
                                     self.epoch)
        self.train_writer.add_scalar('val/thres20', mean_thres20,
                                     self.epoch)

    if not args.evaluate_only:
        # Track the best checkpoint by the configured validation metric
        # (checkpoints written only by the local master process).
        if args.val_metric == 'd1':
            if mean_d1 < self.best_epe:
                # Actually best_epe here is d1
                self.best_epe = mean_d1
                self.best_epoch = self.epoch

                utils.save_checkpoint(args.checkpoint_dir, self.optimizer,
                                      self.aanet, epoch=self.epoch,
                                      num_iter=self.num_iter, epe=mean_d1,
                                      best_epe=self.best_epe,
                                      best_epoch=self.best_epoch,
                                      filename='aanet_best.pth') if local_master else None
        elif args.val_metric == 'epe':
            if mean_epe < self.best_epe:
                self.best_epe = mean_epe
                self.best_epoch = self.epoch

                utils.save_checkpoint(args.checkpoint_dir, self.optimizer,
                                      self.aanet, epoch=self.epoch,
                                      num_iter=self.num_iter, epe=mean_epe,
                                      best_epe=self.best_epe,
                                      best_epoch=self.best_epoch,
                                      filename='aanet_best.pth') if local_master else None
        else:
            raise NotImplementedError

        if self.epoch == args.max_epoch:
            # Save best validation results at the end of training.
            with open(val_file, 'a') as f:
                f.write('\nbest epoch: %03d \t best %s: %.3f\n\n' %
                        (self.best_epoch, args.val_metric, self.best_epe))
            logger.info('=> best epoch: %03d \t best %s: %.3f\n' %
                        (self.best_epoch, args.val_metric, self.best_epe))

    # Always save the latest model for resuming training
    if not args.evaluate_only:
        utils.save_checkpoint(args.checkpoint_dir, self.optimizer,
                              self.aanet, epoch=self.epoch,
                              num_iter=self.num_iter, epe=mean_epe,
                              best_epe=self.best_epe,
                              best_epoch=self.best_epoch,
                              filename='aanet_latest.pth') if local_master else None

        # Save checkpoint of specific epochs
        if self.epoch % args.save_ckpt_freq == 0:
            model_dir = os.path.join(args.checkpoint_dir, 'models')
            utils.check_path(model_dir)
            utils.save_checkpoint(model_dir, self.optimizer, self.aanet,
                                  epoch=self.epoch, num_iter=self.num_iter,
                                  epe=mean_epe, best_epe=self.best_epe,
                                  best_epoch=self.best_epoch,
                                  save_optimizer=False) if local_master else None