Code example #1: the MVSSystem constructor (model, loss, and checkpoint loading); the full class is shown in code example #2
    def __init__(self, hparams):
        super(MVSSystem, self).__init__()
        self.hparams = hparams
        # undo the ImageNet normalization so images can be visualized
        # (see the numerical check after this listing)
        self.unpreprocess = T.Normalize(
            mean=[-0.485 / 0.229, -0.456 / 0.224, -0.406 / 0.225],
            std=[1 / 0.229, 1 / 0.224, 1 / 0.225])

        self.loss = loss_dict[hparams.loss_type](hparams.levels)

        self.model = CascadeMVSNet(
            n_depths=self.hparams.n_depths,
            interval_ratios=self.hparams.interval_ratios,
            num_groups=self.hparams.num_groups,
            norm_act=InPlaceABN)

        # if num gpu is 1, print model structure and number of params
        if self.hparams.num_gpus == 1:
            # print(self.model)
            n_params = sum(p.numel() for p in self.model.parameters()
                           if p.requires_grad)
            print('number of parameters : %.2f M' % (n_params / 1e6))

        # load model if checkpoint path is provided
        if self.hparams.ckpt_path != '':
            print('Load model from', self.hparams.ckpt_path)
            load_ckpt(self.model, self.hparams.ckpt_path,
                      self.hparams.prefixes_to_ignore)
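
The unpreprocess transform defined above undoes the standard ImageNet normalization (mean 0.485/0.456/0.406, std 0.229/0.224/0.225): Normalize(m', s') computes (x - m') / s', so choosing m' = -m/s and s' = 1/s yields s*x + m, the exact inverse of (x - m)/s. The following standalone check uses random pixel values (an illustration, not repository code) to verify that identity numerically.

import torch
from torchvision import transforms as T

# Standard ImageNet statistics applied by the dataloader's preprocessing.
mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

# Same construction as self.unpreprocess in the constructor above.
unpreprocess = T.Normalize(mean=(-mean / std).flatten().tolist(),
                           std=(1 / std).flatten().tolist())

img = torch.rand(3, 4, 4)               # fake image in [0, 1]
normalized = (img - mean) / std         # what the training transform applies
recovered = unpreprocess(normalized)    # should reproduce img
print(torch.allclose(recovered, img, atol=1e-6))  # True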
Code example #2: the complete MVSSystem LightningModule (data preparation, optimizers, training and validation steps); a short usage sketch follows the listing
# NOTE: only the external imports are listed here; dataset_dict, loss_dict,
# CascadeMVSNet, load_ckpt and the optimizer/metric/visualization helpers are
# assumed to come from the surrounding repository.
import torch
from torch.utils.data import DataLoader
from torchvision import transforms as T
from pytorch_lightning import LightningModule
from inplace_abn import InPlaceABN


class MVSSystem(LightningModule):
    def __init__(self, hparams):
        super(MVSSystem, self).__init__()
        self.hparams = hparams
        # to unnormalize image for visualization
        self.unpreprocess = T.Normalize(
            mean=[-0.485 / 0.229, -0.456 / 0.224, -0.406 / 0.225],
            std=[1 / 0.229, 1 / 0.224, 1 / 0.225])

        self.loss = loss_dict[hparams.loss_type](hparams.levels)

        self.model = CascadeMVSNet(
            n_depths=self.hparams.n_depths,
            interval_ratios=self.hparams.interval_ratios,
            num_groups=self.hparams.num_groups,
            norm_act=InPlaceABN)

        # if num gpu is 1, print model structure and number of params
        if self.hparams.num_gpus == 1:
            # print(self.model)
            n_params = sum(p.numel() for p in self.model.parameters()
                           if p.requires_grad)
            print('number of parameters : %.2f M' % (n_params / 1e6))

        # load model if checkpoint path is provided
        if self.hparams.ckpt_path != '':
            print('Load model from', self.hparams.ckpt_path)
            load_ckpt(self.model, self.hparams.ckpt_path,
                      self.hparams.prefixes_to_ignore)

    def decode_batch(self, batch):
        imgs = batch['imgs']
        proj_mats = batch['proj_mats']
        depths = batch['depths']
        masks = batch['masks']
        init_depth_min = batch['init_depth_min']
        depth_interval = batch['depth_interval']
        return imgs, proj_mats, depths, masks, init_depth_min, depth_interval

    def forward(self, imgs, proj_mats, init_depth_min, depth_interval):
        return self.model(imgs, proj_mats, init_depth_min, depth_interval)

    def prepare_data(self):
        dataset = dataset_dict[self.hparams.dataset_name]
        self.train_dataset = dataset(
            root_dir=self.hparams.root_dir,
            split='train',
            n_views=self.hparams.n_views,
            levels=self.hparams.levels,
            depth_interval=self.hparams.depth_interval)
        self.val_dataset = dataset(root_dir=self.hparams.root_dir,
                                   split='val',
                                   n_views=self.hparams.n_views,
                                   levels=self.hparams.levels,
                                   depth_interval=self.hparams.depth_interval)

    def configure_optimizers(self):
        self.optimizer = get_optimizer(self.hparams, self.model)
        scheduler = get_scheduler(self.hparams, self.optimizer)

        return [self.optimizer], [scheduler]

    def train_dataloader(self):
        return DataLoader(self.train_dataset,
                          shuffle=True,
                          num_workers=4,
                          batch_size=self.hparams.batch_size,
                          pin_memory=True)

    def val_dataloader(self):
        return DataLoader(self.val_dataset,
                          shuffle=False,
                          num_workers=4,
                          batch_size=self.hparams.batch_size,
                          pin_memory=True)

    def training_step(self, batch, batch_nb):
        log = {'lr': get_learning_rate(self.optimizer)}
        imgs, proj_mats, depths, masks, init_depth_min, depth_interval = \
            self.decode_batch(batch)
        results = self(imgs, proj_mats, init_depth_min, depth_interval)
        log['train/loss'] = loss = self.loss(results, depths, masks)

        with torch.no_grad():
            if batch_nb == 0:
                img_ = self.unpreprocess(imgs[0, 0]).cpu()  # batch 0, ref image
                depth_gt_ = visualize_depth(depths['level_0'][0])
                depth_pred_ = visualize_depth(results['depth_0'][0] *
                                              masks['level_0'][0])
                prob = visualize_prob(results['confidence_0'][0] *
                                      masks['level_0'][0])
                stack = torch.stack([img_, depth_gt_, depth_pred_,
                                     prob])  # (4, 3, H, W)
                self.logger.experiment.add_images('train/image_GT_pred_prob',
                                                  stack, self.global_step)

            depth_pred = results['depth_0']
            depth_gt = depths['level_0']
            mask = masks['level_0']
            log['train/abs_err'] = abs_err = abs_error(depth_pred, depth_gt,
                                                       mask).mean()
            log['train/acc_1mm'] = acc_threshold(depth_pred, depth_gt, mask,
                                                 1).mean()
            log['train/acc_2mm'] = acc_threshold(depth_pred, depth_gt, mask,
                                                 2).mean()
            log['train/acc_4mm'] = acc_threshold(depth_pred, depth_gt, mask,
                                                 4).mean()

        return {
            'loss': loss,
            'progress_bar': {
                'train_abs_err': abs_err
            },
            'log': log
        }

    def validation_step(self, batch, batch_nb):
        log = {}
        imgs, proj_mats, depths, masks, init_depth_min, depth_interval = \
            self.decode_batch(batch)
        results = self(imgs, proj_mats, init_depth_min, depth_interval)
        log['val_loss'] = self.loss(results, depths, masks)

        if batch_nb == 0:
            img_ = self.unpreprocess(imgs[0, 0]).cpu()  # batch 0, ref image
            depth_gt_ = visualize_depth(depths['level_0'][0])
            depth_pred_ = visualize_depth(results['depth_0'][0] *
                                          masks['level_0'][0])
            prob = visualize_prob(results['confidence_0'][0] *
                                  masks['level_0'][0])
            stack = torch.stack([img_, depth_gt_, depth_pred_,
                                 prob])  # (4, 3, H, W)
            self.logger.experiment.add_images('val/image_GT_pred_prob', stack,
                                              self.global_step)

        depth_pred = results['depth_0']
        depth_gt = depths['level_0']
        mask = masks['level_0']

        log['val_abs_err'] = abs_error(depth_pred, depth_gt, mask).sum()
        log['val_acc_1mm'] = acc_threshold(depth_pred, depth_gt, mask, 1).sum()
        log['val_acc_2mm'] = acc_threshold(depth_pred, depth_gt, mask, 2).sum()
        log['val_acc_4mm'] = acc_threshold(depth_pred, depth_gt, mask, 4).sum()
        log['mask_sum'] = mask.float().sum()

        return log

    def validation_epoch_end(self, outputs):
        mean_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        mask_sum = torch.stack([x['mask_sum'] for x in outputs]).sum()
        mean_abs_err = torch.stack([x['val_abs_err']
                                    for x in outputs]).sum() / mask_sum
        mean_acc_1mm = torch.stack([x['val_acc_1mm']
                                    for x in outputs]).sum() / mask_sum
        mean_acc_2mm = torch.stack([x['val_acc_2mm']
                                    for x in outputs]).sum() / mask_sum
        mean_acc_4mm = torch.stack([x['val_acc_4mm']
                                    for x in outputs]).sum() / mask_sum

        return {
            'progress_bar': {
                'val_loss': mean_loss,
                'val_abs_err': mean_abs_err
            },
            'log': {
                'val/loss': mean_loss,
                'val/abs_err': mean_abs_err,
                'val/acc_1mm': mean_acc_1mm,
                'val/acc_2mm': mean_acc_2mm,
                'val/acc_4mm': mean_acc_4mm,
            }
        }
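
For orientation, the following is a minimal sketch of how this module could be driven, matching the PyTorch Lightning generation this code targets (legacy self.hparams assignment and *_epoch_end hooks). get_opts() and the num_epochs field are assumptions about the accompanying option parser, not facts taken from the listing above.

# Minimal training entry point (sketch; option names marked as assumptions).
from pytorch_lightning import Trainer

if __name__ == '__main__':
    hparams = get_opts()    # assumed argparse helper providing the fields used above
    system = MVSSystem(hparams)
    trainer = Trainer(max_epochs=hparams.num_epochs,  # assumed option name
                      gpus=hparams.num_gpus,
                      benchmark=True)
    trainer.fit(system)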
Code example #3: evaluation entry point (dataset and model setup, checkpoint loading, and the per-scan depth/confidence prediction loop)
# NOTE: only the external imports are listed here; get_opts, dataset_dict,
# CascadeMVSNet, load_ckpt and decode_batch are assumed to come from the
# surrounding repository.
from tqdm import tqdm
from inplace_abn import ABN

if __name__ == "__main__":
    args = get_opts()
    dataset = dataset_dict[args.dataset_name](
        args.root_dir, args.split,
        n_views=args.n_views, depth_interval=args.depth_interval,
        img_wh=tuple(args.img_wh))

    if args.scan:
        scans = [args.scan]
    else: # evaluate on all scans in dataset
        scans = dataset.scans

    # Step 1. Create depth estimation and probability for each scan
    model = CascadeMVSNet(n_depths=args.n_depths,
                          interval_ratios=args.interval_ratios,
                          num_groups=args.num_groups,
                          norm_act=ABN)
    device = 'cpu' if args.cpu else 'cuda:0'
    model.to(device)
    load_ckpt(model, args.ckpt_path)
    model.eval()

    depth_dir = f'results/{args.dataset_name}/depth'
    print('Creating depth and confidence predictions...')
    if args.scan: # TODO: adapt scan specification to tanks and blendedmvs
        data_range = [i for i, x in enumerate(dataset.metas) if x[0] == args.scan]
    else:
        data_range = range(len(dataset))
    for i in tqdm(data_range):
        imgs, proj_mats, init_depth_min, depth_interval, \
            scan, vid = decode_batch(dataset[i])