Example #1
0
def test_ScanNet2D3DChunksTest():
    """Interactive smoke test for ``ScanNet2D3DChunksTest`` on the 'val' split.

    Walks every chunk of every dataset item, prints each entry of the chunk
    dict, and opens three viewers per chunk: the labelled points, the
    ``image_xyz`` entries selected by ``image_mask``, and the flattened image
    points/colours picked by column 2 of ``knn_indices``.
    """
    from mvpnet.utils.o3d_util import visualize_point_cloud
    from mvpnet.utils.visualize import visualize_labels

    # Machine-specific locations of the RGB-D cache and the resized frames.
    rgbd_cache = osp.join(
        '/home/jiayuan/Projects/mvpnet_private/data/ScanNet/cache_rgbd')
    frames_root = osp.join(
        '/home/jiayuan/Projects/mvpnet_private/data/ScanNet/scans_resize_160x120'
    )

    def _describe(entries):
        # Arrays are summarised by shape/dtype; anything else is printed as is.
        for name, value in entries.items():
            if not isinstance(value, np.ndarray):
                print(name, value)
                continue
            print(name, value.shape, value.dtype)

    np.random.seed(0)
    dataset = ScanNet2D3DChunksTest(cache_dir=rgbd_cache,
                                    image_dir=frames_root,
                                    split='val',
                                    nb_pts=8192,
                                    num_rgbd_frames=3)
    print(dataset)

    for scan_idx in range(len(dataset)):
        for chunk in dataset[scan_idx]:
            chunk_xyz = chunk['points']
            chunk_labels = chunk['seg_label']
            _describe(chunk)
            visualize_labels(chunk_xyz, chunk_labels, style='scannet')

            rgb = chunk['images']
            xyz_per_pixel = chunk['image_xyz']
            valid = chunk['image_mask']
            visualize_point_cloud(xyz_per_pixel[valid])

            # Show the image points referenced by the third knn column.
            third_neighbor = chunk['knn_indices'][:, 2]
            flat_xyz = xyz_per_pixel.reshape(-1, 3)
            flat_rgb = rgb.reshape(-1, 3)
            visualize_point_cloud(flat_xyz[third_neighbor],
                                  flat_rgb[third_neighbor])
Example #2
0
def test_ScanNet3DScene():
    """Interactive smoke test for ``ScanNet3DScene`` on the 'val' split.

    Each sample is cropped/padded to 32768 points and passed through
    ``RandomRotateZ``; its entries are printed and the labelled points are
    shown with the optional ``'feature'`` entry as colours.
    """
    from mvpnet.data.transforms import Compose, CropPad, RandomRotateZ
    from mvpnet.utils.visualize import visualize_labels

    cache_3d_dir = osp.join(_CUR_DIR, '../../data/ScanNet/cache_3d')
    augmentation = Compose([CropPad(32768), RandomRotateZ()])
    dataset = ScanNet3DScene(cache_3d_dir,
                             'val',
                             use_color=True,
                             transform=augmentation)
    dataset.set_mapping('scannet')

    for sample_idx in range(len(dataset)):
        sample = dataset[sample_idx]
        xyz = sample['points']
        labels = sample['seg_label']
        rgb = sample.get('feature', None)

        # Arrays are summarised by shape/dtype; anything else is printed as is.
        for name, value in sample.items():
            if not isinstance(value, np.ndarray):
                print(name, value)
                continue
            print(name, value.shape, value.dtype)

        visualize_labels(xyz, labels, rgb, style='scannet')
Example #3
0
def test_ScanNet3D():
    """Interactive smoke test for ``ScanNet3D`` on the 'test' split.

    Prints every entry of each sample and shows the coloured point cloud.
    The label view is only opened for splits other than 'test'.
    """
    from mvpnet.utils.o3d_util import visualize_point_cloud
    from mvpnet.utils.visualize import visualize_labels

    split = 'test'
    cache_3d_dir = osp.join(_CUR_DIR, '../../data/ScanNet/cache_3d')
    dataset = ScanNet3D(cache_3d_dir, split)
    dataset.set_mapping('scannet')

    for sample_idx in range(len(dataset)):
        sample = dataset[sample_idx]
        xyz = sample['points']
        labels = sample.get('seg_label', None)
        rgb = sample['colors']

        # Arrays are summarised by shape/dtype; anything else is printed as is.
        for name, value in sample.items():
            if not isinstance(value, np.ndarray):
                print(name, value)
                continue
            print(name, value.shape, value.dtype)

        visualize_point_cloud(xyz, rgb, show_frame=True)
        if split == 'test':
            # No label view for the test split.
            continue
        visualize_labels(xyz, labels, rgb, style='scannet')
