def build_part_seg_dataset(cfg, mode='train'):
    """Create the part-segmentation dataset selected by ``cfg.DATASET.TYPE``.

    Args:
        cfg: Config node. Reads ``DATASET.TYPE``, ``DATASET.ROOT_DIR``,
            ``DATASET[mode.upper()]`` (used as the split),
            ``INPUT.NUM_POINTS`` and, for ``'ShapeNetPartNormal'``,
            ``INPUT.USE_NORMAL``.
        mode: Upper-cased to pick the split entry from ``cfg.DATASET``.
            ``mode == 'train'`` is the flag forwarded to
            ``parse_augmentations``.

    Returns:
        The dataset instance built from the matching class in ``D``.

    Raises:
        ValueError: ``cfg.DATASET.TYPE`` is not a supported name.
        AssertionError: ``'ShapeNetPartNormal'`` is requested while
            ``cfg.INPUT.USE_NORMAL`` is falsy.
    """
    split = cfg.DATASET[mode.upper()]

    # Pipeline order: ToTensor -> configured augmentations -> Transpose.
    pipeline = parse_augmentations(cfg, mode == 'train')
    pipeline.insert(0, T.ToTensor())
    pipeline.append(T.Transpose())
    transform = T.ComposeSeg(pipeline)

    dataset_type = cfg.DATASET.TYPE
    supported_types = ('ShapeNetPartH5', 'ShapeNetPart', 'ShapeNetPartNormal')
    if dataset_type not in supported_types:
        raise ValueError('Unsupported type of dataset: {}.'.format(
            cfg.DATASET.TYPE))

    extra_kwargs = {}
    if dataset_type == 'ShapeNetPartNormal':
        # The normal-aware dataset only makes sense when normals are enabled.
        assert cfg.INPUT.USE_NORMAL
        extra_kwargs['with_normal'] = True

    # Every supported type name is also the name of the class in ``D``.
    dataset_cls = getattr(D, dataset_type)
    return dataset_cls(
        root_dir=cfg.DATASET.ROOT_DIR,
        split=split,
        transform=transform,
        num_points=cfg.INPUT.NUM_POINTS,
        load_seg=True,
        **extra_kwargs
    )
def build_cls_dataset(cfg, mode='train'):
    """Create the classification dataset selected by ``cfg.DATASET.TYPE``.

    Args:
        cfg: Config node. Reads ``DATASET.TYPE``, ``DATASET.ROOT_DIR``,
            ``DATASET[mode.upper()]`` (used as the split),
            ``INPUT.NUM_POINTS`` and, for ``'ModelNet40'``,
            ``INPUT.USE_NORMAL``.
        mode: Upper-cased to pick the split entry from ``cfg.DATASET``.
            ``mode == 'train'`` is the flag forwarded to
            ``parse_augmentations``.

    Returns:
        A ``D.ModelNet40H5`` or ``D.ModelNet40`` instance.

    Raises:
        ValueError: ``cfg.DATASET.TYPE`` is neither supported name.
    """
    split = cfg.DATASET[mode.upper()]

    # Pipeline order: ToTensor -> configured augmentations -> Transpose.
    pipeline = parse_augmentations(cfg, mode == 'train')
    pipeline.insert(0, T.ToTensor())
    pipeline.append(T.Transpose())
    transform = T.Compose(pipeline)

    dataset_type = cfg.DATASET.TYPE

    if dataset_type == 'ModelNet40H5':
        return D.ModelNet40H5(
            root_dir=cfg.DATASET.ROOT_DIR,
            split=split,
            num_points=cfg.INPUT.NUM_POINTS,
            transform=transform,
        )

    if dataset_type == 'ModelNet40':
        return D.ModelNet40(
            root_dir=cfg.DATASET.ROOT_DIR,
            split=split,
            num_points=cfg.INPUT.NUM_POINTS,
            normalize=True,
            with_normal=cfg.INPUT.USE_NORMAL,
            transform=transform,
        )

    raise ValueError('Unsupported type of dataset: {}.'.format(
        cfg.DATASET.TYPE))
