Пример #1
0
def main():
    """Load a PSMNet variant, restore pretrained weights, and run it on the test set.

    Relies on the module-level ``args`` and ``logger``; TensorBoard output is
    written under ``<checkpoint_dir>/tensorBoard``.
    """
    board_writer = SummaryWriter(
        os.path.join(args.checkpoint_dir, "tensorBoard"))

    # Select the architecture. Fail fast on an unknown model name: the old
    # code only printed a message and then crashed later with a NameError
    # when `net` was first used.
    if args.model == 'PSMNet_stackhourglass':
        net = PSMNet_stackhourglass(args.max_disp)
    elif args.model == 'PSMNet_basic':
        net = PSMNet_basic(args.max_disp)
    else:
        raise ValueError('Unknown model: %s' % args.model)

    # Validation loader: deterministic crop (validate=True), tensor
    # conversion, ImageNet normalization.
    test_transform_list = [
        myTransforms.RandomCrop(args.test_img_height,
                                args.test_img_width,
                                validate=True),
        myTransforms.ToTensor(),
        myTransforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
    ]

    test_transform = myTransforms.Compose(test_transform_list)
    test_data = StereoDataset(data_dir=args.data_dir,
                              isDebug=args.isDebug,
                              dataset_name=args.dataset_name,
                              mode='test',
                              transform=test_transform)

    logger.info('=> {} test samples found in the test set'.format(
        len(test_data)))

    test_loader = DataLoader(dataset=test_data,
                             batch_size=args.test_batch_size,
                             shuffle=False,
                             num_workers=args.num_workers,
                             pin_memory=True,
                             drop_last=False)

    net.cuda()

    # Pretrained weights are mandatory for testing; bail out early if missing.
    if args.pretrained_net is None:
        logger.info('=>  args.pretrained_net is None! Please specify it!!!')
        return

    logger.info('=> Loading pretrained Net: %s' % args.pretrained_net)
    # Enable training from a partially pretrained model
    utils.load_pretrained_net(net,
                              args.pretrained_net,
                              strict=args.strict,
                              logger=logger)

    assert args.test_batch_size == 1, "test_batch_size must be 1."

    logger.info('=> Start testing...')
    testOnTestSet(net,
                  test_loader,
                  args.dataset_name,
                  board_writer,
                  mode="test",
                  epoch=0)
    def __init__(
        self,
        config,
        device,
        transform=transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=StereoMatcherBase.IMAGENET_MEAN,
                    std=StereoMatcherBase.IMAGENET_STD,
                ),
            ]
        ),
    ):

        StereoMatcherBase.__init__(self, config=config, device=device, transform=transform)

        self._model = nets.AANet(
            config.max_disp,
            num_downsample=config.num_downsample,
            feature_type=config.feature_type,
            no_feature_mdconv=config.no_feature_mdconv,
            feature_pyramid=config.feature_pyramid,
            feature_pyramid_network=config.feature_pyramid_network,
            feature_similarity=config.feature_similarity,
            aggregation_type=config.aggregation_type,
            num_scales=config.num_scales,
            num_fusions=config.num_fusions,
            num_stage_blocks=config.num_stage_blocks,
            num_deform_blocks=config.num_deform_blocks,
            no_intermediate_supervision=config.no_intermediate_supervision,
            refinement_type=config.refinement_type,
            mdconv_dilation=config.mdconv_dilation,
            deformable_groups=config.deformable_groups,
        )

        self._model = self._model.to(device)
        utils.load_pretrained_net(self._model, config.model_path, no_strict=True)

        self._model.eval()
Пример #3
0
def main():
    """Run AANet inference over a test split and save predicted disparity maps.

    Reads configuration from the module-level ``args``; writes one disparity
    map per input image to ``args.output_dir`` as pfm / npy / 16-bit png.
    Optionally measures mean inference time (``args.count_time``).
    """
    # For reproducibility
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    np.random.seed(args.seed)

    torch.backends.cudnn.benchmark = True

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Test loader: tensor conversion + ImageNet normalization, no cropping.
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)])
    test_data = dataloader.StereoDataset(data_dir=args.data_dir,
                                         dataset_name=args.dataset_name,
                                         mode=args.mode,
                                         save_filename=True,
                                         transform=test_transform)
    test_loader = DataLoader(dataset=test_data, batch_size=args.batch_size, shuffle=False,
                             num_workers=args.num_workers, pin_memory=True, drop_last=False)

    aanet = nets.AANet(args.max_disp,
                       num_downsample=args.num_downsample,
                       feature_type=args.feature_type,
                       no_feature_mdconv=args.no_feature_mdconv,
                       feature_pyramid=args.feature_pyramid,
                       feature_pyramid_network=args.feature_pyramid_network,
                       feature_similarity=args.feature_similarity,
                       aggregation_type=args.aggregation_type,
                       num_scales=args.num_scales,
                       num_fusions=args.num_fusions,
                       num_stage_blocks=args.num_stage_blocks,
                       num_deform_blocks=args.num_deform_blocks,
                       no_intermediate_supervision=args.no_intermediate_supervision,
                       refinement_type=args.refinement_type,
                       mdconv_dilation=args.mdconv_dilation,
                       deformable_groups=args.deformable_groups).to(device)

    # Fall back to random weights only when the checkpoint path is absent.
    if os.path.exists(args.pretrained_aanet):
        print('=> Loading pretrained AANet:', args.pretrained_aanet)
        utils.load_pretrained_net(aanet, args.pretrained_aanet, no_strict=True)
    else:
        print('=> Using random initialization')

    # Save parameters
    num_params = utils.count_parameters(aanet)
    print('=> Number of trainable parameters: %d' % num_params)

    if torch.cuda.device_count() > 1:
        print('=> Use %d GPUs' % torch.cuda.device_count())
        aanet = torch.nn.DataParallel(aanet)

    # Inference
    aanet.eval()

    inference_time = 0
    num_imgs = 0

    num_samples = len(test_loader)
    print('=> %d samples found in the test set' % num_samples)

    for i, sample in enumerate(test_loader):
        if args.count_time and i == args.num_images:  # testing time only
            break

        if i % 100 == 0:
            print('=> Inferencing %d/%d' % (i, num_samples))

        left = sample['left'].to(device)  # [B, 3, H, W]
        right = sample['right'].to(device)

        # Pad up to (img_height, img_width): top and right only, so the
        # valid pixels stay anchored at the bottom-left corner.
        ori_height, ori_width = left.size()[2:]
        if ori_height < args.img_height or ori_width < args.img_width:
            top_pad = args.img_height - ori_height
            right_pad = args.img_width - ori_width

            # Pad size: (left_pad, right_pad, top_pad, bottom_pad)
            left = F.pad(left, (0, right_pad, top_pad, 0))
            right = F.pad(right, (0, right_pad, top_pad, 0))

        # Warmup so the timing below excludes CUDA context / autotune cost.
        if i == 0 and args.count_time:
            with torch.no_grad():
                for _ in range(10):
                    aanet(left, right)

        num_imgs += left.size(0)

        with torch.no_grad():
            time_start = time.perf_counter()
            pred_disp = aanet(left, right)[-1]  # [B, H, W]
            inference_time += time.perf_counter() - time_start

        if pred_disp.size(-1) < left.size(-1):
            # Low-resolution prediction: upsample to the input size and scale
            # disparity values by the width ratio.
            # FIX: dropped `recompute_scale_factor=True` — it is not valid
            # together with an explicit output `size` and raises in current
            # PyTorch (sibling scripts in this file also omit it).
            pred_disp = pred_disp.unsqueeze(1)  # [B, 1, H, W]
            pred_disp = F.interpolate(pred_disp, (left.size(-2), left.size(-1)),
                                      mode='bilinear',
                                      align_corners=True) * (left.size(-1) / pred_disp.size(-1))
            pred_disp = pred_disp.squeeze(1)  # [B, H, W]

        # Crop back to the original size (undo top/right padding).
        if ori_height < args.img_height or ori_width < args.img_width:
            if right_pad != 0:
                pred_disp = pred_disp[:, top_pad:, :-right_pad]
            else:
                pred_disp = pred_disp[:, top_pad:]

        for b in range(pred_disp.size(0)):
            disp = pred_disp[b].detach().cpu().numpy()  # [H, W]
            save_name = sample['left_name'][b]
            save_name = os.path.join(args.output_dir, save_name)
            utils.check_path(os.path.dirname(save_name))
            if not args.count_time:
                if args.save_type == 'pfm':
                    # Optional 16-bit PNG visualization alongside the pfm.
                    if args.visualize:
                        skimage.io.imsave(save_name, (disp * 256.).astype(np.uint16))

                    # [:-3] assumes a 3-character extension (e.g. 'png').
                    save_name = save_name[:-3] + 'pfm'
                    write_pfm(save_name, disp)
                elif args.save_type == 'npy':
                    save_name = save_name[:-3] + 'npy'
                    np.save(save_name, disp)
                else:
                    skimage.io.imsave(save_name, (disp * 256.).astype(np.uint16))

    # FIX: guard against ZeroDivisionError when no image was processed
    # (empty loader, or count_time with num_images == 0).
    if num_imgs > 0:
        print('=> Mean inference time for %d images: %.3fs' % (num_imgs, inference_time / num_imgs))
