コード例 #1
0
ファイル: test_mvpnet_3d.py プロジェクト: dcy0577/Thesis_repo
def test(cfg, args, output_dir='', run_name=''):
    """Run whole-scene inference with the MVPNet 3D model and report metrics.

    Every scan delivered by the data loader is a list of chunks. Each chunk is
    forwarded through the model separately; its logits are accumulated into a
    per-point array for the whole scene and finally averaged by the number of
    chunks that covered each point. Points that no chunk covered receive the
    label index ``num_classes``.

    Args:
        cfg: Configuration object. ``cfg.DATASET.ScanNet2D3DChunks`` and
            ``cfg.RNG_SEED`` are read here; the whole object is passed to
            ``build_model_mvpnet_3d``.
        args: Parsed command-line options. The attributes read are
            ``ckpt_path``, ``k``, ``num_views``, ``cache_dir``, ``image_dir``,
            ``split``, ``chunk_size``, ``chunk_stride``, ``chunk_thresh``,
            ``min_nb_pts`` and ``save``.
        output_dir: Directory holding checkpoints; the submission folder and
            the evaluation table are written below it.
        run_name: Name of this run, used for the submission sub-folder and the
            evaluation table file name.

    Raises:
        FileExistsError: If ``args.save`` is set and the submission directory
            for ``run_name`` already exists (``os.makedirs`` is called without
            ``exist_ok``).
    """
    logger = logging.getLogger('mvpnet.test')

    # build mvpnet model
    model = build_model_mvpnet_3d(cfg)[0]
    model = model.cuda()

    # build checkpointer
    checkpointer = CheckpointerV2(model, save_dir=output_dir, logger=logger)
    if args.ckpt_path:
        # load weight if specified ('@' in the path is replaced by output_dir)
        weight_path = args.ckpt_path.replace('@', output_dir)
        checkpointer.load(weight_path, resume=False)
    else:
        # load last checkpoint
        checkpointer.load(None, resume=True)

    # build dataset
    # command-line values take precedence over the config when they are truthy
    k = args.k or cfg.DATASET.ScanNet2D3DChunks.k
    num_views = args.num_views or cfg.DATASET.ScanNet2D3DChunks.num_rgbd_frames
    test_dataset = ScanNet2D3DChunksTest(cache_dir=args.cache_dir,
                                         image_dir=args.image_dir,
                                         split=args.split,
                                         chunk_size=(args.chunk_size, args.chunk_size),
                                         chunk_stride=args.chunk_stride,
                                         chunk_thresh=args.chunk_thresh,
                                         num_rgbd_frames=num_views,
                                         resize=cfg.DATASET.ScanNet2D3DChunks.resize,
                                         image_normalizer=cfg.DATASET.ScanNet2D3DChunks.image_normalizer,
                                         k=k,
                                         to_tensor=True,
                                         )
    # batch_size=1 with an unwrapping collate_fn: one loader item is one scan
    test_dataloader = DataLoader(test_dataset,
                                 batch_size=1,
                                 shuffle=False,
                                 num_workers=4,
                                 collate_fn=lambda x: x[0])

    # evaluator
    class_names = test_dataset.class_names
    # Open question left by the original author: would passing the label list
    # as an extra argument raise an error?
    # labels = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34, 36, 39]
    evaluator = Evaluator(class_names)
    num_classes = len(class_names)
    submit_dir = None
    if args.save:
        submit_dir = osp.join(output_dir, 'submit', run_name)
        os.makedirs(submit_dir)
        logits_dir = osp.join(submit_dir, 'logits')
        os.makedirs(logits_dir)

    # ---------------------------------------------------------------------------- #
    # Test
    # ---------------------------------------------------------------------------- #
    model.eval()
    set_random_seed(cfg.RNG_SEED)
    test_meters = MetricLogger(delimiter='  ')

    with torch.no_grad():
        start_time = time.time()
        start_time_scan = time.time()
        for scan_idx, data_dict_list in enumerate(test_dataloader):
            # fetch data
            scan_id = test_dataset.scan_ids[scan_idx]
            points = test_dataset.data[scan_idx]['points'].astype(np.float32)
            seg_label = None
            if args.split != 'test':
                seg_label = test_dataset.data[scan_idx]['seg_label']
                # Original author's question: is this lookup indexed by id or by value?
                seg_label = test_dataset.nyu40_to_scannet[seg_label]
            data_time = time.time() - start_time_scan

            # prepare outputs
            num_points = len(points)
            pred_logit_whole_scene = np.zeros([num_points, num_classes], dtype=np.float32)
            num_pred_per_point = np.zeros(num_points, dtype=np.uint8)

            # iterate over chunks
            tic = time.time()
            for data_dict in data_dict_list:
                chunk_ind = data_dict.pop('chunk_ind')

                # padding for chunks with points less than min-nb-pts
                # It is required since farthest point sampling requires more points than centroids.
                chunk_points = data_dict['points']
                chunk_nb_pts = chunk_points.shape[1]  # note that already transposed
                if chunk_nb_pts < args.min_nb_pts:
                    print('Too sparse chunk in {} with {} points.'.format(scan_id, chunk_nb_pts))
                    # repeat randomly chosen existing points until min_nb_pts is reached
                    pad = np.random.randint(chunk_nb_pts, size=args.min_nb_pts - chunk_nb_pts)
                    choice = np.hstack([np.arange(chunk_nb_pts), pad])
                    data_dict['points'] = data_dict['points'][:, choice]
                    data_dict['knn_indices'] = data_dict['knn_indices'][choice]

                # add a leading batch dimension of size one, then move to the GPU
                data_batch = {k: torch.tensor([v]) for k, v in data_dict.items()}
                data_batch = {k: v.cuda(non_blocking=True) for k, v in data_batch.items()}
                # forward
                preds = model(data_batch)
                seg_logit = preds['seg_logit'].squeeze(0).cpu().numpy().T
                # drop the rows that belong to padded points
                seg_logit = seg_logit[:len(chunk_ind)]
                # update
                pred_logit_whole_scene[chunk_ind] += seg_logit
                num_pred_per_point[chunk_ind] += 1
            forward_time = time.time() - tic

            # average the logits; max(count, 1) avoids dividing by zero for uncovered points
            pred_logit_whole_scene = pred_logit_whole_scene / np.maximum(num_pred_per_point[:, np.newaxis], 1)
            pred_label_whole_scene = np.argmax(pred_logit_whole_scene, axis=1)

            no_pred_mask = num_pred_per_point == 0
            no_pred_indices = np.nonzero(no_pred_mask)[0]
            if no_pred_indices.size > 0:
                logger.warning('{:s}: There are {:d} points without prediction.'.format(scan_id, no_pred_mask.sum()))
                # uncovered points get the extra label index num_classes
                pred_label_whole_scene[no_pred_indices] = num_classes

            if _DEBUG:
                # DEBUG: visualize scene
                from mvpnet.utils.visualize import visualize_labels
                visualize_labels(points, pred_label_whole_scene)
                visualize_labels(points, seg_label)

            # evaluate
            tic = time.time()
            if seg_label is not None:
                evaluator.update(pred_label_whole_scene, seg_label)
            metric_time = time.time() - tic

            batch_time = time.time() - start_time_scan
            test_meters.update(time=batch_time)
            test_meters.update(data=data_time,
                               forward_time=forward_time,
                               metric_time=metric_time,
                               )

            # save prediction
            if submit_dir:
                remapped_pred_labels = test_dataset.scannet_to_nyu40[pred_label_whole_scene]
                np.savetxt(osp.join(submit_dir, scan_id + '.txt'), remapped_pred_labels, '%d')
                # np.save takes no format string; its third positional
                # parameter is allow_pickle, so only file and array are passed.
                np.save(osp.join(logits_dir, scan_id + '.npy'), pred_logit_whole_scene)

            logger.info(
                test_meters.delimiter.join(
                    [
                        '{:d}/{:d}({:s})',
                        'acc: {acc:.2f}',
                        'IoU: {iou:.2f}',
                        '{meters}',
                    ]
                ).format(
                    scan_idx, len(test_dataset), scan_id,
                    acc=evaluator.overall_acc * 100.0,
                    iou=evaluator.overall_iou * 100.0,
                    meters=str(test_meters),
                )
            )
            start_time_scan = time.time()

        test_time = time.time() - start_time
        logger.info('Test {}  test time: {:.2f}s'.format(test_meters.summary_str, test_time))

    # evaluate
    logger.info('overall accuracy={:.2f}%'.format(100.0 * evaluator.overall_acc))
    logger.info('overall IOU={:.2f}'.format(100.0 * evaluator.overall_iou))
    logger.info('class-wise accuracy and IoU.\n{}'.format(evaluator.print_table()))
    evaluator.save_table(osp.join(output_dir, 'eval.{}.tsv'.format(run_name)))
