Beispiel #1
0
def test(cfg, args, output_dir='', run_name=''):
    """Evaluate a 2D segmentation network on 3D ScanNet scenes via chunks.

    The 2D network is wrapped in ``MVPNet2D`` and applied to every chunk that
    ``ScanNet2D3DChunksTest`` yields for a scan. Per-point logits are summed
    over all chunks containing the point and divided by the number of such
    chunks; the argmax is the scene prediction. Points covered by no chunk
    are assigned the label ``num_classes``.

    Args:
        cfg: Config node. ``RNG_SEED`` and ``DATASET.ScanNet2D.resize`` /
            ``normalizer`` are read here; the node is also passed to
            ``build_model_sem_seg_2d``.
        args: Parsed command-line arguments (``ckpt_path``, ``cache_dir``,
            ``image_dir``, ``split``, ``chunk_size``, ``chunk_stride``,
            ``chunk_thresh``, ``num_views``, ``k``, ``save``).
        output_dir: Checkpoint directory. Substituted for ``'@'`` in
            ``args.ckpt_path``; the evaluation table
            ``eval.<run_name>.tsv`` is written here.
        run_name: Name used for the evaluation table and, when
            ``args.save`` is set, for the ``submit/<run_name>`` folder that
            receives one ``<scan_id>.txt`` label file per scan.
    """
    import os

    logger = logging.getLogger('mvpnet.test')

    # build 2d model
    net_2d = build_model_sem_seg_2d(cfg)[0]

    # build checkpointer
    checkpointer = CheckpointerV2(net_2d, save_dir=output_dir, logger=logger)
    if args.ckpt_path:
        # load weight if specified ('@' is shorthand for output_dir)
        weight_path = args.ckpt_path.replace('@', output_dir)
        checkpointer.load(weight_path, resume=False)
    else:
        # load last checkpoint
        checkpointer.load(None, resume=True)

    # wrapper for 2d model
    model = MVPNet2D(net_2d)
    model = model.cuda()

    # build dataset
    test_dataset = ScanNet2D3DChunksTest(
        cache_dir=args.cache_dir,
        image_dir=args.image_dir,
        split=args.split,
        chunk_size=(args.chunk_size, args.chunk_size),
        chunk_stride=args.chunk_stride,
        chunk_thresh=args.chunk_thresh,
        num_rgbd_frames=args.num_views,
        resize=cfg.DATASET.ScanNet2D.resize,
        image_normalizer=cfg.DATASET.ScanNet2D.normalizer,
        k=args.k,
        to_tensor=True,
    )
    # batch_size=1 + this collate_fn: each iteration yields the single
    # dataset item (the list of chunk dicts of one scan) unwrapped.
    test_dataloader = DataLoader(test_dataset,
                                 batch_size=1,
                                 shuffle=False,
                                 num_workers=4,
                                 collate_fn=lambda x: x[0])

    # evaluator
    class_names = test_dataset.class_names
    evaluator = Evaluator(class_names)
    num_classes = len(class_names)
    submit_dir = None
    if args.save:
        submit_dir = osp.join(output_dir, 'submit', run_name)
        # np.savetxt below does not create missing parent directories.
        os.makedirs(submit_dir, exist_ok=True)

    # ---------------------------------------------------------------------------- #
    # Test
    # ---------------------------------------------------------------------------- #
    model.eval()
    set_random_seed(cfg.RNG_SEED)
    test_meters = MetricLogger(delimiter='  ')

    with torch.no_grad():
        start_time = time.time()
        start_time_scan = time.time()
        for scan_idx, data_dict_list in enumerate(test_dataloader):
            # fetch data
            scan_id = test_dataset.scan_ids[scan_idx]
            scan_data = test_dataset.data[scan_idx]
            points = scan_data['points'].astype(np.float32)
            # Labels may be missing from the cached scan data; only remap them
            # when present (the evaluation step below already tolerates None).
            seg_label = scan_data.get('seg_label', None)
            if seg_label is not None:
                seg_label = test_dataset.nyu40_to_scannet[seg_label]
            data_time = time.time() - start_time_scan

            # prepare outputs
            num_points = len(points)
            pred_logit_whole_scene = np.zeros([num_points, num_classes],
                                              dtype=np.float32)
            # Wider than uint8 so heavily overlapping chunks cannot wrap the count.
            num_pred_per_point = np.zeros(num_points, dtype=np.int32)

            # iterate over chunks
            tic = time.time()
            for data_dict in data_dict_list:
                chunk_ind = data_dict.pop('chunk_ind')
                # add a leading batch dimension of size 1 to every entry
                data_batch = {
                    k: torch.tensor([v])
                    for k, v in data_dict.items()
                }
                data_batch = {
                    k: v.cuda(non_blocking=True)
                    for k, v in data_batch.items()
                }
                # forward
                preds = model(data_batch)
                # transpose to (points, classes) and drop any padded tail
                seg_logit = preds['seg_logit'].squeeze(0).cpu().numpy().T
                seg_logit = seg_logit[:len(chunk_ind)]
                # accumulate into the scene-level buffers
                pred_logit_whole_scene[chunk_ind] += seg_logit
                num_pred_per_point[chunk_ind] += 1
            forward_time = time.time() - tic

            # average over chunks; max(..., 1) avoids dividing by zero for
            # points no chunk covered (handled explicitly below)
            pred_logit_whole_scene = pred_logit_whole_scene / np.maximum(
                num_pred_per_point[:, np.newaxis], 1)
            pred_label_whole_scene = np.argmax(pred_logit_whole_scene, axis=1)

            no_pred_mask = num_pred_per_point == 0
            no_pred_indices = np.nonzero(no_pred_mask)[0]
            if no_pred_indices.size > 0:
                logger.warning(
                    '{:s}: There are {:d} points without prediction.'.format(
                        scan_id, no_pred_mask.sum()))
                pred_label_whole_scene[no_pred_indices] = num_classes

            if _DEBUG:
                # DEBUG: visualize scene
                from mvpnet.utils.visualize import visualize_labels
                visualize_labels(points, pred_label_whole_scene)
                if seg_label is not None:
                    visualize_labels(points, seg_label)

            # evaluate
            tic = time.time()
            if seg_label is not None:
                evaluator.update(pred_label_whole_scene, seg_label)
            metric_time = time.time() - tic

            batch_time = time.time() - start_time_scan
            test_meters.update(time=batch_time)
            test_meters.update(
                data=data_time,
                forward_time=forward_time,
                metric_time=metric_time,
            )

            # save prediction
            if submit_dir:
                remapped_pred_labels = test_dataset.scannet_to_nyu40[
                    pred_label_whole_scene]
                np.savetxt(osp.join(submit_dir, scan_id + '.txt'),
                           remapped_pred_labels, '%d')

            logger.info(
                test_meters.delimiter.join([
                    '{:d}/{:d}({:s})',
                    'acc: {acc:.2f}',
                    'IoU: {iou:.2f}',
                    '{meters}',
                ]).format(
                    scan_idx,
                    len(test_dataset),
                    scan_id,
                    acc=evaluator.overall_acc * 100.0,
                    iou=evaluator.overall_iou * 100.0,
                    meters=str(test_meters),
                ))
            start_time_scan = time.time()

        test_time = time.time() - start_time
        logger.info('Test {}  test time: {:.2f}s'.format(
            test_meters.summary_str, test_time))

    # evaluate
    logger.info('overall accuracy={:.2f}%'.format(100.0 *
                                                  evaluator.overall_acc))
    logger.info('overall IOU={:.2f}'.format(100.0 * evaluator.overall_iou))
    logger.info('class-wise accuracy and IoU.\n{}'.format(
        evaluator.print_table()))
    evaluator.save_table(osp.join(output_dir, 'eval.{}.tsv'.format(run_name)))