Пример #4
0
def main():
    """Train (or evaluate) AANet.

    Configuration comes from the module-level ``args`` and ``logger``. Builds
    the train/validation loaders, the network, optimizer and LR scheduler,
    then runs the epoch loop through ``model.Model``.
    """
    # For reproducibility
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    np.random.seed(args.seed)

    torch.backends.cudnn.benchmark = True

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Train loader: crop + photometric/flip augmentation + normalization.
    train_transform_list = [transforms.RandomCrop(args.img_height, args.img_width),
                            transforms.RandomColor(),
                            transforms.RandomVerticalFlip(),
                            transforms.ToTensor(),
                            transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
                            ]
    train_transform = transforms.Compose(train_transform_list)

    train_data = dataloader.StereoDataset(data_dir=args.data_dir,
                                          dataset_name=args.dataset_name,
                                          mode='train' if args.mode != 'train_all' else 'train_all',
                                          load_pseudo_gt=args.load_pseudo_gt,
                                          transform=train_transform)

    logger.info('=> {} training samples found in the training set'.format(len(train_data)))

    train_loader = DataLoader(dataset=train_data, batch_size=args.batch_size, shuffle=True,
                              num_workers=args.num_workers, pin_memory=True, drop_last=True)

    # Validation loader: deterministic crop (validate=True), no augmentation.
    val_transform_list = [transforms.RandomCrop(args.val_img_height, args.val_img_width, validate=True),
                          transforms.ToTensor(),
                          transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
                         ]
    val_transform = transforms.Compose(val_transform_list)
    val_data = dataloader.StereoDataset(data_dir=args.data_dir,
                                        dataset_name=args.dataset_name,
                                        mode=args.mode,
                                        transform=val_transform)

    val_loader = DataLoader(dataset=val_data, batch_size=args.val_batch_size, shuffle=False,
                            num_workers=args.num_workers, pin_memory=True, drop_last=False)

    # Network
    aanet = nets.AANet(args.max_disp,
                       num_downsample=args.num_downsample,
                       feature_type=args.feature_type,
                       no_feature_mdconv=args.no_feature_mdconv,
                       feature_pyramid=args.feature_pyramid,
                       feature_pyramid_network=args.feature_pyramid_network,
                       feature_similarity=args.feature_similarity,
                       aggregation_type=args.aggregation_type,
                       num_scales=args.num_scales,
                       num_fusions=args.num_fusions,
                       num_stage_blocks=args.num_stage_blocks,
                       num_deform_blocks=args.num_deform_blocks,
                       no_intermediate_supervision=args.no_intermediate_supervision,
                       refinement_type=args.refinement_type,
                       mdconv_dilation=args.mdconv_dilation,
                       deformable_groups=args.deformable_groups).to(device)

    logger.info('%s' % aanet)

    if args.pretrained_aanet is not None:
        logger.info('=> Loading pretrained AANet: %s' % args.pretrained_aanet)
        # Enable training from a partially pretrained model
        utils.load_pretrained_net(aanet, args.pretrained_aanet, no_strict=(not args.strict))

    if torch.cuda.device_count() > 1:
        logger.info('=> Use %d GPUs' % torch.cuda.device_count())
        aanet = torch.nn.DataParallel(aanet)

    # Record the trainable-parameter count as an empty marker file name.
    num_params = utils.count_parameters(aanet)
    logger.info('=> Number of trainable parameters: %d' % num_params)
    save_name = '%d_parameters' % num_params
    open(os.path.join(args.checkpoint_dir, save_name), 'a').close()

    # Optimizer
    # Learning rate for offset learning is set 0.1 times those of existing layers
    specific_params = list(filter(utils.filter_specific_params,
                                  aanet.named_parameters()))
    base_params = list(filter(utils.filter_base_params,
                              aanet.named_parameters()))

    specific_params = [kv[1] for kv in specific_params]  # kv is a tuple (key, value)
    base_params = [kv[1] for kv in base_params]

    specific_lr = args.learning_rate * 0.1
    params_group = [
        {'params': base_params, 'lr': args.learning_rate},
        {'params': specific_params, 'lr': specific_lr},
    ]

    optimizer = torch.optim.Adam(params_group, weight_decay=args.weight_decay)

    # Resume training
    if args.resume:
        # AANet weights + bookkeeping (epoch / iteration / best metrics).
        start_epoch, start_iter, best_epe, best_epoch = utils.resume_latest_ckpt(
            args.checkpoint_dir, aanet, 'aanet')

        # Optimizer state.
        utils.resume_latest_ckpt(args.checkpoint_dir, optimizer, 'optimizer')
    else:
        start_epoch = 0
        start_iter = 0
        best_epe = None
        best_epoch = None

    # LR scheduler
    if args.lr_scheduler_type is not None:
        # last_epoch is -1 on a fresh run so the scheduler starts from scratch.
        last_epoch = start_epoch if args.resume else start_epoch - 1
        if args.lr_scheduler_type == 'MultiStepLR':
            milestones = [int(step) for step in args.milestones.split(',')]
            lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                                milestones=milestones,
                                                                gamma=args.lr_decay_gamma,
                                                                last_epoch=last_epoch)
        else:
            raise NotImplementedError

    train_model = model.Model(args, logger, optimizer, aanet, device, start_iter, start_epoch,
                              best_epe=best_epe, best_epoch=best_epoch)

    logger.info('=> Start training...')

    if args.evaluate_only:
        assert args.val_batch_size == 1
        train_model.validate(val_loader)
    else:
        for _ in range(start_epoch, args.max_epoch):
            # FIX: removed the redundant `if not args.evaluate_only` guard —
            # this branch is only reachable when evaluate_only is falsy.
            train_model.train(train_loader)
            if not args.no_validate:
                train_model.validate(val_loader)
            if args.lr_scheduler_type is not None:
                lr_scheduler.step()

        logger.info('=> End training\n\n')