Example #4
0
def test(cfg, args, output_dir='', run_name=''):
    """Evaluate the MVPNet 3D model chunk by chunk on whole scans.

    Per-chunk logits are summed into a per-scene buffer and divided by the
    number of chunks that covered each point; the arg-max is the prediction.
    Points covered by no chunk get the extra label ``num_classes``.

    Args:
        cfg: config node. ``cfg.DATASET.ScanNet2D3DChunks`` and
            ``cfg.RNG_SEED`` are read here; the node is also passed to the
            model builder.
        args: parsed options. Attributes read: ``ckpt_path``, ``k``,
            ``num_views``, ``cache_dir``, ``image_dir``, ``split``,
            ``chunk_size``, ``chunk_stride``, ``chunk_thresh``,
            ``min_nb_pts`` and ``save``.
        output_dir: directory handed to the checkpointer; the evaluation
            table and, with ``args.save``, ``submit/<run_name>`` go here.
        run_name: suffix used for the submit folder and the table file name.
    """
    logger = logging.getLogger('mvpnet.test')

    # build mvpnet model
    model = build_model_mvpnet_3d(cfg)[0]
    model = model.cuda()

    # build checkpointer
    checkpointer = CheckpointerV2(model, save_dir=output_dir, logger=logger)
    if args.ckpt_path:
        # load weight if specified ('@' in the path stands for output_dir)
        weight_path = args.ckpt_path.replace('@', output_dir)
        checkpointer.load(weight_path, resume=False)
    else:
        # load last checkpoint
        checkpointer.load(None, resume=True)

    # build dataset (command-line values take precedence over the config)
    k = args.k or cfg.DATASET.ScanNet2D3DChunks.k
    num_views = args.num_views or cfg.DATASET.ScanNet2D3DChunks.num_rgbd_frames
    test_dataset = ScanNet2D3DChunksTest(cache_dir=args.cache_dir,
                                         image_dir=args.image_dir,
                                         split=args.split,
                                         chunk_size=(args.chunk_size, args.chunk_size),
                                         chunk_stride=args.chunk_stride,
                                         chunk_thresh=args.chunk_thresh,
                                         num_rgbd_frames=num_views,
                                         resize=cfg.DATASET.ScanNet2D3DChunks.resize,
                                         image_normalizer=cfg.DATASET.ScanNet2D3DChunks.image_normalizer,
                                         k=k,
                                         to_tensor=True,
                                         )
    # batch_size=1 with an unwrapping collate_fn: each batch is the list of
    # chunk dicts of one scan.
    test_dataloader = DataLoader(test_dataset,
                                 batch_size=1,
                                 shuffle=False,
                                 num_workers=4,
                                 collate_fn=lambda x: x[0])

    # evaluator
    class_names = test_dataset.class_names
    # TODO(review): open question from the original author - would passing this
    # extra argument raise an error?
    # labels = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34, 36, 39]
    evaluator = Evaluator(class_names)
    num_classes = len(class_names)
    submit_dir = None
    if args.save:
        submit_dir = osp.join(output_dir, 'submit', run_name)
        os.makedirs(submit_dir)
        logits_dir = osp.join(submit_dir, 'logits')
        os.makedirs(logits_dir)

    # ---------------------------------------------------------------------------- #
    # Test
    # ---------------------------------------------------------------------------- #
    model.eval()
    set_random_seed(cfg.RNG_SEED)
    test_meters = MetricLogger(delimiter='  ')

    with torch.no_grad():
        start_time = time.time()
        start_time_scan = time.time()
        for scan_idx, data_dict_list in enumerate(test_dataloader):
            # fetch data
            scan_id = test_dataset.scan_ids[scan_idx]
            points = test_dataset.data[scan_idx]['points'].astype(np.float32)
            seg_label = None
            if args.split != 'test':
                seg_label = test_dataset.data[scan_idx]['seg_label']
                # TODO(review): original author's question - are these ids or values?
                seg_label = test_dataset.nyu40_to_scannet[seg_label]
            data_time = time.time() - start_time_scan

            # prepare outputs
            num_points = len(points)
            pred_logit_whole_scene = np.zeros([num_points, num_classes], dtype=np.float32)
            num_pred_per_point = np.zeros(num_points, dtype=np.uint8)

            # iterate over chunks
            tic = time.time()
            for data_dict in data_dict_list:
                chunk_ind = data_dict.pop('chunk_ind')

                # padding for chunks with points less than min-nb-pts
                # It is required since farthest point sampling requires more points than centroids.
                chunk_points = data_dict['points']
                chunk_nb_pts = chunk_points.shape[1]  # note that already transposed
                if chunk_nb_pts < args.min_nb_pts:
                    print('Too sparse chunk in {} with {} points.'.format(scan_id, chunk_nb_pts))
                    # pad by repeating randomly chosen existing points
                    pad = np.random.randint(chunk_nb_pts, size=args.min_nb_pts - chunk_nb_pts)
                    choice = np.hstack([np.arange(chunk_nb_pts), pad])
                    data_dict['points'] = data_dict['points'][:, choice]
                    data_dict['knn_indices'] = data_dict['knn_indices'][choice]

                data_batch = {k: torch.tensor([v]) for k, v in data_dict.items()}
                data_batch = {k: v.cuda(non_blocking=True) for k, v in data_batch.items()}
                # forward
                preds = model(data_batch)
                seg_logit = preds['seg_logit'].squeeze(0).cpu().numpy().T
                # drop the rows that belong to padded points
                seg_logit = seg_logit[:len(chunk_ind)]
                # update
                pred_logit_whole_scene[chunk_ind] += seg_logit
                num_pred_per_point[chunk_ind] += 1
            forward_time = time.time() - tic

            # average over chunks; the maximum avoids dividing by zero
            pred_logit_whole_scene = pred_logit_whole_scene / np.maximum(num_pred_per_point[:, np.newaxis], 1)
            pred_label_whole_scene = np.argmax(pred_logit_whole_scene, axis=1)

            no_pred_mask = num_pred_per_point == 0
            no_pred_indices = np.nonzero(no_pred_mask)[0]
            if no_pred_indices.size > 0:
                logger.warning('{:s}: There are {:d} points without prediction.'.format(scan_id, no_pred_mask.sum()))
                pred_label_whole_scene[no_pred_indices] = num_classes

            if _DEBUG:
                # DEBUG: visualize scene
                from mvpnet.utils.visualize import visualize_labels
                visualize_labels(points, pred_label_whole_scene)
                visualize_labels(points, seg_label)

            # evaluate
            tic = time.time()
            if seg_label is not None:
                evaluator.update(pred_label_whole_scene, seg_label)
            metric_time = time.time() - tic

            batch_time = time.time() - start_time_scan
            test_meters.update(time=batch_time)
            test_meters.update(data=data_time,
                               forward_time=forward_time,
                               metric_time=metric_time,
                               )

            # save prediction
            if submit_dir:
                remapped_pred_labels = test_dataset.scannet_to_nyu40[pred_label_whole_scene]
                np.savetxt(osp.join(submit_dir, scan_id + '.txt'), remapped_pred_labels, '%d')
                # np.save has no format argument; a third positional value
                # would be taken as allow_pickle.
                np.save(osp.join(logits_dir, scan_id + '.npy'), pred_logit_whole_scene)

            logger.info(
                test_meters.delimiter.join(
                    [
                        '{:d}/{:d}({:s})',
                        'acc: {acc:.2f}',
                        'IoU: {iou:.2f}',
                        '{meters}',
                    ]
                ).format(
                    scan_idx, len(test_dataset), scan_id,
                    acc=evaluator.overall_acc * 100.0,
                    iou=evaluator.overall_iou * 100.0,
                    meters=str(test_meters),
                )
            )
            start_time_scan = time.time()

        test_time = time.time() - start_time
        logger.info('Test {}  test time: {:.2f}s'.format(test_meters.summary_str, test_time))

    # evaluate
    logger.info('overall accuracy={:.2f}%'.format(100.0 * evaluator.overall_acc))
    logger.info('overall IOU={:.2f}'.format(100.0 * evaluator.overall_iou))
    logger.info('class-wise accuracy and IoU.\n{}'.format(evaluator.print_table()))
    evaluator.save_table(osp.join(output_dir, 'eval.{}.tsv'.format(run_name)))