Beispiel #2
0
def test(cfg, args, output_dir='', run_name=''):
    """Evaluate a 3D segmentation network on whole ScanNet scans by voting.

    For every scan, ``args.num_votes`` subsets of ``args.nb_pts`` point
    indices are drawn. When the scan has at least ``nb_pts`` points, the
    indices are drawn randomly without replacement. Otherwise, all points
    are taken and the tail is padded with index 0. The subsets are run
    through the network as one batch. Each vote's logits are propagated to
    every scene point through its nearest sampled point, and the
    accumulated logits are divided by ``num_votes``. The argmax of the
    result is the prediction.

    Args:
        cfg: Config node. ``DATASET.ROOT_DIR`` is read here; the node is also
            passed to ``build_model_sem_seg_3d``.
        args: Parsed command-line arguments (``ckpt_path``, ``split``,
            ``save``, ``no_rot``, ``use_color``, ``num_votes``, ``seed``,
            ``nb_pts``).
        output_dir: Checkpoint directory. Substituted for ``'@'`` in
            ``args.ckpt_path``; the evaluation table
            ``eval.<run_name>.tsv`` is written here.
        run_name: Name used for the evaluation table and, when
            ``args.save`` is set, for the ``submit/<run_name>`` folder that
            receives one ``<scan_id>.txt`` label file per scan (labels
            remapped through ``scannet_to_raw``).
    """
    import os

    logger = logging.getLogger('mvpnet.test')

    # build model
    model = build_model_sem_seg_3d(cfg)[0]
    model = model.cuda()

    # build checkpointer
    checkpointer = CheckpointerV2(model, save_dir=output_dir, logger=logger)

    if args.ckpt_path:
        # load weight if specified ('@' is shorthand for output_dir)
        weight_path = args.ckpt_path.replace('@', output_dir)
        checkpointer.load(weight_path, resume=False)
    else:
        # load last checkpoint
        checkpointer.load(None, resume=True)

    # build dataset
    test_dataset = ScanNet3D(cfg.DATASET.ROOT_DIR, split=args.split)
    test_dataset.set_mapping('scannet')

    # evaluator
    class_names = test_dataset.class_names
    evaluator = Evaluator(class_names)
    num_classes = len(class_names)
    submit_dir = None
    if args.save:
        submit_dir = osp.join(output_dir, 'submit', run_name)
        # np.savetxt below does not create missing parent directories.
        os.makedirs(submit_dir, exist_ok=True)

    # others
    # NOTE(review): aug_list is built but never passed to the transform below,
    # so --no_rot currently has no effect. Confirm whether RandomRotateZ was
    # meant to be part of the voting transform before wiring it in.
    aug_list = []
    if not args.no_rot:
        aug_list.append(T.RandomRotateZ())
    transform = T.Compose([T.ToTensor(), T.Transpose()])
    use_color = args.use_color or (model.in_channels == 3)
    num_votes = args.num_votes

    # ---------------------------------------------------------------------------- #
    # Test
    # ---------------------------------------------------------------------------- #
    model.eval()
    set_random_seed(args.seed)
    test_meters = MetricLogger(delimiter='  ')

    with torch.no_grad():
        start_time = time.time()
        for scan_idx in range(len(test_dataset)):
            start_time_scan = time.time()
            # fetch data
            tic = time.time()
            data_dict = test_dataset[scan_idx]
            scan_id = data_dict['scan_id']
            points = data_dict['points']  # (n, 3)
            colors = data_dict['colors']  # (n, 3)
            seg_label = data_dict.get('seg_label', None)  # (n,)
            data_time = time.time() - tic

            # prepare inputs
            tic = time.time()
            data_list = []
            points_list = []
            colors_list = []
            ind_list = []
            for vote_ind in range(num_votes):
                if len(points) >= args.nb_pts:
                    ind = np.random.choice(len(points),
                                           size=args.nb_pts,
                                           replace=False)
                else:
                    # Take every point and pad by repeating point 0. The pad
                    # must be an integer array: the default float zeros would
                    # make ``points[ind]`` raise IndexError.
                    ind = np.hstack([
                        np.arange(len(points)),
                        np.zeros(args.nb_pts - len(points), dtype=np.int64)
                    ])
                points_list.append(points[ind])
                colors_list.append(colors[ind])
                ind_list.append(ind)
            for vote_ind in range(num_votes):
                data_single = {
                    'points': points_list[vote_ind],
                }
                if use_color:
                    data_single['feature'] = colors_list[vote_ind]
                data_list.append(transform(**data_single))
            # stack all votes into a single batch (every vote has the same keys)
            data_batch = {
                k: torch.stack([x[k] for x in data_list])
                for k in data_single
            }
            data_batch = {
                k: v.cuda(non_blocking=True)
                for k, v in data_batch.items()
            }
            preprocess_time = time.time() - tic

            # forward
            tic = time.time()
            preds = model(data_batch)
            seg_logit_batch = preds['seg_logit'].cpu().numpy()
            forward_time = time.time() - tic

            # propagate predictions and ensemble
            tic = time.time()
            pred_logit_whole_scene = np.zeros([len(points), num_classes],
                                              dtype=np.float32)
            for vote_ind in range(num_votes):
                points_per_vote = points_list[vote_ind]
                # transpose one vote's logits to (sampled points, classes)
                seg_logit_per_vote = seg_logit_batch[vote_ind].T
                # Propagate to nearest neighbours
                nbrs = NearestNeighbors(
                    n_neighbors=1, algorithm='ball_tree').fit(points_per_vote)
                _, nn_indices = nbrs.kneighbors(points[:, 0:3])
                pred_logit_whole_scene += seg_logit_per_vote[nn_indices[:, 0]]
            # if we use softmax, it is necessary to normalize logits
            pred_logit_whole_scene = pred_logit_whole_scene / num_votes
            pred_label_whole_scene = np.argmax(pred_logit_whole_scene, axis=1)
            postprocess_time = time.time() - tic

            if _DEBUG:
                # DEBUG: visualize scene
                from mvpnet.utils.visualize import visualize_labels
                visualize_labels(points, pred_label_whole_scene, colors=colors)

            # evaluate
            tic = time.time()
            if seg_label is not None:
                evaluator.update(pred_label_whole_scene, seg_label)
            metric_time = time.time() - tic

            batch_time = time.time() - start_time_scan
            test_meters.update(time=batch_time)
            test_meters.update(data=data_time,
                               preprocess_time=preprocess_time,
                               forward_time=forward_time,
                               postprocess_time=postprocess_time,
                               metric_time=metric_time)

            # save prediction
            if submit_dir:
                remapped_pred_labels = test_dataset.scannet_to_raw[
                    pred_label_whole_scene]
                np.savetxt(osp.join(submit_dir, scan_id + '.txt'),
                           remapped_pred_labels, '%d')

            logger.info(
                test_meters.delimiter.join([
                    '{:d}/{:d}({:s})',
                    'acc: {acc:.2f}',
                    'IoU: {iou:.2f}',
                    '{meters}',
                ]).format(
                    scan_idx,
                    len(test_dataset),
                    scan_id,
                    acc=evaluator.overall_acc * 100.0,
                    iou=evaluator.overall_iou * 100.0,
                    meters=str(test_meters),
                ))
        test_time = time.time() - start_time
        logger.info('Test {}  test time: {:.2f}s'.format(
            test_meters.summary_str, test_time))

    # evaluate
    logger.info('overall accuracy={:.2f}%'.format(100.0 *
                                                  evaluator.overall_acc))
    logger.info('overall IOU={:.2f}'.format(100.0 * evaluator.overall_iou))
    logger.info('class-wise accuracy and IoU.\n{}'.format(
        evaluator.print_table()))
    evaluator.save_table(osp.join(output_dir, 'eval.{}.tsv'.format(run_name)))