コード例 #2
0
def ensemble(run_name, split='test'):
    """Combine the saved logits of four runs and write per-scan label files.

    For every ``.npy`` file found in the first run's logits folder, the logits
    of all four runs are loaded, turned into softmax scores, summed, and the
    arg-max label per point is written to ``<scan_id>.txt`` under
    ``<output_dir>/ensemble/<run_name>``. With ``split='val'`` the predictions
    are additionally scored against the ground truth and an evaluation table
    is printed and saved.

    Args:
        run_name: Name of the output sub-folder and of the evaluation table.
        split: ``'val'`` enables evaluation; any other value only writes the
            label files.
    """
    output_dir = '/home/docker_user/workspace/mvpnet_private/outputs/scannet/'
    # Logit folders of the four runs, relative to output_dir.
    run_folders = (
        'mvpnet_3d_pn2ssg_unet_resnet34_v2_160x120_3views_3nn_adam_log_weights_use_2d_log_weights_training_cotrain/submit/12-17_08-35-45.rits-computervision-salsa_5views/logits/',
        'mvpnet_3d_pn2ssg_unet_resnet34_v2_160x120_3views_3nn_adam_log_weights_use_2d_log_weights_training/submit/12-16_22-04-41.rits-computervision-salsa_5views/logits/',
        'mvpnet_3d_pn2ssg_unet_resnet34_v2_160x120_5views_3nn_adam_log_weights_use_2d_log_weights_training/submit/12-18_17-13-56.rits-computervision-salsa/logits/',
        'mvpnet_3d_pn2ssg_unet_resnet34_v2_160x120_5views_3nn_adam/submit/12-18_17-14-05.rits-computervision-salsa/logits/',
    )
    submit_dirs = [osp.join(output_dir, folder) for folder in run_folders]

    ensemble_save_dir = osp.join(output_dir, 'ensemble', run_name)
    os.makedirs(ensemble_save_dir)

    # Ground truth is only loaded for the validation split.
    dataset = None
    if split == 'val':
        dataset = ScanNet2D3DChunks(
            '/home/docker_user/workspace/mvpnet_private/data/ScanNet/cache_rgbd',
            '', 'val')
        # Sorted by scan id so that entries line up with the sorted file names.
        data = sorted(dataset.data, key=lambda entry: entry['scan_id'])
        evaluator = Evaluator(dataset.class_names)

    # Only the first two runs are listed and compared, as in the original.
    logit_fnames0 = sorted(os.listdir(submit_dirs[0]))
    logit_fnames1 = sorted(os.listdir(submit_dirs[1]))
    assert logit_fnames0 == logit_fnames1

    for i, fname in enumerate(logit_fnames0):
        scan_id = osp.splitext(fname)[0]
        print('{}/{}: {}'.format(i + 1, len(logit_fnames0), scan_id))

        # Load all four logit arrays first, then convert each to softmax scores.
        logits_per_run = [np.load(osp.join(folder, fname)) for folder in submit_dirs]
        scores_per_run = [
            scipy.special.softmax(logits, axis=1) for logits in logits_per_run
        ]
        # Summed left to right, run 0 through run 3.
        summed_scores = (scores_per_run[0] + scores_per_run[1]
                         + scores_per_run[2] + scores_per_run[3])
        pred_labels_whole_scene = summed_scores.argmax(1)

        if dataset is not None:
            gt_label = dataset.nyu40_to_scannet[data[i]['seg_label']]
            evaluator.update(pred_labels_whole_scene, gt_label)

        # save to txt file for submission
        # NOTE(review): scannet_to_nyu40 is not defined in this function;
        # presumably a module-level lookup table - confirm.
        remapped_pred_labels = scannet_to_nyu40[pred_labels_whole_scene]
        txt_path = osp.join(ensemble_save_dir, scan_id + '.txt')
        np.savetxt(txt_path, remapped_pred_labels, '%d')

    if dataset is None:
        return

    print('overall accuracy={:.2f}%'.format(100.0 * evaluator.overall_acc))
    print('overall IOU={:.2f}'.format(100.0 * evaluator.overall_iou))
    print('class-wise accuracy and IoU.\n{}'.format(evaluator.print_table()))
    table_path = osp.join(ensemble_save_dir, 'eval.{}.tsv'.format(run_name))
    evaluator.save_table(table_path)
