def get_dataset(mode, cfg, return_idx=False, return_category=False):
    ''' Returns the dataset.

    Args:
        mode (str): which split to load; one of 'train', 'val' or 'test'
        cfg (dict): config dictionary
        return_idx (bool): whether to include an index field ('idx')
        return_category (bool): whether to include a category field
            ('category')

    Returns:
        The assembled Shapes3D dataset.

    Raises:
        ValueError: if `mode` is not a known split or the configured
            dataset type is not supported.
    '''
    method = cfg['method']
    dataset_type = cfg['data']['dataset']
    dataset_folder = cfg['data']['path']
    categories = cfg['data']['classes']

    # Map the requested mode onto the split name given in the config.
    splits = {
        'train': cfg['data']['train_split'],
        'val': cfg['data']['val_split'],
        'test': cfg['data']['test_split'],
    }
    # Fail with a clear message instead of a bare KeyError on a bad mode.
    if mode not in splits:
        raise ValueError('Invalid mode "%s"' % mode)
    split = splits[mode]

    # Create dataset
    if dataset_type == 'Shapes3D':
        # Method-specific fields (usually correspond to output).
        fields = method_dict[method].config.get_data_fields(mode, cfg)
        # Input fields (may be absent for some configurations).
        inputs_field = get_inputs_field(mode, cfg)
        if inputs_field is not None:
            fields['inputs'] = inputs_field

        if return_idx:
            fields['idx'] = data.IndexField()

        if return_category:
            fields['category'] = data.CategoryField()

        dataset = data.Shapes3dDataset(
            dataset_folder, fields,
            split=split,
            categories=categories,
        )
    else:
        raise ValueError('Invalid dataset "%s"' % cfg['data']['dataset'])

    return dataset
# NOTE(review): this chunk starts mid-statement — the opening of the
# `data.PointsField(...)` call (including its positional file argument and
# the `points_field = ` assignment target referenced below) lies outside the
# visible region; the two keyword arguments on the next line are its
# continuation. Do not reformat further without seeing the call head.
    unpackbits=cfg['data']['points_unpackbits'],
    multi_files=cfg['data']['multi_files'])

# Reference point cloud field — presumably the ground truth used for the
# Chamfer metric, judging by the key names; confirm against MeshEvaluator.
pointcloud_field = data.PointCloudField(
    cfg['data']['pointcloud_chamfer_file'],
    multi_files=cfg['data']['multi_files'])

# Per-sample fields the dataset will load: IoU evaluation points, the
# reference point cloud, and the sample index.
fields = {
    'points_iou': points_field,
    'pointcloud_chamfer': pointcloud_field,
    'idx': data.IndexField(),
}

print('Test split: ', cfg['data']['test_split'])

dataset_folder = cfg['data']['path']
# NOTE(review): unlike other dataset constructions in this file, this call
# passes the split positionally and forwards `cfg=cfg` — verify the
# Shapes3dDataset signature accepts both.
dataset = data.Shapes3dDataset(
    dataset_folder, fields,
    cfg['data']['test_split'],
    categories=cfg['data']['classes'],
    cfg=cfg)

# Evaluator
# n_points: number of points used by the mesh evaluator per sample.
evaluator = MeshEvaluator(n_points=100000)

# Loader
# batch_size=1 with shuffle=False: samples are evaluated one at a time in a
# deterministic order; num_workers=0 loads in the main process.
test_loader = torch.utils.data.DataLoader(
    dataset, batch_size=1, num_workers=0, shuffle=False)

# Evaluate all classes
eval_dicts = []
print('Evaluating meshes...')