Beispiel #3
0
def test(cfg, args, output_dir='', run_name=''):
    """Evaluate a 2D semantic segmentation network on ScanNet images.

    Runs the model over the ``ScanNet2D`` split in batches. It accumulates
    per-pixel predictions (argmax over dim 1 of ``seg_logit``) in an
    ``Evaluator`` whenever the batch carries ``seg_label``. It logs running
    accuracy/IoU every ``args.log_period`` iterations and, at the end, the
    overall and class-wise metrics.

    Args:
        cfg: Config node. ``DATASET.ROOT_DIR``, ``DATASET.ScanNet2D.resize``
            / ``normalizer``, ``VAL.BATCH_SIZE``, ``DATALOADER.NUM_WORKERS``
            and ``RNG_SEED`` are read here; the node is also passed to
            ``build_model_sem_seg_2d``.
        args: Parsed command-line arguments (``ckpt_path``, ``split``,
            ``batch_size``, ``num_workers``, ``save``, ``log_period``).
        output_dir: Checkpoint directory. Substituted for ``'@'`` in
            ``args.ckpt_path``; the evaluation table
            ``eval.<run_name>.tsv`` is written here.
        run_name: Name used for the evaluation table.
    """
    logger = logging.getLogger('mvpnet.test')

    # build model
    model = build_model_sem_seg_2d(cfg)[0]
    model = model.cuda()

    # build checkpointer
    checkpointer = CheckpointerV2(model, save_dir=output_dir, logger=logger)

    if args.ckpt_path:
        # load weight if specified ('@' is shorthand for output_dir)
        weight_path = args.ckpt_path.replace('@', output_dir)
        checkpointer.load(weight_path, resume=False)
    else:
        # load last checkpoint
        checkpointer.load(None, resume=True)

    # build dataset
    test_dataset = ScanNet2D(
        cfg.DATASET.ROOT_DIR,
        split=args.split,
        subsample=None,
        to_tensor=True,
        resize=cfg.DATASET.ScanNet2D.resize,
        normalizer=cfg.DATASET.ScanNet2D.normalizer,
    )
    # NOTE(review): `or` falls back to the config for any falsy CLI value,
    # so an explicit 0 (e.g. num_workers=0) is replaced by the cfg value.
    batch_size = args.batch_size or cfg.VAL.BATCH_SIZE
    num_workers = args.num_workers or cfg.DATALOADER.NUM_WORKERS
    test_dataloader = DataLoader(test_dataset,
                                 batch_size=batch_size,
                                 shuffle=False,
                                 num_workers=num_workers,
                                 drop_last=False)

    # evaluator
    evaluator = Evaluator(test_dataset.class_names)
    if args.save:
        # This evaluation does not write per-image predictions; make that
        # explicit instead of silently ignoring the option.
        logger.warning('Saving predictions is not supported by the 2D image '
                       'test; ignoring the save option.')

    # ---------------------------------------------------------------------------- #
    # Test
    # ---------------------------------------------------------------------------- #
    model.eval()
    set_random_seed(cfg.RNG_SEED)
    delimiter = '  '

    with torch.no_grad():
        start_time = time.time()
        for iteration, data_batch in enumerate(test_dataloader):
            # keep a CPU handle on the labels before the batch moves to GPU
            gt_label = data_batch.get('seg_label', None)
            data_batch = {
                k: v.cuda(non_blocking=True)
                for k, v in data_batch.items()
            }
            # forward
            preds = model(data_batch)
            pred_label = preds['seg_logit'].argmax(
                1).cpu().numpy()  # (b, h, w)
            # evaluate
            if gt_label is not None:
                gt_label = gt_label.cpu().numpy()
                evaluator.batch_update(pred_label, gt_label)
            # logging
            if args.log_period and iteration % args.log_period == 0:
                logger.info(
                    delimiter.join([
                        '{:d}/{:d}',
                        'acc: {acc:.2f}',
                        'IoU: {iou:.2f}',
                    ]).format(
                        iteration,
                        len(test_dataloader),
                        acc=evaluator.overall_acc * 100.0,
                        iou=evaluator.overall_iou * 100.0,
                    ))
        test_time = time.time() - start_time
        logger.info('Test time: {:.2f}s'.format(test_time))

    # evaluate
    logger.info('overall accuracy={:.2f}%'.format(100.0 *
                                                  evaluator.overall_acc))
    logger.info('overall IOU={:.2f}'.format(100.0 * evaluator.overall_iou))
    logger.info('class-wise accuracy and IoU.\n{}'.format(
        evaluator.print_table()))
    evaluator.save_table(osp.join(output_dir, 'eval.{}.tsv'.format(run_name)))
