def __init__(self, root: str, cache_dir: str, categories: list = ['chair'],
             train: bool = True, split: float = .7, resolutions=[128, 32],
             no_progress: bool = False):
    """Voxelize ShapeNet meshes and cache the results on disk.

    For every resolution in ``resolutions`` a ``tfs.CacheCompose`` pipeline
    is created over ``<cache_dir>/voxels``:
    ``TriangleMeshToVoxelGrid(res, normalize=False, vertex_offset=0.5)`` ->
    ``FillVoxelGrid(thresh=0.5)`` -> ``ExtractProjectOdmsFromVoxelGrid()``.
    Every mesh whose name is not already in that pipeline's ``cached_ids``
    is loaded from ``ShapeNet_Meshes`` and passed through the pipeline.

    Args:
        root: Dataset root; stored as ``self.root`` and forwarded to
            ``ShapeNet_Meshes``.
        cache_dir: Base cache directory; voxel data is cached under
            ``<cache_dir>/voxels``.
        categories: Category names forwarded to ``ShapeNet_Meshes``. A single
            string is accepted and treated as a one-element list.
        train: Forwarded to ``ShapeNet_Meshes``.
        split: Forwarded to ``ShapeNet_Meshes``.
        resolutions: Voxel-grid resolutions to build caches for; stored in
            ``self.params['resolutions']``. A single int is accepted and
            treated as a one-element list.
        no_progress: If True, hide the tqdm progress bars (also forwarded to
            ``ShapeNet_Meshes``).
    """
    # The defaults above are mutable lists evaluated once at definition time.
    # Copy them (and any caller-provided list) so that ``self.params`` and the
    # mesh dataset never alias the shared default objects or the caller's list.
    if isinstance(categories, str):
        categories = [categories]
    categories = list(categories)
    if isinstance(resolutions, int):
        resolutions = [resolutions]
    resolutions = list(resolutions)

    self.root = Path(root)
    self.cache_dir = Path(cache_dir) / 'voxels'
    self.cache_transforms = {}
    self.params = {
        'resolutions': resolutions,
    }

    mesh_dataset = ShapeNet_Meshes(root=root, categories=categories, train=train,
                                   split=split, no_progress=no_progress)
    # Re-expose the metadata of the underlying mesh dataset.
    self.names = mesh_dataset.names
    self.synset_idxs = mesh_dataset.synset_idxs
    self.synsets = mesh_dataset.synsets
    self.labels = mesh_dataset.labels

    for res in self.params['resolutions']:
        cache = tfs.CacheCompose([
            tfs.TriangleMeshToVoxelGrid(res, normalize=False, vertex_offset=0.5),
            tfs.FillVoxelGrid(thresh=0.5),
            tfs.ExtractProjectOdmsFromVoxelGrid()
        ], self.cache_dir)
        self.cache_transforms[res] = cache

        # Name the resolution so repeated progress bars are distinguishable.
        desc = 'converting to voxels to resolution {0}'.format(res)
        for idx in tqdm(range(len(mesh_dataset)), desc=desc, disable=no_progress):
            name = mesh_dataset.names[idx]
            # Meshes already present in this resolution's cache are skipped,
            # so re-instantiating the dataset reuses earlier conversions.
            if name in cache.cached_ids:
                continue
            sample = mesh_dataset[idx]
            mesh = TriangleMesh.from_tensors(sample['data']['vertices'],
                                             sample['data']['faces'])
            cache(name, mesh)
def __init__(self, basedir: str, cache_dir: Optional[str] = None,
             split: Optional[str] = 'train', categories: list = ['bed'],
             resolutions: List[int] = [32],
             device: Optional[Union[torch.device, str]] = 'cpu',
             no_progress: bool = False):
    """Voxelize ModelNet meshes and cache the results on disk.

    For every resolution in ``resolutions`` a ``tfs.CacheCompose`` pipeline
    is created over ``self.cache_dir``:
    ``TriangleMeshToVoxelGrid(res, normalize=True, vertex_offset=0.5)`` ->
    ``FillVoxelGrid(thresh=0.5)`` -> ``ExtractProjectOdmsFromVoxelGrid()``.
    Every mesh whose name is not already in that pipeline's ``cached_ids``
    is loaded from ``ModelNet``, moved to ``device`` and passed through it.

    Args:
        basedir: Forwarded to ``ModelNet``; also the parent of the default
            cache directory.
        cache_dir: Cache directory; defaults to
            ``os.path.join(basedir, 'cache')``.
        split: Forwarded to ``ModelNet``.
        categories: Category names forwarded to ``ModelNet``. A single string
            is accepted and treated as a one-element list.
        resolutions: Voxel-grid resolutions to build caches for; stored in
            ``self.params['resolutions']``. A single int is accepted and
            treated as a one-element list.
        device: Stored as ``torch.device(device)`` in ``self.device``;
            forwarded to ``ModelNet`` and used to move each mesh.
        no_progress: If True, hide the tqdm progress bars. Defaults to False,
            i.e. bars are shown as before.
    """
    # The defaults above are mutable lists evaluated once at definition time.
    # Copy them (and any caller-provided list) so that ``self.params`` and the
    # mesh dataset never alias the shared default objects or the caller's list.
    if isinstance(categories, str):
        categories = [categories]
    categories = list(categories)
    if isinstance(resolutions, int):
        resolutions = [resolutions]
    resolutions = list(resolutions)

    self.basedir = basedir
    self.device = torch.device(device)
    if cache_dir is None:
        cache_dir = os.path.join(basedir, 'cache')
    self.cache_dir = cache_dir
    self.params = {'resolutions': resolutions}
    self.cache_transforms = {}

    mesh_dataset = ModelNet(basedir=basedir, split=split,
                            categories=categories, device=device)
    # Re-expose the metadata of the underlying mesh dataset.
    self.names = mesh_dataset.names
    self.categories = mesh_dataset.categories
    self.cat_idxs = mesh_dataset.cat_idxs

    for res in self.params['resolutions']:
        cache = tfs.CacheCompose([
            tfs.TriangleMeshToVoxelGrid(
                res, normalize=True, vertex_offset=0.5),
            tfs.FillVoxelGrid(thresh=0.5),
            tfs.ExtractProjectOdmsFromVoxelGrid()
        ], self.cache_dir)
        self.cache_transforms[res] = cache

        desc = 'converting to voxels to resolution {0}'.format(res)
        for idx in tqdm(range(len(mesh_dataset)), desc=desc, disable=no_progress):
            name = mesh_dataset.names[idx]
            # Meshes already present in this resolution's cache are skipped,
            # so re-instantiating the dataset reuses earlier conversions.
            if name in cache.cached_ids:
                continue
            mesh, _ = mesh_dataset[idx]
            # NOTE(review): the return value of ``to`` is discarded, so this
            # relies on the mesh's ``to`` moving its tensors in place — confirm.
            mesh.to(device=device)
            cache(name, mesh)