Example #1
    def train(self, train_loader):
        args = self.args
        logger = self.logger

        steps_per_epoch = len(train_loader)
        device = self.device

        self.aanet.train()

        if args.freeze_bn:

            def set_bn_eval(m):
                classname = m.__class__.__name__
                if classname.find('BatchNorm') != -1:
                    m.eval()

            self.aanet.apply(set_bn_eval)

        # Learning rate summary
        base_lr = self.optimizer.param_groups[0]['lr']
        offset_lr = self.optimizer.param_groups[1]['lr']
        self.train_writer.add_scalar('base_lr', base_lr, self.epoch + 1)
        self.train_writer.add_scalar('offset_lr', offset_lr, self.epoch + 1)

        last_print_time = time.time()

        for i, sample in enumerate(train_loader):
            left = sample['left'].to(device)  # [B, 3, H, W]
            right = sample['right'].to(device)
            gt_disp = sample['disp'].to(device)  # [B, H, W]

            mask = (gt_disp > 0) & (gt_disp < args.max_disp)

            if args.load_pseudo_gt:
                pseudo_gt_disp = sample['pseudo_disp'].to(device)
                pseudo_mask = (pseudo_gt_disp > 0) & (
                    pseudo_gt_disp < args.max_disp) & (~mask)  # inverse mask

            if not mask.any():
                continue

            pred_disp_pyramid = self.aanet(
                left, right)  # list of H/12, H/6, H/3, H/2, H

            if args.highest_loss_only:
                pred_disp_pyramid = [
                    pred_disp_pyramid[-1]
                ]  # only the last highest resolution output

            disp_loss = 0
            pseudo_disp_loss = 0
            pyramid_loss = []
            pseudo_pyramid_loss = []

            # Loss weights
            if len(pred_disp_pyramid) == 5:
                pyramid_weight = [1 / 3, 2 / 3, 1.0, 1.0, 1.0]  # AANet
            elif len(pred_disp_pyramid) == 4:
                pyramid_weight = [1 / 3, 2 / 3, 1.0, 1.0]  # AANet+
            elif len(pred_disp_pyramid) == 3:
                pyramid_weight = [1.0, 1.0, 1.0]  # 1 scale only
            elif len(pred_disp_pyramid) == 1:
                pyramid_weight = [1.0]  # highest loss only
            else:
                raise NotImplementedError

            assert len(pyramid_weight) == len(pred_disp_pyramid)
            for k in range(len(pred_disp_pyramid)):
                pred_disp = pred_disp_pyramid[k]
                weight = pyramid_weight[k]

                if pred_disp.size(-1) != gt_disp.size(-1):
                    pred_disp = pred_disp.unsqueeze(1)  # [B, 1, H, W]
                    # Disparity is a horizontal pixel offset, so the upsampled
                    # map must also be rescaled by the width ratio.
                    pred_disp = F.interpolate(
                        pred_disp,
                        size=(gt_disp.size(-2), gt_disp.size(-1)),
                        mode='bilinear',
                        align_corners=False) * (gt_disp.size(-1) /
                                                pred_disp.size(-1))
                    pred_disp = pred_disp.squeeze(1)  # [B, H, W]

                curr_loss = F.smooth_l1_loss(pred_disp[mask],
                                             gt_disp[mask],
                                             reduction='mean')
                disp_loss += weight * curr_loss
                pyramid_loss.append(curr_loss)

                # Pseudo gt loss
                if args.load_pseudo_gt:
                    pseudo_curr_loss = F.smooth_l1_loss(
                        pred_disp[pseudo_mask],
                        pseudo_gt_disp[pseudo_mask],
                        reduction='mean')
                    pseudo_disp_loss += weight * pseudo_curr_loss

                    pseudo_pyramid_loss.append(pseudo_curr_loss)

            total_loss = disp_loss + pseudo_disp_loss

            self.optimizer.zero_grad()
            total_loss.backward()
            self.optimizer.step()

            self.num_iter += 1

            if self.num_iter % args.print_freq == 0:
                this_cycle = time.time() - last_print_time
                last_print_time += this_cycle

                logger.info(
                    'Epoch: [%3d/%3d] [%5d/%5d] time: %4.2fs disp_loss: %.3f' %
                    (self.epoch + 1, args.max_epoch, i + 1, steps_per_epoch,
                     this_cycle, disp_loss.item()))

            if self.num_iter % args.summary_freq == 0:
                img_summary = dict()
                img_summary['left'] = left
                img_summary['right'] = right
                img_summary['gt_disp'] = gt_disp

                if args.load_pseudo_gt:
                    img_summary['pseudo_gt_disp'] = pseudo_gt_disp

                # Save pyramid disparity prediction
                for s in range(len(pred_disp_pyramid)):
                    # Scale from low to high, reverse
                    save_name = 'pred_disp' + str(
                        len(pred_disp_pyramid) - s - 1)
                    save_value = pred_disp_pyramid[s]
                    img_summary[save_name] = save_value

                pred_disp = pred_disp_pyramid[-1]

                if pred_disp.size(-1) != gt_disp.size(-1):
                    pred_disp = pred_disp.unsqueeze(1)  # [B, 1, H, W]
                    pred_disp = F.interpolate(
                        pred_disp,
                        size=(gt_disp.size(-2), gt_disp.size(-1)),
                        mode='bilinear',
                        align_corners=False) * (gt_disp.size(-1) /
                                                pred_disp.size(-1))
                    pred_disp = pred_disp.squeeze(1)  # [B, H, W]
                img_summary['disp_error'] = disp_error_img(pred_disp, gt_disp)

                save_images(self.train_writer, 'train', img_summary,
                            self.num_iter)

                epe = F.l1_loss(gt_disp[mask],
                                pred_disp[mask],
                                reduction='mean')

                self.train_writer.add_scalar('train/epe', epe.item(),
                                             self.num_iter)
                self.train_writer.add_scalar('train/disp_loss',
                                             disp_loss.item(), self.num_iter)
                self.train_writer.add_scalar('train/total_loss',
                                             total_loss.item(), self.num_iter)

                # Save loss of different scale
                for s in range(len(pyramid_loss)):
                    save_name = 'train/loss' + str(len(pyramid_loss) - s - 1)
                    save_value = pyramid_loss[s]
                    self.train_writer.add_scalar(save_name, save_value,
                                                 self.num_iter)

                d1 = d1_metric(pred_disp, gt_disp, mask)
                self.train_writer.add_scalar('train/d1', d1.item(),
                                             self.num_iter)
                thres1 = thres_metric(pred_disp, gt_disp, mask, 1.0)
                thres2 = thres_metric(pred_disp, gt_disp, mask, 2.0)
                thres3 = thres_metric(pred_disp, gt_disp, mask, 3.0)
                self.train_writer.add_scalar('train/thres1', thres1.item(),
                                             self.num_iter)
                self.train_writer.add_scalar('train/thres2', thres2.item(),
                                             self.num_iter)
                self.train_writer.add_scalar('train/thres3', thres3.item(),
                                             self.num_iter)

        self.epoch += 1

        # Always save the latest model for resuming training
        if args.no_validate:
            utils.save_checkpoint(args.checkpoint_dir,
                                  self.optimizer,
                                  self.aanet,
                                  epoch=self.epoch,
                                  num_iter=self.num_iter,
                                  epe=-1,
                                  best_epe=self.best_epe,
                                  best_epoch=self.best_epoch,
                                  filename='aanet_latest.pth')

            # Save checkpoint of specific epoch
            if self.epoch % args.save_ckpt_freq == 0:
                model_dir = os.path.join(args.checkpoint_dir, 'models')
                utils.check_path(model_dir)
                utils.save_checkpoint(model_dir,
                                      self.optimizer,
                                      self.aanet,
                                      epoch=self.epoch,
                                      num_iter=self.num_iter,
                                      epe=-1,
                                      best_epe=self.best_epe,
                                      best_epoch=self.best_epoch,
                                      save_optimizer=False)
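
A note on the recurring pattern above: F.interpolate(...) * (gt_disp.size(-1) / pred_disp.size(-1)) upsamples a low-resolution disparity map and rescales its values. Disparity is a horizontal pixel offset, so when the map is resized its values must grow by the width ratio too. A minimal standalone sketch of that step (the helper name is hypothetical, not part of the repo):

import torch
import torch.nn.functional as F

def upsample_disparity(disp, target_h, target_w):
    # disp: [B, H, W] disparity map. Bilinear upsampling changes the pixel
    # grid; since disparity is measured in pixels along the width axis, its
    # values must be multiplied by the width ratio as well.
    scale = target_w / disp.size(-1)
    disp = F.interpolate(disp.unsqueeze(1), size=(target_h, target_w),
                         mode='bilinear', align_corners=False)
    return disp.squeeze(1) * scale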
Example #2
    def validate(self, val_loader):
        args = self.args
        logger = self.logger
        logger.info('=> Start validation...')

        if args.evaluate_only:
            if args.pretrained_aanet is not None:
                pretrained_aanet = args.pretrained_aanet
            else:
                model_name = 'aanet_best.pth'
                pretrained_aanet = os.path.join(args.checkpoint_dir,
                                                model_name)
                if not os.path.exists(
                        pretrained_aanet):  # KITTI without validation
                    pretrained_aanet = pretrained_aanet.replace(
                        model_name, 'aanet_latest.pth')

            logger.info('=> loading pretrained aanet: %s' % pretrained_aanet)
            utils.load_pretrained_net(self.aanet,
                                      pretrained_aanet,
                                      no_strict=True)

        self.aanet.eval()

        num_samples = len(val_loader)
        logger.info('=> %d samples found in the validation set' % num_samples)

        val_epe = 0
        val_d1 = 0
        val_thres1 = 0
        val_thres2 = 0
        val_thres3 = 0

        val_count = 0

        val_file = os.path.join(args.checkpoint_dir, 'val_results.txt')

        num_imgs = 0
        valid_samples = 0

        for i, sample in enumerate(val_loader):
            if i % 100 == 0:
                logger.info('=> Validating %d/%d' % (i, num_samples))

            left = sample['left'].to(self.device)  # [B, 3, H, W]
            right = sample['right'].to(self.device)
            gt_disp = sample['disp'].to(self.device)  # [B, H, W]
            mask = (gt_disp > 0) & (gt_disp < args.max_disp)

            if not mask.any():
                continue

            valid_samples += 1

            num_imgs += gt_disp.size(0)

            with torch.no_grad():
                pred_disp = self.aanet(left, right)[-1]  # [B, H, W]

            if pred_disp.size(-1) < gt_disp.size(-1):
                pred_disp = pred_disp.unsqueeze(1)  # [B, 1, H, W]
                pred_disp = F.interpolate(
                    pred_disp, (gt_disp.size(-2), gt_disp.size(-1)),
                    mode='bilinear',
                    align_corners=False) * (gt_disp.size(-1) / pred_disp.size(-1))
                pred_disp = pred_disp.squeeze(1)  # [B, H, W]

            epe = F.l1_loss(gt_disp[mask], pred_disp[mask], reduction='mean')
            d1 = d1_metric(pred_disp, gt_disp, mask)
            thres1 = thres_metric(pred_disp, gt_disp, mask, 1.0)
            thres2 = thres_metric(pred_disp, gt_disp, mask, 2.0)
            thres3 = thres_metric(pred_disp, gt_disp, mask, 3.0)

            val_epe += epe.item()
            val_d1 += d1.item()
            val_thres1 += thres1.item()
            val_thres2 += thres2.item()
            val_thres3 += thres3.item()

            # Save 3 images for visualization
            if not args.evaluate_only:
                if i in [
                        num_samples // 4, num_samples // 2,
                        num_samples // 4 * 3
                ]:
                    img_summary = dict()
                    img_summary['disp_error'] = disp_error_img(
                        pred_disp, gt_disp)
                    img_summary['left'] = left
                    img_summary['right'] = right
                    img_summary['gt_disp'] = gt_disp
                    img_summary['pred_disp'] = pred_disp
                    save_images(self.train_writer, 'val' + str(val_count),
                                img_summary, self.epoch)
                    val_count += 1

        logger.info('=> Validation done!')

        mean_epe = val_epe / valid_samples
        mean_d1 = val_d1 / valid_samples
        mean_thres1 = val_thres1 / valid_samples
        mean_thres2 = val_thres2 / valid_samples
        mean_thres3 = val_thres3 / valid_samples

        # Save validation results
        with open(val_file, 'a') as f:
            f.write('epoch: %03d\t' % self.epoch)
            f.write('epe: %.3f\t' % mean_epe)
            f.write('d1: %.4f\t' % mean_d1)
            f.write('thres1: %.4f\t' % mean_thres1)
            f.write('thres2: %.4f\t' % mean_thres2)
            f.write('thres3: %.4f\n' % mean_thres3)

        logger.info('=> Mean validation epe of epoch %d: %.3f' %
                    (self.epoch, mean_epe))

        if not args.evaluate_only:
            self.train_writer.add_scalar('val/epe', mean_epe, self.epoch)
            self.train_writer.add_scalar('val/d1', mean_d1, self.epoch)
            self.train_writer.add_scalar('val/thres1', mean_thres1, self.epoch)
            self.train_writer.add_scalar('val/thres2', mean_thres2, self.epoch)
            self.train_writer.add_scalar('val/thres3', mean_thres3, self.epoch)

        if not args.evaluate_only:
            if args.val_metric == 'd1':
                if mean_d1 < self.best_epe:
                    # best_epe actually tracks d1 when val_metric is 'd1'
                    self.best_epe = mean_d1
                    self.best_epoch = self.epoch

                    utils.save_checkpoint(args.checkpoint_dir,
                                          self.optimizer,
                                          self.aanet,
                                          epoch=self.epoch,
                                          num_iter=self.num_iter,
                                          epe=mean_d1,
                                          best_epe=self.best_epe,
                                          best_epoch=self.best_epoch,
                                          filename='aanet_best.pth')
            elif args.val_metric == 'epe':
                if mean_epe < self.best_epe:
                    self.best_epe = mean_epe
                    self.best_epoch = self.epoch

                    utils.save_checkpoint(args.checkpoint_dir,
                                          self.optimizer,
                                          self.aanet,
                                          epoch=self.epoch,
                                          num_iter=self.num_iter,
                                          epe=mean_epe,
                                          best_epe=self.best_epe,
                                          best_epoch=self.best_epoch,
                                          filename='aanet_best.pth')
            else:
                raise NotImplementedError

        if self.epoch == args.max_epoch:
            # Save best validation results
            with open(val_file, 'a') as f:
                f.write('\nbest epoch: %03d \t best %s: %.3f\n\n' %
                        (self.best_epoch, args.val_metric, self.best_epe))

            logger.info('=> best epoch: %03d \t best %s: %.3f\n' %
                        (self.best_epoch, args.val_metric, self.best_epe))

        # Always save the latest model for resuming training
        if not args.evaluate_only:
            utils.save_checkpoint(args.checkpoint_dir,
                                  self.optimizer,
                                  self.aanet,
                                  epoch=self.epoch,
                                  num_iter=self.num_iter,
                                  epe=mean_epe,
                                  best_epe=self.best_epe,
                                  best_epoch=self.best_epoch,
                                  filename='aanet_latest.pth')

            # Save checkpoint of specific epochs
            if self.epoch % args.save_ckpt_freq == 0:
                model_dir = os.path.join(args.checkpoint_dir, 'models')
                utils.check_path(model_dir)
                utils.save_checkpoint(model_dir,
                                      self.optimizer,
                                      self.aanet,
                                      epoch=self.epoch,
                                      num_iter=self.num_iter,
                                      epe=mean_epe,
                                      best_epe=self.best_epe,
                                      best_epoch=self.best_epoch,
                                      save_optimizer=False)
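
For reference, d1_metric and thres_metric are imported from the repo's metric utilities; a minimal sketch of what they typically compute, assuming the standard KITTI definitions (3 px / 5% for D1):

import torch

def thres_metric(pred_disp, gt_disp, mask, threshold):
    # Fraction of valid pixels whose absolute disparity error exceeds
    # `threshold` pixels.
    err = torch.abs(pred_disp[mask] - gt_disp[mask])
    return (err > threshold).float().mean()

def d1_metric(pred_disp, gt_disp, mask):
    # KITTI D1: a pixel is an outlier if its error is both > 3 px and > 5%
    # of the ground-truth disparity.
    err = torch.abs(pred_disp[mask] - gt_disp[mask])
    outliers = (err > 3.0) & (err > 0.05 * gt_disp[mask])
    return outliers.float().mean()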
Example #3
    def train(self, train_loader, local_master, trainLoss_dict, trainLossKey):
        args = self.args
        logger = self.logger

        steps_per_epoch = len(train_loader) // args.accumulation_steps  # len(train_loader) returns the number of batches
        device = self.device

        self.aanet.train()  # set the model to training mode

        if args.freeze_bn:
            def set_bn_eval(m):
                classname = m.__class__.__name__  # instance.__class__ points to the class; __name__ is that class's name
                if classname.find('BatchNorm') != -1:
                    m.eval()

            self.aanet.apply(
                set_bn_eval)  # Module.apply(fn) applies fn recursively to self and every submodule (as returned by .children())

        # Learning rate summary
        base_lr = self.optimizer.param_groups[0]['lr']
        offset_lr = self.optimizer.param_groups[1]['lr']
        self.train_writer.add_scalar('lr/base_lr', base_lr, self.epoch + 1)
        self.train_writer.add_scalar('lr/offset_lr', offset_lr, self.epoch + 1)

        last_print_time = time.time()

        validate_count = 0
        total_epe = 0
        total_d1 = 0
        total_thres1 = 0
        total_thres2 = 0
        total_thres3 = 0
        total_thres10 = 0
        total_thres20 = 0
        loss_acum = 0
        for i, sample in enumerate(train_loader):
            left = sample['left'].to(device)  # [B, 3, H, W]
            right = sample['right'].to(device)
            gt_disp = sample['disp'].to(device)  # [B, H, W]

            mask = (gt_disp > 0) & (gt_disp < args.max_disp)  # KITTI convention: a disparity of 0 marks an invalid pixel

            if args.load_pseudo_gt:
                pseudo_gt_disp = sample['pseudo_disp'].to(device)
                pseudo_mask = (pseudo_gt_disp > 0) & (pseudo_gt_disp < args.max_disp) & (
                    ~mask)  # inverse mask: pixels where gt is invalid and the pseudo label fills in

            if not mask.any():  # skip samples with no valid ground-truth pixels
                continue

            # Distributed training with gradient accumulation: under DDP, wrap the
            # forward/backward in no_sync() on iterations that are not a multiple of
            # args.accumulation_steps, so the gradient all-reduce is skipped there.
            # Reference: https://blog.csdn.net/a40850273/article/details/111829836
            my_context = self.aanet.no_sync if args.distributed and (
                        i + 1) % args.accumulation_steps != 0 else nullcontext
            with my_context():
                pred_disp_pyramid = self.aanet(left, right)  # list of H/12, H/6, H/3, H/2, H

                if args.highest_loss_only:
                    pred_disp_pyramid = [pred_disp_pyramid[-1]]  # only the last highest resolution output

                disp_loss = 0
                pseudo_disp_loss = 0
                pyramid_loss = []
                pseudo_pyramid_loss = []

                # Loss weights
                if len(pred_disp_pyramid) == 5:
                    pyramid_weight = [1 / 3, 2 / 3, 1.0, 1.0, 1.0]  # AANet and AANet+
                elif len(pred_disp_pyramid) == 4:
                    pyramid_weight = [1 / 3, 2 / 3, 1.0, 1.0]
                elif len(pred_disp_pyramid) == 3:
                    pyramid_weight = [1.0, 1.0, 1.0]  # 1 scale only
                elif len(pred_disp_pyramid) == 1:
                    pyramid_weight = [1.0]  # highest loss only
                else:
                    raise NotImplementedError

                assert len(pyramid_weight) == len(pred_disp_pyramid)
                for k in range(len(pred_disp_pyramid)):
                    pred_disp = pred_disp_pyramid[k]
                    weight = pyramid_weight[k]

                    if pred_disp.size(-1) != gt_disp.size(-1):
                        pred_disp = pred_disp.unsqueeze(1)  # [B, 1, H, W]
                        pred_disp = F.interpolate(pred_disp, size=(gt_disp.size(-2), gt_disp.size(-1)),
                                                  mode='bilinear', align_corners=False) * (
                                                gt_disp.size(-1) / pred_disp.size(-1))  # the width-ratio factor is required: when the image is upsampled, disparities must be scaled up accordingly
                        pred_disp = pred_disp.squeeze(1)  # [B, H, W]

                    curr_loss = F.smooth_l1_loss(pred_disp[mask], gt_disp[mask],
                                                 reduction='mean')
                    disp_loss += weight * curr_loss
                    pyramid_loss.append(curr_loss)

                    # Pseudo gt loss
                    if args.load_pseudo_gt:
                        pseudo_curr_loss = F.smooth_l1_loss(pred_disp[pseudo_mask], pseudo_gt_disp[pseudo_mask],
                                                            reduction='mean')
                        pseudo_disp_loss += weight * pseudo_curr_loss

                        pseudo_pyramid_loss.append(pseudo_curr_loss)

                total_loss = disp_loss + pseudo_disp_loss

                total_loss /= args.accumulation_steps
                total_loss.backward()

                # For logging and analysis only
                with torch.no_grad():
                    validate_count += 1
                    total_epe += F.l1_loss(gt_disp[mask], pred_disp_pyramid[-1][mask], reduction='mean').detach().cpu().numpy()
                    total_d1 += d1_metric(pred_disp_pyramid[-1], gt_disp, mask).detach().cpu().numpy()
                    total_thres1 += thres_metric(pred_disp_pyramid[-1], gt_disp, mask, 1.0).detach().cpu().numpy()  # mask.shape=[B, H, W]
                    total_thres2 += thres_metric(pred_disp_pyramid[-1], gt_disp, mask, 2.0).detach().cpu().numpy()
                    total_thres3 += thres_metric(pred_disp_pyramid[-1], gt_disp, mask, 3.0).detach().cpu().numpy()
                    total_thres10 += thres_metric(pred_disp_pyramid[-1], gt_disp, mask, 10.0).detach().cpu().numpy()
                    total_thres20 += thres_metric(pred_disp_pyramid[-1], gt_disp, mask, 20.0).detach().cpu().numpy()
                    loss_acum += total_loss.detach().cpu().numpy()

            if (i + 1) % args.accumulation_steps == 0:
                self.optimizer.step()
                self.optimizer.zero_grad()

                self.num_iter += 1

                if self.num_iter % args.print_freq == 0:
                    this_cycle = time.time() - last_print_time
                    last_print_time += this_cycle

                    time_to_finish = (args.max_epoch - self.epoch) * (1.0 * steps_per_epoch / args.print_freq) * \
                                     this_cycle / 3600.0  # estimated training time remaining, in hours

                    logger.info('Epoch: [%3d/%3d] [%5d/%5d] time: %4.2fs remainT: %4.2fh disp_loss: %.3f' %
                                (self.epoch + 1, args.max_epoch, self.num_iter, steps_per_epoch, this_cycle,
                                 time_to_finish,
                                 disp_loss.item()))
                    # self.num_iter: total iterations so far (one parameter update = one iteration).
                    # steps_per_epoch: iterations (parameter updates) per epoch.

                if self.num_iter % args.summary_freq == 0:
                    img_summary = dict()
                    img_summary['left'] = left  # [B, C=3, H, W]
                    img_summary['right'] = right  # [B, C=3, H, W]
                    img_summary['gt_disp'] = gt_disp  # [B, H, W]

                    if args.load_pseudo_gt:
                        img_summary['pseudo_gt_disp'] = pseudo_gt_disp

                    # Save pyramid disparity prediction
                    for s in range(len(pred_disp_pyramid)):
                        # Scale from low to high, reverse
                        save_name = 'pred_disp' + str(
                            len(pred_disp_pyramid) - s - 1)  # pred_disp0 --> pred_disp4: high resolution --> low resolution
                        save_value = pred_disp_pyramid[s]
                        img_summary[save_name] = save_value

                    pred_disp = pred_disp_pyramid[-1]

                    if pred_disp.size(-1) != gt_disp.size(-1):
                        pred_disp = pred_disp.unsqueeze(1)  # [B, 1, H, W]
                        pred_disp = F.interpolate(pred_disp, size=(gt_disp.size(-2), gt_disp.size(-1)),
                                                  mode='bilinear', align_corners=False) * (
                                                gt_disp.size(-1) / pred_disp.size(-1))
                        pred_disp = pred_disp.squeeze(1)  # [B, H, W]
                    img_summary['disp_error'] = disp_error_img(pred_disp, gt_disp)  # [B, C=3, H, W]

                    save_images(self.train_writer, 'train', img_summary, self.num_iter)

                    epe = F.l1_loss(gt_disp[mask], pred_disp[mask], reduction='mean')

                    self.train_writer.add_scalar('train/epe', epe.item(), self.num_iter)
                    self.train_writer.add_scalar('train/disp_loss', disp_loss.item(), self.num_iter)
                    self.train_writer.add_scalar('train/total_loss', total_loss.item(), self.num_iter)

                    # Save loss of different scale
                    for s in range(len(pyramid_loss)):
                        save_name = 'train/loss' + str(len(pyramid_loss) - s - 1)  # loss0 --> loss4: high resolution --> low resolution
                        save_value = pyramid_loss[s]
                        self.train_writer.add_scalar(save_name, save_value, self.num_iter)

                    d1 = d1_metric(pred_disp, gt_disp, mask)  # pred_disp.shape=[B, H, W], gt_disp.shape=[B, H, W]
                    self.train_writer.add_scalar('train/d1', d1.item(), self.num_iter)
                    thres1 = thres_metric(pred_disp, gt_disp, mask, 1.0)  # mask.shape=[B, H, W]
                    thres2 = thres_metric(pred_disp, gt_disp, mask, 2.0)
                    thres3 = thres_metric(pred_disp, gt_disp, mask, 3.0)
                    thres10 = thres_metric(pred_disp, gt_disp, mask, 10.0)
                    thres20 = thres_metric(pred_disp, gt_disp, mask, 20.0)
                    self.train_writer.add_scalar('train/thres1', thres1.item(), self.num_iter)
                    self.train_writer.add_scalar('train/thres2', thres2.item(), self.num_iter)
                    self.train_writer.add_scalar('train/thres3', thres3.item(), self.num_iter)
                    self.train_writer.add_scalar('train/thres10', thres10.item(), self.num_iter)
                    self.train_writer.add_scalar('train/thres20', thres20.item(), self.num_iter)

        self.epoch += 1

        # Record per-epoch statistics (saved to a MATLAB .mat file for analysis and comparison)
        trainLoss_dict[trainLossKey]['epochs'].append(self.epoch)
        trainLoss_dict[trainLossKey]['avgEPE'].append(total_epe / validate_count)
        trainLoss_dict[trainLossKey]['avg_d1'].append(total_d1 / validate_count)
        trainLoss_dict[trainLossKey]['avg_thres1'].append(total_thres1 / validate_count)
        trainLoss_dict[trainLossKey]['avg_thres2'].append(total_thres2 / validate_count)
        trainLoss_dict[trainLossKey]['avg_thres3'].append(total_thres3 / validate_count)
        trainLoss_dict[trainLossKey]['avg_thres10'].append(total_thres10 / validate_count)
        trainLoss_dict[trainLossKey]['avg_thres20'].append(total_thres20 / validate_count)
        trainLoss_dict[trainLossKey]['avg_loss'].append(loss_acum / validate_count)

        # End of epoch.
        # If args.no_validate is True, self.validate() will not run afterwards, so the
        # following must be recorded here; if it is False, self.validate() records it instead.
        # Information to record:
        # 1. the latest model and training state (aanet_latest.pth, optimizer_latest.pth) for resuming training;
        # 2. a checkpoint of this specific epoch.
        # Always save the latest model for resuming training
        if args.no_validate:
            if local_master:
                utils.save_checkpoint(args.checkpoint_dir, self.optimizer, self.aanet,
                                      epoch=self.epoch, num_iter=self.num_iter,
                                      epe=-1, best_epe=self.best_epe,
                                      best_epoch=self.best_epoch,
                                      filename='aanet_latest.pth')

            # Save checkpoint of specific epoch
            if self.epoch % args.save_ckpt_freq == 0:
                model_dir = os.path.join(args.checkpoint_dir, 'models')
                utils.check_path(model_dir)
                if local_master:
                    utils.save_checkpoint(model_dir, self.optimizer, self.aanet,
                                          epoch=self.epoch, num_iter=self.num_iter,
                                          epe=-1, best_epe=self.best_epe,
                                          best_epoch=self.best_epoch,
                                          save_optimizer=False)
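
The my_context construction above is the standard way to combine gradient accumulation with DistributedDataParallel: wrap the intermediate micro-batches in no_sync() so DDP skips the gradient all-reduce, and synchronize only on the iteration that performs the optimizer step. Isolated as a minimal sketch (model, loader, compute_loss and accumulation_steps are placeholder names):

from contextlib import nullcontext

for i, batch in enumerate(loader):
    sync_now = (i + 1) % accumulation_steps == 0
    # model.no_sync() disables DDP's gradient all-reduce for this backward pass
    context = nullcontext if sync_now else model.no_sync
    with context():
        loss = compute_loss(model, batch) / accumulation_steps
        loss.backward()
    if sync_now:
        optimizer.step()
        optimizer.zero_grad()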
Example #4
def test(dataloader, model, log):
    stages = 1  # len(args.loss_weights)
    EPEs = [AverageMeter() for _ in range(stages)]
    thres1 = [AverageMeter() for _ in range(stages)]

    length_loader = len(dataloader)

    model.eval()

    padding_len = 16
    inference_time = 0

    for batch_idx, (imgL, imgR, disp_L) in enumerate(dataloader):
        imgL = imgL.float().cuda()
        imgR = imgR.float().cuda()
        disp_L = disp_L.float().cuda()

        mask = disp_L < args.maxdisp

        # Pad the image height so the network's downsampling stages divide evenly
        if imgL.shape[2] % padding_len != 0:
            times = imgL.shape[2] // padding_len
            if times % 2 == 0:
                top_pad = (times + 2) * padding_len - imgL.shape[2]
            else:
                top_pad = (times + 1) * padding_len - imgL.shape[2]
        else:
            top_pad = 0

        right_pad = 0

        imgL = F.pad(imgL, (0, right_pad, top_pad, 0))
        imgR = F.pad(imgR, (0, right_pad, top_pad, 0))

        with torch.no_grad():
            time_start = time.perf_counter()
            outputs = model(imgL, imgR)
            single_inference_time = time.perf_counter() - time_start
            inference_time += single_inference_time

            for x in range(stages):
                if len(disp_L[mask]) == 0:
                    EPEs[x].update(0)
                    thres1[x].update(0)
                    continue

                output = torch.squeeze(outputs[x], 1)
                if top_pad != 0:
                    output = output[:, top_pad:, :]  # crop the padding back off

                if args.visualization and batch_idx <= 50:
                    _, H, W = output.shape

                    GT_color = torch.zeros((H, W))
                    GT_color[:, :W] = disp_L[0, :, :]
                    GT_color = cv.applyColorMap(np.array(GT_color * 2, dtype=np.uint8),
                                                cv.COLORMAP_JET)
                    cv.imwrite(join(args.save_path, "iter-%d_GT_color.jpg" % batch_idx),
                               GT_color)

                    pred_color = torch.zeros((H, W))
                    pred_color[:, :W] = output[0, :, :]
                    pred_color = cv.applyColorMap(np.array(pred_color * 2, dtype=np.uint8),
                                                  cv.COLORMAP_JET)
                    error = (output[mask] - disp_L[mask]).abs().mean()
                    cv.imwrite(join(args.save_path,
                                    "iter-%d_pred_color-%.3f.jpg" % (batch_idx, error)),
                               pred_color)

                EPEs[x].update((output[mask] - disp_L[mask]).abs().mean())
                thres1[x].update(thres_metric(output, disp_L, mask, 1.0))

        if not batch_idx % args.print_freq:
            print("single_inference_time:", single_inference_time)

            info_str = '\t'.join(['Stage {} = {:.2f}({:.2f})'.format(x, EPEs[x].val, EPEs[x].avg)
                                  for x in range(stages)])
            log.info('EPEs  [{}/{}] {}'.format(batch_idx, length_loader, info_str))

            info_str = '\t'.join(['Stage {} = {:.3f}({:.3f})'.format(x, thres1[x].val, thres1[x].avg)
                                  for x in range(stages)])
            log.info('thres1  [{}/{}] {}'.format(batch_idx, length_loader, info_str))

    log.info('=> Mean inference time for %d images: %.3fs' % (
        length_loader, inference_time / length_loader))

    info_str = ', '.join(['Stage {}={:.2f}'.format(x, EPEs[x].avg) for x in range(stages)])
    log.info('Average test EPE = ' + info_str)

    info_str = ', '.join(['Stage {}={:.3f}'.format(x, thres1[x].avg) for x in range(stages)])
    log.info('Average test thres1 = ' + info_str)
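
The padding logic in this example rounds the image height up before inference and crops top_pad rows off the prediction afterwards, so that the network's downsampling stages divide the input evenly. Note the argument order of F.pad: pairs start from the last dimension, i.e. (left, right, top, bottom) for a [B, C, H, W] tensor. A simpler sketch that pads straight to the next multiple (the example above sometimes rounds further, up to an even multiple of padding_len; the helper name is hypothetical):

import torch.nn.functional as F

def pad_height_to_multiple(img, padding_len=16):
    # img: [B, C, H, W]. Returns the padded image and the number of rows of
    # top padding, which the caller crops off the prediction afterwards.
    top_pad = (-img.shape[2]) % padding_len  # 0 when H is already divisible
    return F.pad(img, (0, 0, top_pad, 0)), top_pad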
Example #5
    def validate(self, val_loader, local_master, valLossDict, valLossKey):
        args = self.args
        logger = self.logger
        logger.info('=> Start validation...')
        # When only evaluating, load a trained model from file; otherwise use the
        # self.aanet held by this trainer (which may not have finished all epochs yet).
        if args.evaluate_only:
            if args.pretrained_aanet is not None:
                pretrained_aanet = args.pretrained_aanet
            else:
                model_name = 'aanet_best.pth'
                pretrained_aanet = os.path.join(args.checkpoint_dir, model_name)
                if not os.path.exists(pretrained_aanet):  # KITTI without validation
                    pretrained_aanet = pretrained_aanet.replace(model_name, 'aanet_latest.pth')

            logger.info('=> loading pretrained aanet: %s' % pretrained_aanet)
            utils.load_pretrained_net(self.aanet, pretrained_aanet, no_strict=True)

        self.aanet.eval()

        num_samples = len(val_loader)
        logger.info('=> %d samples found in the validation set' % num_samples)

        val_epe = 0
        val_d1 = 0
        val_thres1 = 0
        val_thres2 = 0
        val_thres3 = 0
        val_thres10 = 0
        val_thres20 = 0

        val_count = 0

        val_file = os.path.join(args.checkpoint_dir, 'val_results.txt')

        num_imgs = 0
        valid_samples = 0

        # Iterate over the validation or test samples
        for i, sample in enumerate(val_loader):
            if (i + 1) % 100 == 0:
                logger.info('=> Validating %d/%d' % (i + 1, num_samples))

            left = sample['left'].to(self.device)  # [B, 3, H, W]
            right = sample['right'].to(self.device)
            gt_disp = sample['disp'].to(self.device)  # [B, H, W]
            mask = (gt_disp > 0) & (gt_disp < args.max_disp)

            if not mask.any():
                continue

            valid_samples += 1

            num_imgs += gt_disp.size(0)

            with torch.no_grad():
                disparity_pyramid = self.aanet(left, right)  # [B, H, W]
                pred_disp = disparity_pyramid[-1]

            if pred_disp.size(-1) < gt_disp.size(-1):
                pred_disp = pred_disp.unsqueeze(1)  # [B, 1, H, W]
                pred_disp = F.interpolate(pred_disp, (gt_disp.size(-2), gt_disp.size(-1)),
                                          mode='bilinear', align_corners=False) * (
                                        gt_disp.size(-1) / pred_disp.size(-1))
                pred_disp = pred_disp.squeeze(1)  # [B, H, W]

            epe = F.l1_loss(gt_disp[mask], pred_disp[mask], reduction='mean')
            d1 = d1_metric(pred_disp, gt_disp, mask)
            thres1 = thres_metric(pred_disp, gt_disp, mask, 1.0)
            thres2 = thres_metric(pred_disp, gt_disp, mask, 2.0)
            thres3 = thres_metric(pred_disp, gt_disp, mask, 3.0)
            thres10 = thres_metric(pred_disp, gt_disp, mask, 10.0)
            thres20 = thres_metric(pred_disp, gt_disp, mask, 20.0)

            val_epe += epe.item()
            val_d1 += d1.item()
            val_thres1 += thres1.item()
            val_thres2 += thres2.item()
            val_thres3 += thres3.item()
            val_thres10 += thres10.item()
            val_thres20 += thres20.item()

            # Save images for error analysis
            with torch.no_grad():
                saveImgErrorAnalysis(i, sample['left_name'], './myDataAnalysis', 'SceneFlow_valIdx_{}'.format(i),
                                     left, right, gt_disp, disparity_pyramid, disp_error_img(pred_disp, gt_disp))

            # Save 5 images for visualization
            if not args.evaluate_only or args.mode == 'test':
                if i in [num_samples // 6, num_samples // 6 * 2, num_samples // 6 * 3,
                         num_samples // 6 * 4, num_samples // 6 * 5]:
                    img_summary = dict()
                    img_summary['disp_error'] = disp_error_img(pred_disp, gt_disp)
                    img_summary['left'] = left
                    img_summary['right'] = right
                    img_summary['gt_disp'] = gt_disp
                    img_summary['pred_disp'] = pred_disp
                    save_images(self.train_writer, 'val' + str(val_count), img_summary, self.epoch)

                    disp_error = disp_error_hist(pred_disp, gt_disp, args.max_disp)
                    save_hist(self.train_writer, '{}/{}'.format('val' + str(val_count), 'hist'), disp_error, self.epoch)

                    val_count += 1
        # Done iterating over the validation/test samples

        logger.info('=> Validation done!')

        mean_epe = val_epe / valid_samples
        mean_d1 = val_d1 / valid_samples
        mean_thres1 = val_thres1 / valid_samples
        mean_thres2 = val_thres2 / valid_samples
        mean_thres3 = val_thres3 / valid_samples
        mean_thres10 = val_thres10 / valid_samples
        mean_thres20 = val_thres20 / valid_samples

        # Record statistics (saved to a MATLAB .mat file for analysis and comparison)
        valLossDict[valLossKey]["epochs"].append(self.epoch)
        valLossDict[valLossKey]["avgEPE"].append(mean_epe)
        valLossDict[valLossKey]["avg_d1"].append(mean_d1)
        valLossDict[valLossKey]["avg_thres1"].append(mean_thres1)
        valLossDict[valLossKey]["avg_thres2"].append(mean_thres2)
        valLossDict[valLossKey]["avg_thres3"].append(mean_thres3)
        valLossDict[valLossKey]["avg_thres10"].append(mean_thres10)
        valLossDict[valLossKey]["avg_thres20"].append(mean_thres20)

        # Save validation results
        with open(val_file, 'a') as f:
            f.write('epoch: %03d\t' % self.epoch)
            f.write('epe: %.3f\t' % mean_epe)
            f.write('d1: %.4f\t' % mean_d1)
            f.write('thres1: %.4f\t' % mean_thres1)
            f.write('thres2: %.4f\t' % mean_thres2)
            f.write('thres3: %.4f\t' % mean_thres3)
            f.write('thres10: %.4f\t' % mean_thres10)
            f.write('thres20: %.4f\n' % mean_thres20)
            f.write('dataset_name= %s\t mode=%s\n' % (args.dataset_name, args.mode))

        logger.info('=> Mean validation epe of epoch %d: %.3f' % (self.epoch, mean_epe))

        if not args.evaluate_only:
            self.train_writer.add_scalar('val/epe', mean_epe, self.epoch)
            self.train_writer.add_scalar('val/d1', mean_d1, self.epoch)
            self.train_writer.add_scalar('val/thres1', mean_thres1, self.epoch)
            self.train_writer.add_scalar('val/thres2', mean_thres2, self.epoch)
            self.train_writer.add_scalar('val/thres3', mean_thres3, self.epoch)
            self.train_writer.add_scalar('val/thres10', mean_thres10, self.epoch)
            self.train_writer.add_scalar('val/thres20', mean_thres20, self.epoch)

        if not args.evaluate_only:
            if args.val_metric == 'd1':
                if mean_d1 < self.best_epe:
                    # best_epe actually tracks d1 when val_metric is 'd1'
                    self.best_epe = mean_d1
                    self.best_epoch = self.epoch

                    if local_master:
                        utils.save_checkpoint(args.checkpoint_dir, self.optimizer, self.aanet,
                                              epoch=self.epoch, num_iter=self.num_iter,
                                              epe=mean_d1, best_epe=self.best_epe,
                                              best_epoch=self.best_epoch,
                                              filename='aanet_best.pth')
            elif args.val_metric == 'epe':
                if mean_epe < self.best_epe:
                    self.best_epe = mean_epe
                    self.best_epoch = self.epoch

                    if local_master:
                        utils.save_checkpoint(args.checkpoint_dir, self.optimizer, self.aanet,
                                              epoch=self.epoch, num_iter=self.num_iter,
                                              epe=mean_epe, best_epe=self.best_epe,
                                              best_epoch=self.best_epoch,
                                              filename='aanet_best.pth')
            else:
                raise NotImplementedError

        if self.epoch == args.max_epoch:
            # Save best validation results
            with open(val_file, 'a') as f:
                f.write('\nbest epoch: %03d \t best %s: %.3f\n\n' % (self.best_epoch,
                                                                     args.val_metric,
                                                                     self.best_epe))

            logger.info('=> best epoch: %03d \t best %s: %.3f\n' % (self.best_epoch,
                                                                    args.val_metric,
                                                                    self.best_epe))

        # Always save the latest model for resuming training
        if not args.evaluate_only:
            if local_master:
                utils.save_checkpoint(args.checkpoint_dir, self.optimizer, self.aanet,
                                      epoch=self.epoch, num_iter=self.num_iter,
                                      epe=mean_epe, best_epe=self.best_epe,
                                      best_epoch=self.best_epoch,
                                      filename='aanet_latest.pth')

            # Save checkpoint of specific epochs
            if self.epoch % args.save_ckpt_freq == 0:
                model_dir = os.path.join(args.checkpoint_dir, 'models')
                utils.check_path(model_dir)
                if local_master:
                    utils.save_checkpoint(model_dir, self.optimizer, self.aanet,
                                          epoch=self.epoch, num_iter=self.num_iter,
                                          epe=mean_epe, best_epe=self.best_epe,
                                          best_epoch=self.best_epoch,
                                          save_optimizer=False)