Beispiel #4
0
def test(cfg, args, output_dir='', run_name=''):
    """Evaluate a 3D segmentation network on ScanNet scans via sliding chunks.

    Each scan is split into chunks by ``scene2chunks_legacy``, and every
    chunk is run through the network on its own. Per-point logits are summed
    over all chunks containing the point and divided by the number of such
    chunks; the argmax is the scene prediction. Points covered by no chunk
    are assigned the label ``num_classes``.

    Args:
        cfg: Config node. ``DATASET.ROOT_DIR`` and ``RNG_SEED`` are read
            here; the node is also passed to ``build_model_sem_seg_3d``.
        args: Parsed command-line arguments (``ckpt_path``, ``split``,
            ``save``, ``min_nb_pts``, ``use_color``, ``chunk_size``,
            ``chunk_stride``, ``chunk_thresh``).
        output_dir: Checkpoint directory. Substituted for ``'@'`` in
            ``args.ckpt_path``; the evaluation table
            ``eval.<run_name>.tsv`` is written here.
        run_name: Name used for the evaluation table and, when
            ``args.save`` is set, for the ``submit/<run_name>`` folder that
            receives one ``<scan_id>.txt`` label file per scan (labels
            remapped through ``scannet_to_raw``).
    """
    import os

    logger = logging.getLogger('mvpnet.test')

    # build model
    model = build_model_sem_seg_3d(cfg)[0]
    model = model.cuda()

    # build checkpointer
    checkpointer = CheckpointerV2(model, save_dir=output_dir, logger=logger)

    if args.ckpt_path:
        # load weight if specified ('@' is shorthand for output_dir)
        weight_path = args.ckpt_path.replace('@', output_dir)
        checkpointer.load(weight_path, resume=False)
    else:
        # load last checkpoint
        checkpointer.load(None, resume=True)

    # build dataset
    test_dataset = ScanNet3D(cfg.DATASET.ROOT_DIR, split=args.split)
    test_dataset.set_mapping('scannet')

    # evaluator
    class_names = test_dataset.class_names
    evaluator = Evaluator(class_names)
    num_classes = len(class_names)
    submit_dir = None
    if args.save:
        submit_dir = osp.join(output_dir, 'submit', run_name)
        # np.savetxt below does not create missing parent directories.
        os.makedirs(submit_dir, exist_ok=True)

    # others
    # T.Pad(args.min_nb_pts) presumably pads small chunks up to min_nb_pts
    # points; the padded tail's logits are sliced off after the forward pass.
    transform = T.Compose(
        [T.ToTensor(), T.Pad(args.min_nb_pts),
         T.Transpose()])
    use_color = args.use_color or (model.in_channels == 3)

    # ---------------------------------------------------------------------------- #
    # Test
    # ---------------------------------------------------------------------------- #
    model.eval()
    set_random_seed(cfg.RNG_SEED)
    test_meters = MetricLogger(delimiter='  ')

    with torch.no_grad():
        start_time = time.time()
        for scan_idx in range(len(test_dataset)):
            start_time_scan = time.time()
            # fetch data
            tic = time.time()
            data_dict = test_dataset[scan_idx]
            scan_id = data_dict['scan_id']
            points = data_dict['points']  # (n, 3)
            colors = data_dict['colors']  # (n, 3)
            seg_label = data_dict.get('seg_label', None)  # (n,)
            data_time = time.time() - tic

            # generate chunks
            tic = time.time()
            chunk_indices = scene2chunks_legacy(points,
                                                chunk_size=(args.chunk_size,
                                                            args.chunk_size),
                                                stride=args.chunk_stride,
                                                thresh=args.chunk_thresh)
            preprocess_time = time.time() - tic

            # prepare outputs
            num_points = len(points)
            pred_logit_whole_scene = np.zeros([num_points, num_classes],
                                              dtype=np.float32)
            # Wider than uint8 so heavily overlapping chunks cannot wrap the count.
            num_pred_per_point = np.zeros(num_points, dtype=np.int32)

            # iterate over chunks
            tic = time.time()
            for indices in chunk_indices:
                chunk_points = points[indices]
                chunk_feature = colors[indices]
                chunk_num_points = len(chunk_points)
                if chunk_num_points < args.min_nb_pts:
                    logger.warning('Too few points({}) in a chunk of {}'.format(
                        chunk_num_points, scan_id))
                # prepare inputs
                data_batch = {'points': chunk_points}
                if use_color:
                    data_batch['feature'] = chunk_feature
                # preprocess, then add a leading batch dimension of size 1
                data_batch = transform(**data_batch)
                data_batch = {
                    k: torch.stack([v])
                    for k, v in data_batch.items()
                }
                data_batch = {
                    k: v.cuda(non_blocking=True)
                    for k, v in data_batch.items()
                }
                # forward
                preds = model(data_batch)
                # transpose to (points, classes) and drop the padded tail
                seg_logit = preds['seg_logit'].squeeze(0).cpu().numpy().T
                seg_logit = seg_logit[:chunk_num_points]
                # accumulate into the scene-level buffers
                pred_logit_whole_scene[indices] += seg_logit
                num_pred_per_point[indices] += 1
            forward_time = time.time() - tic

            # average over chunks; max(..., 1) avoids dividing by zero for
            # points no chunk covered (handled explicitly below)
            pred_logit_whole_scene = pred_logit_whole_scene / np.maximum(
                num_pred_per_point[:, np.newaxis], 1)
            pred_label_whole_scene = np.argmax(pred_logit_whole_scene, axis=1)

            no_pred_mask = num_pred_per_point == 0
            no_pred_indices = np.nonzero(no_pred_mask)[0]
            if no_pred_indices.size > 0:
                logger.warning(
                    '{:s}: There are {:d} points without prediction.'.format(
                        scan_id, no_pred_mask.sum()))
                pred_label_whole_scene[no_pred_indices] = num_classes

            if _DEBUG:
                # DEBUG: visualize scene
                from mvpnet.utils.visualize import visualize_labels
                visualize_labels(points, pred_label_whole_scene, colors=colors)

            # evaluate
            tic = time.time()
            if seg_label is not None:
                evaluator.update(pred_label_whole_scene, seg_label)
            metric_time = time.time() - tic

            batch_time = time.time() - start_time_scan
            test_meters.update(time=batch_time)
            test_meters.update(data=data_time,
                               preprocess_time=preprocess_time,
                               forward_time=forward_time,
                               metric_time=metric_time)

            # save prediction
            if submit_dir:
                remapped_pred_labels = test_dataset.scannet_to_raw[
                    pred_label_whole_scene]
                np.savetxt(osp.join(submit_dir, scan_id + '.txt'),
                           remapped_pred_labels, '%d')

            logger.info(
                test_meters.delimiter.join([
                    '{:d}/{:d}({:s})',
                    'acc: {acc:.2f}',
                    'IoU: {iou:.2f}',
                    '{meters}',
                ]).format(
                    scan_idx,
                    len(test_dataset),
                    scan_id,
                    acc=evaluator.overall_acc * 100.0,
                    iou=evaluator.overall_iou * 100.0,
                    meters=str(test_meters),
                ))
        test_time = time.time() - start_time
        logger.info('Test {}  test time: {:.2f}s'.format(
            test_meters.summary_str, test_time))

    # evaluate
    logger.info('overall accuracy={:.2f}%'.format(100.0 *
                                                  evaluator.overall_acc))
    logger.info('overall IOU={:.2f}'.format(100.0 * evaluator.overall_iou))
    logger.info('class-wise accuracy and IoU.\n{}'.format(
        evaluator.print_table()))
    evaluator.save_table(osp.join(output_dir, 'eval.{}.tsv'.format(run_name)))