Пример #5
0
def main():
    """Run AANet inference on a directory of stereo pairs and save disparity maps.

    Expects left images under ``<data_dir>/left/*.png`` with right images at
    the same paths with 'left' replaced by 'right'. Configuration comes from
    the module-level ``args``; outputs go to ``args.output_dir`` as
    pfm / npy / 16-bit png depending on ``args.save_type``.
    """
    # For reproducibility
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    np.random.seed(args.seed)

    torch.backends.cudnn.benchmark = True

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Test loader: tensor conversion + ImageNet normalization only.
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
    ])

    aanet = nets.AANet(
        args.max_disp,
        num_downsample=args.num_downsample,
        feature_type=args.feature_type,
        no_feature_mdconv=args.no_feature_mdconv,
        feature_pyramid=args.feature_pyramid,
        feature_pyramid_network=args.feature_pyramid_network,
        feature_similarity=args.feature_similarity,
        aggregation_type=args.aggregation_type,
        num_scales=args.num_scales,
        num_fusions=args.num_fusions,
        num_stage_blocks=args.num_stage_blocks,
        num_deform_blocks=args.num_deform_blocks,
        no_intermediate_supervision=args.no_intermediate_supervision,
        refinement_type=args.refinement_type,
        mdconv_dilation=args.mdconv_dilation,
        deformable_groups=args.deformable_groups).to(device)

    # Fall back to random weights only when the checkpoint path is absent.
    if os.path.exists(args.pretrained_aanet):
        print('=> Loading pretrained AANet:', args.pretrained_aanet)
        utils.load_pretrained_net(aanet, args.pretrained_aanet, no_strict=True)
    else:
        print('=> Using random initialization')

    if torch.cuda.device_count() > 1:
        print('=> Use %d GPUs' % torch.cuda.device_count())
        aanet = torch.nn.DataParallel(aanet)

    # Inference
    aanet.eval()

    # Normalize data_dir so the glob pattern below has exactly one slash.
    if args.data_dir.endswith('/'):
        args.data_dir = args.data_dir[:-1]

    # all_samples = sorted(glob(args.data_dir + '/*left.png'))
    all_samples = sorted(glob(args.data_dir + '/left/*.png'))

    num_samples = len(all_samples)
    print('=> %d samples found in the data dir' % num_samples)

    for i, sample_name in enumerate(all_samples):
        if i % 100 == 0:
            print('=> Inferencing %d/%d' % (i, num_samples))

        left_name = sample_name

        # NOTE(review): str.replace rewrites EVERY 'left' occurrence in the
        # path, so a 'left' inside data_dir itself would also be replaced.
        right_name = left_name.replace('left', 'right')

        left = read_img(left_name)
        right = read_img(right_name)
        sample = {'left': left, 'right': right}
        sample = test_transform(sample)  # to tensor and normalize

        left = sample['left'].to(device)  # [3, H, W]
        left = left.unsqueeze(0)  # [1, 3, H, W]
        right = sample['right'].to(device)
        right = right.unsqueeze(0)

        # Pad (top and right only, keeping valid pixels at the bottom-left).
        ori_height, ori_width = left.size()[2:]

        # Automatic: round the working size up to the next multiple of 48 —
        # presumably the network's total downsampling factor; TODO confirm.
        factor = 48
        args.img_height = math.ceil(ori_height / factor) * factor
        args.img_width = math.ceil(ori_width / factor) * factor

        if ori_height < args.img_height or ori_width < args.img_width:
            top_pad = args.img_height - ori_height
            right_pad = args.img_width - ori_width

            # Pad size: (left_pad, right_pad, top_pad, bottom_pad)
            left = F.pad(left, (0, right_pad, top_pad, 0))
            right = F.pad(right, (0, right_pad, top_pad, 0))

        with torch.no_grad():
            pred_disp = aanet(left, right)[-1]  # [B, H, W]

        if pred_disp.size(-1) < left.size(-1):
            # Low-resolution output: upsample to the input size and scale
            # the disparity values by the width ratio.
            pred_disp = pred_disp.unsqueeze(1)  # [B, 1, H, W]
            pred_disp = F.interpolate(
                pred_disp, (left.size(-2), left.size(-1)),
                mode='bilinear') * (left.size(-1) / pred_disp.size(-1))
            pred_disp = pred_disp.squeeze(1)  # [B, H, W]

        # Crop back to the original size (undo the top/right padding).
        if ori_height < args.img_height or ori_width < args.img_width:
            if right_pad != 0:
                pred_disp = pred_disp[:, top_pad:, :-right_pad]
            else:
                pred_disp = pred_disp[:, top_pad:]

        disp = pred_disp[0].detach().cpu().numpy()  # [H, W]

        # Build the output filename; [:-4] strips the '.png' extension.
        save_name = os.path.basename(
            left_name)[:-4] + '_' + args.save_suffix + '.png'
        save_name = os.path.join(args.output_dir, save_name)

        if args.save_type == 'pfm':
            # Optional 16-bit PNG visualization (disp * 256) beside the pfm.
            if args.visualize:
                skimage.io.imsave(save_name, (disp * 256.).astype(np.uint16))

            # [:-3] assumes a 3-character extension ('png').
            save_name = save_name[:-3] + 'pfm'
            write_pfm(save_name, disp)
        elif args.save_type == 'npy':
            save_name = save_name[:-3] + 'npy'
            np.save(save_name, disp)
        else:
            skimage.io.imsave(save_name, (disp * 256.).astype(np.uint16))