Example #5
0
def test(cfg, args, output_dir='', run_name=''):
    """Evaluate the 2D segmentation network chunk by chunk on whole scans.

    The 2D network is wrapped in ``MVPNet2D``. Per-chunk logits are summed
    into a per-scene buffer and divided by the number of chunks that covered
    each point; the arg-max is the prediction. Points covered by no chunk get
    the extra label ``num_classes``.

    Args:
        cfg: config node. ``cfg.DATASET.ScanNet2D`` and ``cfg.RNG_SEED`` are
            read here; the node is also passed to the model builder.
        args: parsed options. Attributes read: ``ckpt_path``, ``cache_dir``,
            ``image_dir``, ``split``, ``chunk_size``, ``chunk_stride``,
            ``chunk_thresh``, ``num_views``, ``k`` and ``save``.
        output_dir: directory handed to the checkpointer; the evaluation
            table and, with ``args.save``, ``submit/<run_name>`` go here.
        run_name: suffix used for the submit folder and the table file name.
    """
    import os  # local import: only needed to create the submit folder

    logger = logging.getLogger('mvpnet.test')

    # build 2d model
    net_2d = build_model_sem_seg_2d(cfg)[0]

    # build checkpointer
    checkpointer = CheckpointerV2(net_2d, save_dir=output_dir, logger=logger)
    if args.ckpt_path:
        # load weight if specified ('@' in the path stands for output_dir)
        weight_path = args.ckpt_path.replace('@', output_dir)
        checkpointer.load(weight_path, resume=False)
    else:
        # load last checkpoint
        checkpointer.load(None, resume=True)

    # wrapper for 2d model
    model = MVPNet2D(net_2d)
    model = model.cuda()

    # build dataset
    test_dataset = ScanNet2D3DChunksTest(
        cache_dir=args.cache_dir,
        image_dir=args.image_dir,
        split=args.split,
        chunk_size=(args.chunk_size, args.chunk_size),
        chunk_stride=args.chunk_stride,
        chunk_thresh=args.chunk_thresh,
        num_rgbd_frames=args.num_views,
        resize=cfg.DATASET.ScanNet2D.resize,
        image_normalizer=cfg.DATASET.ScanNet2D.normalizer,
        k=args.k,
        to_tensor=True,
    )
    # batch_size=1 with an unwrapping collate_fn: each batch is the list of
    # chunk dicts of one scan.
    test_dataloader = DataLoader(test_dataset,
                                 batch_size=1,
                                 shuffle=False,
                                 num_workers=4,
                                 collate_fn=lambda x: x[0])

    # evaluator
    class_names = test_dataset.class_names
    evaluator = Evaluator(class_names)
    num_classes = len(class_names)
    submit_dir = None
    if args.save:
        submit_dir = osp.join(output_dir, 'submit', run_name)
        # np.savetxt does not create missing folders
        os.makedirs(submit_dir, exist_ok=True)

    # ---------------------------------------------------------------------------- #
    # Test
    # ---------------------------------------------------------------------------- #
    model.eval()
    set_random_seed(cfg.RNG_SEED)
    test_meters = MetricLogger(delimiter='  ')

    with torch.no_grad():
        start_time = time.time()
        start_time_scan = time.time()
        for scan_idx, data_dict_list in enumerate(test_dataloader):
            # fetch data
            scan_id = test_dataset.scan_ids[scan_idx]
            points = test_dataset.data[scan_idx]['points'].astype(np.float32)
            # Labels are only looked up for splits other than 'test'; the
            # evaluation step below already copes with seg_label being None.
            seg_label = None
            if args.split != 'test':
                seg_label = test_dataset.data[scan_idx]['seg_label']
                seg_label = test_dataset.nyu40_to_scannet[seg_label]
            data_time = time.time() - start_time_scan

            # prepare outputs
            num_points = len(points)
            pred_logit_whole_scene = np.zeros([num_points, num_classes],
                                              dtype=np.float32)
            num_pred_per_point = np.zeros(num_points, dtype=np.uint8)

            # iterate over chunks
            tic = time.time()
            for data_dict in data_dict_list:
                chunk_ind = data_dict.pop('chunk_ind')
                data_batch = {
                    k: torch.tensor([v])
                    for k, v in data_dict.items()
                }
                data_batch = {
                    k: v.cuda(non_blocking=True)
                    for k, v in data_batch.items()
                }
                # forward
                preds = model(data_batch)
                seg_logit = preds['seg_logit'].squeeze(0).cpu().numpy().T
                # keep only as many rows as there are chunk indices
                seg_logit = seg_logit[:len(chunk_ind)]
                # update
                pred_logit_whole_scene[chunk_ind] += seg_logit
                num_pred_per_point[chunk_ind] += 1
            forward_time = time.time() - tic

            # average over chunks; the maximum avoids dividing by zero
            pred_logit_whole_scene = pred_logit_whole_scene / np.maximum(
                num_pred_per_point[:, np.newaxis], 1)
            pred_label_whole_scene = np.argmax(pred_logit_whole_scene, axis=1)

            no_pred_mask = num_pred_per_point == 0
            no_pred_indices = np.nonzero(no_pred_mask)[0]
            if no_pred_indices.size > 0:
                logger.warning(
                    '{:s}: There are {:d} points without prediction.'.format(
                        scan_id, no_pred_mask.sum()))
                pred_label_whole_scene[no_pred_indices] = num_classes

            if _DEBUG:
                # DEBUG: visualize scene
                from mvpnet.utils.visualize import visualize_labels
                visualize_labels(points, pred_label_whole_scene)
                if seg_label is not None:
                    visualize_labels(points, seg_label)

            # evaluate
            tic = time.time()
            if seg_label is not None:
                evaluator.update(pred_label_whole_scene, seg_label)
            metric_time = time.time() - tic

            batch_time = time.time() - start_time_scan
            test_meters.update(time=batch_time)
            test_meters.update(
                data=data_time,
                forward_time=forward_time,
                metric_time=metric_time,
            )

            # save prediction
            if submit_dir:
                remapped_pred_labels = test_dataset.scannet_to_nyu40[
                    pred_label_whole_scene]
                np.savetxt(osp.join(submit_dir, scan_id + '.txt'),
                           remapped_pred_labels, '%d')

            logger.info(
                test_meters.delimiter.join([
                    '{:d}/{:d}({:s})',
                    'acc: {acc:.2f}',
                    'IoU: {iou:.2f}',
                    '{meters}',
                ]).format(
                    scan_idx,
                    len(test_dataset),
                    scan_id,
                    acc=evaluator.overall_acc * 100.0,
                    iou=evaluator.overall_iou * 100.0,
                    meters=str(test_meters),
                ))
            start_time_scan = time.time()

        test_time = time.time() - start_time
        logger.info('Test {}  test time: {:.2f}s'.format(
            test_meters.summary_str, test_time))

    # evaluate
    logger.info('overall accuracy={:.2f}%'.format(100.0 *
                                                  evaluator.overall_acc))
    logger.info('overall IOU={:.2f}'.format(100.0 * evaluator.overall_iou))
    logger.info('class-wise accuracy and IoU.\n{}'.format(
        evaluator.print_table()))
    evaluator.save_table(osp.join(output_dir, 'eval.{}.tsv'.format(run_name)))