コード例 #3
0
def test(cfg, args, output_dir='', run_name=''):
    """Evaluate a 2D segmentation network, wrapped as MVPNet2D, on 3D scans.

    The checkpoint is loaded into the bare 2D network, which is then wrapped
    in ``MVPNet2D`` and moved to the GPU. Each scan is processed chunk by
    chunk; chunk logits are accumulated per point and averaged by the number
    of chunks covering the point. Points without any prediction receive the
    label index ``num_classes``.

    Args:
        cfg: Configuration object. ``cfg.DATASET.ScanNet2D`` and
            ``cfg.RNG_SEED`` are read here; the whole object is passed to
            ``build_model_sem_seg_2d``.
        args: Parsed command-line options. The attributes read are
            ``ckpt_path``, ``cache_dir``, ``image_dir``, ``split``,
            ``chunk_size``, ``chunk_stride``, ``chunk_thresh``, ``num_views``,
            ``k`` and ``save``.
        output_dir: Directory holding checkpoints; the submission folder and
            the evaluation table are written below it.
        run_name: Name of this run, used for the submission sub-folder and the
            evaluation table file name.
    """
    logger = logging.getLogger('mvpnet.test')

    # build 2d model
    net_2d = build_model_sem_seg_2d(cfg)[0]

    # build checkpointer
    checkpointer = CheckpointerV2(net_2d, save_dir=output_dir, logger=logger)
    if args.ckpt_path:
        # load weight if specified ('@' in the path is replaced by output_dir)
        weight_path = args.ckpt_path.replace('@', output_dir)
        checkpointer.load(weight_path, resume=False)
    else:
        # load last checkpoint
        checkpointer.load(None, resume=True)

    # wrapper for 2d model
    model = MVPNet2D(net_2d)
    model = model.cuda()

    # build dataset
    test_dataset = ScanNet2D3DChunksTest(
        cache_dir=args.cache_dir,
        image_dir=args.image_dir,
        split=args.split,
        chunk_size=(args.chunk_size, args.chunk_size),
        chunk_stride=args.chunk_stride,
        chunk_thresh=args.chunk_thresh,
        num_rgbd_frames=args.num_views,
        resize=cfg.DATASET.ScanNet2D.resize,
        image_normalizer=cfg.DATASET.ScanNet2D.normalizer,
        k=args.k,
        to_tensor=True,
    )
    # batch_size=1 with an unwrapping collate_fn: one loader item is one scan
    test_dataloader = DataLoader(test_dataset,
                                 batch_size=1,
                                 shuffle=False,
                                 num_workers=4,
                                 collate_fn=lambda x: x[0])

    # evaluator
    class_names = test_dataset.class_names
    evaluator = Evaluator(class_names)
    num_classes = len(class_names)
    submit_dir = None
    if args.save:
        import os  # local import: only os.path (as osp) is known to be in scope

        submit_dir = osp.join(output_dir, 'submit', run_name)
        # np.savetxt does not create missing folders, so make sure it exists
        os.makedirs(submit_dir, exist_ok=True)

    # ---------------------------------------------------------------------------- #
    # Test
    # ---------------------------------------------------------------------------- #
    model.eval()
    set_random_seed(cfg.RNG_SEED)
    test_meters = MetricLogger(delimiter='  ')

    with torch.no_grad():
        start_time = time.time()
        start_time_scan = time.time()
        for scan_idx, data_dict_list in enumerate(test_dataloader):
            # fetch data
            scan_id = test_dataset.scan_ids[scan_idx]
            points = test_dataset.data[scan_idx]['points'].astype(np.float32)
            # Labels are only looked up outside the 'test' split, so the
            # "seg_label is not None" check below is meaningful.
            seg_label = None
            if args.split != 'test':
                seg_label = test_dataset.data[scan_idx]['seg_label']
                seg_label = test_dataset.nyu40_to_scannet[seg_label]
            data_time = time.time() - start_time_scan

            # prepare outputs
            num_points = len(points)
            pred_logit_whole_scene = np.zeros([num_points, num_classes],
                                              dtype=np.float32)
            num_pred_per_point = np.zeros(num_points, dtype=np.uint8)

            # iterate over chunks
            tic = time.time()
            for data_dict in data_dict_list:
                chunk_ind = data_dict.pop('chunk_ind')
                # add a leading batch dimension of size one, then move to the GPU
                data_batch = {
                    k: torch.tensor([v])
                    for k, v in data_dict.items()
                }
                data_batch = {
                    k: v.cuda(non_blocking=True)
                    for k, v in data_batch.items()
                }
                # forward
                preds = model(data_batch)
                seg_logit = preds['seg_logit'].squeeze(0).cpu().numpy().T
                # keep only the rows that correspond to entries of chunk_ind
                seg_logit = seg_logit[:len(chunk_ind)]
                # update
                pred_logit_whole_scene[chunk_ind] += seg_logit
                num_pred_per_point[chunk_ind] += 1
            forward_time = time.time() - tic

            # average the logits; max(count, 1) avoids dividing by zero
            pred_logit_whole_scene = pred_logit_whole_scene / np.maximum(
                num_pred_per_point[:, np.newaxis], 1)
            pred_label_whole_scene = np.argmax(pred_logit_whole_scene, axis=1)

            no_pred_mask = num_pred_per_point == 0
            no_pred_indices = np.nonzero(no_pred_mask)[0]
            if no_pred_indices.size > 0:
                logger.warning(
                    '{:s}: There are {:d} points without prediction.'.format(
                        scan_id, no_pred_mask.sum()))
                # uncovered points get the extra label index num_classes
                pred_label_whole_scene[no_pred_indices] = num_classes

            if _DEBUG:
                # DEBUG: visualize scene
                from mvpnet.utils.visualize import visualize_labels
                visualize_labels(points, pred_label_whole_scene)
                visualize_labels(points, seg_label)

            # evaluate
            tic = time.time()
            if seg_label is not None:
                evaluator.update(pred_label_whole_scene, seg_label)
            metric_time = time.time() - tic

            batch_time = time.time() - start_time_scan
            test_meters.update(time=batch_time)
            test_meters.update(
                data=data_time,
                forward_time=forward_time,
                metric_time=metric_time,
            )

            # save prediction
            if submit_dir:
                remapped_pred_labels = test_dataset.scannet_to_nyu40[
                    pred_label_whole_scene]
                np.savetxt(osp.join(submit_dir, scan_id + '.txt'),
                           remapped_pred_labels, '%d')

            logger.info(
                test_meters.delimiter.join([
                    '{:d}/{:d}({:s})',
                    'acc: {acc:.2f}',
                    'IoU: {iou:.2f}',
                    '{meters}',
                ]).format(
                    scan_idx,
                    len(test_dataset),
                    scan_id,
                    acc=evaluator.overall_acc * 100.0,
                    iou=evaluator.overall_iou * 100.0,
                    meters=str(test_meters),
                ))
            start_time_scan = time.time()

        test_time = time.time() - start_time
        logger.info('Test {}  test time: {:.2f}s'.format(
            test_meters.summary_str, test_time))

    # evaluate
    logger.info('overall accuracy={:.2f}%'.format(100.0 *
                                                  evaluator.overall_acc))
    logger.info('overall IOU={:.2f}'.format(100.0 * evaluator.overall_iou))
    logger.info('class-wise accuracy and IoU.\n{}'.format(
        evaluator.print_table()))
    evaluator.save_table(osp.join(output_dir, 'eval.{}.tsv'.format(run_name)))