Пример #6
0
    def validate(self, val_loader):
        """Evaluate the current AANet on `val_loader` and checkpoint results.

        Computes EPE, D1 and >1/2/3px error rates over all samples with at
        least one valid ground-truth pixel, appends them to val_results.txt,
        logs scalars/images to TensorBoard (training mode only) and writes
        best / latest / periodic checkpoints.
        """
        args = self.args
        logger = self.logger
        logger.info('=> Start validation...')

        # In evaluate-only mode, load weights explicitly: either the given
        # pretrained checkpoint, or the best one found in checkpoint_dir,
        # falling back to the latest when no best exists.
        if args.evaluate_only is True:
            if args.pretrained_aanet is not None:
                pretrained_aanet = args.pretrained_aanet
            else:
                model_name = 'aanet_best.pth'
                pretrained_aanet = os.path.join(args.checkpoint_dir,
                                                model_name)
                if not os.path.exists(
                        pretrained_aanet):  # KITTI without validation
                    pretrained_aanet = pretrained_aanet.replace(
                        model_name, 'aanet_latest.pth')

            logger.info('=> loading pretrained aanet: %s' % pretrained_aanet)
            utils.load_pretrained_net(self.aanet,
                                      pretrained_aanet,
                                      no_strict=True)

        self.aanet.eval()

        num_samples = len(val_loader)
        logger.info('=> %d samples found in the validation set' % num_samples)

        # Running sums of per-batch metrics, averaged over valid_samples below.
        val_epe = 0
        val_d1 = 0
        val_thres1 = 0
        val_thres2 = 0
        val_thres3 = 0

        # Counter used only to tag the visualization images saved below.
        val_count = 0

        val_file = os.path.join(args.checkpoint_dir, 'val_results.txt')

        num_imgs = 0
        valid_samples = 0

        for i, sample in enumerate(val_loader):
            if i % 100 == 0:
                logger.info('=> Validating %d/%d' % (i, num_samples))

            left = sample['left'].to(self.device)  # [B, 3, H, W]
            right = sample['right'].to(self.device)
            gt_disp = sample['disp'].to(self.device)  # [B, H, W]
            # Valid ground-truth pixels: disparity strictly inside (0, max_disp).
            mask = (gt_disp > 0) & (gt_disp < args.max_disp)

            # Skip samples with no valid ground truth at all.
            if not mask.any():
                continue

            valid_samples += 1

            num_imgs += gt_disp.size(0)

            with torch.no_grad():
                pred_disp = self.aanet(left, right)[-1]  # [B, H, W]

            if pred_disp.size(-1) < gt_disp.size(-1):
                # Low-resolution prediction: upsample to ground-truth size and
                # scale the disparity values by the width ratio.
                pred_disp = pred_disp.unsqueeze(1)  # [B, 1, H, W]
                pred_disp = F.interpolate(
                    pred_disp, (gt_disp.size(-2), gt_disp.size(-1)),
                    mode='bilinear') * (gt_disp.size(-1) / pred_disp.size(-1))
                pred_disp = pred_disp.squeeze(1)  # [B, H, W]

            # End-point error over valid pixels only.
            epe = F.l1_loss(gt_disp[mask], pred_disp[mask], reduction='mean')
            d1 = d1_metric(pred_disp, gt_disp, mask)
            thres1 = thres_metric(pred_disp, gt_disp, mask, 1.0)
            thres2 = thres_metric(pred_disp, gt_disp, mask, 2.0)
            thres3 = thres_metric(pred_disp, gt_disp, mask, 3.0)

            val_epe += epe.item()
            val_d1 += d1.item()
            val_thres1 += thres1.item()
            val_thres2 += thres2.item()
            val_thres3 += thres3.item()

            # Save 3 images for visualization
            if not args.evaluate_only:
                if i in [
                        num_samples // 4, num_samples // 2,
                        num_samples // 4 * 3
                ]:
                    img_summary = dict()
                    img_summary['disp_error'] = disp_error_img(
                        pred_disp, gt_disp)
                    img_summary['left'] = left
                    img_summary['right'] = right
                    img_summary['gt_disp'] = gt_disp
                    img_summary['pred_disp'] = pred_disp
                    save_images(self.train_writer, 'val' + str(val_count),
                                img_summary, self.epoch)
                    val_count += 1

        logger.info('=> Validation done!')

        # NOTE(review): if every sample was skipped (no valid pixels),
        # valid_samples is 0 and these divisions raise ZeroDivisionError.
        mean_epe = val_epe / valid_samples
        mean_d1 = val_d1 / valid_samples
        mean_thres1 = val_thres1 / valid_samples
        mean_thres2 = val_thres2 / valid_samples
        mean_thres3 = val_thres3 / valid_samples

        # Save validation results
        with open(val_file, 'a') as f:
            f.write('epoch: %03d\t' % self.epoch)
            f.write('epe: %.3f\t' % mean_epe)
            f.write('d1: %.4f\t' % mean_d1)
            f.write('thres1: %.4f\t' % mean_thres1)
            f.write('thres2: %.4f\t' % mean_thres2)
            f.write('thres3: %.4f\n' % mean_thres3)

        logger.info('=> Mean validation epe of epoch %d: %.3f' %
                    (self.epoch, mean_epe))

        if not args.evaluate_only:
            self.train_writer.add_scalar('val/epe', mean_epe, self.epoch)
            self.train_writer.add_scalar('val/d1', mean_d1, self.epoch)
            self.train_writer.add_scalar('val/thres1', mean_thres1, self.epoch)
            self.train_writer.add_scalar('val/thres2', mean_thres2, self.epoch)
            self.train_writer.add_scalar('val/thres3', mean_thres3, self.epoch)

        # Track the best model by the configured metric. Note that best_epe
        # stores the D1 value when val_metric == 'd1'.
        if not args.evaluate_only:
            if args.val_metric == 'd1':
                if mean_d1 < self.best_epe:
                    # Actually best_epe here is d1
                    self.best_epe = mean_d1
                    self.best_epoch = self.epoch

                    utils.save_checkpoint(args.checkpoint_dir,
                                          self.optimizer,
                                          self.aanet,
                                          epoch=self.epoch,
                                          num_iter=self.num_iter,
                                          epe=mean_d1,
                                          best_epe=self.best_epe,
                                          best_epoch=self.best_epoch,
                                          filename='aanet_best.pth')
            elif args.val_metric == 'epe':
                if mean_epe < self.best_epe:
                    self.best_epe = mean_epe
                    self.best_epoch = self.epoch

                    utils.save_checkpoint(args.checkpoint_dir,
                                          self.optimizer,
                                          self.aanet,
                                          epoch=self.epoch,
                                          num_iter=self.num_iter,
                                          epe=mean_epe,
                                          best_epe=self.best_epe,
                                          best_epoch=self.best_epoch,
                                          filename='aanet_best.pth')
            else:
                raise NotImplementedError

        # At the final epoch, append a summary line with the best result.
        if self.epoch == args.max_epoch:
            # Save best validation results
            with open(val_file, 'a') as f:
                f.write('\nbest epoch: %03d \t best %s: %.3f\n\n' %
                        (self.best_epoch, args.val_metric, self.best_epe))

            logger.info('=> best epoch: %03d \t best %s: %.3f\n' %
                        (self.best_epoch, args.val_metric, self.best_epe))

        # Always save the latest model for resuming training
        if not args.evaluate_only:
            utils.save_checkpoint(args.checkpoint_dir,
                                  self.optimizer,
                                  self.aanet,
                                  epoch=self.epoch,
                                  num_iter=self.num_iter,
                                  epe=mean_epe,
                                  best_epe=self.best_epe,
                                  best_epoch=self.best_epoch,
                                  filename='aanet_latest.pth')

            # Save checkpoint of specific epochs
            if self.epoch % args.save_ckpt_freq == 0:
                model_dir = os.path.join(args.checkpoint_dir, 'models')
                utils.check_path(model_dir)
                utils.save_checkpoint(model_dir,
                                      self.optimizer,
                                      self.aanet,
                                      epoch=self.epoch,
                                      num_iter=self.num_iter,
                                      epe=mean_epe,
                                      best_epe=self.best_epe,
                                      best_epoch=self.best_epoch,
                                      save_optimizer=False)
Пример #7
0
def main():
    """Build data loaders, AANet model, optimizer and LR scheduler, then
    run the training loop (or a single evaluation pass).

    Relies on module-level globals configured by the launcher:
    ``args``, ``logger``, ``device``, ``local_master``, ``local_rank``,
    plus project helpers (``getDataLoader``, ``utils``, ``model``, ...).
    """
    # For reproducibility
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    np.random.seed(args.seed)

    train_loader, val_loader = getDataLoader(args, logger)

    # Network
    aanet = nets.AANet(
        args.max_disp,
        num_downsample=args.num_downsample,
        feature_type=args.feature_type,
        no_feature_mdconv=args.no_feature_mdconv,
        feature_pyramid=args.feature_pyramid,
        feature_pyramid_network=args.feature_pyramid_network,
        feature_similarity=args.feature_similarity,
        aggregation_type=args.aggregation_type,
        useFeatureAtt=args.useFeatureAtt,
        num_scales=args.num_scales,
        num_fusions=args.num_fusions,
        num_stage_blocks=args.num_stage_blocks,
        num_deform_blocks=args.num_deform_blocks,
        no_intermediate_supervision=args.no_intermediate_supervision,
        refinement_type=args.refinement_type,
        mdconv_dilation=args.mdconv_dilation,
        deformable_groups=args.deformable_groups).to(device)

    # logger.info('%s' % aanet) if local_master else None
    # Dump the network structure to a text file (master process only).
    if local_master:
        structure_of_net = os.path.join(args.checkpoint_dir,
                                        'structure_of_net.txt')
        with open(structure_of_net, 'w') as f:
            f.write('%s' % aanet)

    if args.pretrained_aanet is not None:
        logger.info('=> Loading pretrained AANet: %s' % args.pretrained_aanet)
        # Enable training from a partially pretrained model
        utils.load_pretrained_net(aanet,
                                  args.pretrained_aanet,
                                  no_strict=(not args.strict))

    aanet.to(device)
    logger.info('=> Use %d GPUs' %
                torch.cuda.device_count()) if local_master else None
    # if torch.cuda.device_count() > 1:
    if args.distributed:
        # aanet = torch.nn.DataParallel(aanet)
        # Distributed training: convert BatchNorm layers to SyncBatchNorm and
        # wrap the model in DistributedDataParallel on this process's GPU.
        aanet = torch.nn.SyncBatchNorm.convert_sync_batchnorm(aanet)
        aanet = torch.nn.parallel.DistributedDataParallel(
            aanet, device_ids=[local_rank], output_device=local_rank)
        synchronize()

    # Save parameters
    num_params = utils.count_parameters(aanet)
    logger.info('=> Number of trainable parameters: %d' % num_params)
    save_name = '%d_parameters' % num_params
    open(os.path.join(args.checkpoint_dir, save_name), 'a').close(
    ) if local_master else None  # Empty marker file; its name records the trainable-parameter count.

    # Optimizer
    # Learning rate for offset learning is set 0.1 times those of existing layers
    specific_params = list(
        filter(utils.filter_specific_params, aanet.named_parameters()))
    base_params = list(
        filter(utils.filter_base_params, aanet.named_parameters()))

    specific_params = [kv[1]
                       for kv in specific_params]  # kv is a tuple (key, value)
    base_params = [kv[1] for kv in base_params]

    specific_lr = args.learning_rate * 0.1
    params_group = [
        {
            'params': base_params,
            'lr': args.learning_rate
        },
        {
            'params': specific_params,
            'lr': specific_lr
        },
    ]

    optimizer = torch.optim.Adam(params_group, weight_decay=args.weight_decay)

    # Resume training
    if args.resume:
        # 1. resume AANet
        start_epoch, start_iter, best_epe, best_epoch = utils.resume_latest_ckpt(
            args.checkpoint_dir, aanet, 'aanet')
        # 2. resume Optimizer
        utils.resume_latest_ckpt(args.checkpoint_dir, optimizer, 'optimizer')
    else:
        start_epoch = 0
        start_iter = 0
        best_epe = None
        best_epoch = None

    # LR scheduler
    if args.lr_scheduler_type is not None:
        last_epoch = start_epoch if args.resume else start_epoch - 1
        if args.lr_scheduler_type == 'MultiStepLR':
            milestones = [int(step) for step in args.milestones.split(',')]
            lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
                optimizer,
                milestones=milestones,
                gamma=args.lr_decay_gamma,
                last_epoch=last_epoch
            )  # last_epoch matters: on resume the scheduler re-derives the correct LR for that epoch.
        else:
            raise NotImplementedError
    # model.Model is a further wrapper around AANet (training/validation loop state).
    train_model = model.Model(args,
                              logger,
                              optimizer,
                              aanet,
                              device,
                              start_iter,
                              start_epoch,
                              best_epe=best_epe,
                              best_epoch=best_epoch)

    logger.info('=> Start training...')

    trainLoss_dict, trainLossKey, valLoss_dict, valLossKey = getLossRecord(
        netName="AANet")

    if args.evaluate_only:
        assert args.val_batch_size == 1
        train_model.validate(
            val_loader, local_master, valLoss_dict,
            valLossKey)  # Test mode: expects --evaluate_only and --mode 'test'.
        # Save losses for offline (MATLAB) analysis.
        save_loss_for_matlab(trainLoss_dict, valLoss_dict)
    else:
        for epoch in range(start_epoch, args.max_epoch):  # Main training loop over epochs.
            if not args.evaluate_only:
                # ensure distribute worker sample different data,
                # set different random seed by passing epoch to sampler
                if args.distributed:
                    train_loader.sampler.set_epoch(epoch)
                    logger.info(
                        'train_loader.sampler.set_epoch({})'.format(epoch))
                train_model.train(train_loader, local_master, trainLoss_dict,
                                  trainLossKey)
            if not args.no_validate:
                train_model.validate(val_loader, local_master, valLoss_dict,
                                     valLossKey)  # Training mode: validate after each epoch.
            if args.lr_scheduler_type is not None:
                lr_scheduler.step()  # Adjust the learning rate.

            # Save losses after every epoch, overwriting the previous dump, so
            # results survive an interrupted run instead of only at the end.
            save_loss_for_matlab(trainLoss_dict, valLoss_dict)

        logger.info('=> End training\n\n')