Example #6
0
def test(cfg, args, output_dir='', run_name=''):
    """Evaluate a 3D segmentation model on whole scenes with voting.

    For each scan, ``args.num_votes`` subsets of ``args.nb_pts`` points are
    drawn and run through the model as one batch. Each vote's logits are
    propagated to every scene point through its nearest sampled neighbour,
    and the votes are averaged before the arg-max.

    Args:
        cfg: config node. ``cfg.DATASET.ROOT_DIR`` is read here; the node is
            also passed to the model builder.
        args: parsed options. Attributes read: ``ckpt_path``, ``split``,
            ``save``, ``no_rot``, ``use_color``, ``num_votes``, ``seed`` and
            ``nb_pts``.
        output_dir: directory handed to the checkpointer; the evaluation
            table and, with ``args.save``, ``submit/<run_name>`` go here.
        run_name: suffix used for the submit folder and the table file name.
    """
    import os  # local import: only needed to create the submit folder

    logger = logging.getLogger('mvpnet.test')

    # build model
    model = build_model_sem_seg_3d(cfg)[0]
    model = model.cuda()

    # build checkpointer
    checkpointer = CheckpointerV2(model, save_dir=output_dir, logger=logger)

    if args.ckpt_path:
        # load weight if specified ('@' in the path stands for output_dir)
        weight_path = args.ckpt_path.replace('@', output_dir)
        checkpointer.load(weight_path, resume=False)
    else:
        # load last checkpoint
        checkpointer.load(None, resume=True)

    # build dataset
    test_dataset = ScanNet3D(cfg.DATASET.ROOT_DIR, split=args.split)
    test_dataset.set_mapping('scannet')

    # evaluator
    class_names = test_dataset.class_names
    evaluator = Evaluator(class_names)
    num_classes = len(class_names)
    submit_dir = None
    if args.save:
        submit_dir = osp.join(output_dir, 'submit', run_name)
        # np.savetxt does not create missing folders
        os.makedirs(submit_dir, exist_ok=True)

    # others
    # NOTE(review): aug_list is built but never passed to the transform below,
    # so args.no_rot currently has no effect. Left as is to keep results
    # unchanged - confirm whether the rotation was meant to be applied.
    aug_list = []
    if not args.no_rot:
        aug_list.append(T.RandomRotateZ())
    transform = T.Compose([T.ToTensor(), T.Transpose()])
    use_color = args.use_color or (model.in_channels == 3)
    num_votes = args.num_votes

    # ---------------------------------------------------------------------------- #
    # Test
    # ---------------------------------------------------------------------------- #
    model.eval()
    set_random_seed(args.seed)
    test_meters = MetricLogger(delimiter='  ')

    with torch.no_grad():
        start_time = time.time()
        for scan_idx in range(len(test_dataset)):
            start_time_scan = time.time()
            # fetch data
            tic = time.time()
            data_dict = test_dataset[scan_idx]
            scan_id = data_dict['scan_id']
            points = data_dict['points']  # (n, 3)
            colors = data_dict['colors']  # (n, 3)
            seg_label = data_dict.get('seg_label', None)  # (n,)
            data_time = time.time() - tic

            # prepare inputs
            tic = time.time()
            data_list = []
            points_list = []
            colors_list = []
            ind_list = []
            for vote_ind in range(num_votes):
                if len(points) >= args.nb_pts:
                    ind = np.random.choice(len(points),
                                           size=args.nb_pts,
                                           replace=False)
                else:
                    # Too few points: keep them all and pad by repeating
                    # point 0. The padding must be an integer array; a float
                    # one would turn the whole index array into floats, which
                    # NumPy rejects as an index.
                    ind = np.hstack([
                        np.arange(len(points)),
                        np.zeros(args.nb_pts - len(points), dtype=np.int64)
                    ])
                points_list.append(points[ind])
                colors_list.append(colors[ind])
                ind_list.append(ind)
            for vote_ind in range(num_votes):
                data_single = {
                    'points': points_list[vote_ind],
                }
                if use_color:
                    data_single['feature'] = colors_list[vote_ind]
                data_list.append(transform(**data_single))
            # stack the votes into one batch (keys taken from the last vote)
            data_batch = {
                k: torch.stack([x[k] for x in data_list])
                for k in data_single
            }
            data_batch = {
                k: v.cuda(non_blocking=True)
                for k, v in data_batch.items()
            }
            preprocess_time = time.time() - tic

            # forward
            tic = time.time()
            preds = model(data_batch)
            seg_logit_batch = preds['seg_logit'].cpu().numpy()
            forward_time = time.time() - tic

            # propagate predictions and ensemble
            tic = time.time()
            pred_logit_whole_scene = np.zeros([len(points), num_classes],
                                              dtype=np.float32)
            for vote_ind in range(num_votes):
                points_per_vote = points_list[vote_ind]
                seg_logit_per_vote = seg_logit_batch[vote_ind].T
                # Propagate to nearest neighbours
                nbrs = NearestNeighbors(
                    n_neighbors=1, algorithm='ball_tree').fit(points_per_vote)
                _, nn_indices = nbrs.kneighbors(points[:, 0:3])
                pred_logit_whole_scene += seg_logit_per_vote[nn_indices[:, 0]]
            # if we use softmax, it is necessary to normalize logits
            pred_logit_whole_scene = pred_logit_whole_scene / num_votes
            pred_label_whole_scene = np.argmax(pred_logit_whole_scene, axis=1)
            postprocess_time = time.time() - tic

            if _DEBUG:
                # DEBUG: visualize scene
                from mvpnet.utils.visualize import visualize_labels
                visualize_labels(points, pred_label_whole_scene, colors=colors)

            # evaluate
            tic = time.time()
            if seg_label is not None:
                evaluator.update(pred_label_whole_scene, seg_label)
            metric_time = time.time() - tic

            batch_time = time.time() - start_time_scan
            test_meters.update(time=batch_time)
            test_meters.update(data=data_time,
                               preprocess_time=preprocess_time,
                               forward_time=forward_time,
                               postprocess_time=postprocess_time,
                               metric_time=metric_time)

            # save prediction
            if submit_dir:
                remapped_pred_labels = test_dataset.scannet_to_raw[
                    pred_label_whole_scene]
                np.savetxt(osp.join(submit_dir, scan_id + '.txt'),
                           remapped_pred_labels, '%d')

            logger.info(
                test_meters.delimiter.join([
                    '{:d}/{:d}({:s})',
                    'acc: {acc:.2f}',
                    'IoU: {iou:.2f}',
                    '{meters}',
                ]).format(
                    scan_idx,
                    len(test_dataset),
                    scan_id,
                    acc=evaluator.overall_acc * 100.0,
                    iou=evaluator.overall_iou * 100.0,
                    meters=str(test_meters),
                ))
        test_time = time.time() - start_time
        logger.info('Test {}  test time: {:.2f}s'.format(
            test_meters.summary_str, test_time))

    # evaluate
    logger.info('overall accuracy={:.2f}%'.format(100.0 *
                                                  evaluator.overall_acc))
    logger.info('overall IOU={:.2f}'.format(100.0 * evaluator.overall_iou))
    logger.info('class-wise accuracy and IoU.\n{}'.format(
        evaluator.print_table()))
    evaluator.save_table(osp.join(output_dir, 'eval.{}.tsv'.format(run_name)))