コード例 #4
0
ファイル: test_2d.py プロジェクト: dcy0577/Thesis_repo
def test(cfg, args, output_dir='', run_name=''):
    """Evaluate the 2D semantic segmentation network on ScanNet2D images.

    Batches are forwarded without gradients, the arg-max over the class axis
    of ``seg_logit`` is taken as the prediction, and batches that carry a
    ``seg_label`` entry are scored. Running accuracy and IoU are logged every
    ``args.log_period`` iterations; the final table is logged and saved as
    ``eval.<run_name>.tsv`` in ``output_dir``.

    Args:
        cfg: Configuration object. ``cfg.DATASET``, ``cfg.VAL.BATCH_SIZE``,
            ``cfg.DATALOADER.NUM_WORKERS`` and ``cfg.RNG_SEED`` are read here;
            the whole object is passed to ``build_model_sem_seg_2d``.
        args: Parsed command-line options. The attributes read are
            ``ckpt_path``, ``split``, ``batch_size``, ``num_workers``,
            ``save`` and ``log_period``.
        output_dir: Directory holding checkpoints and receiving the table.
        run_name: Name of this run, used in the table file name.
    """
    logger = logging.getLogger('mvpnet.test')

    # Network, placed on the GPU right away.
    model = build_model_sem_seg_2d(cfg)[0].cuda()

    # Restore weights: an explicit path wins, otherwise resume the latest one.
    checkpointer = CheckpointerV2(model, save_dir=output_dir, logger=logger)
    if not args.ckpt_path:
        checkpointer.load(None, resume=True)
    else:
        # '@' in the given path is replaced by output_dir
        checkpointer.load(args.ckpt_path.replace('@', output_dir), resume=False)

    # Dataset and loader; truthy command-line values override the config.
    test_dataset = ScanNet2D(
        cfg.DATASET.ROOT_DIR,
        split=args.split,
        subsample=None,
        to_tensor=True,
        resize=cfg.DATASET.ScanNet2D.resize,
        normalizer=cfg.DATASET.ScanNet2D.normalizer,
    )
    test_dataloader = DataLoader(
        test_dataset,
        batch_size=args.batch_size or cfg.VAL.BATCH_SIZE,
        shuffle=False,
        num_workers=args.num_workers or cfg.DATALOADER.NUM_WORKERS,
        drop_last=False,
    )

    # Metric accumulator.
    class_names = test_dataset.class_names
    evaluator = Evaluator(class_names)
    # NOTE(review): num_classes and submit_dir are computed but never used
    # below - no prediction is written even when args.save is set.
    num_classes = len(class_names)
    submit_dir = osp.join(output_dir, 'submit', run_name) if args.save else None

    # ---------------------------------------------------------------------------- #
    # Test
    # ---------------------------------------------------------------------------- #
    model.eval()
    set_random_seed(cfg.RNG_SEED)
    test_meters = MetricLogger(delimiter='  ')

    # Progress line template; the meter summary stays disabled as before.
    progress_template = test_meters.delimiter.join([
        '{:d}/{:d}',
        'acc: {acc:.2f}',
        'IoU: {iou:.2f}',
    ])

    with torch.no_grad():
        start_time = time.time()
        for iteration, cpu_batch in enumerate(test_dataloader):
            # Keep a handle on the labels before the batch goes to the GPU.
            gt_label = cpu_batch.get('seg_label')
            gpu_batch = {}
            for key, value in cpu_batch.items():
                gpu_batch[key] = value.cuda(non_blocking=True)

            # forward
            preds = model(gpu_batch)
            seg_logit = preds['seg_logit']
            pred_label = seg_logit.argmax(1).cpu().numpy()  # (b, h, w)

            # evaluate
            if gt_label is not None:
                evaluator.batch_update(pred_label, gt_label.cpu().numpy())

            # logging
            if args.log_period and iteration % args.log_period == 0:
                message = progress_template.format(
                    iteration,
                    len(test_dataloader),
                    acc=evaluator.overall_acc * 100.0,
                    iou=evaluator.overall_iou * 100.0,
                )
                logger.info(message)
        # NOTE(review): test_time is measured but not reported.
        test_time = time.time() - start_time

    # Final report.
    overall_acc_msg = 'overall accuracy={:.2f}%'.format(100.0 * evaluator.overall_acc)
    logger.info(overall_acc_msg)
    overall_iou_msg = 'overall IOU={:.2f}'.format(100.0 * evaluator.overall_iou)
    logger.info(overall_iou_msg)
    table_msg = 'class-wise accuracy and IoU.\n{}'.format(evaluator.print_table())
    logger.info(table_msg)
    table_path = osp.join(output_dir, 'eval.{}.tsv'.format(run_name))
    evaluator.save_table(table_path)