Пример #8
0
    def validate(self, val_loader, local_master, valLossDict, valLossKey):
        """Run one validation pass: compute disparity metrics, record them,
        write TensorBoard summaries, and save best/latest checkpoints.

        Args:
            val_loader: DataLoader yielding dicts with 'left', 'right',
                'disp' tensors (and 'left_name' for error analysis).
            local_master: True on the master process; guards checkpoint saves.
            valLossDict: nested dict accumulating per-epoch metrics.
            valLossKey: key into ``valLossDict`` for this network's record.
        """
        args = self.args
        logger = self.logger
        logger.info('=> Start validation...')
        # Evaluate-only mode loads a trained model from disk; otherwise use the
        # in-memory self.aanet (whose training epochs are not yet finished).
        if args.evaluate_only is True:
            if args.pretrained_aanet is not None:
                pretrained_aanet = args.pretrained_aanet
            else:
                model_name = 'aanet_best.pth'
                pretrained_aanet = os.path.join(args.checkpoint_dir, model_name)
                if not os.path.exists(pretrained_aanet):  # KITTI without validation
                    pretrained_aanet = pretrained_aanet.replace(model_name, 'aanet_latest.pth')

            logger.info('=> loading pretrained aanet: %s' % pretrained_aanet)
            utils.load_pretrained_net(self.aanet, pretrained_aanet, no_strict=True)

        self.aanet.eval()

        num_samples = len(val_loader)
        logger.info('=> %d samples found in the validation set' % num_samples)

        # Running sums of per-sample metrics; averaged over valid_samples below.
        val_epe = 0
        val_d1 = 0
        val_thres1 = 0
        val_thres2 = 0
        val_thres3 = 0
        val_thres10 = 0
        val_thres20 = 0

        val_count = 0

        val_file = os.path.join(args.checkpoint_dir, 'val_results.txt')

        num_imgs = 0
        valid_samples = 0

        # Iterate over validation (or test) samples.
        for i, sample in enumerate(val_loader):
            if (i + 1) % 100 == 0:
                logger.info('=> Validating %d/%d' % (i, num_samples))

            left = sample['left'].to(self.device)  # [B, 3, H, W]
            right = sample['right'].to(self.device)
            gt_disp = sample['disp'].to(self.device)  # [B, H, W]
            mask = (gt_disp > 0) & (gt_disp < args.max_disp)

            # Skip samples with no valid ground-truth disparity at all.
            if not mask.any():
                continue

            valid_samples += 1

            num_imgs += gt_disp.size(0)

            with torch.no_grad():
                disparity_pyramid = self.aanet(left, right)  # [B, H, W]
                pred_disp = disparity_pyramid[-1]

            # Upsample low-resolution predictions to ground-truth size; disparity
            # values are scaled by the width ratio to stay metrically correct.
            if pred_disp.size(-1) < gt_disp.size(-1):
                pred_disp = pred_disp.unsqueeze(1)  # [B, 1, H, W]
                pred_disp = F.interpolate(pred_disp, (gt_disp.size(-2), gt_disp.size(-1)),
                                          mode='bilinear', align_corners=False) * (
                                        gt_disp.size(-1) / pred_disp.size(-1))
                pred_disp = pred_disp.squeeze(1)  # [B, H, W]

            epe = F.l1_loss(gt_disp[mask], pred_disp[mask], reduction='mean')
            d1 = d1_metric(pred_disp, gt_disp, mask)
            thres1 = thres_metric(pred_disp, gt_disp, mask, 1.0)
            thres2 = thres_metric(pred_disp, gt_disp, mask, 2.0)
            thres3 = thres_metric(pred_disp, gt_disp, mask, 3.0)
            thres10 = thres_metric(pred_disp, gt_disp, mask, 10.0)
            thres20 = thres_metric(pred_disp, gt_disp, mask, 20.0)

            val_epe += epe.item()
            val_d1 += d1.item()
            val_thres1 += thres1.item()
            val_thres2 += thres2.item()
            val_thres3 += thres3.item()
            val_thres10 += thres10.item()
            val_thres20 += thres20.item()

            # save Image For Error Analysis
            # saveForErrorAnalysis(index, img_name, dstPath, dstName, left, right, gt_disp, disparity_pyramid):
            with torch.no_grad():
                saveImgErrorAnalysis(i, sample['left_name'], './myDataAnalysis', 'SceneFlow_valIdx_{}'.format(i),
                                     left, right, gt_disp, disparity_pyramid, disp_error_img(pred_disp, gt_disp))

            # Save 3 images for visualization
            if not args.evaluate_only or args.mode == 'test':
                # if i in [num_samples // 4, num_samples // 2, num_samples // 4 * 3]:
                # Sample five evenly-spaced indices across the set for summaries.
                if i in [num_samples // 6, num_samples // 6 * 2, num_samples // 6 * 3, num_samples // 6 * 4,
                         num_samples // 6 * 5]:
                    img_summary = dict()
                    img_summary['disp_error'] = disp_error_img(pred_disp, gt_disp)
                    img_summary['left'] = left
                    img_summary['right'] = right
                    img_summary['gt_disp'] = gt_disp
                    img_summary['pred_disp'] = pred_disp
                    save_images(self.train_writer, 'val' + str(val_count), img_summary, self.epoch)

                    disp_error = disp_error_hist(pred_disp, gt_disp, args.max_disp)
                    save_hist(self.train_writer, '{}/{}'.format('val' + str(val_count), 'hist'), disp_error, self.epoch)

                    val_count += 1
        # Done iterating over validation/test samples.

        logger.info('=> Validation done!')

        mean_epe = val_epe / valid_samples
        mean_d1 = val_d1 / valid_samples
        mean_thres1 = val_thres1 / valid_samples
        mean_thres2 = val_thres2 / valid_samples
        mean_thres3 = val_thres3 / valid_samples
        mean_thres10 = val_thres10 / valid_samples
        mean_thres20 = val_thres20 / valid_samples

        # Record metrics (later dumped to a MATLAB .mat file) for analysis and comparison.
        valLossDict[valLossKey]["epochs"].append(self.epoch)
        valLossDict[valLossKey]["avgEPE"].append(mean_epe)
        valLossDict[valLossKey]["avg_d1"].append(mean_d1)
        valLossDict[valLossKey]["avg_thres1"].append(mean_thres1)
        valLossDict[valLossKey]["avg_thres2"].append(mean_thres2)
        valLossDict[valLossKey]["avg_thres3"].append(mean_thres3)
        valLossDict[valLossKey]["avg_thres10"].append(mean_thres10)
        valLossDict[valLossKey]["avg_thres20"].append(mean_thres20)

        # Save validation results
        with open(val_file, 'a') as f:
            f.write('epoch: %03d\t' % self.epoch)
            f.write('epe: %.3f\t' % mean_epe)
            f.write('d1: %.4f\t' % mean_d1)
            f.write('thres1: %.4f\t' % mean_thres1)
            f.write('thres2: %.4f\t' % mean_thres2)
            f.write('thres3: %.4f\t' % mean_thres3)
            f.write('thres10: %.4f\t' % mean_thres10)
            f.write('thres20: %.4f\n' % mean_thres20)
            f.write('dataset_name= %s\t mode=%s\n' % (args.dataset_name, args.mode))

        logger.info('=> Mean validation epe of epoch %d: %.3f' % (self.epoch, mean_epe))

        if not args.evaluate_only:
            self.train_writer.add_scalar('val/epe', mean_epe, self.epoch)
            self.train_writer.add_scalar('val/d1', mean_d1, self.epoch)
            self.train_writer.add_scalar('val/thres1', mean_thres1, self.epoch)
            self.train_writer.add_scalar('val/thres2', mean_thres2, self.epoch)
            self.train_writer.add_scalar('val/thres3', mean_thres3, self.epoch)
            self.train_writer.add_scalar('val/thres10', mean_thres10, self.epoch)
            self.train_writer.add_scalar('val/thres20', mean_thres20, self.epoch)

        # Track the best model according to the configured validation metric.
        if not args.evaluate_only:
            if args.val_metric == 'd1':
                if mean_d1 < self.best_epe:
                    # Actually best_epe here is d1
                    self.best_epe = mean_d1
                    self.best_epoch = self.epoch

                    utils.save_checkpoint(args.checkpoint_dir, self.optimizer, self.aanet,
                                          epoch=self.epoch, num_iter=self.num_iter,
                                          epe=mean_d1, best_epe=self.best_epe,
                                          best_epoch=self.best_epoch,
                                          filename='aanet_best.pth') if local_master else None
            elif args.val_metric == 'epe':
                if mean_epe < self.best_epe:
                    self.best_epe = mean_epe
                    self.best_epoch = self.epoch

                    utils.save_checkpoint(args.checkpoint_dir, self.optimizer, self.aanet,
                                          epoch=self.epoch, num_iter=self.num_iter,
                                          epe=mean_epe, best_epe=self.best_epe,
                                          best_epoch=self.best_epoch,
                                          filename='aanet_best.pth') if local_master else None
            else:
                raise NotImplementedError

        if self.epoch == args.max_epoch:
            # Save best validation results
            with open(val_file, 'a') as f:
                f.write('\nbest epoch: %03d \t best %s: %.3f\n\n' % (self.best_epoch,
                                                                     args.val_metric,
                                                                     self.best_epe))

            logger.info('=> best epoch: %03d \t best %s: %.3f\n' % (self.best_epoch,
                                                                    args.val_metric,
                                                                    self.best_epe))

        # Always save the latest model for resuming training
        if not args.evaluate_only:
            utils.save_checkpoint(args.checkpoint_dir, self.optimizer, self.aanet,
                                  epoch=self.epoch, num_iter=self.num_iter,
                                  epe=mean_epe, best_epe=self.best_epe,
                                  best_epoch=self.best_epoch,
                                  filename='aanet_latest.pth') if local_master else None

            # Save checkpoint of specific epochs
            if self.epoch % args.save_ckpt_freq == 0:
                model_dir = os.path.join(args.checkpoint_dir, 'models')
                utils.check_path(model_dir)
                utils.save_checkpoint(model_dir, self.optimizer, self.aanet,
                                      epoch=self.epoch, num_iter=self.num_iter,
                                      epe=mean_epe, best_epe=self.best_epe,
                                      best_epoch=self.best_epoch,
                                      save_optimizer=False) if local_master else None
Пример #9
0
    def validate(self, val_loader):
        """Validation pass computing both disparity and depth metrics for
        several custom datasets, then logging results and saving checkpoints.

        Depends on names defined elsewhere in the module: ``onlyObj``,
        ``perObject``, ``objectCount``, ``perObjectDisp``, ``perObjectDepth``
        and ``inf`` — presumably module-level globals; verify before reuse.
        """
        args = self.args
        logger = self.logger
        logger.info('=> Start validation...')

        # Evaluate-only mode loads a trained model from disk; otherwise use
        # the in-memory self.aanet that is still being trained.
        if args.evaluate_only is True:
            if args.pretrained_aanet is not None:
                pretrained_aanet = args.pretrained_aanet
            else:
                model_name = 'aanet_best.pth'
                pretrained_aanet = os.path.join(args.checkpoint_dir,
                                                model_name)
                if not os.path.exists(
                        pretrained_aanet):  # KITTI without validation
                    pretrained_aanet = pretrained_aanet.replace(
                        model_name, 'aanet_latest.pth')

            logger.info('=> loading pretrained aanet: %s' % pretrained_aanet)
            utils.load_pretrained_net(self.aanet,
                                      pretrained_aanet,
                                      no_strict=True)

        # NOTE(review): validation normally calls eval(); train() here keeps
        # BatchNorm/Dropout in training mode — confirm this is intentional
        # (the sibling validate() implementation uses self.aanet.eval()).
        self.aanet.train()

        num_samples = len(val_loader)
        logger.info('=> %d samples found in the validation set' % num_samples)

        # Running sums of per-sample metrics; averaged over valid_samples below.
        val_epe = 0
        val_d1 = 0
        # val_thres1 = 0
        # val_thres2 = 0
        # val_thres3 = 0
        val_bad1 = 0
        val_bad2 = 0
        val_abs = 0
        val_mm2 = 0
        val_mm4 = 0
        val_mm8 = 0

        val_count = 0

        val_file = os.path.join(args.checkpoint_dir, 'val_results.txt')

        num_imgs = 0
        valid_samples = 0

        # Fallback camera parameters used when the dataset does not supply
        # per-sample baseline/intrinsics. Presumably baseline is in meters and
        # intrinsic is a 3x3 pinhole matrix — TODO confirm against the dataset.
        baseline = 0.055
        intrinsic = [[1387.095, 0.0, 960.0], [0.0, 1387.095, 540.0],
                     [0.0, 0.0, 1.0]]

        for i, sample in enumerate(val_loader):
            if i % 100 == 0:
                logger.info('=> Validating %d/%d' % (i, num_samples))

            left = sample['left'].to(self.device)  # [B, 3, H, W]
            right = sample['right'].to(self.device)
            gt_disp = sample['disp'].to(self.device)  # [B, H, W]
            gt_depth = []
            pred_disp = []

            if args.dataset_name == 'custom_dataset':  # going to be depthL_fromR_down if from  custom_dataset
                # Stored value is depth*256; convert to disparity and keep the
                # scaled depth. Division-by-zero yields inf, zeroed out below.
                gt_disp_1 = (baseline * 1000 * intrinsic[0][0] /
                             2) / (gt_disp * 256.)
                gt_disp_1[gt_disp_1 == inf] = 0
                gt_depth = gt_disp * 256.
                gt_disp = gt_disp_1

            if (args.dataset_name == 'custom_dataset_full'
                    or args.dataset_name == 'custom_dataset_obj'):

                # convert to disparity then apply warp ops
                temp = gt_disp * 256.
                for x in range(left.shape[0]):
                    baseline = sample['baseline'][x].to(self.device)
                    intrinsic = sample['intrinsic'][x].to(self.device)
                    temp[x] = (baseline * 1000 * intrinsic[0][0] /
                               2) / (temp[x])
                    temp[x][temp[x] == inf] = 0

                # gt_disp = torch.clone(temp)
                gt_disp = apply_disparity_cu(temp.unsqueeze(1),
                                             temp.type(torch.int))
                gt_disp = torch.squeeze(gt_disp)

                gt_depth = temp
                # convert to gt_depth
                for x in range(left.shape[0]):
                    baseline = sample['baseline'][x].to(self.device)
                    intrinsic = sample['intrinsic'][x].to(self.device)
                    gt_depth[x] = (baseline * 1000 * intrinsic[0][0] /
                                   2) / (gt_disp[x])
                    gt_depth[x][gt_depth[x] == inf] = 0
                gt_depth = gt_depth.to(self.device)

            if (args.dataset_name == 'custom_dataset_sim'
                    or args.dataset_name == 'custom_dataset_real'):
                # Same conversion as above, plus re-adding the batch dim that
                # torch.squeeze removed (these datasets use batch size 1).
                temp = gt_disp * 256.
                for x in range(left.shape[0]):
                    baseline = sample['baseline'][x].to(self.device)
                    intrinsic = sample['intrinsic'][x].to(self.device)
                    temp[x] = (baseline * 1000 * intrinsic[0][0] /
                               2) / (temp[x])
                    temp[x][temp[x] == inf] = 0

                # gt_disp = torch.clone(temp)
                gt_disp = apply_disparity_cu(temp.unsqueeze(1),
                                             temp.type(torch.int))
                gt_disp = torch.squeeze(gt_disp)

                gt_disp = torch.unsqueeze(gt_disp, 0)

                gt_depth = temp
                # convert to gt_depth
                for x in range(left.shape[0]):
                    baseline = sample['baseline'][x].to(self.device)
                    intrinsic = sample['intrinsic'][x].to(self.device)
                    gt_depth[x] = (baseline * 1000 * intrinsic[0][0] /
                                   2) / (gt_disp[x])
                    gt_depth[x][gt_depth[x] == inf] = 0
                gt_depth = gt_depth.to(self.device)

            mask_disp = (gt_disp > 0.) & (gt_disp < args.max_disp)

            # Skip samples with no valid ground-truth disparity at all.
            if not mask_disp.any():
                continue

            valid_samples += 1

            num_imgs += gt_disp.size(0)

            with torch.no_grad():
                pred_disp = self.aanet(left, right)[-1]  # [B, H, W]

            # Upsample low-resolution predictions to ground-truth size; values
            # are scaled by the width ratio to stay metrically correct.
            if pred_disp.size(-1) < gt_disp.size(-1):
                pred_disp = pred_disp.unsqueeze(1)  # [B, 1, H, W]
                pred_disp = F.interpolate(
                    pred_disp, (gt_disp.size(-2), gt_disp.size(-1)),
                    mode='bilinear',
                    align_corners=False) * (gt_disp.size(-1) /
                                            pred_disp.size(-1))
                pred_disp = pred_disp.squeeze(1)  # [B, H, W]

            if (onlyObj):
                # Restrict evaluation to object pixels (labels < 17).
                gt_disp[sample['label'] >= 17] = 0
                mask_disp = (gt_disp > 0.) & (gt_disp < args.max_disp)

            epe = F.l1_loss(gt_disp[mask_disp],
                            pred_disp[mask_disp],
                            reduction='mean')
            d1 = d1_metric(pred_disp, gt_disp, mask_disp)

            bad1 = bad(pred_disp, gt_disp, mask_disp)
            bad2 = bad(pred_disp, gt_disp, mask_disp, threshold=2)

            # Convert the predicted disparity back to depth (mm, presumably).
            pred_depth = []
            if (args.dataset_name == 'custom_dataset_full'
                    or args.dataset_name == 'custom_dataset_sim'
                    or args.dataset_name == 'custom_dataset_real'
                    or args.dataset_name == 'custom_dataset_obj'):
                temp = torch.zeros((pred_disp.shape)).to(self.device)
                for x in range(left.shape[0]):
                    baseline = sample['baseline'][x].to(self.device)
                    intrinsic = sample['intrinsic'][x].to(self.device)
                    temp[x] = (baseline * 1000 * intrinsic[0][0] /
                               2) / (pred_disp[x])
                    temp[x][temp[x] == inf] = 0
                pred_depth = temp
            else:
                pred_depth = (baseline * 1000 * intrinsic[0][0] /
                              2) / (pred_disp)
                pred_depth[pred_depth == inf] = 0

            mask_depth = (gt_depth > 0.) & (gt_depth < 2000)

            if (onlyObj):
                # NOTE(review): mixes gt_depth and gt_disp in one mask — the
                # disparity term looks like a copy-paste of the mask above;
                # confirm gt_disp (not gt_depth) vs max_disp is intended here.
                gt_depth[sample['label'] >= 17] = 0
                mask_depth = (gt_depth > 0.) & (gt_disp < args.max_disp)

            # NOTE(review): 'abs' shadows the builtin abs() for the rest of
            # this method; consider renaming (kept as-is here).
            abs = F.l1_loss(gt_depth[mask_depth],
                            pred_depth[mask_depth],
                            reduction='mean')

            mm2 = mm_error(pred_depth, gt_depth, mask_depth)
            mm4 = mm_error(pred_depth, gt_depth, mask_depth, threshold=4)
            mm8 = mm_error(pred_depth, gt_depth, mask_depth, threshold=8)

            pred_depth[pred_depth > 2000] = 0

            if (perObject):
                # Accumulate per-object-label L1 errors for disparity/depth.
                for x in range(left.shape[0]):
                    labels = sample['label'][x].detach().numpy().astype(
                        np.uint8)
                    for obj in np.unique(labels):
                        gtObjectDepth = gt_depth[x].detach().clone()
                        gtObjectDepth[labels != obj] = 0
                        predObjectDepth = pred_depth[x].detach().clone()
                        predObjectDepth[labels != obj] = 0

                        gtObjectDisp = gt_disp[x].detach().clone()
                        gtObjectDisp[labels != obj] = 0
                        predObjectDisp = pred_disp[x].detach().clone()
                        predObjectDisp[labels != obj] = 0

                        mask_depth = (gtObjectDepth > 0.)
                        mask_disp = (gtObjectDisp > 0.)

                        objectCount[obj] += 1

                        perObjectDisp[obj] += F.l1_loss(
                            gtObjectDisp[mask_disp],
                            predObjectDisp[mask_disp],
                            reduction='mean')
                        perObjectDepth[obj] += F.l1_loss(
                            gtObjectDepth[mask_depth],
                            predObjectDepth[mask_depth],
                            reduction='mean')

            # thres1 = thres_metric(pred_disp, gt_disp, mask, 1.0)
            # thres2 = thres_metric(pred_disp, gt_disp, mask, 2.0)
            # thres3 = thres_metric(pred_disp, gt_disp, mask, 3.0)

            val_epe += epe.item()
            val_d1 += d1.item()
            val_bad1 += bad1.item()
            val_bad2 += bad2.item()
            val_abs += abs.item()
            val_mm2 += mm2.item()
            val_mm4 += mm4.item()
            val_mm8 += mm8.item()
            # val_thres1 += thres1.item()
            # val_thres2 += thres2.item()
            # val_thres3 += thres3.item()

            # Save 3 images for visualization

            if i in [num_samples // 4, num_samples // 2, num_samples // 4 * 3]:
                if args.evaluate_only:

                    # Dump depth/label images and camera metadata for offline
                    # inspection (hard-coded cluster path).
                    im = (pred_depth[0]).detach().cpu().numpy().astype(
                        np.uint16)
                    if not os.path.isdir('/cephfs/edward/depths'):
                        os.mkdir('/cephfs/edward/depths')
                    imageio.imwrite('/cephfs/edward/depths/' + str(i) + ".png",
                                    im)

                    im = (gt_depth[0]).detach().cpu().numpy().astype(np.uint16)
                    imageio.imwrite(
                        '/cephfs/edward/depths/' + str(i) + "gt.png", im)

                    # NOTE(review): 'x' here is the leftover loop index from a
                    # preceding batch loop, not necessarily sample 0 — verify
                    # this indexes the intended batch element.
                    imageio.imwrite(
                        '/cephfs/edward/depths/' + str(i) + "label.png",
                        sample['label'][x].detach().numpy().astype(np.uint8))

                    info = {
                        'baseline': sample['baseline'][x],
                        'intrinsic': sample['intrinsic'][x],
                        'object_ids': sample['object_ids'][x],
                        'extrinsic': sample['extrinsic'][x]
                    }
                    filename = '/cephfs/edward/depths/meta' + str(i) + '.pkl'
                    with open(filename, 'wb') as f:
                        pickle.dump(info, f)

                img_summary = {}
                img_summary['left'] = left
                img_summary['right'] = right
                img_summary['gt_depth'] = gt_depth
                img_summary['gt_disp'] = gt_disp

                if (onlyObj):
                    pred_disp[sample['label'] >= 17] = 0
                    pred_depth[sample['label'] >= 17] = 0

                img_summary['disp_error'] = disp_error_img(pred_disp, gt_disp)
                img_summary['depth_error'] = depth_error_img(
                    pred_depth, gt_depth)
                img_summary['pred_disp'] = pred_disp
                img_summary['pred_depth'] = pred_depth

                save_images(self.train_writer, 'val' + str(val_count),
                            img_summary, self.epoch)
                val_count += 1

        logger.info('=> Validation done!')
        if (perObject):
            # Average the accumulated per-object errors by occurrence count.
            for key, value in objectCount.items():

                perObjectDisp[key] = float(perObjectDisp[key]) / value
                perObjectDepth[key] = float(perObjectDepth[key]) / value
            print(perObjectDisp, perObjectDepth, objectCount)

        mean_epe = val_epe / valid_samples
        mean_d1 = val_d1 / valid_samples
        mean_bad1 = val_bad1 / valid_samples
        mean_bad2 = val_bad2 / valid_samples
        mean_abs = val_abs / valid_samples
        mean_mm2 = val_mm2 / valid_samples
        mean_mm4 = val_mm4 / valid_samples
        mean_mm8 = val_mm8 / valid_samples
        # mean_thres1 = val_thres1 / valid_samples
        # mean_thres2 = val_thres2 / valid_samples
        # mean_thres3 = val_thres3 / valid_samples

        # Save validation results
        with open(val_file, 'a') as f:
            f.write('epoch: %03d\t' % self.epoch)
            f.write('epe: %.4f\t' % mean_epe)
            f.write('d1: %.4f\t' % mean_d1)
            f.write('bad1: %.4f\t' % mean_bad1)
            f.write('bad2: %.4f\t' % mean_bad2)
            f.write('abs: %.4f\t' % mean_abs)
            f.write('mm2: %.4f\t' % mean_mm2)
            f.write('mm4: %.4f\t' % mean_mm4)
            f.write('mm8: %.4f\t' % mean_mm8)
            # f.write('thres1: %.4f\t' % mean_thres1)
            # f.write('thres2: %.4f\t' % mean_thres2)
            # f.write('thres3: %.4f\n' % mean_thres3)

        logger.info('=> Mean validation epe of epoch %d: %.3f' %
                    (self.epoch, mean_epe))

        self.train_writer.add_scalar('val/epe', mean_epe, self.epoch)
        self.train_writer.add_scalar('val/d1', mean_d1, self.epoch)
        self.train_writer.add_scalar('val/bad1', mean_bad1, self.epoch)
        self.train_writer.add_scalar('val/bad2', mean_bad2, self.epoch)
        self.train_writer.add_scalar('val/abs', mean_abs, self.epoch)
        self.train_writer.add_scalar('val/mm2', mean_mm2, self.epoch)
        self.train_writer.add_scalar('val/mm4', mean_mm4, self.epoch)
        self.train_writer.add_scalar('val/mm8', mean_mm8, self.epoch)
        # self.train_writer.add_scalar('val/thres1', mean_thres1, self.epoch)
        # self.train_writer.add_scalar('val/thres2', mean_thres2, self.epoch)
        # self.train_writer.add_scalar('val/thres3', mean_thres3, self.epoch)

        # Track the best model according to the configured validation metric.
        if not args.evaluate_only:
            if args.val_metric == 'd1':
                if mean_d1 < self.best_epe:
                    # Actually best_epe here is d1
                    self.best_epe = mean_d1
                    self.best_epoch = self.epoch

                    utils.save_checkpoint(args.checkpoint_dir,
                                          self.optimizer,
                                          self.aanet,
                                          epoch=self.epoch,
                                          num_iter=self.num_iter,
                                          epe=mean_d1,
                                          best_epe=self.best_epe,
                                          best_epoch=self.best_epoch,
                                          filename='aanet_best.pth')
            elif args.val_metric == 'epe':
                if mean_epe < self.best_epe:
                    self.best_epe = mean_epe
                    self.best_epoch = self.epoch

                    utils.save_checkpoint(args.checkpoint_dir,
                                          self.optimizer,
                                          self.aanet,
                                          epoch=self.epoch,
                                          num_iter=self.num_iter,
                                          epe=mean_epe,
                                          best_epe=self.best_epe,
                                          best_epoch=self.best_epoch,
                                          filename='aanet_best.pth')
            else:
                raise NotImplementedError

        if self.epoch == args.max_epoch:
            # Save best validation results
            with open(val_file, 'a') as f:
                f.write('\nbest epoch: %03d \t best %s: %.3f\n\n' %
                        (self.best_epoch, args.val_metric, self.best_epe))

            logger.info('=> best epoch: %03d \t best %s: %.3f\n' %
                        (self.best_epoch, args.val_metric, self.best_epe))

        # Always save the latest model for resuming training
        if not args.evaluate_only:
            utils.save_checkpoint(args.checkpoint_dir,
                                  self.optimizer,
                                  self.aanet,
                                  epoch=self.epoch,
                                  num_iter=self.num_iter,
                                  epe=mean_epe,
                                  best_epe=self.best_epe,
                                  best_epoch=self.best_epoch,
                                  filename='aanet_latest.pth')

            # Save checkpoint of specific epochs
            if self.epoch % args.save_ckpt_freq == 0:
                model_dir = os.path.join(args.checkpoint_dir, 'models')
                utils.check_path(model_dir)
                utils.save_checkpoint(model_dir,
                                      self.optimizer,
                                      self.aanet,
                                      epoch=self.epoch,
                                      num_iter=self.num_iter,
                                      epe=mean_epe,
                                      best_epe=self.best_epe,
                                      best_epoch=self.best_epoch,
                                      save_optimizer=False)