Exemplo n.º 1
0
def build_part_seg_dataset(cfg, mode='train'):
    """Build the part-segmentation dataset selected by ``cfg.DATASET.TYPE``.

    The transform pipeline is ``ToTensor -> <parsed augmentations> -> Transpose``
    wrapped in ``T.ComposeSeg``. Augmentations are parsed with
    ``is_train=True`` only when ``mode == 'train'``.

    Args:
        cfg: Config node. Reads ``DATASET.TYPE``, ``DATASET.ROOT_DIR``,
            ``DATASET[mode.upper()]`` (the split) and ``INPUT.NUM_POINTS``;
            ``INPUT.USE_NORMAL`` is checked for 'ShapeNetPartNormal'.
        mode (str): Dataset mode, e.g. 'train' or 'test'. Its upper-cased form
            is the key of the split name under ``cfg.DATASET``.

    Returns:
        The dataset instance, constructed with ``load_seg=True``.

    Raises:
        ValueError: If ``cfg.DATASET.TYPE`` is not one of 'ShapeNetPartH5',
            'ShapeNetPart' or 'ShapeNetPartNormal', or if 'ShapeNetPartNormal'
            is requested while ``cfg.INPUT.USE_NORMAL`` is false.
    """
    split = cfg.DATASET[mode.upper()]
    is_train = (mode == 'train')

    augmentations = parse_augmentations(cfg, is_train)
    transform = T.ComposeSeg([T.ToTensor()] + augmentations + [T.Transpose()])

    dataset_type = cfg.DATASET.TYPE
    # Constructor kwargs that only some dataset types take.
    extra_kwargs = {}
    if dataset_type == 'ShapeNetPartH5':
        dataset_cls = D.ShapeNetPartH5
    elif dataset_type == 'ShapeNetPart':
        dataset_cls = D.ShapeNetPart
    elif dataset_type == 'ShapeNetPartNormal':
        # Explicit check rather than ``assert``: asserts are stripped under
        # ``python -O``, which would make this configuration check vanish.
        if not cfg.INPUT.USE_NORMAL:
            raise ValueError(
                'Dataset type ShapeNetPartNormal requires cfg.INPUT.USE_NORMAL '
                'to be True.')
        dataset_cls = D.ShapeNetPartNormal
        extra_kwargs['with_normal'] = True
    else:
        raise ValueError('Unsupported type of dataset: {}.'.format(dataset_type))

    return dataset_cls(root_dir=cfg.DATASET.ROOT_DIR,
                       split=split,
                       transform=transform,
                       num_points=cfg.INPUT.NUM_POINTS,
                       load_seg=True,
                       **extra_kwargs)
Exemplo n.º 2
0
def build_cls_dataset(cfg, mode='train'):
    """Build the point-cloud classification dataset named by ``cfg.DATASET.TYPE``.

    Samples go through ``ToTensor``, then the augmentations parsed from
    ``cfg``, then ``Transpose``, all wrapped in ``T.Compose``.

    Args:
        cfg: Config node. Reads ``DATASET.TYPE``, ``DATASET.ROOT_DIR``,
            ``DATASET[mode.upper()]`` (the split) and ``INPUT.NUM_POINTS``;
            'ModelNet40' additionally reads ``INPUT.USE_NORMAL``.
        mode (str): Dataset mode such as 'train' or 'test'. Augmentations are
            parsed in training mode only for 'train'.

    Returns:
        A ``D.ModelNet40H5`` or ``D.ModelNet40`` instance.

    Raises:
        ValueError: If ``cfg.DATASET.TYPE`` names neither supported dataset.
    """
    split = cfg.DATASET[mode.upper()]
    training = mode == 'train'

    augs = parse_augmentations(cfg, training)
    transform = T.Compose([T.ToTensor()] + augs + [T.Transpose()])

    dataset_type = cfg.DATASET.TYPE
    if dataset_type == 'ModelNet40H5':
        return D.ModelNet40H5(
            root_dir=cfg.DATASET.ROOT_DIR,
            split=split,
            num_points=cfg.INPUT.NUM_POINTS,
            transform=transform,
        )
    if dataset_type == 'ModelNet40':
        return D.ModelNet40(
            root_dir=cfg.DATASET.ROOT_DIR,
            split=split,
            num_points=cfg.INPUT.NUM_POINTS,
            normalize=True,
            with_normal=cfg.INPUT.USE_NORMAL,
            transform=transform,
        )
    raise ValueError('Unsupported type of dataset: {}.'.format(dataset_type))