def build_ins_seg_3d_dataset(cfg, mode='train'):
    """Create the 3D instance-segmentation dataset named by ``cfg.DATASET.TYPE``.

    Args:
        cfg: Config node. Reads ``TRAIN.AUGMENTATION`` or
            ``TEST.AUGMENTATION``, ``DATASET.TYPE``, ``DATASET.ROOT_DIR`` and
            the per-type section ``DATASET[<TYPE>]``.
        mode: ``'train'`` selects the training augmentations; any other value
            selects the test ones. Its upper-cased form is the key looked up
            in the per-type section for extra constructor keyword arguments
            (an empty dict is used when the key is absent).

    Returns:
        A ``PartNetInsSeg`` or ``PartNetRegionInsSeg`` instance.

    Raises:
        ValueError: ``cfg.DATASET.TYPE`` is neither supported name.
    """
    if mode == 'train':
        augmentations = cfg.TRAIN.AUGMENTATION
    else:
        augmentations = cfg.TEST.AUGMENTATION

    # Pipeline order: ToTensor -> configured augmentations -> Transpose.
    pipeline = parse_augmentations(augmentations)
    pipeline.insert(0, T.ToTensor())
    pipeline.append(T.Transpose())
    transform = T.ComposeSeg(pipeline)

    dataset_type = cfg.DATASET.TYPE
    # Mode-specific constructor arguments live under DATASET.<TYPE>.<MODE>.
    kwargs_dict = cfg.DATASET[dataset_type].get(mode.upper(), dict())

    if dataset_type == 'PartNetInsSeg':
        dataset_cls = PartNetInsSeg
    elif dataset_type == 'PartNetRegionInsSeg':
        dataset_cls = PartNetRegionInsSeg
    else:
        raise ValueError('Unsupported type of dataset.')

    return dataset_cls(root_dir=cfg.DATASET.ROOT_DIR,
                       transform=transform,
                       **kwargs_dict)
def _build_vote_transforms(cfg):
    """Return one transform per vote, as configured by ``cfg.TEST.VOTE``.

    ``'AUGMENTATION'`` repeats a single transform built from
    ``TEST.VOTE.AUGMENTATION``; ``'MULTI_VIEW'`` builds one rotation per vote,
    evenly spaced over a full turn about ``TEST.VOTE.MULTI_VIEW.AXIS``.

    Raises:
        NotImplementedError: ``cfg.TEST.VOTE.TYPE`` is neither of the above.
    """
    num_vote = cfg.TEST.VOTE.NUM_VOTE
    vote_type = cfg.TEST.VOTE.TYPE

    if vote_type == 'AUGMENTATION':
        # Work on an unfrozen copy so the caller's config stays untouched.
        vote_cfg = cfg.clone()
        vote_cfg.defrost()
        vote_cfg.TEST.AUGMENTATION = vote_cfg.TEST.VOTE.AUGMENTATION
        shared_transform = T.Compose(
            [T.ToTensor()] + parse_augmentations(vote_cfg, False) + [T.Transpose()])
        # The same transform object is reused for every vote.
        return [shared_transform] * num_vote

    if vote_type == 'MULTI_VIEW':
        vote_transforms = []
        for view_ind in range(num_vote):
            if cfg.INPUT.USE_NORMAL:
                rotate_cls = T.RotateByAngleWithNormal
            else:
                rotate_cls = T.RotateByAngle
            rotation = rotate_cls(cfg.TEST.VOTE.MULTI_VIEW.AXIS,
                                  2 * np.pi * view_ind / num_vote)
            steps = [T.ToTensor(), rotation, T.Transpose()]
            if cfg.TEST.VOTE.MULTI_VIEW.SHUFFLE:
                # Some non-deterministic algorithms, like PointNet++,
                # benefit from shuffle. Inserted just before Transpose.
                steps.insert(-1, T.Shuffle())
            vote_transforms.append(T.Compose(steps))
        return vote_transforms

    raise NotImplementedError('Unsupported voting method.')