コード例 #5
0
ファイル: test_3d_scene.py プロジェクト: zebrajack/mvpnet
def test(cfg, args, output_dir='', run_name=''):
    """Evaluate the 3D segmentation network on whole scenes with voting.

    For every scan, ``args.num_votes`` point subsets of size ``args.nb_pts``
    are drawn and forwarded as one batch. The logits of each vote are
    propagated to all points of the scan through a 1-nearest-neighbour lookup,
    averaged over the votes, and the arg-max is taken as the prediction.

    Args:
        cfg: Configuration object. ``cfg.DATASET.ROOT_DIR`` is read here; the
            whole object is passed to ``build_model_sem_seg_3d``.
        args: Parsed command-line options. The attributes read are
            ``ckpt_path``, ``split``, ``save``, ``no_rot``, ``use_color``,
            ``num_votes``, ``seed`` and ``nb_pts``.
        output_dir: Directory holding checkpoints; the submission folder and
            the evaluation table are written below it.
        run_name: Name of this run, used for the submission sub-folder and the
            evaluation table file name.
    """
    logger = logging.getLogger('mvpnet.test')

    # build model
    model = build_model_sem_seg_3d(cfg)[0]
    model = model.cuda()

    # build checkpointer
    checkpointer = CheckpointerV2(model, save_dir=output_dir, logger=logger)

    if args.ckpt_path:
        # load weight if specified ('@' in the path is replaced by output_dir)
        weight_path = args.ckpt_path.replace('@', output_dir)
        checkpointer.load(weight_path, resume=False)
    else:
        # load last checkpoint
        checkpointer.load(None, resume=True)

    # build dataset
    test_dataset = ScanNet3D(cfg.DATASET.ROOT_DIR, split=args.split)
    test_dataset.set_mapping('scannet')

    # evaluator
    class_names = test_dataset.class_names
    evaluator = Evaluator(class_names)
    num_classes = len(class_names)
    submit_dir = None
    if args.save:
        import os  # local import: only os.path (as osp) is known to be in scope

        submit_dir = osp.join(output_dir, 'submit', run_name)
        # np.savetxt does not create missing folders, so make sure it exists
        os.makedirs(submit_dir, exist_ok=True)

    # others
    # NOTE(review): aug_list is filled but never passed to the transform
    # below, so the rotation is not applied. Left unchanged because wiring it
    # in would alter the predictions - confirm the intent.
    aug_list = []
    if not args.no_rot:
        aug_list.append(T.RandomRotateZ())
    transform = T.Compose([T.ToTensor(), T.Transpose()])
    use_color = args.use_color or (model.in_channels == 3)
    num_votes = args.num_votes

    # ---------------------------------------------------------------------------- #
    # Test
    # ---------------------------------------------------------------------------- #
    model.eval()
    set_random_seed(args.seed)
    test_meters = MetricLogger(delimiter='  ')

    with torch.no_grad():
        start_time = time.time()
        for scan_idx in range(len(test_dataset)):
            start_time_scan = time.time()
            # fetch data
            tic = time.time()
            data_dict = test_dataset[scan_idx]
            scan_id = data_dict['scan_id']
            points = data_dict['points']  # (n, 3)
            colors = data_dict['colors']  # (n, 3)
            seg_label = data_dict.get('seg_label', None)  # (n,)
            data_time = time.time() - tic

            # prepare inputs
            tic = time.time()
            data_list = []
            points_list = []
            colors_list = []
            ind_list = []
            for vote_ind in range(num_votes):
                if len(points) >= args.nb_pts:
                    ind = np.random.choice(len(points),
                                           size=args.nb_pts,
                                           replace=False)
                else:
                    # Too few points: take all of them and pad by repeating
                    # point 0. The padding needs an integer dtype, since a
                    # float index array cannot be used for indexing.
                    ind = np.hstack([
                        np.arange(len(points)),
                        np.zeros(args.nb_pts - len(points), dtype=np.int64)
                    ])
                points_list.append(points[ind])
                colors_list.append(colors[ind])
                ind_list.append(ind)
            for vote_ind in range(num_votes):
                data_single = {
                    'points': points_list[vote_ind],
                }
                if use_color:
                    data_single['feature'] = colors_list[vote_ind]
                data_list.append(transform(**data_single))
            # stack the votes into one batch (keys taken from the last vote)
            data_batch = {
                k: torch.stack([x[k] for x in data_list])
                for k in data_single
            }
            data_batch = {
                k: v.cuda(non_blocking=True)
                for k, v in data_batch.items()
            }
            preprocess_time = time.time() - tic

            # forward
            tic = time.time()
            preds = model(data_batch)
            seg_logit_batch = preds['seg_logit'].cpu().numpy()
            forward_time = time.time() - tic

            # propagate predictions and ensemble
            tic = time.time()
            pred_logit_whole_scene = np.zeros([len(points), num_classes],
                                              dtype=np.float32)
            for vote_ind in range(num_votes):
                points_per_vote = points_list[vote_ind]
                seg_logit_per_vote = seg_logit_batch[vote_ind].T
                # Propagate to nearest neighbours
                nbrs = NearestNeighbors(
                    n_neighbors=1, algorithm='ball_tree').fit(points_per_vote)
                _, nn_indices = nbrs.kneighbors(points[:, 0:3])
                pred_logit_whole_scene += seg_logit_per_vote[nn_indices[:, 0]]
            # if we use softmax, it is necessary to normalize logits
            pred_logit_whole_scene = pred_logit_whole_scene / num_votes
            pred_label_whole_scene = np.argmax(pred_logit_whole_scene, axis=1)
            postprocess_time = time.time() - tic

            if _DEBUG:
                # DEBUG: visualize scene
                from mvpnet.utils.visualize import visualize_labels
                visualize_labels(points, pred_label_whole_scene, colors=colors)

            # evaluate
            tic = time.time()
            if seg_label is not None:
                evaluator.update(pred_label_whole_scene, seg_label)
            metric_time = time.time() - tic

            batch_time = time.time() - start_time_scan
            test_meters.update(time=batch_time)
            test_meters.update(data=data_time,
                               preprocess_time=preprocess_time,
                               forward_time=forward_time,
                               postprocess_time=postprocess_time,
                               metric_time=metric_time)

            # save prediction
            if submit_dir:
                remapped_pred_labels = test_dataset.scannet_to_raw[
                    pred_label_whole_scene]
                np.savetxt(osp.join(submit_dir, scan_id + '.txt'),
                           remapped_pred_labels, '%d')

            logger.info(
                test_meters.delimiter.join([
                    '{:d}/{:d}({:s})',
                    'acc: {acc:.2f}',
                    'IoU: {iou:.2f}',
                    '{meters}',
                ]).format(
                    scan_idx,
                    len(test_dataset),
                    scan_id,
                    acc=evaluator.overall_acc * 100.0,
                    iou=evaluator.overall_iou * 100.0,
                    meters=str(test_meters),
                ))
        test_time = time.time() - start_time
        logger.info('Test {}  test time: {:.2f}s'.format(
            test_meters.summary_str, test_time))

    # evaluate
    logger.info('overall accuracy={:.2f}%'.format(100.0 *
                                                  evaluator.overall_acc))
    logger.info('overall IOU={:.2f}'.format(100.0 * evaluator.overall_iou))
    logger.info('class-wise accuracy and IoU.\n{}'.format(
        evaluator.print_table()))
    evaluator.save_table(osp.join(output_dir, 'eval.{}.tsv'.format(run_name)))