Exemplo n.º 3
0
def build_ins_seg_3d_dataset(cfg, mode='train'):
    """Build the 3D instance-segmentation dataset selected by ``cfg.DATASET.TYPE``.

    The transform pipeline is ``ToTensor -> <parsed augmentations> -> Transpose``
    wrapped in ``T.ComposeSeg``. Augmentations come from
    ``cfg.TRAIN.AUGMENTATION`` when ``mode == 'train'``, otherwise from
    ``cfg.TEST.AUGMENTATION``. Extra constructor keyword arguments are read
    from ``cfg.DATASET[<TYPE>][mode.upper()]``; none are passed if that key
    is absent.

    Args:
        cfg: Config node.
        mode (str): Dataset mode, e.g. 'train' or 'test'.

    Returns:
        A ``PartNetInsSeg`` or ``PartNetRegionInsSeg`` instance.

    Raises:
        ValueError: If ``cfg.DATASET.TYPE`` is not 'PartNetInsSeg' or
            'PartNetRegionInsSeg'.
    """
    is_train = (mode == 'train')

    augmentations = cfg.TRAIN.AUGMENTATION if is_train else cfg.TEST.AUGMENTATION
    # NOTE(review): this parse_augmentations is called with the augmentation
    # list alone, unlike the (cfg, is_train) form used by the other builders;
    # confirm which helper is in scope for this module.
    aug_transforms = parse_augmentations(augmentations)
    transform = T.ComposeSeg([T.ToTensor()] + aug_transforms + [T.Transpose()])

    dataset_type = cfg.DATASET.TYPE
    # Validate the type before indexing cfg.DATASET[dataset_type]: an
    # unsupported type may have no config sub-node, and indexing it first
    # would raise an opaque KeyError instead of the ValueError below.
    if dataset_type == 'PartNetInsSeg':
        dataset_cls = PartNetInsSeg
    elif dataset_type == 'PartNetRegionInsSeg':
        dataset_cls = PartNetRegionInsSeg
    else:
        raise ValueError('Unsupported type of dataset: {}.'.format(dataset_type))

    kwargs_dict = cfg.DATASET[dataset_type].get(mode.upper(), {})
    return dataset_cls(root_dir=cfg.DATASET.ROOT_DIR,
                       transform=transform,
                       **kwargs_dict)
Exemplo n.º 4
0
def _build_vote_transforms(cfg):
    """Build the per-vote transforms used for test-time voting in ``test``.

    Returns a list of ``cfg.TEST.VOTE.NUM_VOTE`` callables. Each one is applied
    to a copy of a sample's points and runs ``T.ToTensor`` first and
    ``T.Transpose`` last.

    * 'AUGMENTATION': one transform built from ``cfg.TEST.VOTE.AUGMENTATION``,
      repeated once per vote.
    * 'MULTI_VIEW': one rotation per vote about ``cfg.TEST.VOTE.MULTI_VIEW.AXIS``
      by ``2*pi*i/NUM_VOTE``. If ``cfg.TEST.VOTE.MULTI_VIEW.SHUFFLE`` is set, a
      ``T.Shuffle`` is inserted before ``T.Transpose``.

    Raises:
        NotImplementedError: If ``cfg.TEST.VOTE.TYPE`` is neither of the above.
    """
    num_vote = cfg.TEST.VOTE.NUM_VOTE
    vote_type = cfg.TEST.VOTE.TYPE

    if vote_type == 'AUGMENTATION':
        # Parse the voting augmentations through the regular TEST path by
        # swapping them into a mutable copy of the config.
        tmp_cfg = cfg.clone()
        tmp_cfg.defrost()
        tmp_cfg.TEST.AUGMENTATION = tmp_cfg.TEST.VOTE.AUGMENTATION
        transform = T.Compose([T.ToTensor()] +
                              parse_augmentations(tmp_cfg, False) +
                              [T.Transpose()])
        return [transform] * num_vote

    if vote_type == 'MULTI_VIEW':
        # Loop-invariant: choose the rotation op once.
        rotate_cls = T.RotateByAngleWithNormal if cfg.INPUT.USE_NORMAL else T.RotateByAngle
        transform_list = []
        for view_ind in range(num_vote):
            rotate_by_angle = rotate_cls(cfg.TEST.VOTE.MULTI_VIEW.AXIS,
                                         2 * np.pi * view_ind / num_vote)
            t = [T.ToTensor(), rotate_by_angle, T.Transpose()]
            if cfg.TEST.VOTE.MULTI_VIEW.SHUFFLE:
                # Some non-deterministic algorithms, like PointNet++, benefit from shuffle.
                t.insert(-1, T.Shuffle())
            transform_list.append(T.Compose(t))
        return transform_list

    raise NotImplementedError('Unsupported voting method.')