Example #7
0
def test(cfg, args, output_dir='', run_name=''):
    """Evaluate a 3D segmentation model chunk by chunk on whole scenes.

    Each scene is split into chunks by ``scene2chunks_legacy``. Per-chunk
    logits are summed into a per-scene buffer and divided by the number of
    chunks that covered each point; the arg-max is the prediction. Points
    covered by no chunk get the extra label ``num_classes``.

    Args:
        cfg: config node. ``cfg.DATASET.ROOT_DIR`` and ``cfg.RNG_SEED`` are
            read here; the node is also passed to the model builder.
        args: parsed options. Attributes read: ``ckpt_path``, ``split``,
            ``save``, ``min_nb_pts``, ``use_color``, ``chunk_size``,
            ``chunk_stride`` and ``chunk_thresh``.
        output_dir: directory handed to the checkpointer; the evaluation
            table and, with ``args.save``, ``submit/<run_name>`` go here.
        run_name: suffix used for the submit folder and the table file name.
    """
    import os  # local import: only needed to create the submit folder

    logger = logging.getLogger('mvpnet.test')

    # build model
    model = build_model_sem_seg_3d(cfg)[0]
    model = model.cuda()

    # build checkpointer
    checkpointer = CheckpointerV2(model, save_dir=output_dir, logger=logger)

    if args.ckpt_path:
        # load weight if specified ('@' in the path stands for output_dir)
        weight_path = args.ckpt_path.replace('@', output_dir)
        checkpointer.load(weight_path, resume=False)
    else:
        # load last checkpoint
        checkpointer.load(None, resume=True)

    # build dataset
    test_dataset = ScanNet3D(cfg.DATASET.ROOT_DIR, split=args.split)
    test_dataset.set_mapping('scannet')

    # evaluator
    class_names = test_dataset.class_names
    evaluator = Evaluator(class_names)
    num_classes = len(class_names)
    submit_dir = None
    if args.save:
        submit_dir = osp.join(output_dir, 'submit', run_name)
        # np.savetxt does not create missing folders
        os.makedirs(submit_dir, exist_ok=True)

    # others
    transform = T.Compose(
        [T.ToTensor(), T.Pad(args.min_nb_pts),
         T.Transpose()])
    use_color = args.use_color or (model.in_channels == 3)

    # ---------------------------------------------------------------------------- #
    # Test
    # ---------------------------------------------------------------------------- #
    model.eval()
    set_random_seed(cfg.RNG_SEED)
    test_meters = MetricLogger(delimiter='  ')

    with torch.no_grad():
        start_time = time.time()
        for scan_idx in range(len(test_dataset)):
            start_time_scan = time.time()
            # fetch data
            tic = time.time()
            data_dict = test_dataset[scan_idx]
            scan_id = data_dict['scan_id']
            points = data_dict['points']  # (n, 3)
            colors = data_dict['colors']  # (n, 3)
            seg_label = data_dict.get('seg_label', None)  # (n,)
            data_time = time.time() - tic

            # generate chunks
            tic = time.time()
            chunk_indices = scene2chunks_legacy(points,
                                                chunk_size=(args.chunk_size,
                                                            args.chunk_size),
                                                stride=args.chunk_stride,
                                                thresh=args.chunk_thresh)
            preprocess_time = time.time() - tic

            # prepare outputs
            num_points = len(points)
            pred_logit_whole_scene = np.zeros([num_points, num_classes],
                                              dtype=np.float32)
            num_pred_per_point = np.zeros(num_points, dtype=np.uint8)

            # iterate over chunks
            tic = time.time()
            for indices in chunk_indices:
                chunk_points = points[indices]
                chunk_feature = colors[indices]
                chunk_num_points = len(chunk_points)
                if chunk_num_points < args.min_nb_pts:
                    # only reported here; T.Pad(args.min_nb_pts) is part of
                    # the transform applied below
                    print('Too few points({}) in a chunk of {}'.format(
                        chunk_num_points, scan_id))
                # prepare inputs
                data_batch = {'points': chunk_points}
                if use_color:
                    data_batch['feature'] = chunk_feature
                # preprocess
                data_batch = transform(**data_batch)
                data_batch = {
                    k: torch.stack([v])
                    for k, v in data_batch.items()
                }
                data_batch = {
                    k: v.cuda(non_blocking=True)
                    for k, v in data_batch.items()
                }
                # forward
                preds = model(data_batch)
                seg_logit = preds['seg_logit'].squeeze(0).cpu().numpy().T
                # keep only the rows of the chunk's own points
                seg_logit = seg_logit[:chunk_num_points]
                # update
                pred_logit_whole_scene[indices] += seg_logit
                num_pred_per_point[indices] += 1
            forward_time = time.time() - tic

            # average over chunks; the maximum avoids dividing by zero
            pred_logit_whole_scene = pred_logit_whole_scene / np.maximum(
                num_pred_per_point[:, np.newaxis], 1)
            pred_label_whole_scene = np.argmax(pred_logit_whole_scene, axis=1)

            no_pred_mask = num_pred_per_point == 0
            no_pred_indices = np.nonzero(no_pred_mask)[0]
            if no_pred_indices.size > 0:
                logger.warning(
                    '{:s}: There are {:d} points without prediction.'.format(
                        scan_id, no_pred_mask.sum()))
                pred_label_whole_scene[no_pred_indices] = num_classes

            if _DEBUG:
                # DEBUG: visualize scene
                from mvpnet.utils.visualize import visualize_labels
                visualize_labels(points, pred_label_whole_scene, colors=colors)

            # evaluate
            tic = time.time()
            if seg_label is not None:
                evaluator.update(pred_label_whole_scene, seg_label)
            metric_time = time.time() - tic

            batch_time = time.time() - start_time_scan
            test_meters.update(time=batch_time)
            test_meters.update(data=data_time,
                               preprocess_time=preprocess_time,
                               forward_time=forward_time,
                               metric_time=metric_time)

            # save prediction
            if submit_dir:
                remapped_pred_labels = test_dataset.scannet_to_raw[
                    pred_label_whole_scene]
                np.savetxt(osp.join(submit_dir, scan_id + '.txt'),
                           remapped_pred_labels, '%d')

            logger.info(
                test_meters.delimiter.join([
                    '{:d}/{:d}({:s})',
                    'acc: {acc:.2f}',
                    'IoU: {iou:.2f}',
                    '{meters}',
                ]).format(
                    scan_idx,
                    len(test_dataset),
                    scan_id,
                    acc=evaluator.overall_acc * 100.0,
                    iou=evaluator.overall_iou * 100.0,
                    meters=str(test_meters),
                ))
        test_time = time.time() - start_time
        logger.info('Test {}  test time: {:.2f}s'.format(
            test_meters.summary_str, test_time))

    # evaluate
    logger.info('overall accuracy={:.2f}%'.format(100.0 *
                                                  evaluator.overall_acc))
    logger.info('overall IOU={:.2f}'.format(100.0 * evaluator.overall_iou))
    logger.info('class-wise accuracy and IoU.\n{}'.format(
        evaluator.print_table()))
    evaluator.save_table(osp.join(output_dir, 'eval.{}.tsv'.format(run_name)))