def test(cfg, output_dir=''):
    """Evaluate a classification model on the test split and log the results.

    Builds the model, restores weights, runs inference (with test-time voting
    when ``cfg.TEST.VOTE.NUM_VOTE > 1``), logs overall and per-class accuracy
    and writes the accuracy table to ``<output_dir>/eval.cls.tsv``.

    Args:
        cfg: Config node. Reads ``TEST.WEIGHT``, ``TEST.VOTE``,
            ``TEST.LOG_PERIOD``, ``RNG_SEED`` and whatever ``build_model`` /
            ``build_dataloader`` consume.
        output_dir: Directory used by the checkpointer and for the result
            table. Any ``'@'`` in ``cfg.TEST.WEIGHT`` is replaced by it.

    Raises:
        NotImplementedError: voting is enabled with an unsupported
            ``cfg.TEST.VOTE.TYPE``.
    """
    logger = logging.getLogger('shaper.test')

    # Model, wrapped for multi-GPU execution.
    model, loss_fn, metric = build_model(cfg)
    model = nn.DataParallel(model).cuda()

    # Restore either the explicitly configured weights or the last checkpoint.
    checkpointer = Checkpointer(model, save_dir=output_dir, logger=logger)
    if cfg.TEST.WEIGHT:
        checkpointer.load(cfg.TEST.WEIGHT.replace('@', output_dir), resume=False)
    else:
        checkpointer.load(None, resume=True)

    test_dataloader = build_dataloader(cfg, mode='test')
    test_dataset = test_dataloader.dataset

    # ------------------------------------------------------------------ #
    # Test
    # ------------------------------------------------------------------ #
    model.eval()
    loss_fn.eval()
    metric.eval()
    set_random_seed(cfg.RNG_SEED)
    evaluator = Evaluator(test_dataset.class_names)

    if cfg.TEST.VOTE.NUM_VOTE > 1:
        # Voting applies its own transforms to the raw points, so the
        # dataset-level transform is removed first.
        test_dataset.transform = None
        vote_transforms = _build_vote_transforms(cfg)

        with torch.no_grad():
            # The collate function hands back the first element of each
            # batch list, i.e. an un-collated sample.
            sample_loader = DataLoader(test_dataset,
                                       num_workers=1,
                                       collate_fn=lambda x: x[0])
            start_time = time.time()
            tick = start_time
            for sample_idx, sample in enumerate(sample_loader):
                data_time = time.time() - tick

                raw_points = sample['points']
                # One transformed copy of the sample per vote, stacked into a batch.
                views = torch.stack(
                    [transform(raw_points.copy()) for transform in vote_transforms],
                    dim=0)
                views = views.cuda(non_blocking=True)

                preds = model({'points': views})
                # Presumably one row of logits per vote -- TODO confirm shape.
                vote_logits = preds['cls_logit'].cpu().numpy()
                # Average logits across votes, then pick the best class.
                pred_label = np.argmax(np.mean(vote_logits, axis=0))
                evaluator.update(pred_label, sample['cls_label'])

                batch_time = time.time() - tick
                tick = time.time()

                if cfg.TEST.LOG_PERIOD > 0 and sample_idx % cfg.TEST.LOG_PERIOD == 0:
                    logger.info('iter: {:4d} time:{:.4f} data:{:.4f}'.format(
                        sample_idx, batch_time, data_time))
        test_time = time.time() - start_time
        logger.info('Test total time: {:.2f}s'.format(test_time))
    else:
        test_meters = MetricLogger(delimiter=' ')
        test_meters.bind(metric)

        with torch.no_grad():
            start_time = time.time()
            tick = start_time
            for iteration, data_batch in enumerate(test_dataloader):
                data_time = time.time() - tick

                # Grab the labels as numpy before the batch is moved to the GPU.
                cls_label_batch = data_batch['cls_label'].numpy()
                data_batch = {
                    key: value.cuda(non_blocking=True)
                    for key, value in data_batch.items()
                }

                preds = model(data_batch)

                loss_dict = loss_fn(preds, data_batch)
                total_loss = sum(loss_dict.values())
                test_meters.update(loss=total_loss, **loss_dict)
                metric.update_dict(preds, data_batch)

                # Assumed layout (batch_size, num_classes): argmax over axis 1.
                logits = preds['cls_logit'].cpu().numpy()
                evaluator.batch_update(np.argmax(logits, axis=1), cls_label_batch)

                batch_time = time.time() - tick
                tick = time.time()
                test_meters.update(time=batch_time, data=data_time)

                if cfg.TEST.LOG_PERIOD > 0 and iteration % cfg.TEST.LOG_PERIOD == 0:
                    log_template = test_meters.delimiter.join([
                        'iter: {iter:4d}',
                        '{meters}',
                    ])
                    logger.info(log_template.format(
                        iter=iteration,
                        meters=str(test_meters),
                    ))
        test_time = time.time() - start_time
        logger.info('Test {} total time: {:.2f}s'.format(
            test_meters.summary_str, test_time))

    # Report and persist the accuracy figures gathered by the evaluator.
    logger.info('overall accuracy={:.2f}%'.format(
        100.0 * evaluator.overall_accuracy))
    logger.info('average class accuracy={:.2f}%.\n{}'.format(
        100.0 * np.nanmean(evaluator.class_accuracy),
        evaluator.print_table()))
    evaluator.save_table(osp.join(output_dir, 'eval.cls.tsv'))