def test(cfg, output_dir=''):
    """Evaluate a point-cloud classification model and log its accuracy.

    Builds the model from ``cfg`` and wraps it in ``nn.DataParallel`` on CUDA.
    It then loads weights: from ``cfg.TEST.WEIGHT`` (with every '@' replaced
    by ``output_dir``) if set, otherwise from the last checkpoint.

    * If ``cfg.TEST.VOTE.NUM_VOTE > 1``, each test sample is run under several
      transforms. The predicted class is the argmax of the mean ``cls_logit``.
    * Otherwise the regular test data loader is used, and losses and metrics
      are logged as well.

    Overall and mean per-class accuracy are logged. The evaluator table is
    written to ``<output_dir>/eval.cls.tsv``.

    Args:
        cfg: Config node.
        output_dir (str): Checkpoint directory, substitution for '@' in
            ``cfg.TEST.WEIGHT``, and destination of ``eval.cls.tsv``.
    """
    import operator  # itemgetter gives a picklable collate_fn (see below)

    logger = logging.getLogger('shaper.test')

    # build model
    model, loss_fn, metric = build_model(cfg)
    model = nn.DataParallel(model).cuda()

    # build checkpointer
    checkpointer = Checkpointer(model, save_dir=output_dir, logger=logger)

    if cfg.TEST.WEIGHT:
        # load weight if specified; '@' is a placeholder for output_dir
        weight_path = cfg.TEST.WEIGHT.replace('@', output_dir)
        checkpointer.load(weight_path, resume=False)
    else:
        # load last checkpoint
        checkpointer.load(None, resume=True)

    # build data loader
    test_dataloader = build_dataloader(cfg, mode='test')
    test_dataset = test_dataloader.dataset

    # ---------------------------------------------------------------------------- #
    # Test
    # ---------------------------------------------------------------------------- #
    model.eval()
    loss_fn.eval()
    metric.eval()
    set_random_seed(cfg.RNG_SEED)
    evaluator = Evaluator(test_dataset.class_names)

    if cfg.TEST.VOTE.NUM_VOTE > 1:
        # The per-vote transforms are applied manually below, so drop the
        # dataset's own transform.
        test_dataset.transform = None
        transform_list = _build_vote_transforms(cfg)

        with torch.no_grad():
            # batch_size defaults to 1; itemgetter(0) unwraps that one-sample
            # list exactly like ``lambda x: x[0]``. Unlike a lambda, it can be
            # pickled, which the worker process requires under the 'spawn'
            # start method (the default on Windows and macOS).
            tmp_dataloader = DataLoader(test_dataset,
                                        num_workers=1,
                                        collate_fn=operator.itemgetter(0))
            start_time = time.time()
            end = start_time
            for ind, data in enumerate(tmp_dataloader):
                data_time = time.time() - end
                points = data['points']

                # Apply every vote transform to its own copy of the points
                # and stack the results into a single batch.
                points_batch = [t(points.copy()) for t in transform_list]
                points_batch = torch.stack(points_batch, dim=0)
                points_batch = points_batch.cuda(non_blocking=True)

                preds = model({'points': points_batch})
                cls_logit_batch = preds['cls_logit'].cpu().numpy(
                )  # (batch_size, num_classes)
                # Ensemble: average logits over the votes, then take the argmax.
                cls_logit_ensemble = np.mean(cls_logit_batch, axis=0)
                pred_label = np.argmax(cls_logit_ensemble)
                evaluator.update(pred_label, data['cls_label'])

                batch_time = time.time() - end
                end = time.time()

                if cfg.TEST.LOG_PERIOD > 0 and ind % cfg.TEST.LOG_PERIOD == 0:
                    logger.info('iter: {:4d}  time:{:.4f}  data:{:.4f}'.format(
                        ind, batch_time, data_time))
        test_time = time.time() - start_time
        logger.info('Test total time: {:.2f}s'.format(test_time))
    else:
        test_meters = MetricLogger(delimiter='  ')
        test_meters.bind(metric)
        with torch.no_grad():
            start_time = time.time()
            end = start_time
            for iteration, data_batch in enumerate(test_dataloader):
                data_time = time.time() - end

                # Keep the CPU labels for the evaluator before moving the batch to GPU.
                cls_label_batch = data_batch['cls_label'].numpy()
                data_batch = {
                    k: v.cuda(non_blocking=True)
                    for k, v in data_batch.items()
                }

                preds = model(data_batch)

                loss_dict = loss_fn(preds, data_batch)
                total_loss = sum(loss_dict.values())

                test_meters.update(loss=total_loss, **loss_dict)
                metric.update_dict(preds, data_batch)

                cls_logit_batch = preds['cls_logit'].cpu().numpy(
                )  # (batch_size, num_classes)
                pred_label_batch = np.argmax(cls_logit_batch, axis=1)
                evaluator.batch_update(pred_label_batch, cls_label_batch)

                batch_time = time.time() - end
                end = time.time()
                test_meters.update(time=batch_time, data=data_time)

                if cfg.TEST.LOG_PERIOD > 0 and iteration % cfg.TEST.LOG_PERIOD == 0:
                    logger.info(
                        test_meters.delimiter.join([
                            'iter: {iter:4d}',
                            '{meters}',
                        ]).format(
                            iter=iteration,
                            meters=str(test_meters),
                        ))
        test_time = time.time() - start_time
        logger.info('Test {}  total time: {:.2f}s'.format(
            test_meters.summary_str, test_time))

    # evaluate
    logger.info('overall accuracy={:.2f}%'.format(100.0 *
                                                  evaluator.overall_accuracy))
    # nanmean: ignore NaN per-class entries (presumably classes with no samples).
    logger.info('average class accuracy={:.2f}%.\n{}'.format(
        100.0 * np.nanmean(evaluator.class_accuracy), evaluator.print_table()))
    evaluator.save_table(osp.join(output_dir, 'eval.cls.tsv'))