コード例 #6
0
ファイル: test_3d_chunks.py プロジェクト: zebrajack/mvpnet
def test(cfg, args, output_dir='', run_name=''):
    """Evaluate the 3D segmentation network on scans split into chunks.

    Every scan is divided into chunks by ``scene2chunks_legacy``. Each chunk
    is forwarded separately; its logits are accumulated per point and averaged
    by the number of chunks covering the point. Points without any prediction
    receive the label index ``num_classes``.

    Args:
        cfg: Configuration object. ``cfg.DATASET.ROOT_DIR`` and
            ``cfg.RNG_SEED`` are read here; the whole object is passed to
            ``build_model_sem_seg_3d``.
        args: Parsed command-line options. The attributes read are
            ``ckpt_path``, ``split``, ``save``, ``min_nb_pts``, ``use_color``,
            ``chunk_size``, ``chunk_stride`` and ``chunk_thresh``.
        output_dir: Directory holding checkpoints; the submission folder and
            the evaluation table are written below it.
        run_name: Name of this run, used for the submission sub-folder and the
            evaluation table file name.
    """
    logger = logging.getLogger('mvpnet.test')

    # build model
    model = build_model_sem_seg_3d(cfg)[0]
    model = model.cuda()

    # build checkpointer
    checkpointer = CheckpointerV2(model, save_dir=output_dir, logger=logger)

    if args.ckpt_path:
        # load weight if specified ('@' in the path is replaced by output_dir)
        weight_path = args.ckpt_path.replace('@', output_dir)
        checkpointer.load(weight_path, resume=False)
    else:
        # load last checkpoint
        checkpointer.load(None, resume=True)

    # build dataset
    test_dataset = ScanNet3D(cfg.DATASET.ROOT_DIR, split=args.split)
    test_dataset.set_mapping('scannet')

    # evaluator
    class_names = test_dataset.class_names
    evaluator = Evaluator(class_names)
    num_classes = len(class_names)
    submit_dir = None
    if args.save:
        import os  # local import: only os.path (as osp) is known to be in scope

        submit_dir = osp.join(output_dir, 'submit', run_name)
        # np.savetxt does not create missing folders, so make sure it exists
        os.makedirs(submit_dir, exist_ok=True)

    # others
    transform = T.Compose(
        [T.ToTensor(), T.Pad(args.min_nb_pts),
         T.Transpose()])
    use_color = args.use_color or (model.in_channels == 3)

    # ---------------------------------------------------------------------------- #
    # Test
    # ---------------------------------------------------------------------------- #
    model.eval()
    set_random_seed(cfg.RNG_SEED)
    test_meters = MetricLogger(delimiter='  ')

    with torch.no_grad():
        start_time = time.time()
        for scan_idx in range(len(test_dataset)):
            start_time_scan = time.time()
            # fetch data
            tic = time.time()
            data_dict = test_dataset[scan_idx]
            scan_id = data_dict['scan_id']
            points = data_dict['points']  # (n, 3)
            colors = data_dict['colors']  # (n, 3)
            seg_label = data_dict.get('seg_label', None)  # (n,)
            data_time = time.time() - tic

            # generate chunks
            tic = time.time()
            chunk_indices = scene2chunks_legacy(points,
                                                chunk_size=(args.chunk_size,
                                                            args.chunk_size),
                                                stride=args.chunk_stride,
                                                thresh=args.chunk_thresh)
            # num_chunks = len(chunk_indices)
            preprocess_time = time.time() - tic

            # prepare outputs
            num_points = len(points)
            pred_logit_whole_scene = np.zeros([num_points, num_classes],
                                              dtype=np.float32)
            num_pred_per_point = np.zeros(num_points, dtype=np.uint8)

            # iterate over chunks
            tic = time.time()
            for indices in chunk_indices:
                chunk_points = points[indices]
                chunk_feature = colors[indices]
                chunk_num_points = len(chunk_points)
                if chunk_num_points < args.min_nb_pts:
                    # only reported here; the chunk is still processed
                    print('Too few points({}) in a chunk of {}'.format(
                        chunk_num_points, scan_id))
                # if _DEBUG:
                #     # DEBUG: visualize chunk
                #     from mvpnet.utils.o3d_util import visualize_point_cloud
                #     visualize_point_cloud(chunk_points, colors=chunk_feature)
                # prepare inputs
                data_batch = {'points': chunk_points}
                if use_color:
                    data_batch['feature'] = chunk_feature
                # preprocess
                data_batch = transform(**data_batch)
                # add a leading batch dimension of size one, then move to the GPU
                data_batch = {
                    k: torch.stack([v])
                    for k, v in data_batch.items()
                }
                data_batch = {
                    k: v.cuda(non_blocking=True)
                    for k, v in data_batch.items()
                }
                # forward
                preds = model(data_batch)
                seg_logit = preds['seg_logit'].squeeze(0).cpu().numpy().T
                # keep only the rows of the chunk's own points
                seg_logit = seg_logit[:chunk_num_points]
                # update
                pred_logit_whole_scene[indices] += seg_logit
                num_pred_per_point[indices] += 1
            forward_time = time.time() - tic

            # average the logits; max(count, 1) avoids dividing by zero
            pred_logit_whole_scene = pred_logit_whole_scene / np.maximum(
                num_pred_per_point[:, np.newaxis], 1)
            pred_label_whole_scene = np.argmax(pred_logit_whole_scene, axis=1)

            no_pred_mask = num_pred_per_point == 0
            no_pred_indices = np.nonzero(no_pred_mask)[0]
            if no_pred_indices.size > 0:
                logger.warning(
                    '{:s}: There are {:d} points without prediction.'.format(
                        scan_id, no_pred_mask.sum()))
                # uncovered points get the extra label index num_classes
                pred_label_whole_scene[no_pred_indices] = num_classes

            if _DEBUG:
                # DEBUG: visualize scene
                from mvpnet.utils.visualize import visualize_labels
                visualize_labels(points, pred_label_whole_scene, colors=colors)

            # evaluate
            tic = time.time()
            if seg_label is not None:
                evaluator.update(pred_label_whole_scene, seg_label)
            metric_time = time.time() - tic

            batch_time = time.time() - start_time_scan
            test_meters.update(time=batch_time)
            test_meters.update(data=data_time,
                               preprocess_time=preprocess_time,
                               forward_time=forward_time,
                               metric_time=metric_time)

            # save prediction
            if submit_dir:
                remapped_pred_labels = test_dataset.scannet_to_raw[
                    pred_label_whole_scene]
                np.savetxt(osp.join(submit_dir, scan_id + '.txt'),
                           remapped_pred_labels, '%d')

            logger.info(
                test_meters.delimiter.join([
                    '{:d}/{:d}({:s})',
                    'acc: {acc:.2f}',
                    'IoU: {iou:.2f}',
                    '{meters}',
                ]).format(
                    scan_idx,
                    len(test_dataset),
                    scan_id,
                    acc=evaluator.overall_acc * 100.0,
                    iou=evaluator.overall_iou * 100.0,
                    meters=str(test_meters),
                ))
        test_time = time.time() - start_time
        logger.info('Test {}  test time: {:.2f}s'.format(
            test_meters.summary_str, test_time))

    # evaluate
    logger.info('overall accuracy={:.2f}%'.format(100.0 *
                                                  evaluator.overall_acc))
    logger.info('overall IOU={:.2f}'.format(100.0 * evaluator.overall_iou))
    logger.info('class-wise accuracy and IoU.\n{}'.format(
        evaluator.print_table()))
    evaluator.save_table(osp.join(output_dir, 'eval.{}.tsv'.format(run_name)))